summaryrefslogtreecommitdiff
path: root/drivers/iommu
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu')
-rw-r--r--drivers/iommu/iommu-sva.c13
-rw-r--r--drivers/iommu/iommu.c26
2 files changed, 24 insertions, 15 deletions
diff --git a/drivers/iommu/iommu-sva.c b/drivers/iommu/iommu-sva.c
index 18a35e798b72..0fb923254062 100644
--- a/drivers/iommu/iommu-sva.c
+++ b/drivers/iommu/iommu-sva.c
@@ -99,7 +99,9 @@ struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm
/* Search for an existing domain. */
list_for_each_entry(domain, &mm->iommu_mm->sva_domains, next) {
- ret = iommu_attach_device_pasid(domain, dev, iommu_mm->pasid);
+ handle->handle.domain = domain;
+ ret = iommu_attach_device_pasid(domain, dev, iommu_mm->pasid,
+ &handle->handle);
if (!ret) {
domain->users++;
goto out;
@@ -113,7 +115,9 @@ struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm
goto out_free_handle;
}
- ret = iommu_attach_device_pasid(domain, dev, iommu_mm->pasid);
+ handle->handle.domain = domain;
+ ret = iommu_attach_device_pasid(domain, dev, iommu_mm->pasid,
+ &handle->handle);
if (ret)
goto out_free_domain;
domain->users = 1;
@@ -124,7 +128,6 @@ out:
list_add(&handle->handle_item, &mm->iommu_mm->sva_handles);
mutex_unlock(&iommu_sva_lock);
handle->dev = dev;
- handle->domain = domain;
return handle;
out_free_domain:
@@ -147,7 +150,7 @@ EXPORT_SYMBOL_GPL(iommu_sva_bind_device);
*/
void iommu_sva_unbind_device(struct iommu_sva *handle)
{
- struct iommu_domain *domain = handle->domain;
+ struct iommu_domain *domain = handle->handle.domain;
struct iommu_mm_data *iommu_mm = domain->mm->iommu_mm;
struct device *dev = handle->dev;
@@ -170,7 +173,7 @@ EXPORT_SYMBOL_GPL(iommu_sva_unbind_device);
u32 iommu_sva_get_pasid(struct iommu_sva *handle)
{
- struct iommu_domain *domain = handle->domain;
+ struct iommu_domain *domain = handle->handle.domain;
return mm_get_enqcmd_pasid(domain->mm);
}
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 9df7cc75c1bc..a712b0cc3a1d 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -3352,16 +3352,17 @@ static void __iommu_remove_group_pasid(struct iommu_group *group,
* @domain: the iommu domain.
* @dev: the attached device.
* @pasid: the pasid of the device.
+ * @handle: the attach handle.
*
* Return: 0 on success, or an error.
*/
int iommu_attach_device_pasid(struct iommu_domain *domain,
- struct device *dev, ioasid_t pasid)
+ struct device *dev, ioasid_t pasid,
+ struct iommu_attach_handle *handle)
{
/* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
struct group_device *device;
- void *curr;
int ret;
if (!domain->ops->set_dev_pasid)
@@ -3382,11 +3383,12 @@ int iommu_attach_device_pasid(struct iommu_domain *domain,
}
}
- curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL);
- if (curr) {
- ret = xa_err(curr) ? : -EBUSY;
+ if (handle)
+ handle->domain = domain;
+
+ ret = xa_insert(&group->pasid_array, pasid, handle, GFP_KERNEL);
+ if (ret)
goto out_unlock;
- }
ret = __iommu_set_group_pasid(domain, group, pasid);
if (ret)
@@ -3414,7 +3416,7 @@ void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev,
mutex_lock(&group->mutex);
__iommu_remove_group_pasid(group, pasid, domain);
- WARN_ON(xa_erase(&group->pasid_array, pasid) != domain);
+ xa_erase(&group->pasid_array, pasid);
mutex_unlock(&group->mutex);
}
EXPORT_SYMBOL_GPL(iommu_detach_device_pasid);
@@ -3439,15 +3441,19 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev,
{
/* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
- struct iommu_domain *domain;
+ struct iommu_attach_handle *handle;
+ struct iommu_domain *domain = NULL;
if (!group)
return NULL;
xa_lock(&group->pasid_array);
- domain = xa_load(&group->pasid_array, pasid);
+ handle = xa_load(&group->pasid_array, pasid);
+ if (handle)
+ domain = handle->domain;
+
if (type && domain && domain->type != type)
- domain = ERR_PTR(-EBUSY);
+ domain = NULL;
xa_unlock(&group->pasid_array);
return domain;