summaryrefslogtreecommitdiff
path: root/drivers/iommu/iommu.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu/iommu.c')
-rw-r--r--drivers/iommu/iommu.c26
1 files changed, 16 insertions, 10 deletions
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 9df7cc75c1bc..a712b0cc3a1d 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -3352,16 +3352,17 @@ static void __iommu_remove_group_pasid(struct iommu_group *group,
* @domain: the iommu domain.
* @dev: the attached device.
* @pasid: the pasid of the device.
+ * @handle: the attach handle.
*
* Return: 0 on success, or an error.
*/
int iommu_attach_device_pasid(struct iommu_domain *domain,
- struct device *dev, ioasid_t pasid)
+ struct device *dev, ioasid_t pasid,
+ struct iommu_attach_handle *handle)
{
/* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
struct group_device *device;
- void *curr;
int ret;
if (!domain->ops->set_dev_pasid)
@@ -3382,11 +3383,12 @@ int iommu_attach_device_pasid(struct iommu_domain *domain,
}
}
- curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL);
- if (curr) {
- ret = xa_err(curr) ? : -EBUSY;
+ if (handle)
+ handle->domain = domain;
+
+ ret = xa_insert(&group->pasid_array, pasid, handle, GFP_KERNEL);
+ if (ret)
goto out_unlock;
- }
ret = __iommu_set_group_pasid(domain, group, pasid);
if (ret)
@@ -3414,7 +3416,7 @@ void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev,
mutex_lock(&group->mutex);
__iommu_remove_group_pasid(group, pasid, domain);
- WARN_ON(xa_erase(&group->pasid_array, pasid) != domain);
+ xa_erase(&group->pasid_array, pasid);
mutex_unlock(&group->mutex);
}
EXPORT_SYMBOL_GPL(iommu_detach_device_pasid);
@@ -3439,15 +3441,19 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev,
{
/* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
- struct iommu_domain *domain;
+ struct iommu_attach_handle *handle;
+ struct iommu_domain *domain = NULL;
if (!group)
return NULL;
xa_lock(&group->pasid_array);
- domain = xa_load(&group->pasid_array, pasid);
+ handle = xa_load(&group->pasid_array, pasid);
+ if (handle)
+ domain = handle->domain;
+
if (type && domain && domain->type != type)
- domain = ERR_PTR(-EBUSY);
+ domain = NULL;
xa_unlock(&group->pasid_array);
return domain;