--- /dev/null
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ * Memory Controller-related BPF kfuncs and auxiliary code
+ *
+ * Author: Roman Gushchin <roman.gushchin@linux.dev>
+ */
+
+#include <linux/memcontrol.h>
+#include <linux/bpf.h>
+
+__bpf_kfunc_start_defs();
+
+/**
+ * bpf_get_mem_cgroup - Get a reference to a memory cgroup
+ * @css: pointer to the css structure
+ *
+ * It's fine to pass a css which belongs to any cgroup controller,
+ * e.g. unified hierarchy's main css.
+ *
+ * Implements KF_ACQUIRE semantics.
+ *
+ * Return: A pointer to a mem_cgroup structure after bumping
+ * the corresponding css's reference counter.
+ */
+__bpf_kfunc struct mem_cgroup *
+bpf_get_mem_cgroup(struct cgroup_subsys_state *css)
+{
+ struct mem_cgroup *memcg = NULL;
+ bool rcu_unlock = false;
+
+ if (mem_cgroup_disabled() || !root_mem_cgroup)
+ return NULL;
+
+ if (root_mem_cgroup->css.ss != css->ss) {
+ struct cgroup *cgroup = css->cgroup;
+ int ssid = root_mem_cgroup->css.ss->id;
+
+ rcu_read_lock();
+ rcu_unlock = true;
+ css = rcu_dereference_raw(cgroup->subsys[ssid]);
+ }
+
+ if (css && css_tryget(css))
+ memcg = container_of(css, struct mem_cgroup, css);
+
+ if (rcu_unlock)
+ rcu_read_unlock();
+
+ return memcg;
+}
+
+/**
+ * bpf_put_mem_cgroup - Put a reference to a memory cgroup
+ * @memcg: memory cgroup to release
+ *
+ * Releases a previously acquired memcg reference.
+ * Implements KF_RELEASE semantics.
+ */
+__bpf_kfunc void bpf_put_mem_cgroup(struct mem_cgroup *memcg)
+{
+ css_put(&memcg->css);
+}
+
+__bpf_kfunc_end_defs();
+
+BTF_KFUNCS_START(bpf_memcontrol_kfuncs)
+BTF_ID_FLAGS(func, bpf_get_mem_cgroup, KF_ACQUIRE | KF_RET_NULL | KF_RCU)
+BTF_ID_FLAGS(func, bpf_put_mem_cgroup, KF_RELEASE)
+
+BTF_KFUNCS_END(bpf_memcontrol_kfuncs)
+
+static const struct btf_kfunc_id_set bpf_memcontrol_kfunc_set = {
+ .owner = THIS_MODULE,
+ .set = &bpf_memcontrol_kfuncs,
+};
+
+static int __init bpf_memcontrol_init(void)
+{
+ int err;
+
+ err = register_btf_kfunc_id_set(BPF_PROG_TYPE_UNSPEC,
+ &bpf_memcontrol_kfunc_set);
+ if (err)
+ pr_warn("error while registering bpf memcontrol kfuncs: %d", err);
+
+ return err;
+}
+late_initcall(bpf_memcontrol_init);