diff --git a/drivers/iommu/amd/amd_iommu_types.h b/drivers/iommu/amd/amd_iommu_types.h index 43f0f1b13f2d..d1fed5fc219b 100644 --- a/drivers/iommu/amd/amd_iommu_types.h +++ b/drivers/iommu/amd/amd_iommu_types.h @@ -528,6 +528,7 @@ struct gcr3_tbl_info { u64 *gcr3_tbl; /* Guest CR3 table */ int glx; /* Number of levels for GCR3 table */ u32 pasid_cnt; /* Track attached PASIDs */ + u16 domid; /* Per device domain ID */ }; struct amd_io_pgtable { diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c index ab6af861a3f1..f688443c2764 100644 --- a/drivers/iommu/amd/iommu.c +++ b/drivers/iommu/amd/iommu.c @@ -1443,27 +1443,37 @@ static int device_flush_dte(struct iommu_dev_data *dev_data) return ret; } -/* - * TLB invalidation function which is called from the mapping functions. - * It invalidates a single PTE if the range to flush is within a single - * page. Otherwise it flushes the whole TLB of the IOMMU. - */ -static void __domain_flush_pages(struct protection_domain *domain, +static int domain_flush_pages_v2(struct protection_domain *pdom, u64 address, size_t size) { struct iommu_dev_data *dev_data; + struct iommu_cmd cmd; + int ret = 0; + + list_for_each_entry(dev_data, &pdom->dev_list, list) { + struct amd_iommu *iommu = get_amd_iommu_from_dev(dev_data->dev); + u16 domid = dev_data->gcr3_info.domid; + + build_inv_iommu_pages(&cmd, address, size, + domid, IOMMU_NO_PASID, true); + + ret |= iommu_queue_command(iommu, &cmd); + } + + return ret; +} + +static int domain_flush_pages_v1(struct protection_domain *pdom, + u64 address, size_t size) +{ struct iommu_cmd cmd; int ret = 0, i; - ioasid_t pasid = IOMMU_NO_PASID; - bool gn = false; - if (pdom_is_v2_pgtbl_mode(domain)) - gn = true; - - build_inv_iommu_pages(&cmd, address, size, domain->id, pasid, gn); + build_inv_iommu_pages(&cmd, address, size, + pdom->id, IOMMU_NO_PASID, false); for (i = 0; i < amd_iommu_get_num_iommus(); ++i) { - if (!domain->dev_iommu[i]) + if (!pdom->dev_iommu[i]) continue; /* @@ -1473,6 +1483,28 @@ static void __domain_flush_pages(struct protection_domain *domain, ret |= iommu_queue_command(amd_iommus[i], &cmd); } + return ret; +} + +/* + * TLB invalidation function which is called from the mapping functions. + * It flushes range of PTEs of the domain. + */ +static void __domain_flush_pages(struct protection_domain *domain, + u64 address, size_t size) +{ + struct iommu_dev_data *dev_data; + int ret = 0; + ioasid_t pasid = IOMMU_NO_PASID; + bool gn = false; + + if (pdom_is_v2_pgtbl_mode(domain)) { + gn = true; + ret = domain_flush_pages_v2(domain, address, size); + } else { + ret = domain_flush_pages_v1(domain, address, size); + } + list_for_each_entry(dev_data, &domain->dev_list, list) { if (!dev_data->ats_enabled) @@ -1548,7 +1580,7 @@ void amd_iommu_dev_flush_pasid_pages(struct iommu_dev_data *dev_data, struct amd_iommu *iommu = get_amd_iommu_from_dev(dev_data->dev); build_inv_iommu_pages(&cmd, address, size, - dev_data->domain->id, pasid, true); + dev_data->gcr3_info.domid, pasid, true); iommu_queue_command(iommu, &cmd); if (dev_data->ats_enabled) @@ -1723,6 +1755,9 @@ static void free_gcr3_table(struct gcr3_tbl_info *gcr3_info) gcr3_info->glx = 0; + /* Free per device domain ID */ + domain_id_free(gcr3_info->domid); + free_page((unsigned long)gcr3_info->gcr3_tbl); gcr3_info->gcr3_tbl = NULL; } @@ -1755,9 +1790,14 @@ static int setup_gcr3_table(struct gcr3_tbl_info *gcr3_info, if (gcr3_info->gcr3_tbl) return -EBUSY; + /* Allocate per device domain ID */ + gcr3_info->domid = domain_id_alloc(); + gcr3_info->gcr3_tbl = alloc_pgtable_page(nid, GFP_KERNEL); - if (gcr3_info->gcr3_tbl == NULL) + if (gcr3_info->gcr3_tbl == NULL) { + domain_id_free(gcr3_info->domid); return -ENOMEM; + } gcr3_info->glx = levels; @@ -1856,10 +1896,16 @@ static void set_dte_entry(struct amd_iommu *iommu, u64 flags = 0; u32 old_domid; u16 devid = dev_data->devid; + u16 domid; struct protection_domain *domain = dev_data->domain; struct dev_table_entry *dev_table = get_dev_table(iommu); struct gcr3_tbl_info *gcr3_info = &dev_data->gcr3_info; + if (gcr3_info && gcr3_info->gcr3_tbl) + domid = dev_data->gcr3_info.domid; + else + domid = domain->id; + if (domain->iop.mode != PAGE_MODE_NONE) pte_root = iommu_virt_to_phys(domain->iop.root); @@ -1872,7 +1918,7 @@ static void set_dte_entry(struct amd_iommu *iommu, * When SNP is enabled, Only set TV bit when IOMMU * page translation is in use. */ - if (!amd_iommu_snp_en || (domain->id != 0)) + if (!amd_iommu_snp_en || (domid != 0)) pte_root |= DTE_FLAG_TV; flags = dev_table[devid].data[1]; @@ -1922,7 +1968,7 @@ static void set_dte_entry(struct amd_iommu *iommu, } flags &= ~DEV_DOMID_MASK; - flags |= domain->id; + flags |= domid; old_domid = dev_table[devid].data[1] & DEV_DOMID_MASK; dev_table[devid].data[1] = flags;