#include <linux/rculist_nulls.h>
#include <linux/rcupdate_wait.h>
#include <linux/random.h>
+#include <linux/rhashtable.h>
#include <uapi/linux/btf.h>
#include <linux/rcupdate_trace.h>
#include <linux/btf_ids.h>
BATCH_OPS(htab),
.map_btf_id = &htab_map_btf_ids[0],
};
+
+struct rhtab_elem {
+ struct rhash_head node;
+ /* key bytes, then value bytes follow */
+ u8 data[] __aligned(8);
+};
+
+struct bpf_rhtab {
+ struct bpf_map map;
+ struct rhashtable ht;
+ struct bpf_mem_alloc ma;
+ u32 elem_size;
+};
+
+static const struct rhashtable_params rhtab_params = {
+ .head_offset = offsetof(struct rhtab_elem, node),
+ .key_offset = offsetof(struct rhtab_elem, data),
+};
+
+static inline void *rhtab_elem_value(struct rhtab_elem *l, u32 key_size)
+{
+ return l->data + round_up(key_size, 8);
+}
+
+static struct bpf_map *rhtab_map_alloc(union bpf_attr *attr)
+{
+ struct rhashtable_params params;
+ struct bpf_rhtab *rhtab;
+ int err = 0;
+
+ rhtab = bpf_map_area_alloc(sizeof(*rhtab), NUMA_NO_NODE);
+ if (!rhtab)
+ return ERR_PTR(-ENOMEM);
+
+ bpf_map_init_from_attr(&rhtab->map, attr);
+
+ if (rhtab->map.max_entries > 1UL << 31) {
+ err = -E2BIG;
+ goto free_rhtab;
+ }
+
+ rhtab->elem_size = sizeof(struct rhtab_elem) + round_up(rhtab->map.key_size, 8) +
+ round_up(rhtab->map.value_size, 8);
+
+ params = rhtab_params;
+ params.key_len = rhtab->map.key_size;
+ params.nelem_hint = (u32)attr->map_extra;
+ params.automatic_shrinking = true;
+
+ err = rhashtable_init(&rhtab->ht, ¶ms);
+ if (err)
+ goto free_rhtab;
+
+ /* Set max_elems after rhashtable_init() since init zeroes the struct */
+ rhtab->ht.max_elems = rhtab->map.max_entries;
+
+ err = bpf_mem_alloc_init(&rhtab->ma, rhtab->elem_size, false);
+ if (err)
+ goto destroy_rhtab;
+
+ return &rhtab->map;
+
+destroy_rhtab:
+ rhashtable_destroy(&rhtab->ht);
+free_rhtab:
+ bpf_map_area_free(rhtab);
+ return ERR_PTR(err);
+}
+
+static int rhtab_map_alloc_check(union bpf_attr *attr)
+{
+ if (!(attr->map_flags & BPF_F_NO_PREALLOC))
+ return -EINVAL;
+
+ if (attr->map_flags & BPF_F_ZERO_SEED)
+ return -EINVAL;
+
+ if (attr->key_size > U16_MAX)
+ return -E2BIG;
+
+ if (attr->map_extra >> 32)
+ return -EINVAL;
+
+ if ((u32)attr->map_extra > U16_MAX)
+ return -E2BIG;
+
+ if ((u32)attr->map_extra > attr->max_entries)
+ return -EINVAL;
+
+ return htab_map_alloc_check(attr);
+}
+
+static void rhtab_free_elem(void *ptr, void *arg)
+{
+ struct bpf_rhtab *rhtab = arg;
+ struct rhtab_elem *elem = ptr;
+
+ bpf_mem_cache_free_rcu(&rhtab->ma, elem);
+}
+
+static void rhtab_map_free(struct bpf_map *map)
+{
+ struct bpf_rhtab *rhtab = container_of(map, struct bpf_rhtab, map);
+
+ rhashtable_free_and_destroy(&rhtab->ht, rhtab_free_elem, rhtab);
+ bpf_mem_alloc_destroy(&rhtab->ma);
+ bpf_map_area_free(rhtab);
+}
+
+static void *rhtab_lookup_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_rhtab *rhtab = container_of(map, struct bpf_rhtab, map);
+
+ /* Hold RCU lock in case sleepable program calls via gen_lookup */
+ guard(rcu)();
+
+ return rhashtable_lookup_likely(&rhtab->ht, key, rhtab_params);
+}
+
+static void *rhtab_map_lookup_elem(struct bpf_map *map, void *key) __must_hold(RCU)
+{
+ struct rhtab_elem *l;
+
+ l = rhtab_lookup_elem(map, key);
+ return l ? rhtab_elem_value(l, map->key_size) : NULL;
+}
+
+static void rhtab_read_elem_value(struct bpf_map *map, void *dst, struct rhtab_elem *elem,
+ u64 flags)
+{
+ void *src = rhtab_elem_value(elem, map->key_size);
+
+ if (flags & BPF_F_LOCK)
+ copy_map_value_locked(map, dst, src, true);
+ else
+ copy_map_value(map, dst, src);
+}
+
+static int rhtab_delete_elem(struct bpf_rhtab *rhtab, struct rhtab_elem *elem, void *copy,
+ u64 flags)
+{
+ int err;
+
+ /*
+ * disable_instrumentation() mitigates the deadlock for programs running in NMI context.
+ * rhashtable locks bucket with local_irq_save(). Only NMI programs may reenter
+ * rhashtable code, bpf_disable_instrumentation() disables programs running in NMI, except
+ * raw tracepoints, which we don't have in rhashtable.
+ */
+ bpf_disable_instrumentation();
+ err = rhashtable_remove_fast(&rhtab->ht, &elem->node, rhtab_params);
+ bpf_enable_instrumentation();
+
+ if (err)
+ return err;
+
+ if (copy) {
+ rhtab_read_elem_value(&rhtab->map, copy, elem, flags);
+ check_and_init_map_value(&rhtab->map, copy);
+ }
+
+ bpf_mem_cache_free_rcu(&rhtab->ma, elem);
+ return 0;
+}
+
+
+static long rhtab_map_delete_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_rhtab *rhtab = container_of(map, struct bpf_rhtab, map);
+ struct rhtab_elem *elem;
+
+ guard(rcu)();
+
+ elem = rhtab_lookup_elem(map, key);
+ if (!elem)
+ return -ENOENT;
+
+ return rhtab_delete_elem(rhtab, elem, NULL, 0);
+}
+
+static int rhtab_map_lookup_and_delete_elem(struct bpf_map *map, void *key, void *value, u64 flags)
+{
+ struct bpf_rhtab *rhtab = container_of(map, struct bpf_rhtab, map);
+ struct rhtab_elem *elem;
+ int err;
+
+ err = bpf_map_check_op_flags(map, flags, BPF_F_LOCK);
+ if (err)
+ return err;
+
+ guard(rcu)();
+
+ elem = rhtab_lookup_elem(map, key);
+ if (!elem)
+ return -ENOENT;
+
+ return rhtab_delete_elem(rhtab, elem, value, flags);
+}
+
+static long rhtab_map_update_existing(struct bpf_map *map, struct rhtab_elem *elem, void *value,
+ u64 map_flags)
+{
+ void *old_val = rhtab_elem_value(elem, map->key_size);
+
+ if (map_flags & BPF_NOEXIST)
+ return -EEXIST;
+
+ if (map_flags & BPF_F_LOCK)
+ copy_map_value_locked(map, old_val, value, false);
+ else
+ copy_map_value(map, old_val, value);
+ return 0;
+}
+
+static long rhtab_map_update_elem(struct bpf_map *map, void *key, void *value, u64 map_flags)
+{
+ struct bpf_rhtab *rhtab = container_of(map, struct bpf_rhtab, map);
+ struct rhtab_elem *elem, *tmp;
+
+ if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST))
+ return -EINVAL;
+
+ if ((map_flags & BPF_F_LOCK) && !btf_record_has_field(map->record, BPF_SPIN_LOCK))
+ return -EINVAL;
+
+ guard(rcu)();
+ elem = rhtab_lookup_elem(map, key);
+ if (elem)
+ return rhtab_map_update_existing(map, elem, value, map_flags);
+
+ if (map_flags & BPF_EXIST)
+ return -ENOENT;
+
+ /* Check max_entries limit before inserting new element */
+ if (atomic_read(&rhtab->ht.nelems) >= map->max_entries)
+ return -E2BIG;
+
+ elem = bpf_mem_cache_alloc(&rhtab->ma);
+ if (!elem)
+ return -ENOMEM;
+
+ memcpy(elem->data, key, map->key_size);
+ copy_map_value(map, rhtab_elem_value(elem, map->key_size), value);
+
+ /* Prevent deadlock for NMI programs attempting to take bucket lock */
+ bpf_disable_instrumentation();
+ tmp = rhashtable_lookup_get_insert_fast(&rhtab->ht, &elem->node, rhtab_params);
+ bpf_enable_instrumentation();
+
+ if (tmp) {
+ bpf_mem_cache_free(&rhtab->ma, elem);
+ if (IS_ERR(tmp))
+ return PTR_ERR(tmp);
+
+ return rhtab_map_update_existing(map, tmp, value, map_flags);
+ }
+
+ return 0;
+}
+
+static int rhtab_map_gen_lookup(struct bpf_map *map, struct bpf_insn *insn_buf)
+{
+ struct bpf_insn *insn = insn_buf;
+ const int ret = BPF_REG_0;
+
+ BUILD_BUG_ON(!__same_type(&rhtab_lookup_elem,
+ (void *(*)(struct bpf_map *map, void *key)) NULL));
+ *insn++ = BPF_EMIT_CALL(rhtab_lookup_elem);
+ *insn++ = BPF_JMP_IMM(BPF_JEQ, ret, 0, 1);
+ *insn++ = BPF_ALU64_IMM(BPF_ADD, ret,
+ offsetof(struct rhtab_elem, data) + round_up(map->key_size, 8));
+
+ return insn - insn_buf;
+}
+
+static void rhtab_map_free_internal_structs(struct bpf_map *map)
+{
+}
+
+static int rhtab_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
+{
+ return -EOPNOTSUPP;
+}
+
+static u64 rhtab_map_mem_usage(const struct bpf_map *map)
+{
+ struct bpf_rhtab *rhtab = container_of(map, struct bpf_rhtab, map);
+ u64 num_entries;
+
+ /* Excludes rhashtable bucket overhead (~ nelems * sizeof(void *) at 75% load). */
+ num_entries = atomic_read(&rhtab->ht.nelems);
+ return sizeof(struct bpf_rhtab) + rhtab->elem_size * num_entries;
+}
+
+BTF_ID_LIST_SINGLE(rhtab_map_btf_ids, struct, bpf_rhtab)
+const struct bpf_map_ops rhtab_map_ops = {
+ .map_meta_equal = bpf_map_meta_equal,
+ .map_alloc_check = rhtab_map_alloc_check,
+ .map_alloc = rhtab_map_alloc,
+ .map_free = rhtab_map_free,
+ .map_get_next_key = rhtab_map_get_next_key,
+ .map_release_uref = rhtab_map_free_internal_structs,
+ .map_lookup_elem = rhtab_map_lookup_elem,
+ .map_lookup_and_delete_elem = rhtab_map_lookup_and_delete_elem,
+ .map_update_elem = rhtab_map_update_elem,
+ .map_delete_elem = rhtab_map_delete_elem,
+ .map_gen_lookup = rhtab_map_gen_lookup,
+ .map_mem_usage = rhtab_map_mem_usage,
+ .map_btf_id = &rhtab_map_btf_ids[0],
+};