]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
mm: userfaultfd: recheck dst_pmd entry in move_pages_pte()
authorQi Zheng <zhengqi.arch@bytedance.com>
Wed, 4 Dec 2024 11:09:42 +0000 (19:09 +0800)
committerAndrew Morton <akpm@linux-foundation.org>
Tue, 14 Jan 2025 06:40:46 +0000 (22:40 -0800)
In move_pages_pte(), since dst_pte needs to be none, the subsequent
pte_same() check cannot prevent the dst_pte page from being freed
concurrently, so we also need to abtain dst_pmdval and recheck pmd_same().
Otherwise, once we support empty PTE page reclaimation for anonymous
pages, it may result in moving the src_pte page into the dts_pte page that
is about to be freed by RCU.

[zhengqi.arch@bytedance.com: remove WARN_ON_ONCE()s]
Link: https://lkml.kernel.org/r/20241210084156.89877-1-zhengqi.arch@bytedance.com
Link: https://lkml.kernel.org/r/8108c262757fc492626f3a2ffc44b775f2710e16.1733305182.git.zhengqi.arch@bytedance.com
Signed-off-by: Qi Zheng <zhengqi.arch@bytedance.com>
Cc: Andy Lutomirski <luto@kernel.org>
Cc: Catalin Marinas <catalin.marinas@arm.com>
Cc: Dave Hansen <dave.hansen@linux.intel.com>
Cc: David Hildenbrand <david@redhat.com>
Cc: David Rientjes <rientjes@google.com>
Cc: Hugh Dickins <hughd@google.com>
Cc: Jann Horn <jannh@google.com>
Cc: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
Cc: Matthew Wilcox <willy@infradead.org>
Cc: Mel Gorman <mgorman@suse.de>
Cc: Muchun Song <muchun.song@linux.dev>
Cc: Peter Xu <peterx@redhat.com>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Will Deacon <will@kernel.org>
Cc: Zach O'Keefe <zokeefe@google.com>
Cc: Dan Carpenter <dan.carpenter@linaro.org>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
mm/userfaultfd.c

index 60a0be33766ffd71dd7de01d4b1f8fb71db98893..af3dfc3633dbec5ee01d9977c73bd8740799bfef 100644 (file)
@@ -1020,6 +1020,14 @@ void double_pt_unlock(spinlock_t *ptl1,
                __release(ptl2);
 }
 
