]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
core/socket: rework SocketPeer refcounting
authorZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>
Fri, 5 Aug 2016 01:42:23 +0000 (21:42 -0400)
committerZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>
Fri, 5 Aug 2016 12:12:31 +0000 (08:12 -0400)
Make functions and definitions that don't need to be shared local to
socket.c.

src/core/socket.c
src/core/socket.h

index d3b9a755478056c8fdc8c093636392a75da5adbe..972d494dbc1a73b7eafed41bff9044dfaae0af0e 100644 (file)
 #include "user-util.h"
 #include "in-addr-util.h"
 
+struct SocketPeer {
+        unsigned n_ref;
+
+        Socket *socket;
+        union sockaddr_union peer;
+};
+
 static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
         [SOCKET_DEAD] = UNIT_INACTIVE,
         [SOCKET_START_PRE] = UNIT_ACTIVATING,
@@ -78,9 +85,6 @@ static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
 static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata);
 static int socket_dispatch_timer(sd_event_source *source, usec_t usec, void *userdata);
 
-SocketPeer *socket_peer_new(void);
-int socket_find_peer(Socket *s, int fd, SocketPeer **p);
-
 static void socket_init(Unit *u) {
         Socket *s = SOCKET(u);
 
@@ -482,10 +486,11 @@ static void peer_address_hash_func(const void *p, struct siphash *state) {
         const SocketPeer *s = p;
 
         assert(s);
+        assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
 
         if (s->peer.sa.sa_family == AF_INET)
                 siphash24_compress(&s->peer.in.sin_addr, sizeof(s->peer.in.sin_addr), state);
-        else if (s->peer.sa.sa_family == AF_INET6)
+        else
                 siphash24_compress(&s->peer.in6.sin6_addr, sizeof(s->peer.in6.sin6_addr), state);
 }
 
@@ -503,8 +508,7 @@ static int peer_address_compare_func(const void *a, const void *b) {
         case AF_INET6:
                 return memcmp(&x->peer.in6.sin6_addr, &y->peer.in6.sin6_addr, sizeof(x->peer.in6.sin6_addr));
         }
-
-        return -1;
+        assert_not_reached("Black sheep in the family!");
 }
 
 const struct hash_ops peer_address_hash_ops = {
@@ -537,6 +541,87 @@ static int socket_load(Unit *u) {
         return socket_verify(s);
 }
 
+static SocketPeer *socket_peer_new(void) {
+        SocketPeer *p;
+
+        p = new0(SocketPeer, 1);
+        if (!p)
+                return NULL;
+
+        p->n_ref = 1;
+
+        return p;
+}
+
+SocketPeer *socket_peer_ref(SocketPeer *p) {
+        if (!p)
+                return NULL;
+
+        assert(p->n_ref > 0);
+        p->n_ref++;
+
+        return p;
+}
+
+SocketPeer *socket_peer_unref(SocketPeer *p) {
+        if (!p)
+                return NULL;
+
+        assert(p->n_ref > 0);
+
+        p->n_ref--;
+
+        if (p->n_ref > 0)
+                return NULL;
+
+        if (p->socket)
+                set_remove(p->socket->peers_by_address, p);
+
+        return mfree(p);
+}
+
+static int socket_acquire_peer(Socket *s, int fd, SocketPeer **p) {
+        _cleanup_(socket_peer_unrefp) SocketPeer *remote = NULL;
+        SocketPeer sa = {}, *i;
+        socklen_t salen = sizeof(sa.peer);
+        int r;
+
+        assert(fd >= 0);
+        assert(s);
+
+        r = getpeername(fd, &sa.peer.sa, &salen);
+        if (r < 0)
+                return log_error_errno(errno, "getpeername failed: %m");
+
+        if (!IN_SET(sa.peer.sa.sa_family, AF_INET, AF_INET6)) {
+                *p = NULL;
+                return 0;
+        }
+
+        i = set_get(s->peers_by_address, &sa);
+        if (i) {
+                *p = socket_peer_ref(i);
+                return 1;
+        }
+
+        remote = socket_peer_new();
+        if (!remote)
+                return log_oom();
+
+        remote->peer = sa.peer;
+
+        r = set_put(s->peers_by_address, remote);
+        if (r < 0)
+                return r;
+
+        remote->socket = s;
+
+        *p = remote;
+        remote = NULL;
+
+        return 1;
+}
+
 _const_ static const char* listen_lookup(int family, int type) {
 
         if (family == AF_NETLINK)
@@ -2102,22 +2187,22 @@ static void socket_enter_running(Socket *s, int cfd) {
                 Service *service;
 
                 if (s->n_connections >= s->max_connections) {
-                        log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.", s->n_connections);
+                        log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.",
+                                         s->n_connections);
                         safe_close(cfd);
                         return;
                 }
 
                 if (s->max_connections_per_source > 0) {
-                        r = socket_find_peer(s, cfd, &p);
+                        r = socket_acquire_peer(s, cfd, &p);
                         if (r < 0) {
                                 safe_close(cfd);
                                 return;
-                        }
-
-                        if (p->n_ref > s->max_connections_per_source) {
-                                log_unit_warning(UNIT(s), "Too many incoming connections (%u) from source, refusing connection attempt.", p->n_ref);
+                        } else if (r > 0 && p->n_ref > s->max_connections_per_source) {
+                                log_unit_warning(UNIT(s),
+                                                 "Too many incoming connections (%u) from source, refusing connection attempt.",
+                                                 p->n_ref);
                                 safe_close(cfd);
-                                p = NULL;
                                 return;
                         }
                 }
@@ -2163,10 +2248,8 @@ static void socket_enter_running(Socket *s, int cfd) {
                 cfd = -1; /* We passed ownership of the fd to the service now. Forget it here. */
                 s->n_connections++;
 
-                if (s->max_connections_per_source > 0) {
-                        service->peer = socket_peer_ref(p);
-                        p = NULL;
-                }
+                service->peer = p; /* Pass ownership of the peer reference */
+                p = NULL;
 
                 r = manager_add_job(UNIT(s)->manager, JOB_START, UNIT(service), JOB_REPLACE, &error, NULL);
                 if (r < 0) {
@@ -2662,83 +2745,6 @@ _pure_ static bool socket_check_gc(Unit *u) {
         return s->n_connections > 0;
 }
 
-SocketPeer *socket_peer_new(void) {
-        SocketPeer *p;
-
-        p = new0(SocketPeer, 1);
-        if (!p)
-                return NULL;
-
-        p->n_ref = 1;
-
-        return p;
-}
-
-SocketPeer *socket_peer_ref(SocketPeer *p) {
-        if (!p)
-                return NULL;
-
-        assert(p->n_ref > 0);
-        p->n_ref++;
-
-        return p;
-}
-
-SocketPeer *socket_peer_unref(SocketPeer *p) {
-        if (!p)
-                return NULL;
-
-        assert(p->n_ref > 0);
-
-        p->n_ref--;
-
-        if (p->n_ref > 0)
-                return NULL;
-
-        if (p->socket)
-                set_remove(p->socket->peers_by_address, p);
-
-        free(p);
-
-        return NULL;
-}
-
-int socket_find_peer(Socket *s, int fd, SocketPeer **p) {
-        _cleanup_free_ SocketPeer *remote = NULL;
-        SocketPeer sa, *i;
-        socklen_t salen = sizeof(sa.peer);
-        int r;
-
-        assert(fd >= 0);
-        assert(s);
-
-        r = getpeername(fd, &sa.peer.sa, &salen);
-        if (r < 0)
-                return log_error_errno(errno, "getpeername failed: %m");
-
-        i = set_get(s->peers_by_address, &sa);
-        if (i) {
-                *p = i;
-                return 1;
-        }
-
-        remote = socket_peer_new();
-        if (!remote)
-                return log_oom();
-
-        memcpy(&remote->peer, &sa.peer, sizeof(union sockaddr_union));
-        remote->socket = s;
-
-        r = set_put(s->peers_by_address, remote);
-        if (r < 0)
-                return r;
-
-        *p = remote;
-        remote = NULL;
-
-        return 0;
-}
-
 static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata) {
         SocketPort *p = userdata;
         int cfd = -1;
index edbe9df6b1d5ce86374e35e69c8d3dacd7a10dbe..6a78fd322d84cc0acf67e7a6fac03ff2916f0006 100644 (file)
@@ -168,13 +168,6 @@ struct Socket {
         RateLimit trigger_limit;
 };
 
-struct SocketPeer {
-        unsigned n_ref;
-
-        Socket *socket;
-        union sockaddr_union peer;
-};
-
 SocketPeer *socket_peer_ref(SocketPeer *p);
 SocketPeer *socket_peer_unref(SocketPeer *p);