]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
selftests/bpf: Support more socket types in create_pair()
authorMichal Luczaj <mhal@rbox.co>
Wed, 31 Jul 2024 10:01:26 +0000 (12:01 +0200)
committerMartin KaFai Lau <martin.lau@kernel.org>
Mon, 19 Aug 2024 23:43:31 +0000 (16:43 -0700)
Extend the function to allow creating socket pairs of SOCK_STREAM,
SOCK_DGRAM and SOCK_SEQPACKET.

Adapt direct callers and leave further cleanups for the following patch.

Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
Tested-by: Jakub Sitnicki <jakub@cloudflare.com>
Suggested-by: Jakub Sitnicki <jakub@cloudflare.com>
Signed-off-by: Michal Luczaj <mhal@rbox.co>
Link: https://lore.kernel.org/r/20240731-selftest-sockmap-fixes-v2-1-08a0c73abed2@rbox.co
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
tools/testing/selftests/bpf/prog_tests/sockmap_basic.c
tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h

index 1337153eb0ad7926eb4bc4648ffcbff03a8da90e..5b17d69c9ee68c49f885905886a39eeacdbbc597 100644 (file)
@@ -451,11 +451,11 @@ out:
 #define MAX_EVENTS 10
 static void test_sockmap_skb_verdict_shutdown(void)
 {
+       int n, err, map, verdict, c1 = -1, p1 = -1;
        struct epoll_event ev, events[MAX_EVENTS];
-       int n, err, map, verdict, s, c1 = -1, p1 = -1;
        struct test_sockmap_pass_prog *skel;
-       int epollfd;
        int zero = 0;
+       int epollfd;
        char b;
 
        skel = test_sockmap_pass_prog__open_and_load();
@@ -469,10 +469,7 @@ static void test_sockmap_skb_verdict_shutdown(void)
        if (!ASSERT_OK(err, "bpf_prog_attach"))
                goto out;
 
-       s = socket_loopback(AF_INET, SOCK_STREAM);
-       if (s < 0)
-               goto out;
-       err = create_pair(s, AF_INET, SOCK_STREAM, &c1, &p1);
+       err = create_pair(AF_INET, SOCK_STREAM, &c1, &p1);
        if (err < 0)
                goto out;
 
@@ -570,16 +567,12 @@ out:
 
 static void test_sockmap_skb_verdict_peek_helper(int map)
 {
-       int err, s, c1, p1, zero = 0, sent, recvd, avail;
+       int err, c1, p1, zero = 0, sent, recvd, avail;
        char snd[256] = "0123456789";
        char rcv[256] = "0";
 
-       s = socket_loopback(AF_INET, SOCK_STREAM);
-       if (!ASSERT_GT(s, -1, "socket_loopback(s)"))
-               return;
-
-       err = create_pair(s, AF_INET, SOCK_STREAM, &c1, &p1);
-       if (!ASSERT_OK(err, "create_pairs(s)"))
+       err = create_pair(AF_INET, SOCK_STREAM, &c1, &p1);
+       if (!ASSERT_OK(err, "create_pair()"))
                return;
 
        err = bpf_map_update_elem(map, &zero, &c1, BPF_NOEXIST);
index e880f97bc44d35d1609356129f6b31b268eb435e..77b73333f091aa0d049a80ad751ce1e091745678 100644 (file)
@@ -3,6 +3,9 @@
 
 #include <linux/vm_sockets.h>
 
+/* include/linux/net.h */
+#define SOCK_TYPE_MASK 0xf
+
 #define IO_TIMEOUT_SEC 30
 #define MAX_STRERR_LEN 256
 #define MAX_TEST_NAME 80
@@ -312,54 +315,6 @@ static inline int add_to_sockmap(int sock_mapfd, int fd1, int fd2)
        return xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
 }
 
-static inline int create_pair(int s, int family, int sotype, int *c, int *p)
-{
-       struct sockaddr_storage addr;
-       socklen_t len;
-       int err = 0;
-
-       len = sizeof(addr);
-       err = xgetsockname(s, sockaddr(&addr), &len);
-       if (err)
-               return err;
-
-       *c = xsocket(family, sotype, 0);
-       if (*c < 0)
-               return errno;
-       err = xconnect(*c, sockaddr(&addr), len);
-       if (err) {
-               err = errno;
-               goto close_cli0;
-       }
-
-       *p = xaccept_nonblock(s, NULL, NULL);
-       if (*p < 0) {
-               err = errno;
-               goto close_cli0;
-       }
-       return err;
-close_cli0:
-       close(*c);
-       return err;
-}
-
-static inline int create_socket_pairs(int s, int family, int sotype,
-                                     int *c0, int *c1, int *p0, int *p1)
-{
-       int err;
-
-       err = create_pair(s, family, sotype, c0, p0);
-       if (err)
-               return err;
-
-       err = create_pair(s, family, sotype, c1, p1);
-       if (err) {
-               close(*c0);
-               close(*p0);
-       }
-       return err;
-}
-
 static inline int enable_reuseport(int s, int progfd)
 {
        int err, one = 1;
@@ -412,5 +367,92 @@ static inline int socket_loopback(int family, int sotype)
        return socket_loopback_reuseport(family, sotype, -1);
 }
 
+static inline int create_pair(int family, int sotype, int *p0, int *p1)
+{
+       struct sockaddr_storage addr;
+       socklen_t len = sizeof(addr);
+       int s, c, p, err;
+
+       s = socket_loopback(family, sotype);
+       if (s < 0)
+               return s;
+
+       err = xgetsockname(s, sockaddr(&addr), &len);
+       if (err)
+               goto close_s;
+
+       c = xsocket(family, sotype, 0);
+       if (c < 0) {
+               err = c;
+               goto close_s;
+       }
+
+       err = connect(c, sockaddr(&addr), len);
+       if (err) {
+               if (errno != EINPROGRESS) {
+                       FAIL_ERRNO("connect");
+                       goto close_c;
+               }
+
+               err = poll_connect(c, IO_TIMEOUT_SEC);
+               if (err) {
+                       FAIL_ERRNO("poll_connect");
+                       goto close_c;
+               }
+       }
+
+       switch (sotype & SOCK_TYPE_MASK) {
+       case SOCK_DGRAM:
+               err = xgetsockname(c, sockaddr(&addr), &len);
+               if (err)
+                       goto close_c;
+
+               err = xconnect(s, sockaddr(&addr), len);
+               if (!err) {
+                       *p0 = s;
+                       *p1 = c;
+                       return err;
+               }
+               break;
+       case SOCK_STREAM:
+       case SOCK_SEQPACKET:
+               p = xaccept_nonblock(s, NULL, NULL);
+               if (p >= 0) {
+                       *p0 = p;
+                       *p1 = c;
+                       goto close_s;
+               }
+
+               err = p;
+               break;
+       default:
+               FAIL("Unsupported socket type %#x", sotype);
+               err = -EOPNOTSUPP;
+       }
+
+close_c:
+       close(c);
+close_s:
+       close(s);
+       return err;
+}
+
+static inline int create_socket_pairs(int s, int family, int sotype,
+                                     int *c0, int *c1, int *p0, int *p1)
+{
+       int err;
+
+       err = create_pair(family, sotype, c0, p0);
+       if (err)
+               return err;
+
+       err = create_pair(family, sotype, c1, p1);
+       if (err) {
+               close(*c0);
+               close(*p0);
+       }
+
+       return err;
+}
 
 #endif // __SOCKMAP_HELPERS__