]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
anvil: connect-limit - Keep track of all sessions per process
authorTimo Sirainen <timo.sirainen@open-xchange.com>
Wed, 8 Dec 2021 00:25:39 +0000 (02:25 +0200)
committerTimo Sirainen <timo.sirainen@open-xchange.com>
Tue, 8 Feb 2022 09:48:24 +0000 (10:48 +0100)
src/anvil/connect-limit.c

index 310b62a15179b7075aeb8f5b082773fa78f0bbb0..c231cb5179a3910aad60e2cfeb0d6c8b1a7b5ab2 100644 (file)
@@ -2,12 +2,18 @@
 
 #include "common.h"
 #include "hash.h"
+#include "llist.h"
 #include "str.h"
 #include "str-table.h"
 #include "strescape.h"
 #include "ostream.h"
 #include "connect-limit.h"
 
+struct process {
+       pid_t pid;
+       struct session *sessions;
+};
+
 struct userip {
        char *username;
        const char *service;
@@ -15,9 +21,12 @@ struct userip {
 };
 
 struct session {
+       /* process->sessions linked list */
+       struct session *process_prev, *process_next;
+
        /* points to userip_hash keys */
        struct userip *userip;
-       pid_t pid;
+       struct process *process;
        guid_128_t conn_guid;
        unsigned int refcount;
 };
@@ -29,6 +38,8 @@ struct connect_limit {
        HASH_TABLE(struct userip *, void *) userip_hash;
        /* (userip, pid) => struct session */
        HASH_TABLE(struct session *, struct session *) session_hash;
+       /* pid_t => struct process */
+       HASH_TABLE(void *, struct process *) process_hash;
 };
 
 static unsigned int userip_hash(const struct userip *userip)
@@ -52,7 +63,7 @@ static int userip_cmp(const struct userip *userip1,
 static unsigned int session_hash(const struct session *session)
 {
        return userip_hash(session->userip) ^
-               guid_128_hash(session->conn_guid) ^ session->pid;
+               guid_128_hash(session->conn_guid) ^ session->process->pid;
 }
 
 static int session_cmp(const struct session *session1,
@@ -63,9 +74,9 @@ static int session_cmp(const struct session *session1,
        if (ret != 0)
                return ret;
 
-       if (session1->pid < session2->pid)
+       if (session1->process->pid < session2->process->pid)
                return -1;
-       else if (session1->pid > session2->pid)
+       else if (session1->process->pid > session2->process->pid)
                return 1;
        else
                return userip_cmp(session1->userip, session2->userip);
@@ -81,6 +92,7 @@ struct connect_limit *connect_limit_init(void)
                          userip_hash, userip_cmp);
        hash_table_create(&limit->session_hash, default_pool, 0,
                          session_hash, session_cmp);
+       hash_table_create_direct(&limit->process_hash, default_pool, 0);
        return limit;
 }
 
@@ -91,6 +103,7 @@ void connect_limit_deinit(struct connect_limit **_limit)
        *_limit = NULL;
        hash_table_destroy(&limit->userip_hash);
        hash_table_destroy(&limit->session_hash);
+       hash_table_destroy(&limit->process_hash);
        str_table_deinit(&limit->strings);
        i_free(limit);
 }
@@ -109,6 +122,25 @@ unsigned int connect_limit_lookup(struct connect_limit *limit,
        return POINTER_CAST_TO(value, unsigned int);
 }
 
+static struct process *process_lookup(struct connect_limit *limit, pid_t pid)
+{
+       return hash_table_lookup(limit->process_hash, POINTER_CAST(pid));
+}
+
+static struct process *process_get(struct connect_limit *limit, pid_t pid)
+{
+       struct process *process;
+
+       process = process_lookup(limit, pid);
+       if (process == NULL) {
+               process = i_new(struct process, 1);
+               process->pid = pid;
+               hash_table_insert(limit->process_hash,
+                                 POINTER_CAST(pid), process);
+       }
+       return process;
+}
+
 void connect_limit_connect(struct connect_limit *limit, pid_t pid,
                           const struct connect_limit_key *key,
                           const guid_128_t conn_guid)
@@ -137,17 +169,19 @@ void connect_limit_connect(struct connect_limit *limit, pid_t pid,
 
        struct session session_lookup = {
                .userip = userip,
-               .pid = pid,
+               .process = process_get(limit, pid),
        };
        guid_128_copy(session_lookup.conn_guid, conn_guid);
        session = hash_table_lookup(limit->session_hash, &session_lookup);
        if (session == NULL) {
                session = i_new(struct session, 1);
                session->userip = userip;
-               session->pid = pid;
+               session->process = session_lookup.process;
                guid_128_copy(session->conn_guid, conn_guid);
                session->refcount = 1;
                hash_table_insert(limit->session_hash, session, session);
+               DLLIST_PREPEND_FULL(&session->process->sessions, session,
+                                   process_prev, process_next);
        } else {
                session->refcount++;
        }
@@ -186,16 +220,24 @@ void connect_limit_disconnect(struct connect_limit *limit, pid_t pid,
                              const struct connect_limit_key *key,
                              const guid_128_t conn_guid)
 {
+       struct process *process;
        struct session *session;
+
+       process = process_lookup(limit, pid);
+       if (process == NULL) {
+               i_error("connect limit: disconnection for unknown pid %s",
+                       dec2str(pid));
+               return;
+       }
+
        struct userip userip_lookup = {
                .username = (char *)key->username,
                .service = key->service,
                .ip = key->ip,
        };
-
        struct session session_lookup = {
                .userip = &userip_lookup,
-               .pid = pid,
+               .process = process,
        };
        guid_128_copy(session_lookup.conn_guid, conn_guid);
 
@@ -209,30 +251,40 @@ void connect_limit_disconnect(struct connect_limit *limit, pid_t pid,
        }
 
        if (--session->refcount == 0) {
+               DLLIST_REMOVE_FULL(&process->sessions, session,
+                                  process_prev, process_next);
                hash_table_remove(limit->session_hash, session);
                session_free(session);
        }
 
        userip_hash_unref(limit, &userip_lookup);
+       if (process->sessions == NULL) {
+               hash_table_remove(limit->process_hash, POINTER_CAST(pid));
+               i_free(process);
+       }
 }
 
 void connect_limit_disconnect_pid(struct connect_limit *limit, pid_t pid)
 {
-       struct hash_iterate_context *iter;
-       struct session *session, *value;
+       struct process *process;
+       struct session *session;
 
-       /* this should happen rarely (or never), so this slow implementation
-          should be fine. */
-       iter = hash_table_iterate_init(limit->session_hash);
-       while (hash_table_iterate(iter, limit->session_hash, &session, &value)) {
-               if (session->pid == pid) {
-                       hash_table_remove(limit->session_hash, session);
-                       for (; session->refcount > 0; session->refcount--)
-                               userip_hash_unref(limit, session->userip);
-                       session_free(session);
-               }
+       process = process_lookup(limit, pid);
+       if (process == NULL)
+               return;
+
+       while (process->sessions != NULL) {
+               session = process->sessions;
+               DLLIST_REMOVE_FULL(&process->sessions, session,
+                                  process_prev, process_next);
+
+               hash_table_remove(limit->session_hash, session);
+               for (; session->refcount > 0; session->refcount--)
+                       userip_hash_unref(limit, session->userip);
+               session_free(session);
        }
-       hash_table_iterate_deinit(&iter);
+       hash_table_remove(limit->process_hash, POINTER_CAST(pid));
+       i_free(process);
 }
 
 void connect_limit_dump(struct connect_limit *limit, struct ostream *output)
@@ -246,7 +298,7 @@ void connect_limit_dump(struct connect_limit *limit, struct ostream *output)
        while (ret >= 0 &&
               hash_table_iterate(iter, limit->session_hash, &session, &value)) T_BEGIN {
                str_truncate(str, 0);
-               str_printfa(str, "%ld\t%u\t", (long)session->pid,
+               str_printfa(str, "%ld\t%u\t", (long)session->process->pid,
                            session->refcount);
                str_append_tabescaped(str, session->userip->username);
                str_append_c(str, '\t');