]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
lib: connection - Add handshake support
authorAki Tuomi <aki.tuomi@dovecot.fi>
Wed, 21 Nov 2018 10:13:06 +0000 (12:13 +0200)
committerAki Tuomi <aki.tuomi@dovecot.fi>
Tue, 27 Nov 2018 17:30:40 +0000 (19:30 +0200)
This allows specifying a custom handshake, that will be
called before actual processing starts. Defaults to version check.

src/lib/connection.c
src/lib/connection.h

index f8defde590c54ce05ad2b0d2c5d987edbb054e44..e6801a0eafe5923a0d03ec764ebf24b09680675d 100644 (file)
@@ -35,6 +35,19 @@ void connection_input_default(struct connection *conn)
        struct ostream *output;
        int ret = 0;
 
+       if (!conn->handshake_received &&
+           conn->list->v.handshake != NULL) {
+               if ((ret = conn->list->v.handshake(conn)) < 0) {
+                       conn->disconnect_reason = CONNECTION_DISCONNECT_HANDSHAKE_FAILED;
+                       conn->list->v.destroy(conn);
+                       return;
+               } else if (ret == 0) {
+                       return;
+               } else {
+                       conn->handshake_received = TRUE;
+               }
+       }
+
        switch (connection_input_read(conn)) {
        case -1:
                return;
@@ -54,7 +67,17 @@ void connection_input_default(struct connection *conn)
        }
        while (!input->closed && (line = i_stream_next_line(input)) != NULL) {
                T_BEGIN {
-                       ret = conn->list->v.input_line(conn, line);
+                       if (!conn->handshake_received &&
+                           conn->list->v.handshake_line != NULL) {
+                               ret = conn->list->v.handshake_line(conn, line);
+                               if (ret > 0)
+                                       conn->handshake_received = TRUE;
+                               else if (ret == 0)
+                                       /* continue reading */
+                                       ret = 1;
+                       } else {
+                               ret = conn->list->v.input_line(conn, line);
+                       }
                } T_END;
                if (ret <= 0)
                        break;
@@ -64,7 +87,10 @@ void connection_input_default(struct connection *conn)
                o_stream_unref(&output);
        }
        if (ret < 0 && !input->closed) {
-               conn->disconnect_reason = CONNECTION_DISCONNECT_DEINIT;
+               if (conn->handshake_received)
+                       conn->disconnect_reason = CONNECTION_DISCONNECT_DEINIT;
+               else
+                       conn->disconnect_reason = CONNECTION_DISCONNECT_HANDSHAKE_FAILED;
                conn->list->v.destroy(conn);
        }
        i_stream_unref(&input);
@@ -75,6 +101,9 @@ int connection_verify_version(struct connection *conn,
 {
        unsigned int recv_major_version;
 
+       if (conn->version_received)
+               return 1;
+
        /* VERSION <tab> service_name <tab> major version <tab> minor version */
        if (str_array_length(args) != 4 ||
            strcmp(args[0], "VERSION") != 0 ||
@@ -98,7 +127,9 @@ int connection_verify_version(struct connection *conn,
                        recv_major_version, conn->list->set.major_version);
                return -1;
        }
-       return 0;
+
+       conn->version_received = TRUE;
+       return 1;
 }
 
 int connection_input_line_default(struct connection *conn, const char *line)
@@ -106,17 +137,28 @@ int connection_input_line_default(struct connection *conn, const char *line)
        const char *const *args;
 
        args = t_strsplit_tabescaped(line);
-       if (!conn->version_received) {
-               if (connection_verify_version(conn, args) < 0)
-                       return -1;
-               conn->version_received = TRUE;
-               return 1;
-       }
        if (args[0] == NULL && !conn->list->set.allow_empty_args_input) {
                e_error(conn->event, "Unexpectedly received empty line");
                return -1;
        }
 
+       if (!conn->handshake_received &&
+           (conn->list->v.handshake_args != connection_verify_version ||
+            conn->list->set.major_version != 0)) {
+               int ret;
+               if ((ret = conn->list->v.handshake_args(conn, args)) == 0)
+                       ret = 1; /* continue reading */
+               else if (ret > 0)
+                       conn->handshake_received = TRUE;
+               return ret;
+       } else if (!conn->handshake_received) {
+               /* we don't do handshakes */
+               conn->handshake_received = TRUE;
+       }
+
+       /* version must be handled though, by something */
+       i_assert(conn->version_received);
+
        return conn->list->v.input_args(conn, args);
 }
 
@@ -550,6 +592,8 @@ const char *connection_disconnect_reason(struct connection *conn)
        case CONNECTION_DISCONNECT_NOT:
        case CONNECTION_DISCONNECT_BUFFER_FULL:
                return io_stream_get_disconnect_reason(conn->input, conn->output);
+       case CONNECTION_DISCONNECT_HANDSHAKE_FAILED:
+               return "Handshake failed";
        }
        i_unreached();
 }
@@ -612,6 +656,8 @@ connection_list_init(const struct connection_settings *set,
                list->v.input = connection_input_default;
        if (list->v.input_line == NULL)
                list->v.input_line = connection_input_line_default;
+       if (list->v.handshake_args == NULL)
+               list->v.handshake_args = connection_verify_version;
        if (list->v.idle_timeout == NULL)
                list->v.idle_timeout = connection_idle_timeout;
        if (list->v.connect_timeout == NULL)
index 9bac17097a1b5f4f1e14f24134815d82fe42854d..bec037cd77bf4ca30ffec9326bb20d803dda5152 100644 (file)
@@ -23,7 +23,9 @@ enum connection_disconnect_reason {
        /* connect() timed out */
        CONNECTION_DISCONNECT_CONNECT_TIMEOUT,
        /* remote didn't send input */
-       CONNECTION_DISCONNECT_IDLE_TIMEOUT
+       CONNECTION_DISCONNECT_IDLE_TIMEOUT,
+       /* handshake failed */
+       CONNECTION_DISCONNECT_HANDSHAKE_FAILED,
 };
 
 struct connection_vfuncs {
@@ -44,6 +46,19 @@ struct connection_vfuncs {
        int (*input_line)(struct connection *conn, const char *line);
        int (*input_args)(struct connection *conn, const char *const *args);
 
+       /* handshake functions. Defaults to version checking.
+          must return 1 when handshake is completed, otherwise return 0.
+          return -1 to indicate error and disconnect client.
+
+          if you implement this, remember to call connection_verify_version
+          yourself, otherwise you end up with assert crash.
+
+          these will not be called if you implement `input` virtual function.
+       */
+       int (*handshake)(struct connection *conn);
+       int (*handshake_line)(struct connection *conn, const char *line);
+       int (*handshake_args)(struct connection *conn, const char *const *args);
+
        /* Called when input_idle_timeout_secs is reached, defaults to disconnect */
        void (*idle_timeout)(struct connection *conn);
        /* Called when client_connect_timeout_msecs is reached, defaults to disconnect */
@@ -113,6 +128,7 @@ struct connection {
        enum connection_disconnect_reason disconnect_reason;
 
        bool version_received:1;
+       bool handshake_received:1;
        bool unix_socket:1;
        bool from_streams:1;
        bool disconnected:1;