diff --git a/drivers/iommu/amd_iommu.c b/drivers/iommu/amd_iommu.c index bc98de5fa867..23c1a7eebb06 100644 --- a/drivers/iommu/amd_iommu.c +++ b/drivers/iommu/amd_iommu.c @@ -2459,10 +2459,20 @@ static dma_addr_t map_page(struct device *dev, struct page *page, unsigned long attrs) { phys_addr_t paddr = page_to_phys(page) + offset; - struct protection_domain *domain = get_domain(dev); - struct dma_ops_domain *dma_dom = to_dma_ops_domain(domain); + struct protection_domain *domain; + struct dma_ops_domain *dma_dom; + u64 dma_mask; - return __map_single(dev, dma_dom, paddr, size, dir, *dev->dma_mask); + domain = get_domain(dev); + if (PTR_ERR(domain) == -EINVAL) + return (dma_addr_t)paddr; + else if (IS_ERR(domain)) + return DMA_MAPPING_ERROR; + + dma_mask = *dev->dma_mask; + dma_dom = to_dma_ops_domain(domain); + + return __map_single(dev, dma_dom, paddr, size, dir, dma_mask); } /* @@ -2471,8 +2481,14 @@ static dma_addr_t map_page(struct device *dev, struct page *page, static void unmap_page(struct device *dev, dma_addr_t dma_addr, size_t size, enum dma_data_direction dir, unsigned long attrs) { - struct protection_domain *domain = get_domain(dev); - struct dma_ops_domain *dma_dom = to_dma_ops_domain(domain); + struct protection_domain *domain; + struct dma_ops_domain *dma_dom; + + domain = get_domain(dev); + if (IS_ERR(domain)) + return; + + dma_dom = to_dma_ops_domain(domain); __unmap_single(dma_dom, dma_addr, size, dir); } @@ -2512,13 +2528,20 @@ static int map_sg(struct device *dev, struct scatterlist *sglist, unsigned long attrs) { int mapped_pages = 0, npages = 0, prot = 0, i; - struct protection_domain *domain = get_domain(dev); - struct dma_ops_domain *dma_dom = to_dma_ops_domain(domain); + struct protection_domain *domain; + struct dma_ops_domain *dma_dom; struct scatterlist *s; unsigned long address; - u64 dma_mask = *dev->dma_mask; + u64 dma_mask; int ret; + domain = get_domain(dev); + if (IS_ERR(domain)) + return 0; + + dma_dom = to_dma_ops_domain(domain); + dma_mask = *dev->dma_mask; + npages = sg_num_pages(dev, sglist, nelems); address = dma_ops_alloc_iova(dev, dma_dom, npages, dma_mask); @@ -2592,11 +2615,20 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist, int nelems, enum dma_data_direction dir, unsigned long attrs) { - struct protection_domain *domain = get_domain(dev); - struct dma_ops_domain *dma_dom = to_dma_ops_domain(domain); + struct protection_domain *domain; + struct dma_ops_domain *dma_dom; + unsigned long startaddr; + int npages = 2; - __unmap_single(dma_dom, sg_dma_address(sglist) & PAGE_MASK, - sg_num_pages(dev, sglist, nelems) << PAGE_SHIFT, dir); + domain = get_domain(dev); + if (IS_ERR(domain)) + return; + + startaddr = sg_dma_address(sglist) & PAGE_MASK; + dma_dom = to_dma_ops_domain(domain); + npages = sg_num_pages(dev, sglist, nelems); + + __unmap_single(dma_dom, startaddr, npages << PAGE_SHIFT, dir); } /* @@ -2607,11 +2639,16 @@ static void *alloc_coherent(struct device *dev, size_t size, unsigned long attrs) { u64 dma_mask = dev->coherent_dma_mask; - struct protection_domain *domain = get_domain(dev); + struct protection_domain *domain; struct dma_ops_domain *dma_dom; struct page *page; - if (IS_ERR(domain)) + domain = get_domain(dev); + if (PTR_ERR(domain) == -EINVAL) { + page = alloc_pages(flag, get_order(size)); + *dma_addr = page_to_phys(page); + return page_address(page); + } else if (IS_ERR(domain)) return NULL; dma_dom = to_dma_ops_domain(domain); @@ -2657,13 +2694,22 @@ static void free_coherent(struct device *dev, size_t size, void *virt_addr, dma_addr_t dma_addr, unsigned long attrs) { - struct protection_domain *domain = get_domain(dev); - struct dma_ops_domain *dma_dom = to_dma_ops_domain(domain); - struct page *page = virt_to_page(virt_addr); + struct protection_domain *domain; + struct dma_ops_domain *dma_dom; + struct page *page; + page = virt_to_page(virt_addr); size = PAGE_ALIGN(size); + domain = get_domain(dev); + if (IS_ERR(domain)) + goto free_mem; + + dma_dom = to_dma_ops_domain(domain); + __unmap_single(dma_dom, dma_addr, size, DMA_BIDIRECTIONAL); + +free_mem: if (!dma_release_from_contiguous(dev, page, size >> PAGE_SHIFT)) __free_pages(page, get_order(size)); }