]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
dict-sql: Dictionary iteration fixes.
authorTimo Sirainen <tss@iki.fi>
Sun, 12 Oct 2008 10:36:54 +0000 (13:36 +0300)
committerTimo Sirainen <tss@iki.fi>
Sun, 12 Oct 2008 10:36:54 +0000 (13:36 +0300)
--HG--
branch : HEAD

src/lib-dict/dict-sql.c

index c670c900710fb5ec69514d398bcf36a82c50bb8d..2188ffec9c8867752dbdb9f55157d84795aceb3b 100644 (file)
 
 #define DICT_SQL_MAX_UNUSED_CONNECTIONS 10
 
+enum sql_recurse_type {
+       SQL_DICT_RECURSE_NONE,
+       SQL_DICT_RECURSE_ONE,
+       SQL_DICT_RECURSE_FULL
+};
+
 struct sql_dict {
        struct dict dict;
 
@@ -35,7 +41,7 @@ struct sql_dict_iterate_context {
        struct sql_result *result;
        string_t *key;
        const struct dict_sql_map *map;
-       unsigned int key_prefix_len, next_map_idx;
+       unsigned int key_prefix_len, pattern_prefix_len, next_map_idx;
 };
 
 struct sql_dict_transaction_context {
@@ -86,9 +92,12 @@ static void sql_dict_deinit(struct dict *_dict)
 
 static bool
 dict_sql_map_match(const struct dict_sql_map *map, const char *path,
-                  ARRAY_TYPE(const_string) *values, bool partial_ok)
+                  ARRAY_TYPE(const_string) *values, unsigned int *pat_len_r,
+                  unsigned int *path_len_r, bool partial_ok)
 {
+       const char *path_start = path;
        const char *pat, *field, *p;
+       unsigned int len;
 
        array_clear(values);
        pat = map->pattern;
@@ -99,16 +108,39 @@ dict_sql_map_match(const struct dict_sql_map *map, const char *path,
                        if (*pat == '\0') {
                                /* pattern ended with this variable,
                                   it'll match the rest of the path */
-                               array_append(values, &path, 1);
+                               len = strlen(path);
+                               if (partial_ok) {
+                                       /* iterating - the last field never
+                                          matches fully. if there's a trailing
+                                          '/', drop it. */
+                                       pat--;
+                                       if (path[len-1] == '/') {
+                                               field = t_strndup(path, len-1);
+                                               array_append(values, &field, 1);
+                                       } else {
+                                               array_append(values, &path, 1);
+                                       }
+                               } else {
+                                       array_append(values, &path, 1);
+                                       path += len;
+                               }
+                               *path_len_r = path - path_start;
+                               *pat_len_r = pat - map->pattern;
                                return TRUE;
                        }
                        /* pattern matches until the next '/' in path */
                        p = strchr(path, '/');
-                       if (p == NULL)
-                               return FALSE;
-                       field = t_strdup_until(path, p);
-                       array_append(values, &field, 1);
-                       path = p;
+                       if (p != NULL) {
+                               field = t_strdup_until(path, p);
+                               array_append(values, &field, 1);
+                               path = p;
+                       } else {
+                               /* no '/' anymore, but it'll still match a
+                                  partial */
+                               array_append(values, &path, 1);
+                               path += strlen(path);
+                               pat++;
+                       }
                } else if (*pat == *path) {
                        pat++;
                        path++;
@@ -122,6 +154,8 @@ dict_sql_map_match(const struct dict_sql_map *map, const char *path,
                return FALSE;
        else {
                /* partial matches must end with '/' */
+               *path_len_r = path - path_start;
+               *pat_len_r = pat - map->pattern;
                return pat == map->pattern || pat[-1] == '/';
        }
 }
@@ -131,14 +165,15 @@ sql_dict_find_map(struct sql_dict *dict, const char *path,
                  ARRAY_TYPE(const_string) *values)
 {
        const struct dict_sql_map *maps;
-       unsigned int i, idx, count;
+       unsigned int i, idx, count, len;
 
        t_array_init(values, dict->set->max_field_count);
        maps = array_get(&dict->set->maps, &count);
        for (i = 0; i < count; i++) {
                /* start matching from the previously successful match */
                idx = (dict->prev_map_match_idx + i) % count;
-               if (dict_sql_map_match(&maps[idx], path, values, FALSE)) {
+               if (dict_sql_map_match(&maps[idx], path, values,
+                                      &len, &len, FALSE)) {
                        dict->prev_map_match_idx = idx;
                        return &maps[idx];
                }
@@ -149,10 +184,11 @@ sql_dict_find_map(struct sql_dict *dict, const char *path,
 static void
 sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                     const ARRAY_TYPE(const_string) *values_arr,
-                    const char *key, string_t *query)
+                    const char *key, enum sql_recurse_type recurse_type,
+                    string_t *query)
 {
        const char *const *sql_fields, *const *values;
-       unsigned int i, count, count2;
+       unsigned int i, count, count2, exact_count;
        bool priv = *key == DICT_PATH_PRIVATE[0];
 
        sql_fields = array_get(&map->sql_fields, &count);
@@ -165,13 +201,43 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                return;
        }
 
-       str_append(query, "WHERE");
-       for (i = 0; i < count2; i++) {
+       str_append(query, " WHERE");
+       exact_count = count == count2 && recurse_type != SQL_DICT_RECURSE_NONE ?
+               count2-1 : count2;
+       for (i = 0; i < exact_count; i++) {
                if (i > 0)
                        str_append(query, " AND");
                str_printfa(query, " %s = '%s'", sql_fields[i],
                            sql_escape_string(dict->db, values[i]));
        }
+       switch (recurse_type) {
+       case SQL_DICT_RECURSE_NONE:
+               break;
+       case SQL_DICT_RECURSE_ONE:
+               if (i > 0)
+                       str_append(query, " AND");
+               if (i < count2) {
+                       str_printfa(query, " %s LIKE '%s/%%' AND "
+                                   "%s NOT LIKE '%s/%%/%%'",
+                                   sql_fields[i],
+                                   sql_escape_string(dict->db, values[i]),
+                                   sql_fields[i],
+                                   sql_escape_string(dict->db, values[i]));
+               } else {
+                       str_printfa(query, " %s LIKE '%%' AND "
+                                   "%s NOT LIKE '%%/%%'",
+                                   sql_fields[i], sql_fields[i]);
+               }
+               break;
+       case SQL_DICT_RECURSE_FULL:
+               if (i < count2) {
+                       if (i > 0)
+                               str_append(query, " AND");
+                       str_printfa(query, " %s LIKE '%s/%%'", sql_fields[i],
+                                   sql_escape_string(dict->db, values[i]));
+               }
+               break;
+       }
        if (priv) {
                if (count2 > 0)
                        str_append(query, " AND");
@@ -199,9 +265,10 @@ static int sql_dict_lookup(struct dict *_dict, pool_t pool,
        T_BEGIN {
                string_t *query = t_str_new(256);
 
-               str_printfa(query, "SELECT %s FROM %s ",
+               str_printfa(query, "SELECT %s FROM %s",
                            map->value_field, map->table);
-               sql_dict_where_build(dict, map, &values, key, query);
+               sql_dict_where_build(dict, map, &values, key,
+                                    SQL_DICT_RECURSE_NONE, query);
                result = sql_query_s(dict->db, str_c(query));
        } T_END;
 
@@ -227,14 +294,17 @@ sql_dict_iterate_find_next_map(struct sql_dict_iterate_context *ctx,
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
        const struct dict_sql_map *maps;
-       unsigned int i, count;
+       unsigned int i, count, pat_len, path_len;
 
        t_array_init(values, dict->set->max_field_count);
        maps = array_get(&dict->set->maps, &count);
        for (i = ctx->next_map_idx; i < count; i++) {
-               if (dict_sql_map_match(&maps[i], ctx->path, values, TRUE) &&
+               if (dict_sql_map_match(&maps[i], ctx->path,
+                                      values, &pat_len, &path_len, TRUE) &&
                    ((ctx->flags & DICT_ITERATE_FLAG_RECURSE) != 0 ||
-                    array_count(values)+1 == array_count(&maps[i].sql_fields))) {
+                    array_count(values)+1 >= array_count(&maps[i].sql_fields))) {
+                       ctx->key_prefix_len = path_len;
+                       ctx->pattern_prefix_len = pat_len;
                        ctx->next_map_idx = i + 1;
                        return &maps[i];
                }
@@ -248,6 +318,7 @@ static bool sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
        const struct dict_sql_map *map;
        ARRAY_TYPE(const_string) values;
        const char *const *sql_fields;
+       enum sql_recurse_type recurse_type;
        unsigned int i, count;
 
        map = sql_dict_iterate_find_next_map(ctx, &values);
@@ -260,20 +331,31 @@ static bool sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
                str_printfa(query, "SELECT %s", map->value_field);
                /* get all missing fields */
                sql_fields = array_get(&map->sql_fields, &count);
-               for (i = array_count(&values); i < count; i++)
+               i = array_count(&values);
+               if (i == count) {
+                       /* we always want to know the last field since we're
+                          iterating its children */
+                       i_assert(i > 0);
+                       i--;
+               }
+               for (; i < count; i++)
                        str_printfa(query, ",%s", sql_fields[i]);
-               str_printfa(query, " FROM %s ", map->table);
-               sql_dict_where_build(dict, map, &values, ctx->path, query);
+               str_printfa(query, " FROM %s", map->table);
+
+               recurse_type = (ctx->flags & DICT_ITERATE_FLAG_RECURSE) == 0 ?
+                       SQL_DICT_RECURSE_ONE : SQL_DICT_RECURSE_FULL;
+               sql_dict_where_build(dict, map, &values, ctx->path,
+                                    recurse_type, query);
 
                if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
-                       str_append(query, "ORDER BY ");
+                       str_append(query, " ORDER BY ");
                        for (i = array_count(&values); i < count; i++) {
                                str_printfa(query, "%s", sql_fields[i]);
                                if (i < count-1)
                                        str_append_c(query, ',');
                        }
                } else if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_VALUE) != 0)
-                       str_printfa(query, "ORDER BY %s", map->value_field);
+                       str_printfa(query, " ORDER BY %s", map->value_field);
                ctx->result = sql_query_s(dict->db, str_c(query));
        } T_END;
 
@@ -292,8 +374,7 @@ sql_dict_iterate_init(struct dict *_dict, const char *path,
        ctx->path = i_strdup(path);
        ctx->flags = flags;
        ctx->key = str_new(default_pool, 256);
-       str_append(ctx->key, path);
-       ctx->key_prefix_len = str_len(ctx->key);
+       str_append(ctx->key, ctx->path);
 
        if (!sql_dict_iterate_next_query(ctx)) {
                i_error("sql dict iterate: Invalid/unmapped path: %s", path);
@@ -328,9 +409,13 @@ static int sql_dict_iterate(struct dict_iterate_context *_ctx,
 
        /* convert fetched row to dict key */
        str_truncate(ctx->key, ctx->key_prefix_len);
+       if (ctx->key_prefix_len > 0 &&
+           str_c(ctx->key)[ctx->key_prefix_len-1] != '/')
+               str_append_c(ctx->key, '/');
+
        count = sql_result_get_fields_count(ctx->result);
        i = 1;
-       for (p = ctx->map->pattern + ctx->key_prefix_len; *p != '\0'; p++) {
+       for (p = ctx->map->pattern + ctx->pattern_prefix_len; *p != '\0'; p++) {
                if (*p != '$')
                        str_append_c(ctx->key, *p);
                else {
@@ -351,7 +436,8 @@ static void sql_dict_iterate_deinit(struct dict_iterate_context *_ctx)
        struct sql_dict_iterate_context *ctx =
                (struct sql_dict_iterate_context *)_ctx;
 
-       sql_result_free(ctx->result);
+       if (ctx->result != NULL)
+               sql_result_free(ctx->result);
        str_free(&ctx->key);
        i_free(ctx->path);
        i_free(ctx);
@@ -494,8 +580,9 @@ static void sql_dict_unset(struct dict_transaction_context *_ctx,
        T_BEGIN {
                string_t *query = t_str_new(256);
 
-               str_printfa(query, "DELETE FROM %s ", map->table);
-               sql_dict_where_build(dict, map, &values, key, query);
+               str_printfa(query, "DELETE FROM %s", map->table);
+               sql_dict_where_build(dict, map, &values, key,
+                                    SQL_DICT_RECURSE_NONE, query);
                sql_update(ctx->sql_ctx, str_c(query));
        } T_END;
 }