]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
lib-master: When finishing auth, don't send REQUEST if auth process has restarted.
authorTimo Sirainen <tss@iki.fi>
Mon, 20 Sep 2010 16:50:29 +0000 (17:50 +0100)
committerTimo Sirainen <tss@iki.fi>
Mon, 20 Sep 2010 16:50:29 +0000 (17:50 +0100)
This avoids unnecessary "Master requested auth for nonexistent client" errors
when auth process restarts (crashes).

src/lib-master/master-login-auth.c

index 030eee83ad058e8bd397f6cdf334bce112ef62f1..06aecec9a01a2a84953e6ae422ab72ac49d0f110 100644 (file)
@@ -23,6 +23,11 @@ struct master_login_auth_request {
        unsigned int id;
        time_t create_stamp;
 
+       pid_t auth_pid;
+       unsigned int auth_id;
+       unsigned int client_pid;
+       uint8_t cookie[MASTER_AUTH_COOKIE_SIZE];
+
        master_login_auth_request_callback_t *callback;
        void *context;
 };
@@ -43,10 +48,14 @@ struct master_login_auth {
        /* linked list of requests, ordered by create_stamp */
        struct master_login_auth_request *request_head, *request_tail;
 
+       pid_t auth_server_pid;
+
        unsigned int version_received:1;
+       unsigned int spid_received:1;
 };
 
 static void master_login_auth_set_timeout(struct master_login_auth *auth);
+static void master_login_auth_send_all_requests(struct master_login_auth *auth);
 
 struct master_login_auth *master_login_auth_init(const char *auth_socket_path)
 {
@@ -156,27 +165,35 @@ static void master_login_auth_set_timeout(struct master_login_auth *auth)
        }
 }
 
-static struct master_login_auth_request *
-master_login_auth_lookup_request(struct master_login_auth *auth,
-                                unsigned int id)
+static void
+master_login_auth_request_remove(struct master_login_auth *auth,
+                                struct master_login_auth_request *request)
 {
-       struct master_login_auth_request *request;
        bool update_timeout;
 
-       request = hash_table_lookup(auth->requests, POINTER_CAST(id));
-       if (request == NULL) {
-               i_error("Auth server sent reply with unknown ID %u", id);
-               return NULL;
-       }
        update_timeout = request->prev == NULL;
 
-       hash_table_remove(auth->requests, POINTER_CAST(id));
+       hash_table_remove(auth->requests, POINTER_CAST(request->id));
        DLLIST2_REMOVE(&auth->request_head, &auth->request_tail, request);
 
        if (update_timeout) {
                timeout_remove(&auth->to);
                master_login_auth_set_timeout(auth);
        }
+}
+
+static struct master_login_auth_request *
+master_login_auth_lookup_request(struct master_login_auth *auth,
+                                unsigned int id)
+{
+       struct master_login_auth_request *request;
+
+       request = hash_table_lookup(auth->requests, POINTER_CAST(id));
+       if (request == NULL) {
+               i_error("Auth server sent reply with unknown ID %u", id);
+               return NULL;
+       }
+       master_login_auth_request_remove(auth, request);
        return request;
 }
 
@@ -291,6 +308,21 @@ static void master_login_auth_input(struct master_login_auth *auth)
                }
                auth->version_received = TRUE;
        }
+       if (!auth->spid_received) {
+               line = i_stream_next_line(auth->input);
+               if (line == NULL)
+                       return;
+
+               if (strncmp(line, "SPID\t", 5) != 0 ||
+                   str_to_pid(line + 5, &auth->auth_server_pid) < 0) {
+                       i_error("Authentication server didn't "
+                               "send valid SPID as expected: %s", line);
+                       master_login_auth_disconnect(auth);
+                       return;
+               }
+               auth->spid_received = TRUE;
+               master_login_auth_send_all_requests(auth);
+       }
 
        auth->refcount++;
        while ((line = i_stream_next_line(auth->input)) != NULL) {
@@ -331,6 +363,41 @@ master_login_auth_connect(struct master_login_auth *auth)
        return 0;
 }
 
+static void
+master_login_auth_send_request(struct master_login_auth *auth,
+                              struct master_login_auth_request *req)
+{
+       string_t *str;
+
+       i_assert(auth->spid_received);
+
+       if (auth->auth_server_pid != req->auth_pid) {
+               /* auth server was restarted. don't even attempt a login. */
+               master_login_auth_request_remove(auth, req);
+               req->callback(NULL, MASTER_AUTH_ERRMSG_INTERNAL_FAILURE,
+                             req->context);
+               i_free(req);
+               return;
+       }
+
+       str = t_str_new(128);
+       str_printfa(str, "REQUEST\t%u\t%u\t%u\t", req->id,
+                   req->client_pid, req->auth_id);
+       binary_to_hex_append(str, req->cookie, sizeof(req->cookie));
+       str_append_c(str, '\n');
+       o_stream_send(auth->output, str_data(str), str_len(str));
+}
+
+static void master_login_auth_send_all_requests(struct master_login_auth *auth)
+{
+       struct master_login_auth_request *req, *next;
+
+       for (req = auth->request_head; req != NULL; req = next) {
+               next = req->next;
+               master_login_auth_send_request(auth, req);
+       }
+}
+
 void master_login_auth_request(struct master_login_auth *auth,
                               const struct master_auth_request *req,
                               master_login_auth_request_callback_t *callback,
@@ -338,33 +405,30 @@ void master_login_auth_request(struct master_login_auth *auth,
 {
        struct master_login_auth_request *login_req;
        unsigned int id;
-       string_t *str;
 
-       str = t_str_new(128);
        if (auth->fd == -1) {
                if (master_login_auth_connect(auth) < 0) {
                        callback(NULL, MASTER_AUTH_ERRMSG_INTERNAL_FAILURE,
                                 context);
                        return;
                }
-               str_printfa(str, "VERSION\t%u\t%u\n",
-                           AUTH_MASTER_PROTOCOL_MAJOR_VERSION,
-                           AUTH_MASTER_PROTOCOL_MINOR_VERSION);
+               o_stream_send_str(auth->output,
+                       t_strdup_printf("VERSION\t%u\t%u\n",
+                                       AUTH_MASTER_PROTOCOL_MAJOR_VERSION,
+                                       AUTH_MASTER_PROTOCOL_MINOR_VERSION));
        }
 
        id = ++auth->id_counter;
        if (id == 0)
                id++;
 
-       str_printfa(str, "REQUEST\t%u\t%u\t%u\t", id,
-                   req->client_pid, req->auth_id);
-       binary_to_hex_append(str, req->cookie, sizeof(req->cookie));
-       str_append_c(str, '\n');
-       o_stream_send(auth->output, str_data(str), str_len(str));
-
        login_req = i_new(struct master_login_auth_request, 1);
        login_req->create_stamp = ioloop_time;
        login_req->id = id;
+       login_req->auth_pid = req->auth_pid;
+       login_req->client_pid = req->client_pid;
+       login_req->auth_id = req->auth_id;
+       memcpy(login_req->cookie, req->cookie, sizeof(login_req->cookie));
        login_req->callback = callback;
        login_req->context = context;
        hash_table_insert(auth->requests, POINTER_CAST(id), login_req);
@@ -372,6 +436,9 @@ void master_login_auth_request(struct master_login_auth *auth,
 
        if (auth->to == NULL)
                master_login_auth_set_timeout(auth);
+
+       if (auth->spid_received)
+               master_login_auth_send_request(auth, login_req);
 }
 
 unsigned int master_login_auth_request_count(struct master_login_auth *auth)