]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
shared: don't unconditionally set SOCK_STREAM as type in socket_address_parse()
authorZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>
Thu, 3 Sep 2020 21:51:21 +0000 (23:51 +0200)
committerZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>
Wed, 9 Sep 2020 22:46:44 +0000 (00:46 +0200)
We would set .type to a fake value. All real callers (outside of tests)
immediately overwrite .type with a proper value after calling
socket_address_parse(). So let's not set it and adjust the few places
that relied on it being set to the fake value.

socket_address_parse() is modernized to only set the output argument on
success.

src/basic/socket-util.c
src/shared/socket-netlink.c
src/test/test-socket-netlink.c

index fa51997581204ed7ccc64f0ac1dd5c6837bfad42..85edc83cae341770460243c4cfc664de6ed56888 100644 (file)
@@ -68,7 +68,7 @@ int socket_address_verify(const SocketAddress *a, bool strict) {
                 if (a->sockaddr.in.sin_port == 0)
                         return -EINVAL;
 
-                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM))
+                if (!IN_SET(a->type, 0, SOCK_STREAM, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
@@ -80,7 +80,7 @@ int socket_address_verify(const SocketAddress *a, bool strict) {
                 if (a->sockaddr.in6.sin6_port == 0)
                         return -EINVAL;
 
-                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM))
+                if (!IN_SET(a->type, 0, SOCK_STREAM, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
@@ -114,7 +114,7 @@ int socket_address_verify(const SocketAddress *a, bool strict) {
                         }
                 }
 
-                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET))
+                if (!IN_SET(a->type, 0, SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET))
                         return -EINVAL;
 
                 return 0;
