]> git.ipfire.org Git - thirdparty/kernel/stable.git/blobdiff - net/bpf/test_run.c
bpf: add tests for direct packet access from CGROUP_SKB
[thirdparty/kernel/stable.git] / net / bpf / test_run.c
index f4078830ea505ee955d4bc83010bd2fec59583fe..c89c22c49015ff070f228bece397937d7cdce8b5 100644 (file)
 #include <linux/etherdevice.h>
 #include <linux/filter.h>
 #include <linux/sched/signal.h>
+#include <net/sock.h>
+#include <net/tcp.h>
 
 static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx,
-                                           struct bpf_cgroup_storage *storage)
+               struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE])
 {
        u32 ret;
 
@@ -28,13 +30,20 @@ static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx,
 
 static u32 bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *time)
 {
-       struct bpf_cgroup_storage *storage = NULL;
+       struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = { 0 };
+       enum bpf_cgroup_storage_type stype;
        u64 time_start, time_spent = 0;
        u32 ret = 0, i;
 
-       storage = bpf_cgroup_storage_alloc(prog);
-       if (IS_ERR(storage))
-               return PTR_ERR(storage);
+       for_each_cgroup_storage_type(stype) {
+               storage[stype] = bpf_cgroup_storage_alloc(prog, stype);
+               if (IS_ERR(storage[stype])) {
+                       storage[stype] = NULL;
+                       for_each_cgroup_storage_type(stype)
+                               bpf_cgroup_storage_free(storage[stype]);
+                       return -ENOMEM;
+               }
+       }
 
        if (!repeat)
                repeat = 1;
@@ -53,7 +62,8 @@ static u32 bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *time)
        do_div(time_spent, repeat);
        *time = time_spent > U32_MAX ? U32_MAX : (u32)time_spent;
 
-       bpf_cgroup_storage_free(storage);
+       for_each_cgroup_storage_type(stype)
+               bpf_cgroup_storage_free(storage[stype]);
 
        return ret;
 }
@@ -107,6 +117,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
        u32 retval, duration;
        int hh_len = ETH_HLEN;
        struct sk_buff *skb;
+       struct sock *sk;
        void *data;
        int ret;
 
@@ -129,11 +140,21 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
                break;
        }
 
+       sk = kzalloc(sizeof(struct sock), GFP_USER);
+       if (!sk) {
+               kfree(data);
+               return -ENOMEM;
+       }
+       sock_net_set(sk, current->nsproxy->net_ns);
+       sock_init_data(NULL, sk);
+
        skb = build_skb(data, 0);
        if (!skb) {
                kfree(data);
+               kfree(sk);
                return -ENOMEM;
        }
+       skb->sk = sk;
 
        skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN);
        __skb_put(skb, size);
@@ -151,6 +172,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
 
                        if (pskb_expand_head(skb, nhead, 0, GFP_USER)) {
                                kfree_skb(skb);
+                               kfree(sk);
                                return -ENOMEM;
                        }
                }
@@ -163,6 +185,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
                size = skb_headlen(skb);
        ret = bpf_test_finish(kattr, uattr, skb->data, size, retval, duration);
        kfree_skb(skb);
+       kfree(sk);
        return ret;
 }