diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 8598466a3805784f58497d9607c5ace6f081cefb..9402889840bf7e4fe2adb743d387b9dcdbe17024 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -24,8 +24,16 @@ struct bpf_stab {
 #define SOCK_CREATE_FLAG_MASK                          \
        (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
 
+/* This mutex is used to
+ *  - protect against races between prog/link attach/detach and link prog update, and
+ *  - protect against races between releasing and accessing the map in bpf_link.
+ * A single global mutex is used since contention is expected to be low.
+ */
+static DEFINE_MUTEX(sockmap_mutex);
+
 static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
-                               struct bpf_prog *old, u32 which);
+                               struct bpf_prog *old, struct bpf_link *link,
+                               u32 which);
 static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);
 
 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
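The new sockmap_mutex serializes all writers of the psock prog/link slots: legacy prog attach/detach, link attach/detach, link prog update, and link release. The patch open-codes the lock around each sock_map_prog_update() call site, as the hunks below show. As a sketch only (this helper is hypothetical and not part of the patch), the pattern at every call site is equivalent to:

/* Hypothetical wrapper, not part of the patch: every caller of
 * sock_map_prog_update() follows this lock/update/unlock pattern so that
 * prog/link attach, detach, link prog update and link release are all
 * serialized by the single global sockmap_mutex.
 */
static int sock_map_prog_update_locked(struct bpf_map *map,
				       struct bpf_prog *prog,
				       struct bpf_prog *old,
				       struct bpf_link *link, u32 which)
{
	int ret;

	mutex_lock(&sockmap_mutex);
	ret = sock_map_prog_update(map, prog, old, link, which);
	mutex_unlock(&sockmap_mutex);
	return ret;
}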
@@ -71,7 +79,9 @@ int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
        map = __bpf_map_get(f);
        if (IS_ERR(map))
                return PTR_ERR(map);
-       ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
+       mutex_lock(&sockmap_mutex);
+       ret = sock_map_prog_update(map, prog, NULL, NULL, attr->attach_type);
+       mutex_unlock(&sockmap_mutex);
        fdput(f);
        return ret;
 }
@@ -103,7 +113,9 @@ int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
                goto put_prog;
        }
 
-       ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
+       mutex_lock(&sockmap_mutex);
+       ret = sock_map_prog_update(map, NULL, prog, NULL, attr->attach_type);
+       mutex_unlock(&sockmap_mutex);
 put_prog:
        bpf_prog_put(prog);
 put_map:
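For context, the two call sites above are reached from the legacy BPF_PROG_ATTACH and BPF_PROG_DETACH commands, with the sockmap (or sockhash) fd as the attach target. A minimal userspace sketch using libbpf; prog_fd and map_fd are assumed to have been loaded and created elsewhere:

#include <errno.h>
#include <stdio.h>
#include <bpf/bpf.h>

/* prog_fd: a loaded BPF_PROG_TYPE_SK_MSG program
 * map_fd:  a BPF_MAP_TYPE_SOCKMAP (or SOCKHASH) map
 * Both are assumed to have been set up elsewhere.
 */
static int legacy_attach_detach(int prog_fd, int map_fd)
{
	int err;

	/* BPF_PROG_ATTACH path: ends up in sock_map_get_from_fd() above. */
	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
	if (err) {
		fprintf(stderr, "attach failed: %d\n", err);
		return err;
	}

	/* BPF_PROG_DETACH path: ends up in sock_map_prog_detach() above. */
	err = bpf_prog_detach2(prog_fd, map_fd, BPF_SK_MSG_VERDICT);
	if (err)
		fprintf(stderr, "detach failed: %d\n", err);
	return err;
}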
@@ -1460,55 +1472,84 @@ static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
        return NULL;
 }
 
-static int sock_map_prog_lookup(struct bpf_map *map, struct bpf_prog ***pprog,
-                               u32 which)
+static int sock_map_prog_link_lookup(struct bpf_map *map, struct bpf_prog ***pprog,
+                                    struct bpf_link ***plink, u32 which)
 {
        struct sk_psock_progs *progs = sock_map_progs(map);
+       struct bpf_prog **cur_pprog;
+       struct bpf_link **cur_plink;
 
        if (!progs)
                return -EOPNOTSUPP;
 
        switch (which) {
        case BPF_SK_MSG_VERDICT:
-               *pprog = &progs->msg_parser;
+               cur_pprog = &progs->msg_parser;
+               cur_plink = &progs->msg_parser_link;
                break;
 #if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
        case BPF_SK_SKB_STREAM_PARSER:
-               *pprog = &progs->stream_parser;
+               cur_pprog = &progs->stream_parser;
+               cur_plink = &progs->stream_parser_link;
                break;
 #endif
        case BPF_SK_SKB_STREAM_VERDICT:
                if (progs->skb_verdict)
                        return -EBUSY;
-               *pprog = &progs->stream_verdict;
+               cur_pprog = &progs->stream_verdict;
+               cur_plink = &progs->stream_verdict_link;
                break;
        case BPF_SK_SKB_VERDICT:
                if (progs->stream_verdict)
                        return -EBUSY;
-               *pprog = &progs->skb_verdict;
+               cur_pprog = &progs->skb_verdict;
+               cur_plink = &progs->skb_verdict_link;
                break;
        default:
                return -EOPNOTSUPP;
        }
 
+       *pprog = cur_pprog;
+       if (plink)
+               *plink = cur_plink;
        return 0;
 }
 
+/* Handle the following four cases:
+ * prog_attach: prog != NULL, old == NULL, link == NULL
+ * prog_detach: prog == NULL, old != NULL, link == NULL
+ * link_attach: prog != NULL, old == NULL, link != NULL
+ * link_detach: prog == NULL, old != NULL, link != NULL
+ */
 static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
-                               struct bpf_prog *old, u32 which)
+                               struct bpf_prog *old, struct bpf_link *link,
+                               u32 which)
 {
        struct bpf_prog **pprog;
+       struct bpf_link **plink;
        int ret;
 
-       ret = sock_map_prog_lookup(map, &pprog, which);
+       ret = sock_map_prog_link_lookup(map, &pprog, &plink, which);
        if (ret)
                return ret;
 
-       if (old)
-               return psock_replace_prog(pprog, prog, old);
+       /* For prog_attach/prog_detach/link_attach, return an error if a
+        * bpf_link is already attached at this attach type.
+        */
+       if ((!link || prog) && *plink)
+               return -EBUSY;
 
-       psock_set_prog(pprog, prog);
-       return 0;
+       if (old) {
+               ret = psock_replace_prog(pprog, prog, old);
+               if (!ret)
+                       *plink = NULL;
+       } else {
+               psock_set_prog(pprog, prog);
+               if (link)
+                       *plink = link;
+       }
+
+       return ret;
 }
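The guard "(!link || prog) && *plink" enforces that only link_detach may operate on a slot already owned by a bpf_link; prog_attach, prog_detach and link_attach all fail with -EBUSY. A hypothetical standalone model of just that guard (illustration only, not kernel code):

/* Hypothetical userspace model of the "(!link || prog) && *plink" guard in
 * sock_map_prog_update(); for illustration only.
 */
#include <errno.h>
#include <stdbool.h>
#include <stdio.h>

static int slot_guard(bool prog, bool link, bool slot_has_link)
{
	/* Only link_detach (prog == NULL, link != NULL) may operate on a
	 * slot that is already owned by a bpf_link; everything else gets
	 * -EBUSY.
	 */
	if ((!link || prog) && slot_has_link)
		return -EBUSY;
	return 0;
}

int main(void)
{
	/* All four cases against a slot that already has a link attached. */
	printf("prog_attach: %d\n", slot_guard(true, false, true));   /* -EBUSY */
	printf("prog_detach: %d\n", slot_guard(false, false, true));  /* -EBUSY */
	printf("link_attach: %d\n", slot_guard(true, true, true));    /* -EBUSY */
	printf("link_detach: %d\n", slot_guard(false, true, true));   /* 0 */
	return 0;
}

Compiled and run, the first three cases print -16 (-EBUSY) and the last prints 0.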
 
 int sock_map_bpf_prog_query(const union bpf_attr *attr,
@@ -1533,7 +1574,7 @@ int sock_map_bpf_prog_query(const union bpf_attr *attr,
 
        rcu_read_lock();
 
-       ret = sock_map_prog_lookup(map, &pprog, attr->query.attach_type);
+       ret = sock_map_prog_link_lookup(map, &pprog, NULL, attr->query.attach_type);
        if (ret)
                goto end;
 
@@ -1663,6 +1704,196 @@ void sock_map_close(struct sock *sk, long timeout)
 }
 EXPORT_SYMBOL_GPL(sock_map_close);
 
+struct sockmap_link {
+       struct bpf_link link;
+       struct bpf_map *map;
+       enum bpf_attach_type attach_type;
+};
+
+static void sock_map_link_release(struct bpf_link *link)
+{
+       struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);
+
+       mutex_lock(&sockmap_mutex);
+       if (!sockmap_link->map)
+               goto out;
+
+       WARN_ON_ONCE(sock_map_prog_update(sockmap_link->map, NULL, link->prog, link,
+                                         sockmap_link->attach_type));
+
+       bpf_map_put_with_uref(sockmap_link->map);
+       sockmap_link->map = NULL;
+out:
+       mutex_unlock(&sockmap_mutex);
+}
+
+static int sock_map_link_detach(struct bpf_link *link)
+{
+       sock_map_link_release(link);
+       return 0;
+}
+
+static void sock_map_link_dealloc(struct bpf_link *link)
+{
+       kfree(link);
+}
+
+/* Handle the following two cases:
+ * case 1: link != NULL, prog != NULL, old != NULL
+ * case 2: link != NULL, prog != NULL, old == NULL
+ */
+static int sock_map_link_update_prog(struct bpf_link *link,
+                                    struct bpf_prog *prog,
+                                    struct bpf_prog *old)
+{
+       const struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);
+       struct bpf_prog **pprog, *old_link_prog;
+       struct bpf_link **plink;
+       int ret = 0;
+
+       mutex_lock(&sockmap_mutex);
+
+       /* If old prog is not NULL, ensure old prog is the same as link->prog. */
+       if (old && link->prog != old) {
+               ret = -EPERM;
+               goto out;
+       }
+       /* Ensure link->prog has the same type/attach_type as the new prog. */
+       if (link->prog->type != prog->type ||
+           link->prog->expected_attach_type != prog->expected_attach_type) {
+               ret = -EINVAL;
+               goto out;
+       }
+
+       ret = sock_map_prog_link_lookup(sockmap_link->map, &pprog, &plink,
+                                       sockmap_link->attach_type);
+       if (ret)
+               goto out;
+
+       /* Return an error if the stored bpf_link does not match the incoming bpf_link. */
+       if (link != *plink) {
+               ret = -EBUSY;
+               goto out;
+       }
+
+       if (old) {
+               ret = psock_replace_prog(pprog, prog, old);
+               if (ret)
+                       goto out;
+       } else {
+               psock_set_prog(pprog, prog);
+       }
+
+       bpf_prog_inc(prog);
+       old_link_prog = xchg(&link->prog, prog);
+       bpf_prog_put(old_link_prog);
+
+out:
+       mutex_unlock(&sockmap_mutex);
+       return ret;
+}
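From userspace, sock_map_link_update_prog() is reached via the BPF_LINK_UPDATE command. A hedged libbpf sketch; link_fd, new_prog_fd and old_prog_fd are assumed to come from elsewhere, and BPF_F_REPLACE makes the kernel enforce the old-prog check above:

#include <bpf/bpf.h>

/* Replace the program behind an existing sockmap link. With BPF_F_REPLACE
 * set, the kernel additionally checks that old_prog_fd matches link->prog
 * (the "old && link->prog != old" check above).
 */
static int update_sockmap_link(int link_fd, int new_prog_fd, int old_prog_fd)
{
	LIBBPF_OPTS(bpf_link_update_opts, opts,
		.flags = BPF_F_REPLACE,
		.old_prog_fd = old_prog_fd,
	);

	return bpf_link_update(link_fd, new_prog_fd, &opts);
}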
+
+static u32 sock_map_link_get_map_id(const struct sockmap_link *sockmap_link)
+{
+       u32 map_id = 0;
+
+       mutex_lock(&sockmap_mutex);
+       if (sockmap_link->map)
+               map_id = sockmap_link->map->id;
+       mutex_unlock(&sockmap_mutex);
+       return map_id;
+}
+
+static int sock_map_link_fill_info(const struct bpf_link *link,
+                                  struct bpf_link_info *info)
+{
+       const struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);
+       u32 map_id = sock_map_link_get_map_id(sockmap_link);
+
+       info->sockmap.map_id = map_id;
+       info->sockmap.attach_type = sockmap_link->attach_type;
+       return 0;
+}
+
+static void sock_map_link_show_fdinfo(const struct bpf_link *link,
+                                     struct seq_file *seq)
+{
+       const struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);
+       u32 map_id = sock_map_link_get_map_id(sockmap_link);
+
+       seq_printf(seq, "map_id:\t%u\n", map_id);
+       seq_printf(seq, "attach_type:\t%u\n", sockmap_link->attach_type);
+}
+
+static const struct bpf_link_ops sock_map_link_ops = {
+       .release = sock_map_link_release,
+       .dealloc = sock_map_link_dealloc,
+       .detach = sock_map_link_detach,
+       .update_prog = sock_map_link_update_prog,
+       .fill_link_info = sock_map_link_fill_info,
+       .show_fdinfo = sock_map_link_show_fdinfo,
+};
+
+int sock_map_link_create(const union bpf_attr *attr, struct bpf_prog *prog)
+{
+       struct bpf_link_primer link_primer;
+       struct sockmap_link *sockmap_link;
+       enum bpf_attach_type attach_type;
+       struct bpf_map *map;
+       int ret;
+
+       if (attr->link_create.flags)
+               return -EINVAL;
+
+       map = bpf_map_get_with_uref(attr->link_create.target_fd);
+       if (IS_ERR(map))
+               return PTR_ERR(map);
+       if (map->map_type != BPF_MAP_TYPE_SOCKMAP && map->map_type != BPF_MAP_TYPE_SOCKHASH) {
+               ret = -EINVAL;
+               goto out;
+       }
+
+       sockmap_link = kzalloc(sizeof(*sockmap_link), GFP_USER);
+       if (!sockmap_link) {
+               ret = -ENOMEM;
+               goto out;
+       }
+
+       attach_type = attr->link_create.attach_type;
+       bpf_link_init(&sockmap_link->link, BPF_LINK_TYPE_SOCKMAP, &sock_map_link_ops, prog);
+       sockmap_link->map = map;
+       sockmap_link->attach_type = attach_type;
+
+       ret = bpf_link_prime(&sockmap_link->link, &link_primer);
+       if (ret) {
+               kfree(sockmap_link);
+               goto out;
+       }
+
+       mutex_lock(&sockmap_mutex);
+       ret = sock_map_prog_update(map, prog, NULL, &sockmap_link->link, attach_type);
+       mutex_unlock(&sockmap_mutex);
+       if (ret) {
+               bpf_link_cleanup(&link_primer);
+               goto out;
+       }
+
+       /* Take an extra reference on the prog for the map slot: that reference
+        * is dropped by psock_replace_prog()/psock_set_prog() when the prog is
+        * later replaced or cleared.
+        *
+        * Strictly speaking the bpf_link already holds a reference, but a
+        * separate reference for the map slot keeps replacing/setting the prog
+        * simple.
+        */
+       bpf_prog_inc(prog);
+
+       return bpf_link_settle(&link_primer);
+
+out:
+       bpf_map_put_with_uref(map);
+       return ret;
+}
+
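Putting the pieces together, a link-based attachment is created with BPF_LINK_CREATE using the map fd as target_fd. A hedged userspace sketch (error handling abbreviated; prog_fd and map_fd assumed to be set up elsewhere) that also exercises the -EBUSY interaction with the legacy attach path and the detach-on-release behaviour of sock_map_link_release():

#include <errno.h>
#include <stdio.h>
#include <unistd.h>
#include <bpf/bpf.h>

/* prog_fd: loaded BPF_PROG_TYPE_SK_MSG program, map_fd: BPF_MAP_TYPE_SOCKMAP;
 * both assumed to exist already.
 */
static int sockmap_link_demo(int prog_fd, int map_fd)
{
	int link_fd, err;

	/* BPF_LINK_CREATE with the sockmap fd as target_fd ends up in
	 * sock_map_link_create() above.
	 */
	link_fd = bpf_link_create(prog_fd, map_fd, BPF_SK_MSG_VERDICT, NULL);
	if (link_fd < 0)
		return link_fd;

	/* While the link owns the slot, the legacy attach path is refused. */
	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
	if (err && errno == EBUSY)
		printf("legacy attach rejected as expected (EBUSY)\n");

	/* Closing the last fd releases the link, which detaches the prog and
	 * drops the map uref (sock_map_link_release() above).
	 */
	close(link_fd);
	return 0;
}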
 static int sock_map_iter_attach_target(struct bpf_prog *prog,
                                       union bpf_iter_link_info *linfo,
                                       struct bpf_iter_aux_info *aux)