]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
auth: Added CACHE-FLUSH command to flush some/all users from auth cache.
authorTimo Sirainen <tss@iki.fi>
Wed, 4 Jul 2012 07:56:53 +0000 (10:56 +0300)
committerTimo Sirainen <tss@iki.fi>
Wed, 4 Jul 2012 07:56:53 +0000 (10:56 +0300)
src/auth/auth-cache.c
src/auth/auth-cache.h
src/auth/auth-master-connection.c
src/auth/auth-request.h
src/auth/auth.c
src/auth/test-auth-cache.c

index 1cf18f88c0bbf77e6ae0c1e8b3ff3e64c3739459..598a2ab6b54aa57e8b1691c8324269450a720793 100644 (file)
@@ -23,36 +23,72 @@ struct auth_cache {
        unsigned long long pos_size, neg_size;
 };
 
-static const struct var_expand_table *
-auth_request_var_expand_tab_find(const char *key, unsigned int size)
+static bool
+auth_request_var_expand_tab_find(const char *key, unsigned int size,
+                                unsigned int *idx_r)
 {
        const struct var_expand_table *tab = auth_request_var_expand_static_tab;
        unsigned int i;
 
        for (i = 0; tab[i].key != '\0' || tab[i].long_key != NULL; i++) {
                if (size == 1) {
-                       if (key[0] == tab[i].key)
-                               return &tab[i];
+                       if (key[0] == tab[i].key) {
+                               *idx_r = i;
+                               return TRUE;
+                       }
                } else if (tab[i].long_key != NULL) {
                        if (strncmp(key, tab[i].long_key, size) == 0 &&
-                           tab[i].long_key[size] == '\0')
-                               return &tab[i];
+                           tab[i].long_key[size] == '\0') {
+                               *idx_r = i;
+                               return TRUE;
+                       }
                }
        }
