--- /dev/null
+From 08f70b6345667b36b24a756767a275a43ebb1225 Mon Sep 17 00:00:00 2001
+From: Sasha Levin <sashal@kernel.org>
+Date: Tue, 6 Aug 2024 12:30:36 -0700
+Subject: mm: gup: stop abusing try_grab_folio
+
+From: Yang Shi <yang@os.amperecomputing.com>
+
+commit f442fa6141379a20b48ae3efabee827a3d260787 upstream
+
+A kernel warning was reported when pinning a folio in CMA memory while
+launching an SEV virtual machine.  The splat looks like:
+
+[ 464.325306] WARNING: CPU: 13 PID: 6734 at mm/gup.c:1313 __get_user_pages+0x423/0x520
+[ 464.325464] CPU: 13 PID: 6734 Comm: qemu-kvm Kdump: loaded Not tainted 6.6.33+ #6
+[ 464.325477] RIP: 0010:__get_user_pages+0x423/0x520
+[ 464.325515] Call Trace:
+[ 464.325520] <TASK>
+[ 464.325523] ? __get_user_pages+0x423/0x520
+[ 464.325528] ? __warn+0x81/0x130
+[ 464.325536] ? __get_user_pages+0x423/0x520
+[ 464.325541] ? report_bug+0x171/0x1a0
+[ 464.325549] ? handle_bug+0x3c/0x70
+[ 464.325554] ? exc_invalid_op+0x17/0x70
+[ 464.325558] ? asm_exc_invalid_op+0x1a/0x20
+[ 464.325567] ? __get_user_pages+0x423/0x520
+[ 464.325575] __gup_longterm_locked+0x212/0x7a0
+[ 464.325583] internal_get_user_pages_fast+0xfb/0x190
+[ 464.325590] pin_user_pages_fast+0x47/0x60
+[ 464.325598] sev_pin_memory+0xca/0x170 [kvm_amd]
+[ 464.325616] sev_mem_enc_register_region+0x81/0x130 [kvm_amd]
+
+Per the analysis done by yangge, when starting the SEV virtual machine, it
+will call pin_user_pages_fast(..., FOLL_LONGTERM, ...) to pin the memory.
+But the page is in a CMA area, so fast GUP fails and falls back to the
+slow path due to the longterm pinnable check in try_grab_folio().
+
+The slow path will try to pin the pages and then migrate them out of the
+CMA area.  But the slow path also uses try_grab_folio() to pin the page;
+it fails due to the same check, and the above warning is triggered.
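+
+The check in question, quoted from the old try_grab_folio() (removed in
+the hunk below), is correct in fast GUP, where returning NULL just means
+"fall back to the slow path", but fatal when the slow path itself runs
+into it:
+
+	if (unlikely((flags & FOLL_LONGTERM) &&
+		     !folio_is_longterm_pinnable(folio))) {
+		if (!put_devmap_managed_page_refs(&folio->page, refs))
+			folio_put_refs(folio, refs);
+		return NULL;
+	}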
+
+In addition, try_grab_folio() is supposed to be used in the fast path
+only: it elevates the folio refcount with "add ref unless zero"
+semantics.  In the slow path we are guaranteed to hold at least one
+stable reference, so a simple atomic add suffices.  The performance
+difference should be trivial, but the misuse is confusing and misleading.
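+
+A minimal sketch of the two reference-taking schemes (the real helpers,
+try_get_folio() and folio_ref_add(), carry additional checks):
+
+	/* Fast path: the folio may be freed concurrently, so take the
+	 * reference with "add ref unless zero" semantics and let the
+	 * caller fall back on failure: */
+	if (!folio_ref_try_add_rcu(folio, refs))
+		return NULL;
+
+	/* Slow path: a stable reference is already held, so a plain
+	 * atomic add can never race with the final folio_put(): */
+	folio_ref_add(folio, refs);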
+
+Rename try_grab_folio() to try_grab_folio_fast(), and try_grab_page() to
+try_grab_folio(), and use each in its proper path.  This solves both the
+misuse and the kernel warning.
+
+The proper naming makes their use cases clearer and should prevent
+future misuse.
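+
+After the rename, the two helpers (declared in the hunks below) are:
+
+	/* slow path: requires a stable folio reference to be held */
+	int __must_check try_grab_folio(struct folio *folio, int refs,
+					unsigned int flags);
+
+	/* fast path only: takes its own speculative reference */
+	static struct folio *try_grab_folio_fast(struct page *page,
+						 int refs, unsigned int flags);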
+
+peterx said:
+
+: The user will see the pin fail; for gup-slow it further triggers the WARN
+: right below that failure (as in the original report):
+:
+: folio = try_grab_folio(page, page_increm - 1,
+: foll_flags);
+: if (WARN_ON_ONCE(!folio)) { <------------------------ here
+: /*
+: * Release the 1st page ref if the
+: * folio is problematic, fail hard.
+: */
+: gup_put_folio(page_folio(page), 1,
+: foll_flags);
+: ret = -EFAULT;
+: goto out;
+: }
+
+[1] https://lore.kernel.org/linux-mm/1719478388-31917-1-git-send-email-yangge1116@126.com/
+
+[shy828301@gmail.com: fix implicit declaration of function try_grab_folio_fast]
+ Link: https://lkml.kernel.org/r/CAHbLzkowMSso-4Nufc9hcMehQsK9PNz3OSu-+eniU-2Mm-xjhA@mail.gmail.com
+Link: https://lkml.kernel.org/r/20240628191458.2605553-1-yang@os.amperecomputing.com
+Fixes: 57edfcfd3419 ("mm/gup: accelerate thp gup even for "pages != NULL"")
+Signed-off-by: Yang Shi <yang@os.amperecomputing.com>
+Reported-by: yangge <yangge1116@126.com>
+Cc: Christoph Hellwig <hch@infradead.org>
+Cc: David Hildenbrand <david@redhat.com>
+Cc: Peter Xu <peterx@redhat.com>
+Cc: <stable@vger.kernel.org> [6.6+]
+Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
+Signed-off-by: Sasha Levin <sashal@kernel.org>
+---
+ mm/gup.c | 251 ++++++++++++++++++++++++-----------------------
+ mm/huge_memory.c | 4 +-
+ mm/hugetlb.c | 2 +-
+ mm/internal.h | 4 +-
+ 4 files changed, 134 insertions(+), 127 deletions(-)
+
+diff --git a/mm/gup.c b/mm/gup.c
+index f50fe2219a13b..fdd75384160d8 100644
+--- a/mm/gup.c
++++ b/mm/gup.c
+@@ -97,95 +97,6 @@ static inline struct folio *try_get_folio(struct page *page, int refs)
+ return folio;
+ }
+
+-/**
+- * try_grab_folio() - Attempt to get or pin a folio.
+- * @page: pointer to page to be grabbed
+- * @refs: the value to (effectively) add to the folio's refcount
+- * @flags: gup flags: these are the FOLL_* flag values.
+- *
+- * "grab" names in this file mean, "look at flags to decide whether to use
+- * FOLL_PIN or FOLL_GET behavior, when incrementing the folio's refcount.
+- *
+- * Either FOLL_PIN or FOLL_GET (or neither) must be set, but not both at the
+- * same time. (That's true throughout the get_user_pages*() and
+- * pin_user_pages*() APIs.) Cases:
+- *
+- * FOLL_GET: folio's refcount will be incremented by @refs.
+- *
+- * FOLL_PIN on large folios: folio's refcount will be incremented by
+- * @refs, and its pincount will be incremented by @refs.
+- *
+- * FOLL_PIN on single-page folios: folio's refcount will be incremented by
+- * @refs * GUP_PIN_COUNTING_BIAS.
+- *
+- * Return: The folio containing @page (with refcount appropriately
+- * incremented) for success, or NULL upon failure. If neither FOLL_GET
+- * nor FOLL_PIN was set, that's considered failure, and furthermore,
+- * a likely bug in the caller, so a warning is also emitted.
+- */
+-struct folio *try_grab_folio(struct page *page, int refs, unsigned int flags)
+-{
+- struct folio *folio;
+-
+- if (WARN_ON_ONCE((flags & (FOLL_GET | FOLL_PIN)) == 0))
+- return NULL;
+-
+- if (unlikely(!(flags & FOLL_PCI_P2PDMA) && is_pci_p2pdma_page(page)))
+- return NULL;
+-
+- if (flags & FOLL_GET)
+- return try_get_folio(page, refs);
+-
+- /* FOLL_PIN is set */
+-
+- /*
+- * Don't take a pin on the zero page - it's not going anywhere
+- * and it is used in a *lot* of places.
+- */
+- if (is_zero_page(page))
+- return page_folio(page);
+-
+- folio = try_get_folio(page, refs);
+- if (!folio)
+- return NULL;
+-
+- /*
+- * Can't do FOLL_LONGTERM + FOLL_PIN gup fast path if not in a
+- * right zone, so fail and let the caller fall back to the slow
+- * path.
+- */
+- if (unlikely((flags & FOLL_LONGTERM) &&
+- !folio_is_longterm_pinnable(folio))) {
+- if (!put_devmap_managed_page_refs(&folio->page, refs))
+- folio_put_refs(folio, refs);
+- return NULL;
+- }
+-
+- /*
+- * When pinning a large folio, use an exact count to track it.
+- *
+- * However, be sure to *also* increment the normal folio
+- * refcount field at least once, so that the folio really
+- * is pinned. That's why the refcount from the earlier
+- * try_get_folio() is left intact.
+- */
+- if (folio_test_large(folio))
+- atomic_add(refs, &folio->_pincount);
+- else
+- folio_ref_add(folio,
+- refs * (GUP_PIN_COUNTING_BIAS - 1));
+- /*
+- * Adjust the pincount before re-checking the PTE for changes.
+- * This is essentially a smp_mb() and is paired with a memory
+- * barrier in page_try_share_anon_rmap().
+- */
+- smp_mb__after_atomic();
+-
+- node_stat_mod_folio(folio, NR_FOLL_PIN_ACQUIRED, refs);
+-
+- return folio;
+-}
+-
+ static void gup_put_folio(struct folio *folio, int refs, unsigned int flags)
+ {
+ if (flags & FOLL_PIN) {
+@@ -203,58 +114,59 @@ static void gup_put_folio(struct folio *folio, int refs, unsigned int flags)
+ }
+
+ /**
+- * try_grab_page() - elevate a page's refcount by a flag-dependent amount
+- * @page: pointer to page to be grabbed
+- * @flags: gup flags: these are the FOLL_* flag values.
++ * try_grab_folio() - increment a folio's refcount by a flag-dependent amount
++ * @folio: pointer to folio to be grabbed
++ * @refs: the value to (effectively) add to the folio's refcount
++ * @flags: gup flags: these are the FOLL_* flag values
+ *
+ * This might not do anything at all, depending on the flags argument.
+ *
+ * "grab" names in this file mean, "look at flags to decide whether to use
+- * FOLL_PIN or FOLL_GET behavior, when incrementing the page's refcount.
++ * FOLL_PIN or FOLL_GET behavior, when incrementing the folio's refcount.
+ *
+ * Either FOLL_PIN or FOLL_GET (or neither) may be set, but not both at the same
+- * time. Cases: please see the try_grab_folio() documentation, with
+- * "refs=1".
++ * time.
+ *
+ * Return: 0 for success, or if no action was required (if neither FOLL_PIN
+ * nor FOLL_GET was set, nothing is done). A negative error code for failure:
+ *
+- * -ENOMEM FOLL_GET or FOLL_PIN was set, but the page could not
++ * -ENOMEM FOLL_GET or FOLL_PIN was set, but the folio could not
+ * be grabbed.
++ *
++ * It is called when we have a stable reference to the folio, typically in
++ * the GUP slow path.
+ */
+-int __must_check try_grab_page(struct page *page, unsigned int flags)
++int __must_check try_grab_folio(struct folio *folio, int refs,
++ unsigned int flags)
+ {
+- struct folio *folio = page_folio(page);
+-
+ if (WARN_ON_ONCE(folio_ref_count(folio) <= 0))
+ return -ENOMEM;
+
+- if (unlikely(!(flags & FOLL_PCI_P2PDMA) && is_pci_p2pdma_page(page)))
++ if (unlikely(!(flags & FOLL_PCI_P2PDMA) && is_pci_p2pdma_page(&folio->page)))
+ return -EREMOTEIO;
+
+ if (flags & FOLL_GET)
+- folio_ref_inc(folio);
++ folio_ref_add(folio, refs);
+ else if (flags & FOLL_PIN) {
+ /*
+ * Don't take a pin on the zero page - it's not going anywhere
+ * and it is used in a *lot* of places.
+ */
+- if (is_zero_page(page))
++ if (is_zero_folio(folio))
+ return 0;
+
+ /*
+- * Similar to try_grab_folio(): be sure to *also*
+- * increment the normal page refcount field at least once,
++ * Increment the normal page refcount field at least once,
+ * so that the page really is pinned.
+ */
+ if (folio_test_large(folio)) {
+- folio_ref_add(folio, 1);
+- atomic_add(1, &folio->_pincount);
++ folio_ref_add(folio, refs);
++ atomic_add(refs, &folio->_pincount);
+ } else {
+- folio_ref_add(folio, GUP_PIN_COUNTING_BIAS);
++ folio_ref_add(folio, refs * GUP_PIN_COUNTING_BIAS);
+ }
+
+- node_stat_mod_folio(folio, NR_FOLL_PIN_ACQUIRED, 1);
++ node_stat_mod_folio(folio, NR_FOLL_PIN_ACQUIRED, refs);
+ }
+
+ return 0;
+@@ -647,8 +559,8 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
+ VM_BUG_ON_PAGE((flags & FOLL_PIN) && PageAnon(page) &&
+ !PageAnonExclusive(page), page);
+
+- /* try_grab_page() does nothing unless FOLL_GET or FOLL_PIN is set. */
+- ret = try_grab_page(page, flags);
++ /* try_grab_folio() does nothing unless FOLL_GET or FOLL_PIN is set. */
++ ret = try_grab_folio(page_folio(page), 1, flags);
+ if (unlikely(ret)) {
+ page = ERR_PTR(ret);
+ goto out;
+@@ -899,7 +811,7 @@ static int get_gate_page(struct mm_struct *mm, unsigned long address,
+ goto unmap;
+ *page = pte_page(entry);
+ }
+- ret = try_grab_page(*page, gup_flags);
++ ret = try_grab_folio(page_folio(*page), 1, gup_flags);
+ if (unlikely(ret))
+ goto unmap;
+ out:
+@@ -1302,20 +1214,19 @@ static long __get_user_pages(struct mm_struct *mm,
+ * pages.
+ */
+ if (page_increm > 1) {
+- struct folio *folio;
++ struct folio *folio = page_folio(page);
+
+ /*
+ * Since we already hold refcount on the
+ * large folio, this should never fail.
+ */
+- folio = try_grab_folio(page, page_increm - 1,
+- foll_flags);
+- if (WARN_ON_ONCE(!folio)) {
++ if (try_grab_folio(folio, page_increm - 1,
++ foll_flags)) {
+ /*
+ * Release the 1st page ref if the
+ * folio is problematic, fail hard.
+ */
+- gup_put_folio(page_folio(page), 1,
++ gup_put_folio(folio, 1,
+ foll_flags);
+ ret = -EFAULT;
+ goto out;
+@@ -2541,6 +2452,102 @@ static void __maybe_unused undo_dev_pagemap(int *nr, int nr_start,
+ }
+ }
+
++/**
++ * try_grab_folio_fast() - Attempt to get or pin a folio in fast path.
++ * @page: pointer to page to be grabbed
++ * @refs: the value to (effectively) add to the folio's refcount
++ * @flags: gup flags: these are the FOLL_* flag values.
++ *
++ * "grab" names in this file mean, "look at flags to decide whether to use
++ * FOLL_PIN or FOLL_GET behavior, when incrementing the folio's refcount.
++ *
++ * Either FOLL_PIN or FOLL_GET (or neither) must be set, but not both at the
++ * same time. (That's true throughout the get_user_pages*() and
++ * pin_user_pages*() APIs.) Cases:
++ *
++ * FOLL_GET: folio's refcount will be incremented by @refs.
++ *
++ * FOLL_PIN on large folios: folio's refcount will be incremented by
++ * @refs, and its pincount will be incremented by @refs.
++ *
++ * FOLL_PIN on single-page folios: folio's refcount will be incremented by
++ * @refs * GUP_PIN_COUNTING_BIAS.
++ *
++ * Return: The folio containing @page (with refcount appropriately
++ * incremented) for success, or NULL upon failure. If neither FOLL_GET
++ * nor FOLL_PIN was set, that's considered failure, and furthermore,
++ * a likely bug in the caller, so a warning is also emitted.
++ *
++ * It elevates the folio refcount using "add ref unless zero" semantics and
++ * must be called in the fast path only.
++ */
++static struct folio *try_grab_folio_fast(struct page *page, int refs,
++ unsigned int flags)
++{
++ struct folio *folio;
++
++	/* Warn if we are not called from fast GUP (IRQs disabled) */
++ VM_WARN_ON_ONCE(!irqs_disabled());
++
++ if (WARN_ON_ONCE((flags & (FOLL_GET | FOLL_PIN)) == 0))
++ return NULL;
++
++ if (unlikely(!(flags & FOLL_PCI_P2PDMA) && is_pci_p2pdma_page(page)))
++ return NULL;
++
++ if (flags & FOLL_GET)
++ return try_get_folio(page, refs);
++
++ /* FOLL_PIN is set */
++
++ /*
++ * Don't take a pin on the zero page - it's not going anywhere
++ * and it is used in a *lot* of places.
++ */
++ if (is_zero_page(page))
++ return page_folio(page);
++
++ folio = try_get_folio(page, refs);
++ if (!folio)
++ return NULL;
++
++ /*
++ * Can't do FOLL_LONGTERM + FOLL_PIN gup fast path if not in a
++ * right zone, so fail and let the caller fall back to the slow
++ * path.
++ */
++ if (unlikely((flags & FOLL_LONGTERM) &&
++ !folio_is_longterm_pinnable(folio))) {
++ if (!put_devmap_managed_page_refs(&folio->page, refs))
++ folio_put_refs(folio, refs);
++ return NULL;
++ }
++
++ /*
++ * When pinning a large folio, use an exact count to track it.
++ *
++ * However, be sure to *also* increment the normal folio
++ * refcount field at least once, so that the folio really
++ * is pinned. That's why the refcount from the earlier
++ * try_get_folio() is left intact.
++ */
++ if (folio_test_large(folio))
++ atomic_add(refs, &folio->_pincount);
++ else
++ folio_ref_add(folio,
++ refs * (GUP_PIN_COUNTING_BIAS - 1));
++ /*
++ * Adjust the pincount before re-checking the PTE for changes.
++ * This is essentially a smp_mb() and is paired with a memory
++ * barrier in folio_try_share_anon_rmap_*().
++ */
++ smp_mb__after_atomic();
++
++ node_stat_mod_folio(folio, NR_FOLL_PIN_ACQUIRED, refs);
++
++ return folio;
++}
++
+ #ifdef CONFIG_ARCH_HAS_PTE_SPECIAL
+ /*
+ * Fast-gup relies on pte change detection to avoid concurrent pgtable
+@@ -2605,7 +2612,7 @@ static int gup_pte_range(pmd_t pmd, pmd_t *pmdp, unsigned long addr,
+ VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
+ page = pte_page(pte);
+
+- folio = try_grab_folio(page, 1, flags);
++ folio = try_grab_folio_fast(page, 1, flags);
+ if (!folio)
+ goto pte_unmap;
+
+@@ -2699,7 +2706,7 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
+
+ SetPageReferenced(page);
+ pages[*nr] = page;
+- if (unlikely(try_grab_page(page, flags))) {
++ if (unlikely(try_grab_folio(page_folio(page), 1, flags))) {
+ undo_dev_pagemap(nr, nr_start, flags, pages);
+ break;
+ }
+@@ -2808,7 +2815,7 @@ static int gup_hugepte(pte_t *ptep, unsigned long sz, unsigned long addr,
+ page = nth_page(pte_page(pte), (addr & (sz - 1)) >> PAGE_SHIFT);
+ refs = record_subpages(page, addr, end, pages + *nr);
+
+- folio = try_grab_folio(page, refs, flags);
++ folio = try_grab_folio_fast(page, refs, flags);
+ if (!folio)
+ return 0;
+
+@@ -2879,7 +2886,7 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
+ page = nth_page(pmd_page(orig), (addr & ~PMD_MASK) >> PAGE_SHIFT);
+ refs = record_subpages(page, addr, end, pages + *nr);
+
+- folio = try_grab_folio(page, refs, flags);
++ folio = try_grab_folio_fast(page, refs, flags);
+ if (!folio)
+ return 0;
+
+@@ -2923,7 +2930,7 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
+ page = nth_page(pud_page(orig), (addr & ~PUD_MASK) >> PAGE_SHIFT);
+ refs = record_subpages(page, addr, end, pages + *nr);
+
+- folio = try_grab_folio(page, refs, flags);
++ folio = try_grab_folio_fast(page, refs, flags);
+ if (!folio)
+ return 0;
+
+@@ -2963,7 +2970,7 @@ static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
+ page = nth_page(pgd_page(orig), (addr & ~PGDIR_MASK) >> PAGE_SHIFT);
+ refs = record_subpages(page, addr, end, pages + *nr);
+
+- folio = try_grab_folio(page, refs, flags);
++ folio = try_grab_folio_fast(page, refs, flags);
+ if (!folio)
+ return 0;
+
+diff --git a/mm/huge_memory.c b/mm/huge_memory.c
+index 79fbd6ddec49f..fc773b0c4438c 100644
+--- a/mm/huge_memory.c
++++ b/mm/huge_memory.c
+@@ -1052,7 +1052,7 @@ struct page *follow_devmap_pmd(struct vm_area_struct *vma, unsigned long addr,
+ if (!*pgmap)
+ return ERR_PTR(-EFAULT);
+ page = pfn_to_page(pfn);
+- ret = try_grab_page(page, flags);
++ ret = try_grab_folio(page_folio(page), 1, flags);
+ if (ret)
+ page = ERR_PTR(ret);
+
+@@ -1471,7 +1471,7 @@ struct page *follow_trans_huge_pmd(struct vm_area_struct *vma,
+ VM_BUG_ON_PAGE((flags & FOLL_PIN) && PageAnon(page) &&
+ !PageAnonExclusive(page), page);
+
+- ret = try_grab_page(page, flags);
++ ret = try_grab_folio(page_folio(page), 1, flags);
+ if (ret)
+ return ERR_PTR(ret);
+
+diff --git a/mm/hugetlb.c b/mm/hugetlb.c
+index a480affd475bf..ab040f8d19876 100644
+--- a/mm/hugetlb.c
++++ b/mm/hugetlb.c
+@@ -6532,7 +6532,7 @@ struct page *hugetlb_follow_page_mask(struct vm_area_struct *vma,
+ * try_grab_page() should always be able to get the page here,
+ * because we hold the ptl lock and have verified pte_present().
+ */
+- ret = try_grab_page(page, flags);
++ ret = try_grab_folio(page_folio(page), 1, flags);
+
+ if (WARN_ON_ONCE(ret)) {
+ page = ERR_PTR(ret);
+diff --git a/mm/internal.h b/mm/internal.h
+index abed947f784b7..ef8d787a510c5 100644
+--- a/mm/internal.h
++++ b/mm/internal.h
+@@ -938,8 +938,8 @@ int migrate_device_coherent_page(struct page *page);
+ /*
+ * mm/gup.c
+ */
+-struct folio *try_grab_folio(struct page *page, int refs, unsigned int flags);
+-int __must_check try_grab_page(struct page *page, unsigned int flags);
++int __must_check try_grab_folio(struct folio *folio, int refs,
++ unsigned int flags);
+
+ /*
+ * mm/huge_memory.c
+--
+2.43.0
+