--- /dev/null
+From 9972605a238339b85bd16b084eed5f18414d22db Mon Sep 17 00:00:00 2001
+From: Shakeel Butt <shakeel.butt@linux.dev>
+Date: Fri, 2 Aug 2024 16:58:22 -0700
+Subject: memcg: protect concurrent access to mem_cgroup_idr
+
+From: Shakeel Butt <shakeel.butt@linux.dev>
+
+commit 9972605a238339b85bd16b084eed5f18414d22db upstream.
+
+Commit 73f576c04b94 ("mm: memcontrol: fix cgroup creation failure after
+many small jobs") decoupled the memcg IDs from the CSS ID space to fix the
+cgroup creation failures. It introduced IDR to maintain the memcg ID
+space. The IDR depends on external synchronization mechanisms for
+modifications. For the mem_cgroup_idr, the idr_alloc() and idr_replace()
+happen within css callback and thus are protected through cgroup_mutex
+from concurrent modifications. However idr_remove() for mem_cgroup_idr
+was not protected against concurrency and can be run concurrently for
+different memcgs when they hit their refcnt to zero. Fix that.
+
+We have been seeing list_lru based kernel crashes at a low frequency in
+our fleet for a long time. These crashes were in different part of
+list_lru code including list_lru_add(), list_lru_del() and reparenting
+code. Upon further inspection, it looked like for a given object (dentry
+and inode), the super_block's list_lru didn't have list_lru_one for the
+memcg of that object. The initial suspicions were either the object is
+not allocated through kmem_cache_alloc_lru() or somehow
+memcg_list_lru_alloc() failed to allocate list_lru_one() for a memcg but
+returned success. No evidence were found for these cases.
+
+Looking more deeply, we started seeing situations where valid memcg's id
+is not present in mem_cgroup_idr and in some cases multiple valid memcgs
+have same id and mem_cgroup_idr is pointing to one of them. So, the most
+reasonable explanation is that these situations can happen due to race
+between multiple idr_remove() calls or race between
+idr_alloc()/idr_replace() and idr_remove(). These races are causing
+multiple memcgs to acquire the same ID and then offlining of one of them
+would cleanup list_lrus on the system for all of them. Later access from
+other memcgs to the list_lru cause crashes due to missing list_lru_one.
+
+Link: https://lkml.kernel.org/r/20240802235822.1830976-1-shakeel.butt@linux.dev
+Fixes: 73f576c04b94 ("mm: memcontrol: fix cgroup creation failure after many small jobs")
+Signed-off-by: Shakeel Butt <shakeel.butt@linux.dev>
+Acked-by: Muchun Song <muchun.song@linux.dev>
+Reviewed-by: Roman Gushchin <roman.gushchin@linux.dev>
+Acked-by: Johannes Weiner <hannes@cmpxchg.org>
+Cc: Michal Hocko <mhocko@suse.com>
+Cc: <stable@vger.kernel.org>
+Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
+[ Adapted due to commit be740503ed03 ("mm: memcontrol: fix cannot alloc the
+ maximum memcg ID") and 6f0df8e16eb5 ("memcontrol: ensure memcg acquired by id
+ is properly set up") not in the tree ]
+Signed-off-by: Tomas Krcka <krckatom@amazon.de>
+Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
+---
+ mm/memcontrol.c | 23 ++++++++++++++++++++---
+ 1 file changed, 20 insertions(+), 3 deletions(-)
+
+--- a/mm/memcontrol.c
++++ b/mm/memcontrol.c
+@@ -5160,11 +5160,28 @@ static struct cftype mem_cgroup_legacy_f
+ */
+
+ static DEFINE_IDR(mem_cgroup_idr);
++static DEFINE_SPINLOCK(memcg_idr_lock);
++
++static int mem_cgroup_alloc_id(void)
++{
++ int ret;
++
++ idr_preload(GFP_KERNEL);
++ spin_lock(&memcg_idr_lock);
++ ret = idr_alloc(&mem_cgroup_idr, NULL, 1, MEM_CGROUP_ID_MAX + 1,
++ GFP_NOWAIT);
++ spin_unlock(&memcg_idr_lock);
++ idr_preload_end();
++ return ret;
++}
+
+ static void mem_cgroup_id_remove(struct mem_cgroup *memcg)
+ {
+ if (memcg->id.id > 0) {
++ spin_lock(&memcg_idr_lock);
+ idr_remove(&mem_cgroup_idr, memcg->id.id);
++ spin_unlock(&memcg_idr_lock);
++
+ memcg->id.id = 0;
+ }
+ }
+@@ -5294,9 +5311,7 @@ static struct mem_cgroup *mem_cgroup_all
+ if (!memcg)
+ return ERR_PTR(error);
+
+- memcg->id.id = idr_alloc(&mem_cgroup_idr, NULL,
+- 1, MEM_CGROUP_ID_MAX,
+- GFP_KERNEL);
++ memcg->id.id = mem_cgroup_alloc_id();
+ if (memcg->id.id < 0) {
+ error = memcg->id.id;
+ goto fail;
+@@ -5342,7 +5357,9 @@ static struct mem_cgroup *mem_cgroup_all
+ INIT_LIST_HEAD(&memcg->deferred_split_queue.split_queue);
+ memcg->deferred_split_queue.split_queue_len = 0;
+ #endif
++ spin_lock(&memcg_idr_lock);
+ idr_replace(&mem_cgroup_idr, memcg, memcg->id.id);
++ spin_unlock(&memcg_idr_lock);
+ return memcg;
+ fail:
+ mem_cgroup_id_remove(memcg);