-       return NULL;
+       return FALSE;
+}
+
+static void
+auth_cache_key_add_var(string_t *str, const char *data, unsigned int len)
+{
+       if (str_len(str) > 0)
+               str_append_c(str, '\t');
+       str_append_c(str, '%');
+       if (len == 1)
+               str_append_c(str, data[0]);
+       else {
+               str_append_c(str, '{');
+               str_append_n(str, data, len);
+               str_append_c(str, '}');
+       }
+}
+
+static void auth_cache_key_add_tab_idx(string_t *str, unsigned int i)
+{
+       const struct var_expand_table *tab =
+               &auth_request_var_expand_static_tab[i];
+
+       if (str_len(str) > 0)
+               str_append_c(str, '\t');
+       str_append_c(str, '%');
+       if (tab->key != '\0')
+               str_append_c(str, tab->key);
+       else {
+               str_append_c(str, '{');
+               str_append(str, tab->long_key);
+               str_append_c(str, '}');
+       }
 }
 
 char *auth_cache_parse_key(pool_t pool, const char *query)
 {
-       const struct var_expand_table *tab;
        string_t *str;
-       bool key_seen[100];
-       unsigned int idx, size, tab_idx;
-       bool add_key;
+       bool key_seen[AUTH_REQUEST_VAR_TAB_COUNT];
+       const char *extra_vars;
+       unsigned int i, idx, size, tab_idx;
 
        memset(key_seen, 0, sizeof(key_seen));
 
-       str = str_new(pool, 32);
+       str = t_str_new(32);
        for (; *query != '\0'; ) {
                if (*query != '%') {
                        query++;
@@ -66,34 +102,45 @@ char *auth_cache_parse_key(pool_t pool, const char *query)
                }
                query += idx;
 
-               tab = auth_request_var_expand_tab_find(query, size);
-               if (tab == NULL) {
+               if (!auth_request_var_expand_tab_find(query, size, &tab_idx)) {
                        /* just add the key. it would be nice to prevent
                           duplicates here as well, but that's just too
                           much trouble and probably very rare. */
-                       add_key = TRUE;
+                       auth_cache_key_add_var(str, query, size);
                } else {
-                       tab_idx = tab - auth_request_var_expand_static_tab;
                        i_assert(tab_idx < N_ELEMENTS(key_seen));
-                       /* @UNSAFE */
-                       add_key = !key_seen[tab_idx];
                        key_seen[tab_idx] = TRUE;
                }
-               if (add_key) {
-                       if (str_len(str) != 0)
-                               str_append_c(str, '\t');
-                       str_append_c(str, '%');
-                       if (size == 1)
-                               str_append_c(str, query[0]);
-                       else {
-                               str_append_c(str, '{');
-                               str_append_n(str, query, size);
-                               str_append_c(str, '}');
-                       }
-               }
                query += size;
        }
-       return str_free_without_data(&str);
+
+       if (key_seen[AUTH_REQUEST_VAR_TAB_USERNAME_IDX] &&
+           key_seen[AUTH_REQUEST_VAR_TAB_DOMAIN_IDX]) {
+               /* %n and %d both used -> replace with %u */
+               key_seen[AUTH_REQUEST_VAR_TAB_USER_IDX] = TRUE;
+               key_seen[AUTH_REQUEST_VAR_TAB_USERNAME_IDX] = FALSE;
+               key_seen[AUTH_REQUEST_VAR_TAB_DOMAIN_IDX] = FALSE;
+       }
+
+       /* we rely on these being at the beginning */
+       i_assert(AUTH_REQUEST_VAR_TAB_USER_IDX == 0);
+       i_assert(AUTH_REQUEST_VAR_TAB_USERNAME_IDX == 1);
+       i_assert(AUTH_REQUEST_VAR_TAB_DOMAIN_IDX == 2);
+
+       extra_vars = t_strdup(str_c(str));
+       str_truncate(str, 0);
+       for (i = 0; i < N_ELEMENTS(key_seen); i++) {
+               if (key_seen[i])
+                       auth_cache_key_add_tab_idx(str, i);
+       }
+
+       if (*extra_vars != '\0') {
+               if (str_len(str) > 0)
+                       str_append_c(str, '\t');
+               str_append(str, extra_vars);
+       }
+
+       return p_strdup(pool, str_c(str));
 }
 
 static void
@@ -142,8 +189,8 @@ static void sig_auth_cache_clear(const siginfo_t *si ATTR_UNUSED, void *context)
 {
        struct auth_cache *cache = context;
 
-       i_info("SIGHUP received, clearing cache");
-       auth_cache_clear(cache);
+       i_info("SIGHUP received, %u cache entries flushed",
+              auth_cache_clear(cache));
 }
 
 static void sig_auth_cache_stats(const siginfo_t *si ATTR_UNUSED, void *context)
@@ -200,11 +247,69 @@ void auth_cache_free(struct auth_cache **_cache)
        i_free(cache);
 }
 
-void auth_cache_clear(struct auth_cache *cache)
+unsigned int auth_cache_clear(struct auth_cache *cache)
 {
+       unsigned int ret = hash_table_count(cache->hash);
+
        while (cache->tail != NULL)
                auth_cache_node_destroy(cache, cache->tail);
        hash_table_clear(cache->hash, FALSE);
+       return ret;
+}
+
+static bool auth_cache_node_is_user(struct auth_cache_node *node,
+                                   const char *username)
+{
+       const char *data = node->data;
+       unsigned int username_len;
+
+       /* The cache nodes begin with "P"/"U", passdb/userdb ID, "/" and
+          then usually followed by the username. It's too much trouble to
+          keep track of all the cache keys, so we'll just match it as if it
+          was the username. If e.g. '%n' is used in the cache key instead of
+          '%u', it means that cache entries can be removed only when @domain
+          isn't in the username parameter. */
+       if (*data != 'P' && *data != 'U')
+               return FALSE;
+       data++;
+
+       while (*data >= '0' && *data <= '9')
+               data++;
+       if (*data != '/')
+               return FALSE;
+       data++;
+
+       username_len = strlen(username);
+       return strncmp(data, username, username_len) == 0 &&
+               (data[username_len] == '\t' || data[username_len] == '\0');
+}
+
+static bool auth_cache_node_is_one_of_users(struct auth_cache_node *node,
+                                           const char *const *usernames)
+{
+       unsigned int i;
+
+       for (i = 0; usernames[i] != NULL; i++) {
+               if (auth_cache_node_is_user(node, usernames[i]))
+                       return TRUE;
+       }
+       return FALSE;
+}
+
+unsigned int auth_cache_clear_users(struct auth_cache *cache,
+                                   const char *const *usernames)
+{
+       struct auth_cache_node *node, *next;
+       unsigned int ret = 0;
+
+       for (node = cache->tail; node != NULL; node = next) {
+               next = node->next;
+               if (auth_cache_node_is_one_of_users(node, usernames)) {
+                       auth_cache_node_destroy(cache, cache->tail);
+                       ret++;
+               }
+       }
+       return ret;
 }
 
 static const char *
