--- /dev/null
+From 8782fb61cc848364e1e1599d76d3c9dd58a1cc06 Mon Sep 17 00:00:00 2001
+From: Steven Price <steven.price@arm.com>
+Date: Fri, 2 Sep 2022 12:26:12 +0100
+Subject: mm: pagewalk: Fix race between unmap and page walker
+MIME-Version: 1.0
+Content-Type: text/plain; charset=UTF-8
+Content-Transfer-Encoding: 8bit
+
+From: Steven Price <steven.price@arm.com>
+
+commit 8782fb61cc848364e1e1599d76d3c9dd58a1cc06 upstream.
+
+The mmap lock protects the page walker from changes to the page tables
+during the walk.  However, a read lock is insufficient to protect those
+areas which don't have a VMA, because munmap() detaches the VMAs before
+downgrading to a read lock and only then tears down the PTEs/page tables.
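+
+To illustrate the window (a simplified sketch with names from the
+5.4-era mm/mmap.c; the real __do_munmap() has more cases, and both the
+downgrade and the final unlock are conditional):
+
+  down_write(&mm->mmap_sem);
+  detach_vmas_to_be_unmapped(mm, vma, prev, end);  /* VMAs leave the tree */
+  downgrade_write(&mm->mmap_sem);                  /* only a read lock now */
+  unmap_region(mm, vma, prev, start, end);         /* page tables torn down */
+  up_read(&mm->mmap_sem);
+
+A walker that took the read lock and called walk_page_range() over this
+range sees no VMA, yet the old code still descended into page tables
+that unmap_region() may be freeing at the same time.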
+
+For users of walk_page_range() the solution is simply to call pte_hole()
+immediately when a VMA is not present, without checking the actual page
+tables.  We now never call __walk_page_range() without a valid vma.
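+
+For illustration, a hypothetical walker that only counts unmapped bytes
+(the hole_stats/count_unmapped names are invented for this sketch; the
+mm_walk_ops/walk_page_range interface is the 5.4 one this patch touches):
+
+  #include <linux/pagewalk.h>
+
+  struct hole_stats {
+          unsigned long bytes;
+  };
+
+  static int count_holes(unsigned long addr, unsigned long next,
+                         struct mm_walk *walk)
+  {
+          struct hole_stats *stats = walk->private;
+
+          stats->bytes += next - addr;  /* gap reported via pte_hole() */
+          return 0;
+  }
+
+  static const struct mm_walk_ops hole_ops = {
+          .pte_hole = count_holes,
+  };
+
+  static unsigned long count_unmapped(struct mm_struct *mm,
+                                      unsigned long start, unsigned long end)
+  {
+          struct hole_stats stats = { 0 };
+
+          /* caller holds mmap_sem for reading */
+          walk_page_range(mm, start, end, &hole_ops, &stats);
+          return stats.bytes;
+  }
+
+With this change such a walker still gets pte_hole() callbacks for
+ranges with no VMA, but the walk no longer descends into page tables
+that a racing munmap() may be freeing.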
+
+For walk_page_range_novma() the locking requirements are tightened: the
+mmap write lock must be taken, and the pgd is then walked directly with
+'no_vma' set.
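+
+(Not part of this 5.4 backport, since walk_page_range_novma() does not
+exist here.  Purely as an upstream illustration of the tightened rule, a
+rough sketch of what such a caller is expected to do; the argument list
+shown is approximate:)
+
+  mmap_write_lock(mm);
+  err = walk_page_range_novma(mm, start, end, &ops, NULL, private);
+  mmap_write_unlock(mm);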
+
+This in turn means that all page walkers either have a valid vma, or
+they hit the special 'novma' case for page table debugging.  As a
+result, all the odd '(!walk->vma && !walk->no_vma)' tests can be removed.
+
+Fixes: dd2283f2605e ("mm: mmap: zap pages with read mmap_sem in munmap")
+Reported-by: Jann Horn <jannh@google.com>
+Signed-off-by: Steven Price <steven.price@arm.com>
+Cc: Vlastimil Babka <vbabka@suse.cz>
+Cc: Thomas Hellström <thomas.hellstrom@linux.intel.com>
+Cc: Konstantin Khlebnikov <koct9i@gmail.com>
+Cc: Andrew Morton <akpm@linux-foundation.org>
+Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
+[manually backported. backport note: walk_page_range_novma() does not exist in
+5.4, so I'm omitting it from the backport]
+Signed-off-by: Jann Horn <jannh@google.com>
+Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
+---
+ mm/pagewalk.c | 13 ++++++++-----
+ 1 file changed, 8 insertions(+), 5 deletions(-)
+
+--- a/mm/pagewalk.c
++++ b/mm/pagewalk.c
+@@ -38,7 +38,7 @@ static int walk_pmd_range(pud_t *pud, un
+ do {
+ again:
+ next = pmd_addr_end(addr, end);
+- if (pmd_none(*pmd) || !walk->vma) {
++ if (pmd_none(*pmd)) {
+ if (ops->pte_hole)
+ err = ops->pte_hole(addr, next, walk);
+ if (err)
+@@ -84,7 +84,7 @@ static int walk_pud_range(p4d_t *p4d, un
+ do {
+ again:
+ next = pud_addr_end(addr, end);
+- if (pud_none(*pud) || !walk->vma) {
++ if (pud_none(*pud)) {
+ if (ops->pte_hole)
+ err = ops->pte_hole(addr, next, walk);
+ if (err)
+@@ -254,7 +254,7 @@ static int __walk_page_range(unsigned lo
+ int err = 0;
+ struct vm_area_struct *vma = walk->vma;
+
+- if (vma && is_vm_hugetlb_page(vma)) {
++ if (is_vm_hugetlb_page(vma)) {
+ if (walk->ops->hugetlb_entry)
+ err = walk_hugetlb_range(start, end, walk);
+ } else
+@@ -324,9 +324,13 @@ int walk_page_range(struct mm_struct *mm
+ if (!vma) { /* after the last vma */
+ walk.vma = NULL;
+ next = end;
++ if (ops->pte_hole)
++ err = ops->pte_hole(start, next, &walk);
+ } else if (start < vma->vm_start) { /* outside vma */
+ walk.vma = NULL;
+ next = min(end, vma->vm_start);
++ if (ops->pte_hole)
++ err = ops->pte_hole(start, next, &walk);
+ } else { /* inside vma */
+ walk.vma = vma;
+ next = min(end, vma->vm_end);
+@@ -344,9 +348,8 @@ int walk_page_range(struct mm_struct *mm
+ }
+ if (err < 0)
+ break;
+- }
+- if (walk.vma || walk.ops->pte_hole)
+ err = __walk_page_range(start, next, &walk);
++ }
+ if (err)
+ break;
+ } while (start = next, start < end);