diff options
Diffstat (limited to 'drivers/iommu/iommu.c')
-rw-r--r-- | drivers/iommu/iommu.c | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index 9df7cc75c1bc..a712b0cc3a1d 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -3352,16 +3352,17 @@ static void __iommu_remove_group_pasid(struct iommu_group *group, * @domain: the iommu domain. * @dev: the attached device. * @pasid: the pasid of the device. + * @handle: the attach handle. * * Return: 0 on success, or an error. */ int iommu_attach_device_pasid(struct iommu_domain *domain, - struct device *dev, ioasid_t pasid) + struct device *dev, ioasid_t pasid, + struct iommu_attach_handle *handle) { /* Caller must be a probed driver on dev */ struct iommu_group *group = dev->iommu_group; struct group_device *device; - void *curr; int ret; if (!domain->ops->set_dev_pasid) @@ -3382,11 +3383,12 @@ int iommu_attach_device_pasid(struct iommu_domain *domain, } } - curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL); - if (curr) { - ret = xa_err(curr) ? : -EBUSY; + if (handle) + handle->domain = domain; + + ret = xa_insert(&group->pasid_array, pasid, handle, GFP_KERNEL); + if (ret) goto out_unlock; - } ret = __iommu_set_group_pasid(domain, group, pasid); if (ret) @@ -3414,7 +3416,7 @@ void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev, mutex_lock(&group->mutex); __iommu_remove_group_pasid(group, pasid, domain); - WARN_ON(xa_erase(&group->pasid_array, pasid) != domain); + xa_erase(&group->pasid_array, pasid); mutex_unlock(&group->mutex); } EXPORT_SYMBOL_GPL(iommu_detach_device_pasid); @@ -3439,15 +3441,19 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev, { /* Caller must be a probed driver on dev */ struct iommu_group *group = dev->iommu_group; - struct iommu_domain *domain; + struct iommu_attach_handle *handle; + struct iommu_domain *domain = NULL; if (!group) return NULL; xa_lock(&group->pasid_array); - domain = xa_load(&group->pasid_array, pasid); + handle = xa_load(&group->pasid_array, pasid); + if (handle) + domain = handle->domain; + if (type && domain && domain->type != type) - domain = ERR_PTR(-EBUSY); + domain = NULL; xa_unlock(&group->pasid_array); return domain; |