@@ -216,12 +321,27 @@ auth_cache_escape(const char *string,
        return str_tabescape(string);
 }
 
+static const char *
+auth_request_expand_cache_key(const struct auth_request *request,
+                             const char *key)
+{
+       string_t *str;
+
+       /* Uniquely identify the request's passdb/userdb with the P/U prefix
+          and by "%!", which expands to the passdb/userdb ID number. */
+       key = t_strconcat(request->userdb_lookup ? "U" : "P", "%!/", key, NULL);
+
+       str = t_str_new(256);
+       var_expand(str, key,
+                  auth_request_get_var_expand_table(request, auth_cache_escape));
+       return str_c(str);
+}
+
 const char *
 auth_cache_lookup(struct auth_cache *cache, const struct auth_request *request,
                  const char *key, struct auth_cache_node **node_r,
                  bool *expired_r, bool *neg_expired_r)
 {
-       string_t *str;
        struct auth_cache_node *node;
        const char *value;
        unsigned int ttl_secs;
@@ -230,13 +350,8 @@ auth_cache_lookup(struct auth_cache *cache, const struct auth_request *request,
        *expired_r = FALSE;
        *neg_expired_r = FALSE;
 
-       /* %! is prepended automatically. it contains the passdb ID number. */
-       str = t_str_new(256);
-       var_expand(str, t_strconcat(request->userdb_lookup ? "U" : "P",
-                                   "%!/", key, NULL),
-                  auth_request_get_var_expand_table(request, auth_cache_escape));
-
-       node = hash_table_lookup(cache->hash, str_c(str));
+       key = auth_request_expand_cache_key(request, key);
+       node = hash_table_lookup(cache->hash, key);
        if (node == NULL) {
                cache->miss_count++;
                return NULL;
@@ -269,9 +384,8 @@ auth_cache_lookup(struct auth_cache *cache, const struct auth_request *request,
 void auth_cache_insert(struct auth_cache *cache, struct auth_request *request,
                       const char *key, const char *value, bool last_success)
 {
-       string_t *str;
         struct auth_cache_node *node;
-       size_t data_size, alloc_size, value_len = strlen(value);
+       size_t data_size, alloc_size, key_len, value_len = strlen(value);
        char *current_username;
 
        if (*value == '\0' && cache->neg_ttl_secs == 0) {
@@ -286,15 +400,12 @@ void auth_cache_insert(struct auth_cache *cache, struct auth_request *request,
            request->requested_login_user == NULL)
                request->user = t_strdup_noconst(request->translated_username);
 
-       /* %! is prepended automatically. it contains the db ID number. */
-       str = t_str_new(256);
-       var_expand(str, t_strconcat(request->userdb_lookup ? "U" : "P",
-                                   "%!/", key, NULL),
-                  auth_request_get_var_expand_table(request, auth_cache_escape));
+       key = auth_request_expand_cache_key(request, key);
+       key_len = strlen(key);
 
        request->user = current_username;
 
-       data_size = str_len(str) + 1 + value_len + 1;
+       data_size = key_len + 1 + value_len + 1;
        alloc_size = sizeof(struct auth_cache_node) -
                sizeof(node->data) + data_size;
 
@@ -302,7 +413,7 @@ void auth_cache_insert(struct auth_cache *cache, struct auth_request *request,
        while (cache->size_left < alloc_size && cache->tail != NULL)
                auth_cache_node_destroy(cache, cache->tail);
 
-       node = hash_table_lookup(cache->hash, str_c(str));
+       node = hash_table_lookup(cache->hash, key);
        if (node != NULL) {
                /* key is already in cache (probably expired), remove it */
                auth_cache_node_destroy(cache, node);
@@ -313,8 +424,8 @@ void auth_cache_insert(struct auth_cache *cache, struct auth_request *request,
        node->created = time(NULL);
        node->alloc_size = alloc_size;
        node->last_success = last_success;
-       memcpy(node->data, str_data(str), str_len(str));
-       memcpy(node->data + str_len(str) + 1, value, value_len);
+       memcpy(node->data, key, key_len);
+       memcpy(node->data + key_len + 1, value, value_len);
 
        auth_cache_node_link_head(cache, node);
 
@@ -331,17 +442,12 @@ void auth_cache_insert(struct auth_cache *cache, struct auth_request *request,
 }
 
 void auth_cache_remove(struct auth_cache *cache,
-                      const struct auth_request *request,
-                      const char *key)
+                      const struct auth_request *request, const char *key)
 {
-       string_t *str;
        struct auth_cache_node *node;
 
-       str = t_str_new(256);
-       var_expand(str, key,
-                  auth_request_get_var_expand_table(request, auth_cache_escape));
-
-       node = hash_table_lookup(cache->hash, str_c(str));
+       key = auth_request_expand_cache_key(request, key);
+       node = hash_table_lookup(cache->hash, key);
        if (node == NULL)
                return;
 
index d82507cb3d1f39c1d79e390c208ab2c16eee2421..9839b438a6f22d31e4590d28deafb3f11a4f8ff5 100644 (file)
@@ -28,8 +28,10 @@ struct auth_cache *auth_cache_new(size_t max_size, unsigned int ttl_secs,
                                  unsigned int neg_ttl_secs);
 void auth_cache_free(struct auth_cache **cache);
 
-/* Clear the cache. */
-void auth_cache_clear(struct auth_cache *cache);
+/* Clear the cache. Returns how many entries were removed. */
+unsigned int auth_cache_clear(struct auth_cache *cache);
+unsigned int auth_cache_clear_users(struct auth_cache *cache,
+                                   const char *const *usernames);
 
 /* Look key from cache. key should be the same string as returned by
    auth_cache_parse_key(). Returned node can't be used after any other
index d8a4491175654100db219e6bd949a78c88ecef85..e3c270f8ae6e42b097a289c4bbc3b1ec6bc4a8dd 100644 (file)
@@ -18,6 +18,7 @@
 #include "userdb.h"
 #include "userdb-blocking.h"
 #include "master-interface.h"
+#include "passdb-cache.h"
 #include "auth-request-handler.h"
 #include "auth-client-connection.h"
 #include "auth-master-connection.h"
@@ -136,6 +137,30 @@ master_input_request(struct auth_master_connection *conn, const char *args)
        return TRUE;
 }
 
+static int
+master_input_cache_flush(struct auth_master_connection *conn, const char *args)
+{
+       const char *const *list;
+       unsigned int count;
+
+       /* <id> [<user> [<user> [..]] */
+       list = t_strsplit_tab(args);
+       if (list[0] == NULL) {
+               i_error("BUG: doveadm sent broken CACHE-FLUSH");
+               return FALSE;
+       }
+
+       if (list[1] == NULL) {
+               /* flush the whole cache */
+               count = auth_cache_clear(passdb_cache);
+       } else {
+               count = auth_cache_clear_users(passdb_cache, list+1);
+       }
+       (void)o_stream_send_str(conn->output,
+               t_strdup_printf("OK\t%s\t%u\n", list[0], count));
+       return TRUE;
+}
+
 static int
 master_input_auth_request(struct auth_master_connection *conn, const char *args,
                          const char *cmd, struct auth_request **request_r,
@@ -566,6 +591,8 @@ auth_master_input_line(struct auth_master_connection *conn, const char *line)
                i_assert(conn->userdb_restricted_uid == 0);
                if (strncmp(line, "REQUEST\t", 8) == 0)
                        return master_input_request(conn, line + 8);
+               if (strncmp(line, "CACHE-FLUSH\t", 12) == 0)
+                       return master_input_cache_flush(conn, line + 12);
                if (strncmp(line, "CPID\t", 5) == 0) {
                        i_error("Authentication client trying to connect to "
                                "master socket");
index d450e6b571e2e1432af92c1018a47fe93c38fb6e..b28ab350fd11ef2d7f5a49111f2c3c3a0c7f6cbd 100644 (file)
@@ -125,6 +125,10 @@ struct auth_request {
 typedef void auth_request_proxy_cb_t(bool success, struct auth_request *);
 
 extern unsigned int auth_request_state_count[AUTH_REQUEST_STATE_MAX];
+#define AUTH_REQUEST_VAR_TAB_USER_IDX 0
+#define AUTH_REQUEST_VAR_TAB_USERNAME_IDX 1
+#define AUTH_REQUEST_VAR_TAB_DOMAIN_IDX 2
+#define AUTH_REQUEST_VAR_TAB_COUNT 19
 extern const struct var_expand_table auth_request_var_expand_static_tab[];
 
 struct auth_request *
index 5c2d605ab0c05d877b35420543b6e54253daa7ed..74f642e751ea9e7810cfd52febb1c9b45a6afcf6 100644 (file)
@@ -288,6 +288,15 @@ void auths_init(void)
 {
        struct auth *const *auth;
 
+       /* sanity checks */
+       i_assert(auth_request_var_expand_static_tab[AUTH_REQUEST_VAR_TAB_USER_IDX].key == 'u');
+       i_assert(auth_request_var_expand_static_tab[AUTH_REQUEST_VAR_TAB_USERNAME_IDX].key == 'n');
+       i_assert(auth_request_var_expand_static_tab[AUTH_REQUEST_VAR_TAB_DOMAIN_IDX].key == 'd');
+       i_assert(auth_request_var_expand_static_tab[AUTH_REQUEST_VAR_TAB_COUNT].key == '\0' &&
+                auth_request_var_expand_static_tab[AUTH_REQUEST_VAR_TAB_COUNT].long_key == NULL);
+       i_assert(auth_request_var_expand_static_tab[AUTH_REQUEST_VAR_TAB_COUNT-1].key != '\0' ||
+                auth_request_var_expand_static_tab[AUTH_REQUEST_VAR_TAB_COUNT-1].long_key != NULL);
+
        array_foreach(&auths, auth)
                auth_init(*auth);
 }
index db80d38da339384b354c640fe59d35474109db8e..a8b17c8c4338dc5f796db80be8b36b65f0f518d8 100644 (file)
@@ -6,6 +6,11 @@
 #include "test-common.h"
 
 const struct var_expand_table auth_request_var_expand_static_tab[] = {
+       /* these 3 must be in this order */
+       { 'u', NULL, "user" },
+       { 'n', NULL, "username" },
+       { 'd', NULL, "domain" },
+
        { 'a', NULL, NULL },
        { '\0', NULL, "longb" },
        { 'c', NULL, "longc" },
@@ -24,14 +29,21 @@ static void test_auth_cache_parse_key(void)
        struct {
                const char *in, *out;
        } tests[] = {
+               { "%n@%d", "%u" },
+               { "%{username}@%{domain}", "%u" },
+               { "%n%d%u", "%u" },
+               { "%n", "%n" },
+               { "%d", "%d" },
+               { "%a%b%u", "%u\t%a\t%b" },
+
                { "foo%5.5Mabar", "%a" },
                { "foo%5.5M{longb}bar", "%{longb}" },
                { "foo%5.5Mcbar", "%c" },
-               { "foo%5.5M{longc}bar", "%{longc}" },
+               { "foo%5.5M{longc}bar", "%c" },
                { "%a%b", "%a\t%b" },
                { "%a%{longb}%a", "%a\t%{longb}" },
-               { "%{longc}%c", "%{longc}" },
-               { "%c%a%{longc}%c", "%c\t%a" },
+               { "%{longc}%c", "%c" },
+               { "%c%a%{longc}%c", "%a\t%c" },
                { "%a%{env:foo}%{env:foo}%a", "%a\t%{env:foo}\t%{env:foo}" }
        };
        const char *cache_key;