@@ -124,7 +124,7 @@ int socket_address_verify(const SocketAddress *a, bool strict) {
                 if (a->size != sizeof(struct sockaddr_nl))
                         return -EINVAL;
 
-                if (!IN_SET(a->type, SOCK_RAW, SOCK_DGRAM))
+                if (!IN_SET(a->type, 0, SOCK_RAW, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
@@ -133,7 +133,7 @@ int socket_address_verify(const SocketAddress *a, bool strict) {
                 if (a->size != sizeof(struct sockaddr_vm))
                         return -EINVAL;
 
-                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM))
+                if (!IN_SET(a->type, 0, SOCK_STREAM, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
index 25bc0167a63999c08741c642666b533741e3f156..198892b007f526c8127fdc83d959162fa222b5d7 100644 (file)
@@ -62,29 +62,25 @@ int socket_address_parse(SocketAddress *a, const char *s) {
         assert(a);
         assert(s);
 
-        *a = (SocketAddress) {
-                .type = SOCK_STREAM,
-        };
-
         if (*s == '/') {
                 /* AF_UNIX socket */
-                size_t l;
 
-                l = strlen(s);
+                size_t l = strlen(s);
                 if (l >= sizeof(a->sockaddr.un.sun_path)) /* Note that we refuse non-NUL-terminated sockets when
                                                            * parsing (the kernel itself is less strict here in what it
                                                            * accepts) */
                         return -EINVAL;
 
-                a->sockaddr.un.sun_family = AF_UNIX;
+                *a = (SocketAddress) {
+                        .sockaddr.un.sun_family = AF_UNIX,
+                        .size = offsetof(struct sockaddr_un, sun_path) + l + 1,
+                };
                 memcpy(a->sockaddr.un.sun_path, s, l);
-                a->size = offsetof(struct sockaddr_un, sun_path) + l + 1;
 
         } else if (*s == '@') {
                 /* Abstract AF_UNIX socket */
-                size_t l;
 
-                l = strlen(s+1);
+                size_t l = strlen(s+1);
                 if (l >= sizeof(a->sockaddr.un.sun_path) - 1) /* Note that we refuse non-NUL-terminated sockets here
                                                                * when parsing, even though abstract namespace sockets
                                                                * explicitly allow embedded NUL bytes and don't consider
@@ -92,14 +88,16 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                                                                * sockets. */
                         return -EINVAL;
 
-                a->sockaddr.un.sun_family = AF_UNIX;
+                *a = (SocketAddress) {
+                        .sockaddr.un.sun_family = AF_UNIX,
+                        .size = offsetof(struct sockaddr_un, sun_path) + 1 + l,
+                };
                 memcpy(a->sockaddr.un.sun_path+1, s+1, l);
-                a->size = offsetof(struct sockaddr_un, sun_path) + 1 + l;
 
         } else if (startswith(s, "vsock:")) {
                 /* AF_VSOCK socket in vsock:cid:port notation */
                 const char *cid_start = s + STRLEN("vsock:");
-                unsigned port;
+                unsigned port, cid;
 
                 e = strchr(cid_start, ':');
                 if (!e)
@@ -113,16 +111,22 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                 if (!n)
                         return -ENOMEM;
 
-                if (!isempty(n)) {
-                        r = safe_atou(n, &a->sockaddr.vm.svm_cid);
+                if (isempty(n))
+                        cid = VMADDR_CID_ANY;
+                else {
+                        r = safe_atou(n, &cid);
                         if (r < 0)
                                 return r;
-                } else
-                        a->sockaddr.vm.svm_cid = VMADDR_CID_ANY;
+                }
 
-                a->sockaddr.vm.svm_family = AF_VSOCK;
-                a->sockaddr.vm.svm_port = port;
-                a->size = sizeof(struct sockaddr_vm);
+                *a = (SocketAddress) {
+                        .sockaddr.vm = {
+                                .svm_cid = cid,
+                                .svm_family = AF_VSOCK,
+                                .svm_port = port,
+                        },
+                        .size = sizeof(struct sockaddr_vm),
+                };
 
         } else {
                 uint16_t port;
@@ -132,17 +136,24 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         return r; /* Valid port syntax, but the numerical value is wrong for a port. */
                 if (r >= 0) {
                         /* Just a port */
-                        if (socket_ipv6_is_supported()) {
-                                a->sockaddr.in6.sin6_family = AF_INET6;
-                                a->sockaddr.in6.sin6_port = htobe16(port);
-                                a->sockaddr.in6.sin6_addr = in6addr_any;
-                                a->size = sizeof(struct sockaddr_in6);
-                        } else {
-                                a->sockaddr.in.sin_family = AF_INET;
-                                a->sockaddr.in.sin_port = htobe16(port);
-                                a->sockaddr.in.sin_addr.s_addr = INADDR_ANY;
-                                a->size = sizeof(struct sockaddr_in);
-                        }
+                        if (socket_ipv6_is_supported())
+                                *a = (SocketAddress) {
+                                        .sockaddr.in6 = {
+                                                .sin6_family = AF_INET6,
+                                                .sin6_port = htobe16(port),
+                                                .sin6_addr = in6addr_any,
+                                        },
+                                        .size = sizeof(struct sockaddr_in6),
+                                };
+                        else
+                                *a = (SocketAddress) {
+                                        .sockaddr.in = {
+                                                .sin_family = AF_INET,
+                                                .sin_port = htobe16(port),
+                                                .sin_addr.s_addr = INADDR_ANY,
+                                        },
+                                        .size = sizeof(struct sockaddr_in),
+                                };
 
                 } else {
                         union in_addr_union address;
@@ -155,21 +166,25 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         if (port == 0) /* No port, no go. */
                                 return -EINVAL;
 
-                        if (family == AF_INET) {
-                                a->sockaddr.in = (struct sockaddr_in) {
-                                        .sin_family = AF_INET,
-                                        .sin_addr = address.in,
-                                        .sin_port = htobe16(port),
+                        if (family == AF_INET)
+                                *a = (SocketAddress) {
+                                        .sockaddr.in = {
+                                                .sin_family = AF_INET,
+                                                .sin_addr = address.in,
+                                                .sin_port = htobe16(port),
+                                        },
+                                        .size = sizeof(struct sockaddr_in),
                                 };
-                                a->size = sizeof(struct sockaddr_in);
-                        } else if (family == AF_INET6) {
-                                a->sockaddr.in6 = (struct sockaddr_in6) {
-                                        .sin6_family = AF_INET6,
-                                        .sin6_addr = address.in6,
-                                        .sin6_port = htobe16(port),
+                        else if (family == AF_INET6)
+                                *a = (SocketAddress) {
+                                        .sockaddr.in6 = {
+                                                .sin6_family = AF_INET6,
+                                                .sin6_addr = address.in6,
+                                                .sin6_port = htobe16(port),
+                                        },
+                                        .size = sizeof(struct sockaddr_in6),
                                 };
-                                a->size = sizeof(struct sockaddr_in6);
-                        } else
+                        else
                                 assert_not_reached("Family quarrel");
                 }
         }
index 06a08cd9d790f43313869d1c82ae765902470360..b87cb7b126880e5521eafde90362815a17f25a63 100644 (file)
@@ -17,6 +17,7 @@ static void test_socket_address_parse_one(const char *in, int ret, int family, c
                 if (r < 0)
                         log_error_errno(r, "Printing failed for \"%s\": %m", in);
                 assert(r >= 0);
+                assert_se(a.type == 0);
         }
 
         log_info("\"%s\" → %s %d → \"%s\" (expect %d / \"%s\")",
@@ -206,9 +207,13 @@ static void test_socket_address_is(void) {
         log_info("/* %s */", __func__);
 
         assert_se(socket_address_parse(&a, "192.168.1.1:8888") >= 0);
-        assert_se(socket_address_is(&a, "192.168.1.1:8888", SOCK_STREAM));
+        assert_se( socket_address_is(&a, "192.168.1.1:8888", 0 /* unspecified yet */));
+        assert_se(!socket_address_is(&a, "route", 0));
         assert_se(!socket_address_is(&a, "route", SOCK_STREAM));
         assert_se(!socket_address_is(&a, "192.168.1.1:8888", SOCK_RAW));
+        assert_se(!socket_address_is(&a, "192.168.1.1:8888", SOCK_STREAM));
+        a.type = SOCK_STREAM;
+        assert_se( socket_address_is(&a, "192.168.1.1:8888", SOCK_STREAM));
 }
 
 static void test_socket_address_is_netlink(void) {
@@ -217,7 +222,7 @@ static void test_socket_address_is_netlink(void) {
         log_info("/* %s */", __func__);
 
         assert_se(socket_address_parse_netlink(&a, "route 10") >= 0);
-        assert_se(socket_address_is_netlink(&a, "route 10"));
+        assert_se( socket_address_is_netlink(&a, "route 10"));
         assert_se(!socket_address_is_netlink(&a, "192.168.1.1:8888"));
         assert_se(!socket_address_is_netlink(&a, "route 1"));
 }