vfio/type1: handle DMA map/unmap up to the addressable limit
author		Alex Mastro <amastro@fb.com>
		Tue, 28 Oct 2025 16:15:02 +0000 (09:15 -0700)
committer	Alex Williamson <alex@shazbot.org>
		Tue, 28 Oct 2025 21:54:41 +0000 (15:54 -0600)
Before this commit, it was possible to create end-of-address-space
mappings, but unmapping them via VFIO_IOMMU_UNMAP_DMA, replaying them
for newly added iommu domains, and querying their dirty pages via
VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP were all broken by comparisons
against (iova + size) expressions, which overflow to zero.

Additionally, there appears to be a page-pinning leak in the
vfio_iommu_type1_release() path, since the loop body in
vfio_unmap_unpin() where the unmap_unpin_*() helpers are called is
never entered, due to (iova + size) overflowing to zero.
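
As a standalone illustration (a hypothetical userspace sketch, not
kernel code; it only assumes dma_addr_t is an unsigned 64-bit type), a
mapping whose last page ends at the top of the address space makes
(iova + size) wrap to zero, so a loop guarded by "iova < end" is never
entered:

	#include <stdint.h>
	#include <stdio.h>

	typedef uint64_t dma_addr_t;

	int main(void)
	{
		dma_addr_t size = 0x1000;                    /* one 4 KiB page */
		dma_addr_t iova = ~(dma_addr_t)0 - size + 1; /* last page */
		dma_addr_t end = iova + size;                /* wraps to 0 */

		while (iova < end) {                         /* 0 < 0: never true */
			printf("unmap page at 0x%llx\n",
			       (unsigned long long)iova);
			iova += 0x1000;
		}

		printf("end = 0x%llx\n", (unsigned long long)end); /* 0x0 */
		return 0;
	}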

This commit handles DMA map/unmap operations up to the addressable
limit by comparing against inclusive end-of-range limits, and by
changing iteration to perform relative traversals across range sizes
rather than absolute traversals across addresses.
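
A minimal sketch of the two replacement patterns (overlaps() and
walk() are hypothetical names, not part of the patch): comparisons use
the inclusive last byte, start + size - 1, which cannot wrap for a
valid non-empty range, and iteration keeps a relative offset pos in
[0, size), so the running cursor itself can never overflow:

	#include <stddef.h>
	#include <stdint.h>

	typedef uint64_t dma_addr_t;
	#define PAGE_SIZE 4096UL

	/* Inclusive-end overlap test; assumes both sizes are non-zero. */
	static int overlaps(dma_addr_t a, size_t a_size,
			    dma_addr_t b, size_t b_size)
	{
		return a + a_size - 1 >= b && b + b_size - 1 >= a;
	}

	/* Relative traversal: pos stays below size, so it never wraps. */
	static void walk(dma_addr_t base, size_t size)
	{
		size_t pos = 0;

		while (pos < size) {
			dma_addr_t iova = base + pos; /* valid: pos < size */
			/* ... operate on the page at iova ... */
			pos += PAGE_SIZE;
		}
	}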

vfio_link_dma() inserts a zero-sized vfio_dma into the rb-tree, and is
only used for that purpose, so discard the size from consideration for
the insertion point.
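
As a quick sanity check (hypothetical userspace code, not part of the
patch): with new->size == 0, the old interval ordering test reduces to
a plain point comparison, so dropping the size term picks the same
insertion point:

	#include <assert.h>
	#include <stdint.h>

	typedef uint64_t dma_addr_t;

	int main(void)
	{
		/* a freshly linked dma, as vfio_link_dma() sees it */
		dma_addr_t new_iova = 0x2000, new_size = 0;
		dma_addr_t node_iova = 0x3000;

		/* old: new->iova + new->size <= dma->iova
		 * new: new->iova <= dma->iova
		 * identical whenever new_size == 0 */
		assert((new_iova + new_size <= node_iova) ==
		       (new_iova <= node_iova));
		return 0;
	}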

Tested-by: Alejandro Jimenez <alejandro.j.jimenez@oracle.com>
Fixes: 73fa0d10d077 ("vfio: Type1 IOMMU implementation")
Reviewed-by: Jason Gunthorpe <jgg@nvidia.com>
Reviewed-by: Alejandro Jimenez <alejandro.j.jimenez@oracle.com>
Signed-off-by: Alex Mastro <amastro@fb.com>
Link: https://lore.kernel.org/r/20251028-fix-unmap-v6-3-2542b96bcc8e@fb.com
Signed-off-by: Alex Williamson <alex@shazbot.org>
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 48bcc0633d445f571b243685ec40487942764ff8..5167bec14e363bff1f76af11c1a4b6ec90561dcd 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -168,12 +168,14 @@ static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 {
        struct rb_node *node = iommu->dma_list.rb_node;
 
+       WARN_ON(!size);
+
        while (node) {
                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 
-               if (start + size <= dma->iova)
+               if (start + size - 1 < dma->iova)
                        node = node->rb_left;
-               else if (start >= dma->iova + dma->size)
+               else if (start > dma->iova + dma->size - 1)
                        node = node->rb_right;
                else
                        return dma;
@@ -183,16 +185,19 @@ static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 }
 
 static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
-                                               dma_addr_t start, size_t size)
+                                               dma_addr_t start,
+                                               dma_addr_t end)
 {
        struct rb_node *res = NULL;
        struct rb_node *node = iommu->dma_list.rb_node;
        struct vfio_dma *dma_res = NULL;
 
+       WARN_ON(end < start);
+
        while (node) {
                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 
-               if (start < dma->iova + dma->size) {
+               if (start <= dma->iova + dma->size - 1) {
                        res = node;
                        dma_res = dma;
                        if (start >= dma->iova)
@@ -202,7 +207,7 @@ static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
                        node = node->rb_right;
                }
        }
-       if (res && size && dma_res->iova >= start + size)
+       if (res && dma_res->iova > end)
                res = NULL;
        return res;
 }
@@ -212,11 +217,13 @@ static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
        struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
        struct vfio_dma *dma;
 
+       WARN_ON(new->size != 0);
+
        while (*link) {
                parent = *link;
                dma = rb_entry(parent, struct vfio_dma, node);
 
-               if (new->iova + new->size <= dma->iova)
+               if (new->iova <= dma->iova)
                        link = &(*link)->rb_left;
                else
                        link = &(*link)->rb_right;
@@ -1141,12 +1148,12 @@ static size_t unmap_unpin_slow(struct vfio_domain *domain,
 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
                             bool do_accounting)
 {
-       dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
        struct vfio_domain *domain, *d;
        LIST_HEAD(unmapped_region_list);
        struct iommu_iotlb_gather iotlb_gather;
        int unmapped_region_cnt = 0;
        long unlocked = 0;
+       size_t pos = 0;
 
        if (!dma->size)
                return 0;
@@ -1170,13 +1177,14 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
        }
 
        iommu_iotlb_gather_init(&iotlb_gather);
-       while (iova < end) {
+       while (pos < dma->size) {
                size_t unmapped, len;
                phys_addr_t phys, next;
+               dma_addr_t iova = dma->iova + pos;
 
                phys = iommu_iova_to_phys(domain->domain, iova);
                if (WARN_ON(!phys)) {
-                       iova += PAGE_SIZE;
+                       pos += PAGE_SIZE;
                        continue;
                }
 
@@ -1185,7 +1193,7 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
                 * may require hardware cache flushing, try to find the
                 * largest contiguous physical memory chunk to unmap.
                 */
-               for (len = PAGE_SIZE; iova + len < end; len += PAGE_SIZE) {
+               for (len = PAGE_SIZE; pos + len < dma->size; len += PAGE_SIZE) {
                        next = iommu_iova_to_phys(domain->domain, iova + len);
                        if (next != phys + len)
                                break;
@@ -1206,7 +1214,7 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
                                break;
                }
 
-               iova += unmapped;
+               pos += unmapped;
        }
 
        dma->iommu_mapped = false;
@@ -1298,7 +1306,7 @@ static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
 }
 
 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
-                                 dma_addr_t iova, size_t size, size_t pgsize)
+                                 dma_addr_t iova, dma_addr_t iova_end, size_t pgsize)
 {
        struct vfio_dma *dma;
        struct rb_node *n;
@@ -1315,8 +1323,8 @@ static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
        if (dma && dma->iova != iova)
                return -EINVAL;
 
-       dma = vfio_find_dma(iommu, iova + size - 1, 0);
-       if (dma && dma->iova + dma->size != iova + size)
+       dma = vfio_find_dma(iommu, iova_end, 1);
+       if (dma && dma->iova + dma->size - 1 != iova_end)
                return -EINVAL;
 
        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
@@ -1325,7 +1333,7 @@ static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
                if (dma->iova < iova)
                        continue;
 
-               if (dma->iova > iova + size - 1)
+               if (dma->iova > iova_end)
                        break;
 
                ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
@@ -1418,7 +1426,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
        if (unmap_all) {
                if (iova || size)
                        goto unlock;
-               size = SIZE_MAX;
+               iova_end = ~(dma_addr_t)0;
        } else {
                if (!size || size & (pgsize - 1))
                        goto unlock;
@@ -1473,17 +1481,17 @@ again:
                if (dma && dma->iova != iova)
                        goto unlock;
 
-               dma = vfio_find_dma(iommu, iova_end, 0);
-               if (dma && dma->iova + dma->size != iova + size)
+               dma = vfio_find_dma(iommu, iova_end, 1);
+               if (dma && dma->iova + dma->size - 1 != iova_end)
                        goto unlock;
        }
 
        ret = 0;
-       n = first_n = vfio_find_dma_first_node(iommu, iova, size);
+       n = first_n = vfio_find_dma_first_node(iommu, iova, iova_end);
 
        while (n) {
                dma = rb_entry(n, struct vfio_dma, node);
-               if (dma->iova >= iova + size)
+               if (dma->iova > iova_end)
                        break;
 
                if (!iommu->v2 && iova > dma->iova)
@@ -1813,12 +1821,12 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 
        for (; n; n = rb_next(n)) {
                struct vfio_dma *dma;
-               dma_addr_t iova;
+               size_t pos = 0;
 
                dma = rb_entry(n, struct vfio_dma, node);
-               iova = dma->iova;
 
-               while (iova < dma->iova + dma->size) {
+               while (pos < dma->size) {
+                       dma_addr_t iova = dma->iova + pos;
                        phys_addr_t phys;
                        size_t size;
 
@@ -1834,14 +1842,14 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
                                phys = iommu_iova_to_phys(d->domain, iova);
 
                                if (WARN_ON(!phys)) {
-                                       iova += PAGE_SIZE;
+                                       pos += PAGE_SIZE;
                                        continue;
                                }
 
                                size = PAGE_SIZE;
                                p = phys + size;
                                i = iova + size;
-                               while (i < dma->iova + dma->size &&
+                               while (pos + size < dma->size &&
                                       p == iommu_iova_to_phys(d->domain, i)) {
                                        size += PAGE_SIZE;
                                        p += PAGE_SIZE;
@@ -1849,9 +1857,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
                                }
                        } else {
                                unsigned long pfn;
-                               unsigned long vaddr = dma->vaddr +
-                                                    (iova - dma->iova);
-                               size_t n = dma->iova + dma->size - iova;
+                               unsigned long vaddr = dma->vaddr + pos;
+                               size_t n = dma->size - pos;
                                long npage;
 
                                npage = vfio_pin_pages_remote(dma, vaddr,
@@ -1882,7 +1889,7 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
                                goto unwind;
                        }
 
-                       iova += size;
+                       pos += size;
                }
        }
 
@@ -1899,29 +1906,29 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 unwind:
        for (; n; n = rb_prev(n)) {
                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
-               dma_addr_t iova;
+               size_t pos = 0;
 
                if (dma->iommu_mapped) {
                        iommu_unmap(domain->domain, dma->iova, dma->size);
                        continue;
                }
 
-               iova = dma->iova;
-               while (iova < dma->iova + dma->size) {
+               while (pos < dma->size) {
+                       dma_addr_t iova = dma->iova + pos;
                        phys_addr_t phys, p;
                        size_t size;
                        dma_addr_t i;
 
                        phys = iommu_iova_to_phys(domain->domain, iova);
                        if (!phys) {
-                               iova += PAGE_SIZE;
+                               pos += PAGE_SIZE;
                                continue;
                        }
 
                        size = PAGE_SIZE;
                        p = phys + size;
                        i = iova + size;
-                       while (i < dma->iova + dma->size &&
+                       while (pos + size < dma->size &&
                               p == iommu_iova_to_phys(domain->domain, i)) {
                                size += PAGE_SIZE;
                                p += PAGE_SIZE;
@@ -3059,7 +3066,7 @@ static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
 
                if (iommu->dirty_page_tracking)
                        ret = vfio_iova_dirty_bitmap(range.bitmap.data,
-                                                    iommu, iova, size,
+                                                    iommu, iova, iova_end,
                                                     range.bitmap.pgsize);
                else
                        ret = -EINVAL;