]> git.ipfire.org Git - thirdparty/kernel/stable.git/blobdiff - net/bpf/test_run.c
bpf: add tests for direct packet access from CGROUP_SKB
[thirdparty/kernel/stable.git] / net / bpf / test_run.c
index 0c423b8cd75cce9ada273d545216112c829964e0..c89c22c49015ff070f228bece397937d7cdce8b5 100644 (file)
@@ -10,6 +10,8 @@
 #include <linux/etherdevice.h>
 #include <linux/filter.h>
 #include <linux/sched/signal.h>
+#include <net/sock.h>
+#include <net/tcp.h>
 
 static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx,
                struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE])
@@ -115,6 +117,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
        u32 retval, duration;
        int hh_len = ETH_HLEN;
        struct sk_buff *skb;
+       struct sock *sk;
        void *data;
        int ret;
 
@@ -137,11 +140,21 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
                break;
        }
 
+       sk = kzalloc(sizeof(struct sock), GFP_USER);
+       if (!sk) {
+               kfree(data);
+               return -ENOMEM;
+       }
+       sock_net_set(sk, current->nsproxy->net_ns);
+       sock_init_data(NULL, sk);
+
        skb = build_skb(data, 0);
        if (!skb) {
                kfree(data);
+               kfree(sk);
                return -ENOMEM;
        }
+       skb->sk = sk;
 
        skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN);
        __skb_put(skb, size);
@@ -159,6 +172,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
 
                        if (pskb_expand_head(skb, nhead, 0, GFP_USER)) {
                                kfree_skb(skb);
+                               kfree(sk);
                                return -ENOMEM;
                        }
                }
@@ -171,6 +185,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
                size = skb_headlen(skb);
        ret = bpf_test_finish(kattr, uattr, skb->data, size, retval, duration);
        kfree_skb(skb);
+       kfree(sk);
        return ret;
 }