#include <linux/io.h>
  #include <linux/ftrace.h>
  #include <linux/syscalls.h>
++#include <linux/iommu.h>
  
  #include <asm/processor.h>
  #include <asm/pkru.h>
 
        select NEED_DMA_MAP_STATE
        select DMAR_TABLE
        select SWIOTLB
--      select IOASID
        select PCI_ATS
        select PCI_PRI
        select PCI_PASID
 
  #include "iommu-sva.h"
  
  static DEFINE_MUTEX(iommu_sva_lock);
- static DECLARE_IOASID_SET(iommu_sva_pasid);
+ static DEFINE_IDA(iommu_global_pasid_ida);
  
- /**
-  * iommu_sva_alloc_pasid - Allocate a PASID for the mm
-  * @mm: the mm
-  * @min: minimum PASID value (inclusive)
-  * @max: maximum PASID value (inclusive)
-  *
-  * Try to allocate a PASID for this mm, or take a reference to the existing one
-  * provided it fits within the [@min, @max] range. On success the PASID is
-  * available in mm->pasid and will be available for the lifetime of the mm.
-  *
-  * Returns 0 on success and < 0 on error.
-  */
- int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
+ /* Allocate a PASID for the mm within range (inclusive) */
+ static int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
  {
        int ret = 0;
-       ioasid_t pasid;
  
-       if (min == INVALID_IOASID || max == INVALID_IOASID ||
 -      if (!pasid_valid(min) || !pasid_valid(max) ||
++      if (min == IOMMU_PASID_INVALID ||
++          max == IOMMU_PASID_INVALID ||
            min == 0 || max < min)
                return -EINVAL;
  
 +      if (!arch_pgtable_dma_compat(mm))
 +              return -EBUSY;
 +
        mutex_lock(&iommu_sva_lock);
        /* Is a PASID already associated with this mm? */
 -      if (pasid_valid(mm->pasid)) {
 +      if (mm_valid_pasid(mm)) {
-               if (mm->pasid < min || mm->pasid >= max)
+               if (mm->pasid < min || mm->pasid > max)
                        ret = -EOVERFLOW;
                goto out;
        }
  
        return status;
  }
 -      if (likely(!pasid_valid(mm->pasid)))
+ 
+ void mm_pasid_drop(struct mm_struct *mm)
+ {
++      if (likely(!mm_valid_pasid(mm)))
+               return;
+ 
+       ida_free(&iommu_global_pasid_ida, mm->pasid);
+ }
 
  
  static int iommu_bus_notifier(struct notifier_block *nb,
                              unsigned long action, void *data);
+ static void iommu_release_device(struct device *dev);
  static int iommu_alloc_default_domain(struct iommu_group *group,
                                      struct device *dev);
 -static struct iommu_domain *__iommu_domain_alloc(struct bus_type *bus,
 +static struct iommu_domain *__iommu_domain_alloc(const struct bus_type *bus,
                                                 unsigned type);
  static int __iommu_attach_device(struct iommu_domain *domain,
                                 struct device *dev);
 
        return dev->iommu->iommu_dev->ops;
  }
  
 -extern int bus_iommu_probe(struct bus_type *bus);
 -extern bool iommu_present(struct bus_type *bus);
 +extern int bus_iommu_probe(const struct bus_type *bus);
 +extern bool iommu_present(const struct bus_type *bus);
  extern bool device_iommu_capable(struct device *dev, enum iommu_cap cap);
  extern bool iommu_group_has_isolated_msi(struct iommu_group *group);
 -extern struct iommu_domain *iommu_domain_alloc(struct bus_type *bus);
 +extern struct iommu_domain *iommu_domain_alloc(const struct bus_type *bus);
- extern struct iommu_group *iommu_group_get_by_id(int id);
  extern void iommu_domain_free(struct iommu_domain *domain);
  extern int iommu_attach_device(struct iommu_domain *domain,
                               struct device *dev);
        return false;
  }
  
 -static inline bool pasid_valid(ioasid_t ioasid)
 -{
 -      return ioasid != IOMMU_PASID_INVALID;
 -}
 -
  #ifdef CONFIG_IOMMU_SVA
+ static inline void mm_pasid_init(struct mm_struct *mm)
+ {
+       mm->pasid = IOMMU_PASID_INVALID;
+ }
++static inline bool mm_valid_pasid(struct mm_struct *mm)
++{
++      return mm->pasid != IOMMU_PASID_INVALID;
++}
+ void mm_pasid_drop(struct mm_struct *mm);
  struct iommu_sva *iommu_sva_bind_device(struct device *dev,
                                        struct mm_struct *mm);
  void iommu_sva_unbind_device(struct iommu_sva *handle);
  {
        return IOMMU_PASID_INVALID;
  }
+ static inline void mm_pasid_init(struct mm_struct *mm) {}
++static inline bool mm_valid_pasid(struct mm_struct *mm) { return false; }
+ static inline void mm_pasid_drop(struct mm_struct *mm) {}
  #endif /* CONFIG_IOMMU_SVA */
  
  #endif /* __LINUX_IOMMU_H */
 
  #include <linux/io_uring.h>
  #include <linux/bpf.h>
  #include <linux/stackprotector.h>
 +#include <linux/user_events.h>
+ #include <linux/iommu.h>
  
  #include <asm/pgalloc.h>
  #include <linux/uaccess.h>