]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/basic/socket-util.c
Merge pull request #8143 from yuwata/drop-unused-func
[thirdparty/systemd.git] / src / basic / socket-util.c
index f1b2256b2178cc06b90b3310f43b92e899444328..b91b093132e787d722e45f4fdb1cff05c85901e7 100644 (file)
@@ -1,3 +1,4 @@
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
@@ -40,6 +41,7 @@
 #include "missing.h"
 #include "parse-util.h"
 #include "path-util.h"
+#include "process-util.h"
 #include "socket-util.h"
 #include "string-table.h"
 #include "string-util.h"
 #include "utf8.h"
 #include "util.h"
 
-#ifdef ENABLE_IDN
-#  define IDN_FLAGS (NI_IDN|NI_IDN_USE_STD3_ASCII_RULES)
+#if ENABLE_IDN
+#  define IDN_FLAGS NI_IDN
 #else
 #  define IDN_FLAGS 0
 #endif
 
+static const char* const socket_address_type_table[] = {
+        [SOCK_STREAM] = "Stream",
+        [SOCK_DGRAM] = "Datagram",
+        [SOCK_RAW] = "Raw",
+        [SOCK_RDM] = "ReliableDatagram",
+        [SOCK_SEQPACKET] = "SequentialPacket",
+        [SOCK_DCCP] = "DatagramCongestionControl",
+};
+
+DEFINE_STRING_TABLE_LOOKUP(socket_address_type, int);
+
 int socket_address_parse(SocketAddress *a, const char *s) {
         char *e, *n;
-        unsigned u;
         int r;
 
         assert(a);
@@ -66,6 +78,8 @@ int socket_address_parse(SocketAddress *a, const char *s) {
         a->type = SOCK_STREAM;
 
         if (*s == '[') {
+                uint16_t port;
+
                 /* IPv6 in [x:.....:z]:p notation */
 
                 e = strchr(s+1, ']');
@@ -83,15 +97,12 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         return -EINVAL;
 
                 e++;
-                r = safe_atou(e, &u);
+                r = parse_ip_port(e, &port);
                 if (r < 0)
                         return r;
 
-                if (u <= 0 || u > 0xFFFF)
-                        return -EINVAL;
-
                 a->sockaddr.in6.sin6_family = AF_INET6;
-                a->sockaddr.in6.sin6_port = htobe16((uint16_t)u);
+                a->sockaddr.in6.sin6_port = htobe16(port);
                 a->size = sizeof(struct sockaddr_in6);
 
         } else if (*s == '/') {
@@ -121,13 +132,14 @@ int socket_address_parse(SocketAddress *a, const char *s) {
 
         } else if (startswith(s, "vsock:")) {
                 /* AF_VSOCK socket in vsock:cid:port notation */
-                const char *cid_start = s + strlen("vsock:");
+                const char *cid_start = s + STRLEN("vsock:");
+                unsigned port;
 
                 e = strchr(cid_start, ':');
                 if (!e)
                         return -EINVAL;
 
-                r = safe_atou(e+1, &u);
+                r = safe_atou(e+1, &port);
                 if (r < 0)
                         return r;
 
@@ -140,19 +152,18 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         a->sockaddr.vm.svm_cid = VMADDR_CID_ANY;
 
                 a->sockaddr.vm.svm_family = AF_VSOCK;
-                a->sockaddr.vm.svm_port = u;
+                a->sockaddr.vm.svm_port = port;
                 a->size = sizeof(struct sockaddr_vm);
 
         } else {
+                uint16_t port;
+
                 e = strchr(s, ':');
                 if (e) {
-                        r = safe_atou(e+1, &u);
+                        r = parse_ip_port(e + 1, &port);
                         if (r < 0)
                                 return r;
 
-                        if (u <= 0 || u > 0xFFFF)
-                                return -EINVAL;
-
                         n = strndupa(s, e-s);
 
                         /* IPv4 in w.x.y.z:p notation? */
@@ -163,7 +174,7 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         if (r > 0) {
                                 /* Gotcha, it's a traditional IPv4 address */
                                 a->sockaddr.in.sin_family = AF_INET;
-                                a->sockaddr.in.sin_port = htobe16((uint16_t)u);
+                                a->sockaddr.in.sin_port = htobe16(port);
                                 a->size = sizeof(struct sockaddr_in);
                         } else {
                                 unsigned idx;
@@ -177,7 +188,7 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                                         return -EINVAL;
 
                                 a->sockaddr.in6.sin6_family = AF_INET6;
-                                a->sockaddr.in6.sin6_port = htobe16((uint16_t)u);
+                                a->sockaddr.in6.sin6_port = htobe16(port);
                                 a->sockaddr.in6.sin6_scope_id = idx;
                                 a->sockaddr.in6.sin6_addr = in6addr_any;
                                 a->size = sizeof(struct sockaddr_in6);
@@ -185,21 +196,18 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                 } else {
 
                         /* Just a port */
-                        r = safe_atou(s, &u);
+                        r = parse_ip_port(s, &port);
                         if (r < 0)
                                 return r;
 
-                        if (u <= 0 || u > 0xFFFF)
-                                return -EINVAL;
-
                         if (socket_ipv6_is_supported()) {
                                 a->sockaddr.in6.sin6_family = AF_INET6;
-                                a->sockaddr.in6.sin6_port = htobe16((uint16_t)u);
+                                a->sockaddr.in6.sin6_port = htobe16(port);
                                 a->sockaddr.in6.sin6_addr = in6addr_any;
                                 a->size = sizeof(struct sockaddr_in6);
                         } else {
                                 a->sockaddr.in.sin_family = AF_INET;
-                                a->sockaddr.in.sin_port = htobe16((uint16_t)u);
+                                a->sockaddr.in.sin_port = htobe16(port);
                                 a->sockaddr.in.sin_addr.s_addr = INADDR_ANY;
                                 a->size = sizeof(struct sockaddr_in);
                         }
@@ -268,7 +276,7 @@ int socket_address_verify(const SocketAddress *a) {
                 if (a->sockaddr.in.sin_port == 0)
                         return -EINVAL;
 
-                if (a->type != SOCK_STREAM && a->type != SOCK_DGRAM)
+                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
@@ -280,7 +288,7 @@ int socket_address_verify(const SocketAddress *a) {
                 if (a->sockaddr.in6.sin6_port == 0)
                         return -EINVAL;
 
-                if (a->type != SOCK_STREAM && a->type != SOCK_DGRAM)
+                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
@@ -304,7 +312,7 @@ int socket_address_verify(const SocketAddress *a) {
                         }
                 }
 
-                if (a->type != SOCK_STREAM && a->type != SOCK_DGRAM && a->type != SOCK_SEQPACKET)
+                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET))
                         return -EINVAL;
 
                 return 0;
@@ -314,7 +322,7 @@ int socket_address_verify(const SocketAddress *a) {
                 if (a->size != sizeof(struct sockaddr_nl))
                         return -EINVAL;
 
-                if (a->type != SOCK_RAW && a->type != SOCK_DGRAM)
+                if (!IN_SET(a->type, SOCK_RAW, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
@@ -323,7 +331,7 @@ int socket_address_verify(const SocketAddress *a) {
                 if (a->size != sizeof(struct sockaddr_vm))
                         return -EINVAL;
 
-                if (a->type != SOCK_STREAM && a->type != SOCK_DGRAM)
+                if (!IN_SET(a->type, SOCK_STREAM, SOCK_DGRAM))
                         return -EINVAL;
 
                 return 0;
@@ -364,8 +372,7 @@ bool socket_address_can_accept(const SocketAddress *a) {
         assert(a);
 
         return
-                a->type == SOCK_STREAM ||
-                a->type == SOCK_SEQPACKET;
+                IN_SET(a->type, SOCK_STREAM, SOCK_SEQPACKET);
 }
 
 bool socket_address_equal(const SocketAddress *a, const SocketAddress *b) {
@@ -528,22 +535,25 @@ bool socket_address_matches_fd(const SocketAddress *a, int fd) {
         return socket_address_equal(a, &b);
 }
 
-int sockaddr_port(const struct sockaddr *_sa, unsigned *port) {
+int sockaddr_port(const struct sockaddr *_sa, unsigned *ret_port) {
         union sockaddr_union *sa = (union sockaddr_union*) _sa;
 
+        /* Note, this returns the port as 'unsigned' rather than 'uint16_t', as AF_VSOCK knows larger ports */
+
         assert(sa);
 
         switch (sa->sa.sa_family) {
+
         case AF_INET:
-                *port = be16toh(sa->in.sin_port);
+                *ret_port = be16toh(sa->in.sin_port);
                 return 0;
 
         case AF_INET6:
-                *port = be16toh(sa->in6.sin6_port);
+                *ret_port = be16toh(sa->in6.sin6_port);
                 return 0;
 
         case AF_VSOCK:
-                *port = sa->vm.svm_port;
+                *ret_port = sa->vm.svm_port;
                 return 0;
 
         default:
@@ -748,19 +758,6 @@ int socknameinfo_pretty(union sockaddr_union *sa, socklen_t salen, char **_ret)
         return 0;
 }
 
-int getnameinfo_pretty(int fd, char **ret) {
-        union sockaddr_union sa;
-        socklen_t salen = sizeof(sa);
-
-        assert(fd >= 0);
-        assert(ret);
-
-        if (getsockname(fd, &sa.sa, &salen) < 0)
-                return -errno;
-
-        return socknameinfo_pretty(&sa, salen, ret);
-}
-
 int socket_address_unlink(SocketAddress *a) {
         assert(a);
 
@@ -807,6 +804,18 @@ static const char* const socket_address_bind_ipv6_only_table[_SOCKET_ADDRESS_BIN
 
 DEFINE_STRING_TABLE_LOOKUP(socket_address_bind_ipv6_only, SocketAddressBindIPv6Only);
 
+SocketAddressBindIPv6Only parse_socket_address_bind_ipv6_only_or_bool(const char *n) {
+        int r;
+
+        r = parse_boolean(n);
+        if (r > 0)
+                return SOCKET_ADDRESS_IPV6_ONLY;
+        if (r == 0)
+                return SOCKET_ADDRESS_BOTH;
+
+        return socket_address_bind_ipv6_only_from_string(n);
+}
+
 bool sockaddr_equal(const union sockaddr_union *a, const union sockaddr_union *b) {
         assert(a);
         assert(b);
@@ -893,7 +902,7 @@ bool ifname_valid(const char *p) {
                 if ((unsigned char) *p <= 32U)
                         return false;
 
-                if (*p == ':' || *p == '/')
+                if (IN_SET(*p, ':', '/'))
                         return false;
 
                 numeric = numeric && (*p >= '0' && *p <= '9');
@@ -941,58 +950,80 @@ int getpeercred(int fd, struct ucred *ucred) {
         if (n != sizeof(struct ucred))
                 return -EIO;
 
-        /* Check if the data is actually useful and not suppressed due
-         * to namespacing issues */
-        if (u.pid <= 0)
-                return -ENODATA;
-        if (u.uid == UID_INVALID)
-                return -ENODATA;
-        if (u.gid == GID_INVALID)
+        /* Check if the data is actually useful and not suppressed due to namespacing issues */
+        if (!pid_is_valid(u.pid))
                 return -ENODATA;
 
+        /* Note that we don't check UID/GID here, as namespace translation works differently there: instead of
+         * receiving in "invalid" user/group we get the overflow UID/GID. */
+
         *ucred = u;
         return 0;
 }
 
 int getpeersec(int fd, char **ret) {
+        _cleanup_free_ char *s = NULL;
         socklen_t n = 64;
-        char *s;
-        int r;
 
         assert(fd >= 0);
         assert(ret);
 
-        s = new0(char, n);
-        if (!s)
-                return -ENOMEM;
+        for (;;) {
+                s = new0(char, n+1);
+                if (!s)
+                        return -ENOMEM;
 
-        r = getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n);
-        if (r < 0) {
-                free(s);
+                if (getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n) >= 0)
+                        break;
 
                 if (errno != ERANGE)
                         return -errno;
 
-                s = new0(char, n);
-                if (!s)
-                        return -ENOMEM;
-
-                r = getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n);
-                if (r < 0) {
-                        free(s);
-                        return -errno;
-                }
+                s = mfree(s);
         }
 
-        if (isempty(s)) {
-                free(s);
+        if (isempty(s))
                 return -EOPNOTSUPP;
-        }
 
         *ret = s;
+        s = NULL;
+
         return 0;
 }
 
+int getpeergroups(int fd, gid_t **ret) {
+        socklen_t n = sizeof(gid_t) * 64;
+        _cleanup_free_ gid_t *d = NULL;
+
+        assert(fd >= 0);
+        assert(ret);
+
+        for (;;) {
+                d = malloc(n);
+                if (!d)
+                        return -ENOMEM;
+
+                if (getsockopt(fd, SOL_SOCKET, SO_PEERGROUPS, d, &n) >= 0)
+                        break;
+
+                if (errno != ERANGE)
+                        return -errno;
+
+                d = mfree(d);
+        }
+
+        assert_se(n % sizeof(gid_t) == 0);
+        n /= sizeof(gid_t);
+
+        if ((socklen_t) (int) n != n)
+                return -E2BIG;
+
+        *ret = d;
+        d = NULL;
+
+        return (int) n;
+}
+
 int send_one_fd_sa(
                 int transport_fd,
                 int fd,
@@ -1081,7 +1112,7 @@ ssize_t next_datagram_size_fd(int fd) {
 
         l = recv(fd, NULL, 0, MSG_PEEK|MSG_TRUNC);
         if (l < 0) {
-                if (errno == EOPNOTSUPP || errno == EFAULT)
+                if (IN_SET(errno, EOPNOTSUPP, EFAULT))
                         goto fallback;
 
                 return -errno;