]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/basic/socket-util.c
macro: introduce TAKE_PTR() macro
[thirdparty/systemd.git] / src / basic / socket-util.c
index 56e0e8e434ea656b0089a56534730eef595232b2..fd26ae713798a31965de37bb4c2dab97c971c4c3 100644 (file)
@@ -41,6 +41,7 @@
 #include "missing.h"
 #include "parse-util.h"
 #include "path-util.h"
+#include "process-util.h"
 #include "socket-util.h"
 #include "string-table.h"
 #include "string-util.h"
@@ -50,7 +51,7 @@
 #include "util.h"
 
 #if ENABLE_IDN
-#  define IDN_FLAGS (NI_IDN|NI_IDN_USE_STD3_ASCII_RULES)
+#  define IDN_FLAGS NI_IDN
 #else
 #  define IDN_FLAGS 0
 #endif
@@ -68,7 +69,6 @@ DEFINE_STRING_TABLE_LOOKUP(socket_address_type, int);
 
 int socket_address_parse(SocketAddress *a, const char *s) {
         char *e, *n;
-        unsigned u;
         int r;
 
         assert(a);
@@ -78,6 +78,8 @@ int socket_address_parse(SocketAddress *a, const char *s) {
         a->type = SOCK_STREAM;
 
         if (*s == '[') {
+                uint16_t port;
+
                 /* IPv6 in [x:.....:z]:p notation */
 
                 e = strchr(s+1, ']');
@@ -95,15 +97,12 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         return -EINVAL;
 
                 e++;
-                r = safe_atou(e, &u);
+                r = parse_ip_port(e, &port);
                 if (r < 0)
                         return r;
 
-                if (u <= 0 || u > 0xFFFF)
-                        return -EINVAL;
-
                 a->sockaddr.in6.sin6_family = AF_INET6;
-                a->sockaddr.in6.sin6_port = htobe16((uint16_t)u);
+                a->sockaddr.in6.sin6_port = htobe16(port);
                 a->size = sizeof(struct sockaddr_in6);
 
         } else if (*s == '/') {
@@ -134,12 +133,13 @@ int socket_address_parse(SocketAddress *a, const char *s) {
         } else if (startswith(s, "vsock:")) {
                 /* AF_VSOCK socket in vsock:cid:port notation */
                 const char *cid_start = s + STRLEN("vsock:");
+                unsigned port;
 
                 e = strchr(cid_start, ':');
                 if (!e)
                         return -EINVAL;
 
-                r = safe_atou(e+1, &u);
+                r = safe_atou(e+1, &port);
                 if (r < 0)
                         return r;
 
@@ -152,19 +152,18 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         a->sockaddr.vm.svm_cid = VMADDR_CID_ANY;
 
                 a->sockaddr.vm.svm_family = AF_VSOCK;
-                a->sockaddr.vm.svm_port = u;
+                a->sockaddr.vm.svm_port = port;
                 a->size = sizeof(struct sockaddr_vm);
 
         } else {
+                uint16_t port;
+
                 e = strchr(s, ':');
                 if (e) {
-                        r = safe_atou(e+1, &u);
+                        r = parse_ip_port(e + 1, &port);
                         if (r < 0)
                                 return r;
 
-                        if (u <= 0 || u > 0xFFFF)
-                                return -EINVAL;
-
                         n = strndupa(s, e-s);
 
                         /* IPv4 in w.x.y.z:p notation? */
@@ -175,7 +174,7 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                         if (r > 0) {
                                 /* Gotcha, it's a traditional IPv4 address */
                                 a->sockaddr.in.sin_family = AF_INET;
-                                a->sockaddr.in.sin_port = htobe16((uint16_t)u);
+                                a->sockaddr.in.sin_port = htobe16(port);
                                 a->size = sizeof(struct sockaddr_in);
                         } else {
                                 unsigned idx;
@@ -189,7 +188,7 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                                         return -EINVAL;
 
                                 a->sockaddr.in6.sin6_family = AF_INET6;
-                                a->sockaddr.in6.sin6_port = htobe16((uint16_t)u);
+                                a->sockaddr.in6.sin6_port = htobe16(port);
                                 a->sockaddr.in6.sin6_scope_id = idx;
                                 a->sockaddr.in6.sin6_addr = in6addr_any;
                                 a->size = sizeof(struct sockaddr_in6);
@@ -197,21 +196,18 @@ int socket_address_parse(SocketAddress *a, const char *s) {
                 } else {
 
                         /* Just a port */
-                        r = safe_atou(s, &u);
+                        r = parse_ip_port(s, &port);
                         if (r < 0)
                                 return r;
 
-                        if (u <= 0 || u > 0xFFFF)
-                                return -EINVAL;
-
                         if (socket_ipv6_is_supported()) {
                                 a->sockaddr.in6.sin6_family = AF_INET6;
-                                a->sockaddr.in6.sin6_port = htobe16((uint16_t)u);
+                                a->sockaddr.in6.sin6_port = htobe16(port);
                                 a->sockaddr.in6.sin6_addr = in6addr_any;
                                 a->size = sizeof(struct sockaddr_in6);
                         } else {
                                 a->sockaddr.in.sin_family = AF_INET;
-                                a->sockaddr.in.sin_port = htobe16((uint16_t)u);
+                                a->sockaddr.in.sin_port = htobe16(port);
                                 a->sockaddr.in.sin_addr.s_addr = INADDR_ANY;
                                 a->size = sizeof(struct sockaddr_in);
                         }
@@ -762,19 +758,6 @@ int socknameinfo_pretty(union sockaddr_union *sa, socklen_t salen, char **_ret)
         return 0;
 }
 
-int getnameinfo_pretty(int fd, char **ret) {
-        union sockaddr_union sa;
-        socklen_t salen = sizeof(sa);
-
-        assert(fd >= 0);
-        assert(ret);
-
-        if (getsockname(fd, &sa.sa, &salen) < 0)
-                return -errno;
-
-        return socknameinfo_pretty(&sa, salen, ret);
-}
-
 int socket_address_unlink(SocketAddress *a) {
         assert(a);
 
@@ -967,55 +950,43 @@ int getpeercred(int fd, struct ucred *ucred) {
         if (n != sizeof(struct ucred))
                 return -EIO;
 
-        /* Check if the data is actually useful and not suppressed due
-         * to namespacing issues */
-        if (u.pid <= 0)
-                return -ENODATA;
-        if (u.uid == UID_INVALID)
-                return -ENODATA;
-        if (u.gid == GID_INVALID)
+        /* Check if the data is actually useful and not suppressed due to namespacing issues */
+        if (!pid_is_valid(u.pid))
                 return -ENODATA;
 
+        /* Note that we don't check UID/GID here, as namespace translation works differently there: instead of
+         * receiving in "invalid" user/group we get the overflow UID/GID. */
+
         *ucred = u;
         return 0;
 }
 
 int getpeersec(int fd, char **ret) {
+        _cleanup_free_ char *s = NULL;
         socklen_t n = 64;
-        char *s;
-        int r;
 
         assert(fd >= 0);
         assert(ret);
 
-        s = new0(char, n);
-        if (!s)
-                return -ENOMEM;
+        for (;;) {
+                s = new0(char, n+1);
+                if (!s)
+                        return -ENOMEM;
 
-        r = getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n);
-        if (r < 0) {
-                free(s);
+                if (getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n) >= 0)
+                        break;
 
                 if (errno != ERANGE)
                         return -errno;
 
-                s = new0(char, n);
-                if (!s)
-                        return -ENOMEM;
-
-                r = getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n);
-                if (r < 0) {
-                        free(s);
-                        return -errno;
-                }
+                s = mfree(s);
         }
 
-        if (isempty(s)) {
-                free(s);
+        if (isempty(s))
                 return -EOPNOTSUPP;
-        }
 
-        *ret = s;
+        *ret = TAKE_PTR(s);
+
         return 0;
 }
 
@@ -1023,7 +994,7 @@ int getpeergroups(int fd, gid_t **ret) {
         socklen_t n = sizeof(gid_t) * 64;
         _cleanup_free_ gid_t *d = NULL;
 
-        assert(fd);
+        assert(fd >= 0);
         assert(ret);
 
         for (;;) {