Merge tag 'mm-stable-2024-05-17-19-19' of git://git.kernel.org/pub/scm/linux/kernel...

diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index fabce2b50c69551e46b0197cbad5a96cb79dde19..7fad15b2290c70051261813190068fa77b8be62d 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -350,7 +350,7 @@ static void memcg_reparent_objcgs(struct mem_cgroup *memcg,
 
 /*
  * A lot of the calls to the cache allocation functions are expected to be
- * inlined by the compiler. Since the calls to memcg_slab_pre_alloc_hook() are
+ * inlined by the compiler. Since the calls to memcg_slab_post_alloc_hook() are
  * conditional to this static branch, we'll have to allow modules that do
  * kmem_cache_alloc and such to see this symbol as well
  */
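
The static branch referred to above is the memcg kmem online key, which is exported right below this comment. A minimal sketch of the pattern, using a hypothetical key name (DEFINE_STATIC_KEY_FALSE(), EXPORT_SYMBOL() and static_branch_likely() are the real kernel APIs):

#include <linux/export.h>
#include <linux/jump_label.h>

/* Sketch: a static key guards the accounting fast path. It must be
 * exported so modules whose inlined kmem_cache_alloc() hooks test it
 * can link against it. "example_accounting_key" is a hypothetical name. */
DEFINE_STATIC_KEY_FALSE(example_accounting_key);
EXPORT_SYMBOL(example_accounting_key);

static inline bool example_accounting_enabled(void)
{
	return static_branch_likely(&example_accounting_key);
}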
@@ -575,6 +575,136 @@ mem_cgroup_largest_soft_limit_node(struct mem_cgroup_tree_per_node *mctz)
        return mz;
 }
 
+/* Subset of node_stat_item for memcg stats */
+static const unsigned int memcg_node_stat_items[] = {
+       NR_INACTIVE_ANON,
+       NR_ACTIVE_ANON,
+       NR_INACTIVE_FILE,
+       NR_ACTIVE_FILE,
+       NR_UNEVICTABLE,
+       NR_SLAB_RECLAIMABLE_B,
+       NR_SLAB_UNRECLAIMABLE_B,
+       WORKINGSET_REFAULT_ANON,
+       WORKINGSET_REFAULT_FILE,
+       WORKINGSET_ACTIVATE_ANON,
+       WORKINGSET_ACTIVATE_FILE,
+       WORKINGSET_RESTORE_ANON,
+       WORKINGSET_RESTORE_FILE,
+       WORKINGSET_NODERECLAIM,
+       NR_ANON_MAPPED,
+       NR_FILE_MAPPED,
+       NR_FILE_PAGES,
+       NR_FILE_DIRTY,
+       NR_WRITEBACK,
+       NR_SHMEM,
+       NR_SHMEM_THPS,
+       NR_FILE_THPS,
+       NR_ANON_THPS,
+       NR_KERNEL_STACK_KB,
+       NR_PAGETABLE,
+       NR_SECONDARY_PAGETABLE,
+#ifdef CONFIG_SWAP
+       NR_SWAPCACHE,
+#endif
+};
+
+static const unsigned int memcg_stat_items[] = {
+       MEMCG_SWAP,
+       MEMCG_SOCK,
+       MEMCG_PERCPU_B,
+       MEMCG_VMALLOC,
+       MEMCG_KMEM,
+       MEMCG_ZSWAP_B,
+       MEMCG_ZSWAPPED,
+};
+
+#define NR_MEMCG_NODE_STAT_ITEMS ARRAY_SIZE(memcg_node_stat_items)
+#define MEMCG_VMSTAT_SIZE (NR_MEMCG_NODE_STAT_ITEMS + \
+                          ARRAY_SIZE(memcg_stat_items))
+static int8_t mem_cgroup_stats_index[MEMCG_NR_STAT] __read_mostly;
+
+static void init_memcg_stats(void)
+{
+       int8_t i, j = 0;
+
+       BUILD_BUG_ON(MEMCG_NR_STAT >= S8_MAX);
+
+       for (i = 0; i < NR_MEMCG_NODE_STAT_ITEMS; ++i)
+               mem_cgroup_stats_index[memcg_node_stat_items[i]] = ++j;
+
+       for (i = 0; i < ARRAY_SIZE(memcg_stat_items); ++i)
+               mem_cgroup_stats_index[memcg_stat_items[i]] = ++j;
+}
+
+static inline int memcg_stats_index(int idx)
+{
+       return mem_cgroup_stats_index[idx] - 1;
+}
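
In effect, init_memcg_stats() builds a sparse-to-dense translation table, reserving zero as a "not tracked" sentinel so that memcg_stats_index() returns -1 for any item outside the two arrays. A self-contained user-space sketch of the same scheme (hypothetical stat names, not kernel code):

#include <stdio.h>

#define ARRAY_SIZE(a)	(sizeof(a) / sizeof((a)[0]))

enum { STAT_A, STAT_B, STAT_C, STAT_D, NR_STATS };	/* sparse id space */
static const int tracked[] = { STAT_B, STAT_D };	/* dense subset */
static signed char stats_index[NR_STATS];		/* 0 means "not tracked" */

static void init_stats_index(void)
{
	int j = 0;

	for (size_t i = 0; i < ARRAY_SIZE(tracked); i++)
		stats_index[tracked[i]] = ++j;		/* store 1-based slots */
}

static int dense_index(int idx)
{
	return stats_index[idx] - 1;			/* -1 when not tracked */
}

int main(void)
{
	init_stats_index();
	printf("%d %d\n", dense_index(STAT_D), dense_index(STAT_A));	/* 1 -1 */
	return 0;
}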
+
+struct lruvec_stats_percpu {
+       /* Local (CPU and cgroup) state */
+       long state[NR_MEMCG_NODE_STAT_ITEMS];
+
+       /* Delta calculation for lockless upward propagation */
+       long state_prev[NR_MEMCG_NODE_STAT_ITEMS];
+};
+
+struct lruvec_stats {
+       /* Aggregated (CPU and subtree) state */
+       long state[NR_MEMCG_NODE_STAT_ITEMS];
+
+       /* Non-hierarchical (CPU aggregated) state */
+       long state_local[NR_MEMCG_NODE_STAT_ITEMS];
+
+       /* Pending child counts during tree propagation */
+       long state_pending[NR_MEMCG_NODE_STAT_ITEMS];
+};
+
+unsigned long lruvec_page_state(struct lruvec *lruvec, enum node_stat_item idx)
+{
+       struct mem_cgroup_per_node *pn;
+       long x;
+       int i;
+
+       if (mem_cgroup_disabled())
+               return node_page_state(lruvec_pgdat(lruvec), idx);
+
+       i = memcg_stats_index(idx);
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, idx))
+               return 0;
+
+       pn = container_of(lruvec, struct mem_cgroup_per_node, lruvec);
+       x = READ_ONCE(pn->lruvec_stats->state[i]);
+#ifdef CONFIG_SMP
+       if (x < 0)
+               x = 0;
+#endif
+       return x;
+}
+
+unsigned long lruvec_page_state_local(struct lruvec *lruvec,
+                                     enum node_stat_item idx)
+{
+       struct mem_cgroup_per_node *pn;
+       long x;
+       int i;
+
+       if (mem_cgroup_disabled())
+               return node_page_state(lruvec_pgdat(lruvec), idx);
+
+       i = memcg_stats_index(idx);
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, idx))
+               return 0;
+
+       pn = container_of(lruvec, struct mem_cgroup_per_node, lruvec);
+       x = READ_ONCE(pn->lruvec_stats->state_local[i]);
+#ifdef CONFIG_SMP
+       if (x < 0)
+               x = 0;
+#endif
+       return x;
+}
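
The CONFIG_SMP clamp in both readers guards against transiently negative sums: if, say, a page charged on CPU 0 is uncharged on CPU 1 and only CPU 1's delta has been flushed so far, the aggregate briefly reads -1, so negative values are reported as zero.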
+
 /* Subset of vm_event_item to report for memcg event stats */
 static const unsigned int memcg_vm_event_stat[] = {
        PGPGIN,
@@ -606,11 +736,13 @@ static const unsigned int memcg_vm_event_stat[] = {
 };
 
 #define NR_MEMCG_EVENTS ARRAY_SIZE(memcg_vm_event_stat)
-static int mem_cgroup_events_index[NR_VM_EVENT_ITEMS] __read_mostly;
+static int8_t mem_cgroup_events_index[NR_VM_EVENT_ITEMS] __read_mostly;
 
 static void init_memcg_events(void)
 {
-       int i;
+       int8_t i;
+
+       BUILD_BUG_ON(NR_VM_EVENT_ITEMS >= S8_MAX);
 
        for (i = 0; i < NR_MEMCG_EVENTS; ++i)
                mem_cgroup_events_index[memcg_vm_event_stat[i]] = i + 1;
@@ -632,11 +764,11 @@ struct memcg_vmstats_percpu {
        /* The above should fit a single cacheline for memcg_rstat_updated() */
 
        /* Local (CPU and cgroup) page state & events */
-       long                    state[MEMCG_NR_STAT];
+       long                    state[MEMCG_VMSTAT_SIZE];
        unsigned long           events[NR_MEMCG_EVENTS];
 
        /* Delta calculation for lockless upward propagation */
-       long                    state_prev[MEMCG_NR_STAT];
+       long                    state_prev[MEMCG_VMSTAT_SIZE];
        unsigned long           events_prev[NR_MEMCG_EVENTS];
 
        /* Cgroup1: threshold notifications & softlimit tree updates */
@@ -646,15 +778,15 @@ struct memcg_vmstats_percpu {
 
 struct memcg_vmstats {
        /* Aggregated (CPU and subtree) page state & events */
-       long                    state[MEMCG_NR_STAT];
+       long                    state[MEMCG_VMSTAT_SIZE];
        unsigned long           events[NR_MEMCG_EVENTS];
 
        /* Non-hierarchical (CPU aggregated) page state & events */
-       long                    state_local[MEMCG_NR_STAT];
+       long                    state_local[MEMCG_VMSTAT_SIZE];
        unsigned long           events_local[NR_MEMCG_EVENTS];
 
        /* Pending child counts during tree propagation */
-       long                    state_pending[MEMCG_NR_STAT];
+       long                    state_pending[MEMCG_VMSTAT_SIZE];
        unsigned long           events_pending[NR_MEMCG_EVENTS];
 
        /* Stats updates since the last flush */
@@ -715,6 +847,7 @@ static inline void memcg_rstat_updated(struct mem_cgroup *memcg, int val)
 {
        struct memcg_vmstats_percpu *statc;
        int cpu = smp_processor_id();
+       unsigned int stats_updates;
 
        if (!val)
                return;
@@ -722,8 +855,9 @@ static inline void memcg_rstat_updated(struct mem_cgroup *memcg, int val)
        cgroup_rstat_updated(memcg->css.cgroup, cpu);
        statc = this_cpu_ptr(memcg->vmstats_percpu);
        for (; statc; statc = statc->parent) {
-               statc->stats_updates += abs(val);
-               if (statc->stats_updates < MEMCG_CHARGE_BATCH)
+               stats_updates = READ_ONCE(statc->stats_updates) + abs(val);
+               WRITE_ONCE(statc->stats_updates, stats_updates);
+               if (stats_updates < MEMCG_CHARGE_BATCH)
                        continue;
 
                /*
@@ -731,9 +865,9 @@ static inline void memcg_rstat_updated(struct mem_cgroup *memcg, int val)
                 * redundant. Avoid the overhead of the atomic update.
                 */
                if (!memcg_vmstats_needs_flush(statc->vmstats))
-                       atomic64_add(statc->stats_updates,
+                       atomic64_add(stats_updates,
                                     &statc->vmstats->stats_updates);
-               statc->stats_updates = 0;
+               WRITE_ONCE(statc->stats_updates, 0);
        }
 }
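
The switch to READ_ONCE()/WRITE_ONCE() on stats_updates annotates a data race: the field is read locklessly from other contexts while being updated here. The accessors force exactly one untorn load and one untorn store. A rough user-space equivalent of the pattern (simplified; the kernel macros also handle non-scalar sizes):

/* Simplified READ_ONCE/WRITE_ONCE: a volatile access forbids the
 * compiler from tearing, fusing, or re-reading the location. */
#define READ_ONCE(x)		(*(const volatile __typeof__(x) *)&(x))
#define WRITE_ONCE(x, val)	(*(volatile __typeof__(x) *)&(x) = (val))

static unsigned int stats_updates;	/* read racily by another thread */

static void add_update(unsigned int val)
{
	unsigned int cur = READ_ONCE(stats_updates) + val;

	WRITE_ONCE(stats_updates, cur);	/* one visible store, no tearing */
}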
 
@@ -785,7 +919,13 @@ static void flush_memcg_stats_dwork(struct work_struct *w)
 
 unsigned long memcg_page_state(struct mem_cgroup *memcg, int idx)
 {
-       long x = READ_ONCE(memcg->vmstats->state[idx]);
+       long x;
+       int i = memcg_stats_index(idx);
+
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, idx))
+               return 0;
+
+       x = READ_ONCE(memcg->vmstats->state[i]);
 #ifdef CONFIG_SMP
        if (x < 0)
                x = 0;
@@ -815,20 +955,31 @@ static int memcg_state_val_in_pages(int idx, int val)
  * @idx: the stat item - can be enum memcg_stat_item or enum node_stat_item
  * @val: delta to add to the counter, can be negative
  */
-void __mod_memcg_state(struct mem_cgroup *memcg, int idx, int val)
+void __mod_memcg_state(struct mem_cgroup *memcg, enum memcg_stat_item idx,
+                      int val)
 {
+       int i = memcg_stats_index(idx);
+
        if (mem_cgroup_disabled())
                return;
 
-       __this_cpu_add(memcg->vmstats_percpu->state[idx], val);
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, idx))
+               return;
+
+       __this_cpu_add(memcg->vmstats_percpu->state[i], val);
        memcg_rstat_updated(memcg, memcg_state_val_in_pages(idx, val));
 }
 
 /* idx can be of type enum memcg_stat_item or node_stat_item. */
 static unsigned long memcg_page_state_local(struct mem_cgroup *memcg, int idx)
 {
-       long x = READ_ONCE(memcg->vmstats->state_local[idx]);
+       long x;
+       int i = memcg_stats_index(idx);
+
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, idx))
+               return 0;
 
+       x = READ_ONCE(memcg->vmstats->state_local[i]);
 #ifdef CONFIG_SMP
        if (x < 0)
                x = 0;
@@ -836,11 +987,16 @@ static unsigned long memcg_page_state_local(struct mem_cgroup *memcg, int idx)
        return x;
 }
 
-void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
-                             int val)
+static void __mod_memcg_lruvec_state(struct lruvec *lruvec,
+                                    enum node_stat_item idx,
+                                    int val)
 {
        struct mem_cgroup_per_node *pn;
        struct mem_cgroup *memcg;
+       int i = memcg_stats_index(idx);
+
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, idx))
+               return;
 
        pn = container_of(lruvec, struct mem_cgroup_per_node, lruvec);
        memcg = pn->memcg;
@@ -857,8 +1013,6 @@ void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
                case NR_ANON_MAPPED:
                case NR_FILE_MAPPED:
                case NR_ANON_THPS:
-               case NR_SHMEM_PMDMAPPED:
-               case NR_FILE_PMDMAPPED:
                        WARN_ON_ONCE(!in_task());
                        break;
                default:
@@ -867,10 +1021,10 @@ void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
        }
 
        /* Update memcg */
-       __this_cpu_add(memcg->vmstats_percpu->state[idx], val);
+       __this_cpu_add(memcg->vmstats_percpu->state[i], val);
 
        /* Update lruvec */
-       __this_cpu_add(pn->lruvec_stats_percpu->state[idx], val);
+       __this_cpu_add(pn->lruvec_stats_percpu->state[i], val);
 
        memcg_rstat_updated(memcg, memcg_state_val_in_pages(idx, val));
        memcg_stats_unlock();
@@ -952,34 +1106,38 @@ void __mod_lruvec_kmem_state(void *p, enum node_stat_item idx, int val)
 void __count_memcg_events(struct mem_cgroup *memcg, enum vm_event_item idx,
                          unsigned long count)
 {
-       int index = memcg_events_index(idx);
+       int i = memcg_events_index(idx);
+
+       if (mem_cgroup_disabled())
+               return;
 
-       if (mem_cgroup_disabled() || index < 0)
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, idx))
                return;
 
        memcg_stats_lock();
-       __this_cpu_add(memcg->vmstats_percpu->events[index], count);
+       __this_cpu_add(memcg->vmstats_percpu->events[i], count);
        memcg_rstat_updated(memcg, count);
        memcg_stats_unlock();
 }
 
 static unsigned long memcg_events(struct mem_cgroup *memcg, int event)
 {
-       int index = memcg_events_index(event);
+       int i = memcg_events_index(event);
 
-       if (index < 0)
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, event))
                return 0;
-       return READ_ONCE(memcg->vmstats->events[index]);
+
+       return READ_ONCE(memcg->vmstats->events[i]);
 }
 
 static unsigned long memcg_events_local(struct mem_cgroup *memcg, int event)
 {
-       int index = memcg_events_index(event);
+       int i = memcg_events_index(event);
 
-       if (index < 0)
+       if (WARN_ONCE(i < 0, "%s: missing stat item %d\n", __func__, event))
                return 0;
 
-       return READ_ONCE(memcg->vmstats->events_local[index]);
+       return READ_ONCE(memcg->vmstats->events_local[i]);
 }
 
 static void mem_cgroup_charge_statistics(struct mem_cgroup *memcg,
@@ -2030,8 +2188,6 @@ static bool mem_cgroup_oom(struct mem_cgroup *memcg, gfp_t mask, int order)
                if (current->in_user_fault) {
                        css_get(&memcg->css);
                        current->memcg_in_oom = memcg;
-                       current->memcg_oom_gfp_mask = mask;
-                       current->memcg_oom_order = order;
                }
                return false;
        }
@@ -2310,6 +2466,7 @@ static void memcg_account_kmem(struct mem_cgroup *memcg, int nr_pages)
 static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
        struct memcg_stock_pcp *stock;
+       unsigned int stock_pages;
        unsigned long flags;
        bool ret = false;
 
@@ -2319,8 +2476,9 @@ static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
        local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
        stock = this_cpu_ptr(&memcg_stock);
-       if (memcg == READ_ONCE(stock->cached) && stock->nr_pages >= nr_pages) {
-               stock->nr_pages -= nr_pages;
+       stock_pages = READ_ONCE(stock->nr_pages);
+       if (memcg == READ_ONCE(stock->cached) && stock_pages >= nr_pages) {
+               WRITE_ONCE(stock->nr_pages, stock_pages - nr_pages);
                ret = true;
        }
 
@@ -2334,16 +2492,18 @@ static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
  */
 static void drain_stock(struct memcg_stock_pcp *stock)
 {
+       unsigned int stock_pages = READ_ONCE(stock->nr_pages);
        struct mem_cgroup *old = READ_ONCE(stock->cached);
 
        if (!old)
                return;
 
-       if (stock->nr_pages) {
-               page_counter_uncharge(&old->memory, stock->nr_pages);
+       if (stock_pages) {
+               page_counter_uncharge(&old->memory, stock_pages);
                if (do_memsw_account())
-                       page_counter_uncharge(&old->memsw, stock->nr_pages);
-               stock->nr_pages = 0;
+                       page_counter_uncharge(&old->memsw, stock_pages);
+
+               WRITE_ONCE(stock->nr_pages, 0);
        }
 
        css_put(&old->css);
@@ -2369,8 +2529,7 @@ static void drain_local_stock(struct work_struct *dummy)
        clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
 
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
-       if (old)
-               obj_cgroup_put(old);
+       obj_cgroup_put(old);
 }
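
Several `if (old)` guards disappear in this change because obj_cgroup_put() now tolerates a NULL argument, following the kfree(NULL) convention. A minimal sketch of that idiom (illustrative type, not the kernel's refcount machinery):

#include <stdlib.h>

struct obj { int refcnt; };

/* NULL-tolerant put: callers may drop a maybe-NULL reference without
 * an explicit guard, as with kfree(NULL). */
static void obj_put(struct obj *o)
{
	if (!o)
		return;
	if (--o->refcnt == 0)
		free(o);
}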
 
 /*
@@ -2380,6 +2539,7 @@ static void drain_local_stock(struct work_struct *dummy)
 static void __refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
        struct memcg_stock_pcp *stock;
+       unsigned int stock_pages;
 
        stock = this_cpu_ptr(&memcg_stock);
        if (READ_ONCE(stock->cached) != memcg) { /* reset if necessary */
@@ -2387,9 +2547,10 @@ static void __refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
                css_get(&memcg->css);
                WRITE_ONCE(stock->cached, memcg);
        }
-       stock->nr_pages += nr_pages;
+       stock_pages = READ_ONCE(stock->nr_pages) + nr_pages;
+       WRITE_ONCE(stock->nr_pages, stock_pages);
 
-       if (stock->nr_pages > MEMCG_CHARGE_BATCH)
+       if (stock_pages > MEMCG_CHARGE_BATCH)
                drain_stock(stock);
 }
 
@@ -2428,7 +2589,7 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
 
                rcu_read_lock();
                memcg = READ_ONCE(stock->cached);
-               if (memcg && stock->nr_pages &&
+               if (memcg && READ_ONCE(stock->nr_pages) &&
                    mem_cgroup_is_descendant(memcg, root_memcg))
                        flush = true;
                else if (obj_stock_flush_required(stock, root_memcg))
@@ -2978,88 +3139,44 @@ void mem_cgroup_commit_charge(struct folio *folio, struct mem_cgroup *memcg)
 }
 
 #ifdef CONFIG_MEMCG_KMEM
-/*
- * The allocated objcg pointers array is not accounted directly.
- * Moreover, it should not come from DMA buffer and is not readily
- * reclaimable. So those GFP bits should be masked off.
- */
-#define OBJCGS_CLEAR_MASK      (__GFP_DMA | __GFP_RECLAIMABLE | \
-                                __GFP_ACCOUNT | __GFP_NOFAIL)
 
-/*
- * mod_objcg_mlstate() may be called with irq enabled, so
- * mod_memcg_lruvec_state() should be used.
- */
-static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
-                                    struct pglist_data *pgdat,
-                                    enum node_stat_item idx, int nr)
+static inline void __mod_objcg_mlstate(struct obj_cgroup *objcg,
+                                      struct pglist_data *pgdat,
+                                      enum node_stat_item idx, int nr)
 {
        struct mem_cgroup *memcg;
        struct lruvec *lruvec;
 
+       lockdep_assert_irqs_disabled();
+
        rcu_read_lock();
        memcg = obj_cgroup_memcg(objcg);
        lruvec = mem_cgroup_lruvec(memcg, pgdat);
-       mod_memcg_lruvec_state(lruvec, idx, nr);
+       __mod_memcg_lruvec_state(lruvec, idx, nr);
        rcu_read_unlock();
 }
 
-int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
-                                gfp_t gfp, bool new_slab)
-{
-       unsigned int objects = objs_per_slab(s, slab);
-       unsigned long memcg_data;
-       void *vec;
-
-       gfp &= ~OBJCGS_CLEAR_MASK;
-       vec = kcalloc_node(objects, sizeof(struct obj_cgroup *), gfp,
-                          slab_nid(slab));
-       if (!vec)
-               return -ENOMEM;
-
-       memcg_data = (unsigned long) vec | MEMCG_DATA_OBJCGS;
-       if (new_slab) {
-               /*
-                * If the slab is brand new and nobody can yet access its
-                * memcg_data, no synchronization is required and memcg_data can
-                * be simply assigned.
-                */
-               slab->memcg_data = memcg_data;
-       } else if (cmpxchg(&slab->memcg_data, 0, memcg_data)) {
-               /*
-                * If the slab is already in use, somebody can allocate and
-                * assign obj_cgroups in parallel. In this case the existing
-                * objcg vector should be reused.
-                */
-               kfree(vec);
-               return 0;
-       }
-
-       kmemleak_not_leak(vec);
-       return 0;
-}
-
 static __always_inline
 struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
 {
        /*
         * Slab objects are accounted individually, not per-page.
         * Memcg membership data for each individual object is saved in
-        * slab->memcg_data.
+        * slab->obj_exts.
         */
        if (folio_test_slab(folio)) {
-               struct obj_cgroup **objcgs;
+               struct slabobj_ext *obj_exts;
                struct slab *slab;
                unsigned int off;
 
                slab = folio_slab(folio);
-               objcgs = slab_objcgs(slab);
-               if (!objcgs)
+               obj_exts = slab_obj_exts(slab);
+               if (!obj_exts)
                        return NULL;
 
                off = obj_to_index(slab->slab_cache, slab, p);
-               if (objcgs[off])
-                       return obj_cgroup_memcg(objcgs[off]);
+               if (obj_exts[off].objcg)
+                       return obj_cgroup_memcg(obj_exts[off].objcg);
 
                return NULL;
        }
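
The slot lookup divides the object's offset within the slab by the cache's object size; a simplified version of what obj_to_index() computes (the real helper avoids the division with a precomputed reciprocal):

#include <stddef.h>

/* Sketch of obj_to_index(): offset within the slab divided by the
 * fixed object size yields the slot in the obj_exts array. */
static unsigned int obj_index_sketch(const void *slab_base, size_t obj_size,
				     const void *p)
{
	return (unsigned int)(((const char *)p - (const char *)slab_base) / obj_size);
}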
@@ -3067,7 +3184,7 @@ struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
        /*
         * folio_memcg_check() is used here, because in theory we can encounter
         * a folio where the slab flag has been cleared already, but
-        * slab->memcg_data has not been freed yet
+        * slab->obj_exts has not been freed yet
         * folio_memcg_check() will guarantee that a proper memory
         * cgroup pointer or NULL will be returned.
         */
@@ -3145,8 +3262,7 @@ static struct obj_cgroup *current_objcg_update(void)
                if (old) {
                        old = (struct obj_cgroup *)
                                ((unsigned long)old & ~CURRENT_OBJCG_UPDATE_FLAG);
-                       if (old)
-                               obj_cgroup_put(old);
+                       obj_cgroup_put(old);
 
                        old = NULL;
                }
@@ -3356,7 +3472,7 @@ void __memcg_kmem_uncharge_page(struct page *page, int order)
        obj_cgroup_put(objcg);
 }
 
-void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
+static void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
                     enum node_stat_item idx, int nr)
 {
        struct memcg_stock_pcp *stock;
@@ -3384,12 +3500,12 @@ void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
                struct pglist_data *oldpg = stock->cached_pgdat;
 
                if (stock->nr_slab_reclaimable_b) {
-                       mod_objcg_mlstate(objcg, oldpg, NR_SLAB_RECLAIMABLE_B,
+                       __mod_objcg_mlstate(objcg, oldpg, NR_SLAB_RECLAIMABLE_B,
                                          stock->nr_slab_reclaimable_b);
                        stock->nr_slab_reclaimable_b = 0;
                }
                if (stock->nr_slab_unreclaimable_b) {
-                       mod_objcg_mlstate(objcg, oldpg, NR_SLAB_UNRECLAIMABLE_B,
+                       __mod_objcg_mlstate(objcg, oldpg, NR_SLAB_UNRECLAIMABLE_B,
                                          stock->nr_slab_unreclaimable_b);
                        stock->nr_slab_unreclaimable_b = 0;
                }
@@ -3415,11 +3531,10 @@ void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
                }
        }
        if (nr)
-               mod_objcg_mlstate(objcg, pgdat, idx, nr);
+               __mod_objcg_mlstate(objcg, pgdat, idx, nr);
 
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
-       if (old)
-               obj_cgroup_put(old);
+       obj_cgroup_put(old);
 }
 
 static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
@@ -3482,13 +3597,13 @@ static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
         */
        if (stock->nr_slab_reclaimable_b || stock->nr_slab_unreclaimable_b) {
                if (stock->nr_slab_reclaimable_b) {
-                       mod_objcg_mlstate(old, stock->cached_pgdat,
+                       __mod_objcg_mlstate(old, stock->cached_pgdat,
                                          NR_SLAB_RECLAIMABLE_B,
                                          stock->nr_slab_reclaimable_b);
                        stock->nr_slab_reclaimable_b = 0;
                }
                if (stock->nr_slab_unreclaimable_b) {
-                       mod_objcg_mlstate(old, stock->cached_pgdat,
+                       __mod_objcg_mlstate(old, stock->cached_pgdat,
                                          NR_SLAB_UNRECLAIMABLE_B,
                                          stock->nr_slab_unreclaimable_b);
                        stock->nr_slab_unreclaimable_b = 0;
@@ -3546,8 +3661,7 @@ static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
        }
 
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
-       if (old)
-               obj_cgroup_put(old);
+       obj_cgroup_put(old);
 
        if (nr_pages)
                obj_cgroup_uncharge_pages(objcg, nr_pages);
@@ -3602,6 +3716,96 @@ void obj_cgroup_uncharge(struct obj_cgroup *objcg, size_t size)
        refill_obj_stock(objcg, size, true);
 }
 
+static inline size_t obj_full_size(struct kmem_cache *s)
+{
+       /*
+        * For each accounted object there is an extra space which is used
+        * to store obj_cgroup membership. Charge it too.
+        */
+       return s->size + sizeof(struct obj_cgroup *);
+}
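
For example, on a 64-bit kernel a cache with s->size == 64 is charged 64 + sizeof(struct obj_cgroup *) = 64 + 8 = 72 bytes per accounted object.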
+
+bool __memcg_slab_post_alloc_hook(struct kmem_cache *s, struct list_lru *lru,
+                                 gfp_t flags, size_t size, void **p)
+{
+       struct obj_cgroup *objcg;
+       struct slab *slab;
+       unsigned long off;
+       size_t i;
+
+       /*
+        * The obtained objcg pointer is safe to use within the current scope,
+        * defined by current task or set_active_memcg() pair.
+        * obj_cgroup_get() is used to get a permanent reference.
+        */
+       objcg = current_obj_cgroup();
+       if (!objcg)
+               return true;
+
+       /*
+        * slab_alloc_node() avoids the NULL check, so we might be called with a
+        * single NULL object. kmem_cache_alloc_bulk() aborts if it can't fill
+        * the whole requested size.
+        * Return success as there's nothing to free back.
+        */
+       if (unlikely(*p == NULL))
+               return true;
+
+       flags &= gfp_allowed_mask;
+
+       if (lru) {
+               int ret;
+               struct mem_cgroup *memcg;
+
+               memcg = get_mem_cgroup_from_objcg(objcg);
+               ret = memcg_list_lru_alloc(memcg, lru, flags);
+               css_put(&memcg->css);
+
+               if (ret)
+                       return false;
+       }
+
+       if (obj_cgroup_charge(objcg, flags, size * obj_full_size(s)))
+               return false;
+
+       for (i = 0; i < size; i++) {
+               slab = virt_to_slab(p[i]);
+
+               if (!slab_obj_exts(slab) &&
+                   alloc_slab_obj_exts(slab, s, flags, false)) {
+                       obj_cgroup_uncharge(objcg, obj_full_size(s));
+                       continue;
+               }
+
+               off = obj_to_index(s, slab, p[i]);
+               obj_cgroup_get(objcg);
+               slab_obj_exts(slab)[off].objcg = objcg;
+               mod_objcg_state(objcg, slab_pgdat(slab),
+                               cache_vmstat_idx(s), obj_full_size(s));
+       }
+
+       return true;
+}
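
For context, the slab allocator is expected to reach this function through a static-branch-guarded wrapper roughly like the following (a sketch; the actual wrapper lives on the slab side, and memcg_kmem_online() is the real predicate):

static inline bool memcg_slab_post_alloc_hook(struct kmem_cache *s,
					      struct list_lru *lru, gfp_t flags,
					      size_t size, void **p)
{
	if (!memcg_kmem_online())	/* static-branch fast path */
		return true;
	return __memcg_slab_post_alloc_hook(s, lru, flags, size, p);
}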
+
+void __memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,
+                           void **p, int objects, struct slabobj_ext *obj_exts)
+{
+       for (int i = 0; i < objects; i++) {
+               struct obj_cgroup *objcg;
+               unsigned int off;
+
+               off = obj_to_index(s, slab, p[i]);
+               objcg = obj_exts[off].objcg;
+               if (!objcg)
+                       continue;
+
+               obj_exts[off].objcg = NULL;
+               obj_cgroup_uncharge(objcg, obj_full_size(s));
+               mod_objcg_state(objcg, slab_pgdat(slab), cache_vmstat_idx(s),
+                               -obj_full_size(s));
+               obj_cgroup_put(objcg);
+       }
+}
 #endif /* CONFIG_MEMCG_KMEM */
 
 /*
@@ -5431,26 +5635,33 @@ struct mem_cgroup *mem_cgroup_get_from_ino(unsigned long ino)
 }
 #endif
 
-static int alloc_mem_cgroup_per_node_info(struct mem_cgroup *memcg, int node)
+static bool alloc_mem_cgroup_per_node_info(struct mem_cgroup *memcg, int node)
 {
        struct mem_cgroup_per_node *pn;
 
        pn = kzalloc_node(sizeof(*pn), GFP_KERNEL, node);
        if (!pn)
-               return 1;
+               return false;
+
+       pn->lruvec_stats = kzalloc_node(sizeof(struct lruvec_stats),
+                                       GFP_KERNEL_ACCOUNT, node);
+       if (!pn->lruvec_stats)
+               goto fail;
 
        pn->lruvec_stats_percpu = alloc_percpu_gfp(struct lruvec_stats_percpu,
                                                   GFP_KERNEL_ACCOUNT);
-       if (!pn->lruvec_stats_percpu) {
-               kfree(pn);
-               return 1;
-       }
+       if (!pn->lruvec_stats_percpu)
+               goto fail;
 
        lruvec_init(&pn->lruvec);
        pn->memcg = memcg;
 
        memcg->nodeinfo[node] = pn;
-       return 0;
+       return true;
+fail:
+       kfree(pn->lruvec_stats);
+       kfree(pn);
+       return false;
 }
 
 static void free_mem_cgroup_per_node_info(struct mem_cgroup *memcg, int node)
@@ -5461,6 +5672,7 @@ static void free_mem_cgroup_per_node_info(struct mem_cgroup *memcg, int node)
                return;
 
        free_percpu(pn->lruvec_stats_percpu);
+       kfree(pn->lruvec_stats);
        kfree(pn);
 }
 
@@ -5468,8 +5680,7 @@ static void __mem_cgroup_free(struct mem_cgroup *memcg)
 {
        int node;
 
-       if (memcg->orig_objcg)
-               obj_cgroup_put(memcg->orig_objcg);
+       obj_cgroup_put(memcg->orig_objcg);
 
        for_each_node(node)
                free_mem_cgroup_per_node_info(memcg, node);
@@ -5504,7 +5715,8 @@ static struct mem_cgroup *mem_cgroup_alloc(struct mem_cgroup *parent)
                goto fail;
        }
 
-       memcg->vmstats = kzalloc(sizeof(struct memcg_vmstats), GFP_KERNEL);
+       memcg->vmstats = kzalloc(sizeof(struct memcg_vmstats),
+                                GFP_KERNEL_ACCOUNT);
        if (!memcg->vmstats)
                goto fail;
 
@@ -5522,7 +5734,7 @@ static struct mem_cgroup *mem_cgroup_alloc(struct mem_cgroup *parent)
        }
 
        for_each_node(node)
-               if (alloc_mem_cgroup_per_node_info(memcg, node))
+               if (!alloc_mem_cgroup_per_node_info(memcg, node))
                        goto fail;
 
        if (memcg_wb_domain_init(memcg, GFP_KERNEL))
@@ -5588,6 +5800,7 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
                page_counter_init(&memcg->kmem, &parent->kmem);
                page_counter_init(&memcg->tcpmem, &parent->tcpmem);
        } else {
+               init_memcg_stats();
                init_memcg_events();
                page_counter_init(&memcg->memory, NULL);
                page_counter_init(&memcg->swap, NULL);
@@ -5759,7 +5972,7 @@ static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
 
        statc = per_cpu_ptr(memcg->vmstats_percpu, cpu);
 
-       for (i = 0; i < MEMCG_NR_STAT; i++) {
+       for (i = 0; i < MEMCG_VMSTAT_SIZE; i++) {
                /*
                 * Collect the aggregated propagation counts of groups
                 * below us. We're in a per-cpu loop here and this is
@@ -5814,18 +6027,19 @@ static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
 
        for_each_node_state(nid, N_MEMORY) {
                struct mem_cgroup_per_node *pn = memcg->nodeinfo[nid];
-               struct mem_cgroup_per_node *ppn = NULL;
+               struct lruvec_stats *lstats = pn->lruvec_stats;
+               struct lruvec_stats *plstats = NULL;
                struct lruvec_stats_percpu *lstatc;
 
                if (parent)
-                       ppn = parent->nodeinfo[nid];
+                       plstats = parent->nodeinfo[nid]->lruvec_stats;
 
                lstatc = per_cpu_ptr(pn->lruvec_stats_percpu, cpu);
 
-               for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++) {
-                       delta = pn->lruvec_stats.state_pending[i];
+               for (i = 0; i < NR_MEMCG_NODE_STAT_ITEMS; i++) {
+                       delta = lstats->state_pending[i];
                        if (delta)
-                               pn->lruvec_stats.state_pending[i] = 0;
+                               lstats->state_pending[i] = 0;
 
                        delta_cpu = 0;
                        v = READ_ONCE(lstatc->state[i]);
@@ -5836,16 +6050,16 @@ static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
                        }
 
                        if (delta_cpu)
-                               pn->lruvec_stats.state_local[i] += delta_cpu;
+                               lstats->state_local[i] += delta_cpu;
 
                        if (delta) {
-                               pn->lruvec_stats.state[i] += delta;
-                               if (ppn)
-                                       ppn->lruvec_stats.state_pending[i] += delta;
+                               lstats->state[i] += delta;
+                               if (plstats)
+                                       plstats->state_pending[i] += delta;
                        }
                }
        }
-       statc->stats_updates = 0;
+       WRITE_ONCE(statc->stats_updates, 0);
        /* We are in a per-cpu loop here, only do the atomic write once */
        if (atomic64_read(&memcg->vmstats->stats_updates))
                atomic64_set(&memcg->vmstats->stats_updates, 0);
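
The flush loop above follows the rstat pattern: take each per-cpu value's delta against a prev snapshot, fold it into the local and aggregated state, and queue it as pending for the parent's own flush. A condensed sketch for a single counter and CPU (hypothetical types, not kernel code):

struct node {
	long state;		/* aggregated: self + flushed subtree */
	long state_local;	/* aggregated: self only */
	long state_pending;	/* deltas queued by children */
	long pcpu, pcpu_prev;	/* per-cpu value and last snapshot */
	struct node *parent;
};

static void flush_node(struct node *n)
{
	long delta = n->state_pending;			/* child contributions */
	long delta_cpu = n->pcpu - n->pcpu_prev;	/* local contribution */

	n->state_pending = 0;
	n->pcpu_prev = n->pcpu;

	n->state_local += delta_cpu;
	delta += delta_cpu;
	if (delta) {
		n->state += delta;
		if (n->parent)		/* picked up when the parent flushes */
			n->parent->state_pending += delta;
	}
}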
@@ -6620,8 +6834,7 @@ static void mem_cgroup_exit(struct task_struct *task)
 
        objcg = (struct obj_cgroup *)
                ((unsigned long)objcg & ~CURRENT_OBJCG_UPDATE_FLAG);
-       if (objcg)
-               obj_cgroup_put(objcg);
+       obj_cgroup_put(objcg);
 
        /*
         * Some kernel allocations can happen after this point,
@@ -7448,6 +7661,9 @@ static void uncharge_folio(struct folio *folio, struct uncharge_gather *ug)
        struct obj_cgroup *objcg;
 
        VM_BUG_ON_FOLIO(folio_test_lru(folio), folio);
+       VM_BUG_ON_FOLIO(folio_order(folio) > 1 &&
+                       !folio_test_hugetlb(folio) &&
+                       !list_empty(&folio->_deferred_list), folio);
 
        /*
         * Nobody should be changing or seriously looking at