+static inline bool is_pte_pages_stable(pte_t *dst_pte, pte_t *src_pte,
+                                      pte_t orig_dst_pte, pte_t orig_src_pte,
+                                      pmd_t *dst_pmd, pmd_t dst_pmdval)
+{
+       return pte_same(ptep_get(src_pte), orig_src_pte) &&
+              pte_same(ptep_get(dst_pte), orig_dst_pte) &&
+              pmd_same(dst_pmdval, pmdp_get_lockless(dst_pmd));
+}
 
 static int move_present_pte(struct mm_struct *mm,
                            struct vm_area_struct *dst_vma,
@@ -1027,6 +1035,7 @@ static int move_present_pte(struct mm_struct *mm,
                            unsigned long dst_addr, unsigned long src_addr,
                            pte_t *dst_pte, pte_t *src_pte,
                            pte_t orig_dst_pte, pte_t orig_src_pte,
+                           pmd_t *dst_pmd, pmd_t dst_pmdval,
                            spinlock_t *dst_ptl, spinlock_t *src_ptl,
                            struct folio *src_folio)
 {
@@ -1034,8 +1043,8 @@ static int move_present_pte(struct mm_struct *mm,
 
        double_pt_lock(dst_ptl, src_ptl);
 
-       if (!pte_same(ptep_get(src_pte), orig_src_pte) ||
-           !pte_same(ptep_get(dst_pte), orig_dst_pte)) {
+       if (!is_pte_pages_stable(dst_pte, src_pte, orig_dst_pte, orig_src_pte,
+                                dst_pmd, dst_pmdval)) {
                err = -EAGAIN;
                goto out;
        }
@@ -1071,6 +1080,7 @@ static int move_swap_pte(struct mm_struct *mm,
                         unsigned long dst_addr, unsigned long src_addr,
                         pte_t *dst_pte, pte_t *src_pte,
                         pte_t orig_dst_pte, pte_t orig_src_pte,
+                        pmd_t *dst_pmd, pmd_t dst_pmdval,
                         spinlock_t *dst_ptl, spinlock_t *src_ptl)
 {
        if (!pte_swp_exclusive(orig_src_pte))
@@ -1078,8 +1088,8 @@ static int move_swap_pte(struct mm_struct *mm,
 
        double_pt_lock(dst_ptl, src_ptl);
 
-       if (!pte_same(ptep_get(src_pte), orig_src_pte) ||
-           !pte_same(ptep_get(dst_pte), orig_dst_pte)) {
+       if (!is_pte_pages_stable(dst_pte, src_pte, orig_dst_pte, orig_src_pte,
+                                dst_pmd, dst_pmdval)) {
                double_pt_unlock(dst_ptl, src_ptl);
                return -EAGAIN;
        }
@@ -1097,13 +1107,14 @@ static int move_zeropage_pte(struct mm_struct *mm,
                             unsigned long dst_addr, unsigned long src_addr,
                             pte_t *dst_pte, pte_t *src_pte,
                             pte_t orig_dst_pte, pte_t orig_src_pte,
+                            pmd_t *dst_pmd, pmd_t dst_pmdval,
                             spinlock_t *dst_ptl, spinlock_t *src_ptl)
 {
        pte_t zero_pte;
 
        double_pt_lock(dst_ptl, src_ptl);
-       if (!pte_same(ptep_get(src_pte), orig_src_pte) ||
-           !pte_same(ptep_get(dst_pte), orig_dst_pte)) {
+       if (!is_pte_pages_stable(dst_pte, src_pte, orig_dst_pte, orig_src_pte,
+                                dst_pmd, dst_pmdval)) {
                double_pt_unlock(dst_ptl, src_ptl);
                return -EAGAIN;
        }
@@ -1136,6 +1147,7 @@ static int move_pages_pte(struct mm_struct *mm, pmd_t *dst_pmd, pmd_t *src_pmd,
        pte_t *src_pte = NULL;
        pte_t *dst_pte = NULL;
        pmd_t dummy_pmdval;
+       pmd_t dst_pmdval;
        struct folio *src_folio = NULL;
        struct anon_vma *src_anon_vma = NULL;
        struct mmu_notifier_range range;
@@ -1148,11 +1160,11 @@ static int move_pages_pte(struct mm_struct *mm, pmd_t *dst_pmd, pmd_t *src_pmd,
 retry:
        /*
         * Use the maywrite version to indicate that dst_pte will be modified,
-        * but since we will use pte_same() to detect the change of the pte
-        * entry, there is no need to get pmdval, so just pass a dummy variable
-        * to it.
+        * since dst_pte needs to be none, the subsequent pte_same() check
+        * cannot prevent the dst_pte page from being freed concurrently, so we
+        * also need to abtain dst_pmdval and recheck pmd_same() later.
         */
-       dst_pte = pte_offset_map_rw_nolock(mm, dst_pmd, dst_addr, &dummy_pmdval,
+       dst_pte = pte_offset_map_rw_nolock(mm, dst_pmd, dst_addr, &dst_pmdval,
                                           &dst_ptl);
 
        /* Retry if a huge pmd materialized from under us */
@@ -1161,7 +1173,11 @@ retry:
                goto out;
        }
 
-       /* same as dst_pte */
+       /*
+        * Unlike dst_pte, the subsequent pte_same() check can ensure the
+        * stability of the src_pte page, so there is no need to get pmdval,
+        * just pass a dummy variable to it.
+        */
        src_pte = pte_offset_map_rw_nolock(mm, src_pmd, src_addr, &dummy_pmdval,
                                           &src_ptl);
 
@@ -1177,8 +1193,8 @@ retry:
        }
 
        /* Sanity checks before the operation */
-       if (WARN_ON_ONCE(pmd_none(*dst_pmd)) || WARN_ON_ONCE(pmd_none(*src_pmd)) ||
-           WARN_ON_ONCE(pmd_trans_huge(*dst_pmd)) || WARN_ON_ONCE(pmd_trans_huge(*src_pmd))) {
+       if (pmd_none(*dst_pmd) || pmd_none(*src_pmd) ||
+           pmd_trans_huge(*dst_pmd) || pmd_trans_huge(*src_pmd)) {
                err = -EINVAL;
                goto out;
        }
@@ -1213,7 +1229,7 @@ retry:
                        err = move_zeropage_pte(mm, dst_vma, src_vma,
                                               dst_addr, src_addr, dst_pte, src_pte,
                                               orig_dst_pte, orig_src_pte,
-                                              dst_ptl, src_ptl);
+                                              dst_pmd, dst_pmdval, dst_ptl, src_ptl);
                        goto out;
                }
 
@@ -1303,8 +1319,8 @@ retry:
 
                err = move_present_pte(mm,  dst_vma, src_vma,
                                       dst_addr, src_addr, dst_pte, src_pte,
-                                      orig_dst_pte, orig_src_pte,
-                                      dst_ptl, src_ptl, src_folio);
+                                      orig_dst_pte, orig_src_pte, dst_pmd,
+                                      dst_pmdval, dst_ptl, src_ptl, src_folio);
        } else {
                entry = pte_to_swp_entry(orig_src_pte);
                if (non_swap_entry(entry)) {
@@ -1319,10 +1335,9 @@ retry:
                        goto out;
                }
 
-               err = move_swap_pte(mm, dst_addr, src_addr,
-                                   dst_pte, src_pte,
-                                   orig_dst_pte, orig_src_pte,
-                                   dst_ptl, src_ptl);
+               err = move_swap_pte(mm, dst_addr, src_addr, dst_pte, src_pte,
+                                   orig_dst_pte, orig_src_pte, dst_pmd,
+                                   dst_pmdval, dst_ptl, src_ptl);
        }
 
 out: