struct ostream *output;
int ret = 0;
+ if (!conn->handshake_received &&
+ conn->list->v.handshake != NULL) {
+ if ((ret = conn->list->v.handshake(conn)) < 0) {
+ conn->disconnect_reason = CONNECTION_DISCONNECT_HANDSHAKE_FAILED;
+ conn->list->v.destroy(conn);
+ return;
+ } else if (ret == 0) {
+ return;
+ } else {
+ conn->handshake_received = TRUE;
+ }
+ }
+
switch (connection_input_read(conn)) {
case -1:
return;
}
while (!input->closed && (line = i_stream_next_line(input)) != NULL) {
T_BEGIN {
- ret = conn->list->v.input_line(conn, line);
+ if (!conn->handshake_received &&
+ conn->list->v.handshake_line != NULL) {
+ ret = conn->list->v.handshake_line(conn, line);
+ if (ret > 0)
+ conn->handshake_received = TRUE;
+ else if (ret == 0)
+ /* continue reading */
+ ret = 1;
+ } else {
+ ret = conn->list->v.input_line(conn, line);
+ }
} T_END;
if (ret <= 0)
break;
o_stream_unref(&output);
}
if (ret < 0 && !input->closed) {
- conn->disconnect_reason = CONNECTION_DISCONNECT_DEINIT;
+ if (conn->handshake_received)
+ conn->disconnect_reason = CONNECTION_DISCONNECT_DEINIT;
+ else
+ conn->disconnect_reason = CONNECTION_DISCONNECT_HANDSHAKE_FAILED;
conn->list->v.destroy(conn);
}
i_stream_unref(&input);
{
unsigned int recv_major_version;
+ if (conn->version_received)
+ return 1;
+
/* VERSION <tab> service_name <tab> major version <tab> minor version */
if (str_array_length(args) != 4 ||
strcmp(args[0], "VERSION") != 0 ||
recv_major_version, conn->list->set.major_version);
return -1;
}
- return 0;
+
+ conn->version_received = TRUE;
+ return 1;
}
int connection_input_line_default(struct connection *conn, const char *line)
const char *const *args;
args = t_strsplit_tabescaped(line);
- if (!conn->version_received) {
- if (connection_verify_version(conn, args) < 0)
- return -1;
- conn->version_received = TRUE;
- return 1;
- }
if (args[0] == NULL && !conn->list->set.allow_empty_args_input) {
e_error(conn->event, "Unexpectedly received empty line");
return -1;
}
+ if (!conn->handshake_received &&
+ (conn->list->v.handshake_args != connection_verify_version ||
+ conn->list->set.major_version != 0)) {
+ int ret;
+ if ((ret = conn->list->v.handshake_args(conn, args)) == 0)
+ ret = 1; /* continue reading */
+ else if (ret > 0)
+ conn->handshake_received = TRUE;
+ return ret;
+ } else if (!conn->handshake_received) {
+ /* we don't do handshakes */
+ conn->handshake_received = TRUE;
+ }
+
+ /* version must be handled though, by something */
+ i_assert(conn->version_received);
+
return conn->list->v.input_args(conn, args);
}
case CONNECTION_DISCONNECT_NOT:
case CONNECTION_DISCONNECT_BUFFER_FULL:
return io_stream_get_disconnect_reason(conn->input, conn->output);
+ case CONNECTION_DISCONNECT_HANDSHAKE_FAILED:
+ return "Handshake failed";
}
i_unreached();
}
list->v.input = connection_input_default;
if (list->v.input_line == NULL)
list->v.input_line = connection_input_line_default;
+ if (list->v.handshake_args == NULL)
+ list->v.handshake_args = connection_verify_version;
if (list->v.idle_timeout == NULL)
list->v.idle_timeout = connection_idle_timeout;
if (list->v.connect_timeout == NULL)
/* connect() timed out */
CONNECTION_DISCONNECT_CONNECT_TIMEOUT,
/* remote didn't send input */
- CONNECTION_DISCONNECT_IDLE_TIMEOUT
+ CONNECTION_DISCONNECT_IDLE_TIMEOUT,
+ /* handshake failed */
+ CONNECTION_DISCONNECT_HANDSHAKE_FAILED,
};
struct connection_vfuncs {
int (*input_line)(struct connection *conn, const char *line);
int (*input_args)(struct connection *conn, const char *const *args);
+ /* handshake functions. Defaults to version checking.
+ must return 1 when handshake is completed, otherwise return 0.
+ return -1 to indicate error and disconnect client.
+
+ if you implement this, remember to call connection_verify_version
+ yourself, otherwise you end up with assert crash.
+
+ these will not be called if you implement `input` virtual function.
+ */
+ int (*handshake)(struct connection *conn);
+ int (*handshake_line)(struct connection *conn, const char *line);
+ int (*handshake_args)(struct connection *conn, const char *const *args);
+
/* Called when input_idle_timeout_secs is reached, defaults to disconnect */
void (*idle_timeout)(struct connection *conn);
/* Called when client_connect_timeout_msecs is reached, defaults to disconnect */
enum connection_disconnect_reason disconnect_reason;
bool version_received:1;
+ bool handshake_received:1;
bool unix_socket:1;
bool from_streams:1;
bool disconnected:1;