]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
dict sql: Added configuration for mapping dict paths to SQL fields.
authorTimo Sirainen <tss@iki.fi>
Wed, 27 Aug 2008 09:10:21 +0000 (12:10 +0300)
committerTimo Sirainen <tss@iki.fi>
Wed, 27 Aug 2008 09:10:21 +0000 (12:10 +0300)
--HG--
branch : HEAD

doc/dovecot-dict-sql-example.conf [new file with mode: 0644]
src/dict/Makefile.am
src/lib-dict/Makefile.am
src/lib-dict/dict-sql-settings.c [new file with mode: 0644]
src/lib-dict/dict-sql-settings.h [new file with mode: 0644]
src/lib-dict/dict-sql.c
src/lib-dict/dict.c

diff --git a/doc/dovecot-dict-sql-example.conf b/doc/dovecot-dict-sql-example.conf
new file mode 100644 (file)
index 0000000..674a25f
--- /dev/null
@@ -0,0 +1,39 @@
+#connect = host=localhost dbname=mails user=testuser password=pass
+
+# CREATE TABLE quota (
+#   username varchar(100) not null,
+#   bytes bigint not null default 0,
+#   messages integer not null default 0,
+#   primary key (username)
+# );
+
+map {
+  pattern = priv/quota/storage
+  table = quota
+  username_field = username
+  value_field = bytes
+}
+map {
+  pattern = priv/quota/messages
+  table = quota
+  username_field = username
+  value_field = messages
+}
+
+# CREATE TABLE expires (
+#   username varchar(100) not null,
+#   mailbox varchar(255) not null,
+#   expire_stamp integer not null,
+#   primary key (username, mailbox)
+# );
+
+map {
+  pattern = shared/expire/$user/$mailbox
+  table = expires
+  value_field = expire_stamp
+
+  fields {
+    username = $user
+    mailbox = $mailbox
+  }
+}
index 660b3045443f753d06fe15ee59a53bdb16dc16f6..69ca29cea575bec35bb6ac6c72d32c82d7cfeb70 100644 (file)
@@ -15,6 +15,7 @@ libs = \
        ../lib-dict/libdict_backend.a \
        ../lib-dict/libdict.a \
        ../lib-sql/libsql.a \
+       ../lib-settings/libsettings.a \
        ../lib/liblib.a
 
 dict_LDADD = \
index fd657bc2df7cad06e03a7769eda2b12f44075183..1b2dc8aaf14599369d5a3983836805fdc038d11d 100644 (file)
@@ -5,6 +5,7 @@ dict_drivers = @dict_drivers@
 AM_CPPFLAGS = \
        -I$(top_srcdir)/src/lib \
        -I$(top_srcdir)/src/lib-sql \
+       -I$(top_srcdir)/src/lib-settings \
        -DPKG_RUNDIR=\""$(rundir)"\" \
        $(SQL_CFLAGS)
 
@@ -14,7 +15,8 @@ base_sources = \
 
 backend_sources = \
        dict-db.c \
-       dict-sql.c
+       dict-sql.c \
+       dict-sql-settings.c
 
 libdict_a_SOURCES = \
        $(base_sources)
@@ -27,7 +29,8 @@ headers = \
        dict.h \
        dict-client.h \
        dict-private.h \
-       dict-sql.h
+       dict-sql.h \
+       dict-sql-settings.h
 
 if INSTALL_HEADERS
   pkginc_libdir=$(pkgincludedir)/src/lib-dict
diff --git a/src/lib-dict/dict-sql-settings.c b/src/lib-dict/dict-sql-settings.c
new file mode 100644 (file)
index 0000000..f071efc
--- /dev/null
@@ -0,0 +1,215 @@
+/* Copyright (c) 2008 Dovecot authors, see the included COPYING file */
+
+#include "lib.h"
+#include "array.h"
+#include "str.h"
+#include "settings.h"
+#include "dict-sql-settings.h"
+
+#include <ctype.h>
+
+enum section_type {
+       SECTION_ROOT = 0,
+       SECTION_MAP,
+       SECTION_FIELDS
+};
+
+struct dict_sql_map_field {
+       const char *sql_field, *variable;
+};
+
+struct setting_parser_ctx {
+       pool_t pool;
+       struct dict_sql_settings *set;
+       enum section_type type;
+
+       struct dict_sql_map cur_map;
+       ARRAY_DEFINE(cur_fields, struct dict_sql_map_field);
+};
+
+#define DEF_STR(name) DEF_STRUCT_STR(name, dict_sql_map)
+
+static struct setting_def dict_sql_map_setting_defs[] = {
+       DEF_STR(pattern),
+       DEF_STR(table),
+       DEF_STR(username_field),
+       DEF_STR(value_field),
+
+       { 0, NULL, 0 }
+};
+
+static const char *pattern_read_name(const char **pattern)
+{
+       const char *p = *pattern, *name;
+
+       if (*p == '{') {
+               /* ${name} */
+               name = ++p;
+               p = strchr(p, '}');
+               if (p == NULL) {
+                       /* error, but allow anyway */
+                       *pattern += strlen(*pattern);
+                       return "";
+               }
+               *pattern = p + 1;
+       } else {
+               /* $name - ends at the first non-alnum_ character */
+               name = p;
+               for (; *p != '\0'; p++) {
+                       if (!i_isalnum(*p) && *p != '_')
+                               break;
+               }
+               *pattern = p;
+       }
+       name = t_strdup_until(name, p);
+       return name;
+}
+
+static const char *dict_sql_fields_map(struct setting_parser_ctx *ctx)
+{
+       struct dict_sql_map_field *fields;
+       string_t *pattern;
+       const char *p, *name;
+       unsigned int i, count;
+
+       p_array_init(&ctx->cur_map.sql_fields, ctx->pool, count);
+
+       /* go through the variables in the pattern, replace them with plain
+          '$' character and add its sql field */
+       pattern = t_str_new(strlen(ctx->cur_map.pattern) + 1);
+       fields = array_get_modifiable(&ctx->cur_fields, &count);
+       for (p = ctx->cur_map.pattern; *p != '\0';) {
+               if (*p != '$') {
+                       str_append_c(pattern, *p);
+                       p++;
+                       continue;
+               }
+               p++;
+               str_append_c(pattern, '$');
+
+               name = pattern_read_name(&p);
+               for (i = 0; i < count; i++) {
+                       if (fields[i].variable != NULL &&
+                           strcmp(fields[i].variable, name) == 0)
+                               break;
+               }
+               if (i == count) {
+                       return t_strconcat("Missing SQL field for variable: ",
+                                          name, NULL);
+               }
+
+               /* mark this field as used */
+               fields[i].variable = NULL;
+               array_append(&ctx->cur_map.sql_fields,
+                            &fields[i].sql_field, 1);
+       }
+
+       /* make sure there aren't any unused fields */
+       for (i = 0; i < count; i++) {
+               if (fields[i].variable != NULL) {
+                       return t_strconcat("Unused variable: ",
+                                          fields[i].variable, NULL);
+               }
+       }
+
+       if (ctx->set->max_field_count < count)
+               ctx->set->max_field_count = count;
+       ctx->cur_map.pattern = p_strdup(ctx->pool, str_c(pattern));
+       return NULL;
+}
+
+static const char *dict_sql_map_finish(struct setting_parser_ctx *ctx)
+{
+       if (!array_is_created(&ctx->cur_map.sql_fields)) {
+               /* no fields besides value. allocate the array anyway. */
+               p_array_init(&ctx->cur_map.sql_fields, ctx->pool, 1);
+               if (strchr(ctx->cur_map.pattern, '$') != NULL)
+                       return "Missing fields for pattern variables";
+       }
+       array_append(&ctx->set->maps, &ctx->cur_map, 1);
+       memset(&ctx->cur_map, 0, sizeof(ctx->cur_map));
+       return NULL;
+}
+
+static const char *
+parse_setting(const char *key, const char *value,
+             struct setting_parser_ctx *ctx)
+{
+       struct dict_sql_map_field *field;
+
+       switch (ctx->type) {
+       case SECTION_ROOT:
+               if (strcmp(key, "connect") == 0) {
+                       ctx->set->connect = p_strdup(ctx->pool, value);
+                       return NULL;
+               }
+               break;
+       case SECTION_MAP:
+               return parse_setting_from_defs(ctx->pool,
+                                              dict_sql_map_setting_defs,
+                                              &ctx->cur_map, key, value);
+       case SECTION_FIELDS:
+               if (*value != '$') {
+                       return t_strconcat("Value is missing '$' for field: ",
+                                          key, NULL);
+               }
+               field = array_append_space(&ctx->cur_fields);
+               field->sql_field = p_strdup(ctx->pool, key);
+               field->variable = p_strdup(ctx->pool, value + 1);
+               return NULL;
+       }
+       return t_strconcat("Unknown setting: ", key, NULL);
+}
+
+static bool
+parse_section(const char *type, const char *name ATTR_UNUSED,
+             struct setting_parser_ctx *ctx, const char **error_r)
+{
+       switch (ctx->type) {
+       case SECTION_ROOT:
+               if (type == NULL)
+                       return FALSE;
+               if (strcmp(type, "map") == 0) {
+                       array_clear(&ctx->cur_fields);
+                       ctx->type = SECTION_MAP;
+                       return TRUE;
+               }
+               break;
+       case SECTION_MAP:
+               if (type == NULL) {
+                       ctx->type = SECTION_ROOT;
+                       *error_r = dict_sql_map_finish(ctx);
+                       return FALSE;
+               }
+               if (strcmp(type, "fields") == 0) {
+                       ctx->type = SECTION_FIELDS;
+                       return TRUE;
+               }
+               break;
+       case SECTION_FIELDS:
+               if (type == NULL) {
+                       ctx->type = SECTION_MAP;
+                       *error_r = dict_sql_fields_map(ctx);
+                       return FALSE;
+               }
+               break;
+       }
+       *error_r = t_strconcat("Unknown section: ", type, NULL);
+       return FALSE;
+}
+
+struct dict_sql_settings *dict_sql_settings_read(pool_t pool, const char *path)
+{
+       struct setting_parser_ctx ctx;
+
+       memset(&ctx, 0, sizeof(ctx));
+       ctx.pool = pool;
+       ctx.set = p_new(pool, struct dict_sql_settings, 1);
+       t_array_init(&ctx.cur_fields, 16);
+       p_array_init(&ctx.set->maps, pool, 8);
+
+       if (!settings_read(path, NULL, parse_setting, parse_section, &ctx))
+               return NULL;
+
+       return ctx.set;
+}
diff --git a/src/lib-dict/dict-sql-settings.h b/src/lib-dict/dict-sql-settings.h
new file mode 100644 (file)
index 0000000..d173a2f
--- /dev/null
@@ -0,0 +1,24 @@
+#ifndef DICT_SQL_SETTINGS_H
+#define DICT_SQL_SETTINGS_H
+
+struct dict_sql_map {
+       /* pattern is in simplified form: all variables are stored as simple
+          '$' character. fields array is sorted by the variable index. */
+       const char *pattern;
+       const char *table;
+       const char *username_field;
+       const char *value_field;
+
+       ARRAY_TYPE(const_string) sql_fields;
+};
+
+struct dict_sql_settings {
+       const char *connect;
+
+       unsigned int max_field_count;
+       ARRAY_DEFINE(maps, struct dict_sql_map);
+};
+
+struct dict_sql_settings *dict_sql_settings_read(pool_t pool, const char *path);
+
+#endif
index acda729f665f03ada98833112d31d4e909baeea7..c670c900710fb5ec69514d398bcf36a82c50bb8d 100644 (file)
@@ -7,6 +7,7 @@
 #include "sql-api-private.h"
 #include "sql-pool.h"
 #include "dict-private.h"
+#include "dict-sql-settings.h"
 #include "dict-sql.h"
 
 #include <unistd.h>
@@ -19,19 +20,22 @@ struct sql_dict {
 
        pool_t pool;
        struct sql_db *db;
-
-       const char *connect_string, *username;
-       const char *table, *select_field, *where_field, *username_field;
+       const char *username;
+       const struct dict_sql_settings *set;
+       unsigned int prev_map_match_idx;
 
        unsigned int has_on_duplicate_key:1;
 };
 
 struct sql_dict_iterate_context {
        struct dict_iterate_context ctx;
+       enum dict_iterate_flags flags;
+       char *path;
 
        struct sql_result *result;
-       char *prev_key;
-       bool priv;
+       string_t *key;
+       const struct dict_sql_map *map;
+       unsigned int key_prefix_len, next_map_idx;
 };
 
 struct sql_dict_transaction_context {
@@ -45,77 +49,6 @@ struct sql_dict_transaction_context {
 
 static struct sql_pool *dict_sql_pool;
 
-static void sql_dict_config_parse_line(struct sql_dict *dict, const char *line)
-{
-       const char *p, *value;
-
-       while (*line == ' ') line++;
-       value = strchr(line, '=');
-       if (value == NULL)
-               return;
-
-       for (p = value; p[-1] == ' ' && p != line; p--) ;
-       line = t_strdup_until(line, p);
-       value++;
-       while (*value == ' ') value++;
-
-       if (strcmp(line, "connect") == 0)
-               dict->connect_string = p_strdup(dict->pool, value);
-       else if (strcmp(line, "table") == 0)
-               dict->table = p_strdup(dict->pool, value);
-       else if (strcmp(line, "select_field") == 0)
-               dict->select_field = p_strdup(dict->pool, value);
-       else if (strcmp(line, "where_field") == 0)
-               dict->where_field = p_strdup(dict->pool, value);
-       else if (strcmp(line, "username_field") == 0)
-               dict->username_field = p_strdup(dict->pool, value);
-}
-
-static int sql_dict_read_config(struct sql_dict *dict, const char *path)
-{
-       struct istream *input;
-       const char *line;
-       int fd;
-
-       fd = open(path, O_RDONLY);
-       if (fd == -1) {
-               i_error("open(%s) failed: %m", path);
-               return -1;
-       }
-
-       input = i_stream_create_fd(fd, (size_t)-1, FALSE);
-       while ((line = i_stream_read_next_line(input)) != NULL) {
-               T_BEGIN {
-                       sql_dict_config_parse_line(dict, line);
-               } T_END;
-       }
-       i_stream_destroy(&input);
-       (void)close(fd);
-
-       if (dict->connect_string == NULL) {
-               i_error("%s: 'connect' missing", path);
-               return -1;
-       }
-       if (dict->table == NULL) {
-               i_error("%s: 'table' missing", path);
-               return -1;
-       }
-       if (dict->select_field == NULL) {
-               i_error("%s: 'select_field' missing", path);
-               return -1;
-       }
-       if (dict->where_field == NULL) {
-               i_error("%s: 'where_field' missing", path);
-               return -1;
-       }
-       if (dict->username_field == NULL) {
-               i_error("%s: 'username_field' missing", path);
-               return -1;
-       }
-
-       return 0;
-}
-
 static struct dict *
 sql_dict_init(struct dict *driver, const char *uri,
              enum dict_data_type value_type ATTR_UNUSED,
@@ -129,8 +62,8 @@ sql_dict_init(struct dict *driver, const char *uri,
        dict->pool = pool;
        dict->dict = *driver;
        dict->username = p_strdup(pool, username);
-
-       if (sql_dict_read_config(dict, uri) < 0) {
+       dict->set = dict_sql_settings_read(pool, uri);
+       if (dict->set == NULL) {
                pool_unref(&pool);
                return NULL;
        }
@@ -139,7 +72,7 @@ sql_dict_init(struct dict *driver, const char *uri,
        dict->has_on_duplicate_key = strcmp(driver->name, "mysql") == 0;
 
        dict->db = sql_pool_new(dict_sql_pool, driver->name,
-                               dict->connect_string);
+                               dict->set->connect);
        return &dict->dict;
 }
 
@@ -151,58 +84,135 @@ static void sql_dict_deinit(struct dict *_dict)
        pool_unref(&dict->pool);
 }
 
-static int sql_path_fix(const char **path, bool *private_r)
+static bool
+dict_sql_map_match(const struct dict_sql_map *map, const char *path,
+                  ARRAY_TYPE(const_string) *values, bool partial_ok)
 {
-       const char *p;
-       size_t len;
+       const char *pat, *field, *p;
+
+       array_clear(values);
+       pat = map->pattern;
+       while (*pat != '\0' && *path != '\0') {
+               if (*pat == '$') {
+                       /* variable */
+                       pat++;
+                       if (*pat == '\0') {
+                               /* pattern ended with this variable,
+                                  it'll match the rest of the path */
+                               array_append(values, &path, 1);
+                               return TRUE;
+                       }
+                       /* pattern matches until the next '/' in path */
+                       p = strchr(path, '/');
+                       if (p == NULL)
+                               return FALSE;
+                       field = t_strdup_until(path, p);
+                       array_append(values, &field, 1);
+                       path = p;
+               } else if (*pat == *path) {
+                       pat++;
+                       path++;
+               } else {
+                       return FALSE;
+               }
+       }
+       if (*pat == '\0')
+               return *path == '\0';
+       else if (!partial_ok)
+               return FALSE;
+       else {
+               /* partial matches must end with '/' */
+               return pat == map->pattern || pat[-1] == '/';
+       }
+}
 
-       p = strchr(*path, '/');
-       if (p == NULL)
-               return -1;
-       len = p - *path;
+static const struct dict_sql_map *
+sql_dict_find_map(struct sql_dict *dict, const char *path,
+                 ARRAY_TYPE(const_string) *values)
+{
+       const struct dict_sql_map *maps;
+       unsigned int i, idx, count;
 
-       if (strncmp(*path, DICT_PATH_PRIVATE, len) == 0)
-               *private_r = TRUE;
-       else if (strncmp(*path, DICT_PATH_SHARED, len) == 0)
-               *private_r = FALSE;
-       else
-               return -1;
+       t_array_init(values, dict->set->max_field_count);
+       maps = array_get(&dict->set->maps, &count);
+       for (i = 0; i < count; i++) {
+               /* start matching from the previously successful match */
+               idx = (dict->prev_map_match_idx + i) % count;
+               if (dict_sql_map_match(&maps[idx], path, values, FALSE)) {
+                       dict->prev_map_match_idx = idx;
+                       return &maps[idx];
+               }
+       }
+       return NULL;
+}
+
+static void
+sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
+                    const ARRAY_TYPE(const_string) *values_arr,
+                    const char *key, string_t *query)
+{
+       const char *const *sql_fields, *const *values;
+       unsigned int i, count, count2;
+       bool priv = *key == DICT_PATH_PRIVATE[0];
+
+       sql_fields = array_get(&map->sql_fields, &count);
+       values = array_get(values_arr, &count2);
+       /* if we came here from iteration code there may be less values */
+       i_assert(count2 <= count);
+
+       if (count2 == 0 && !priv) {
+               /* we want everything */
+               return;
+       }
 
-       *path += len + 1;
-       return 0;
+       str_append(query, "WHERE");
+       for (i = 0; i < count2; i++) {
+               if (i > 0)
+                       str_append(query, " AND");
+               str_printfa(query, " %s = '%s'", sql_fields[i],
+                           sql_escape_string(dict->db, values[i]));
+       }
+       if (priv) {
+               if (count2 > 0)
+                       str_append(query, " AND");
+               str_printfa(query, " %s = '%s'", map->username_field,
+                           sql_escape_string(dict->db, dict->username));
+       }
 }
 
 static int sql_dict_lookup(struct dict *_dict, pool_t pool,
                           const char *key, const char **value_r)
 {
        struct sql_dict *dict = (struct sql_dict *)_dict;
+       const struct dict_sql_map *map;
+       ARRAY_TYPE(const_string) values;
        struct sql_result *result;
        int ret;
-       bool priv;
 
-       if (sql_path_fix(&key, &priv) < 0) {
+       map = sql_dict_find_map(dict, key, &values);
+       if (map == NULL) {
+               i_error("sql dict lookup: Invalid/unmapped key: %s", key);
                *value_r = NULL;
-               return -1;
+               return 0;
        }
 
        T_BEGIN {
                string_t *query = t_str_new(256);
-               str_printfa(query, "SELECT %s FROM %s WHERE %s = '%s'",
-                           dict->select_field, dict->table,
-                           dict->where_field,
-                           sql_escape_string(dict->db, key));
-               if (priv) {
-                       str_printfa(query, " AND %s = '%s'",
-                                   dict->username_field,
-                                   sql_escape_string(dict->db, dict->username));
-               }
+
+               str_printfa(query, "SELECT %s FROM %s ",
+                           map->value_field, map->table);
+               sql_dict_where_build(dict, map, &values, key, query);
                result = sql_query_s(dict->db, str_c(query));
        } T_END;
 
        ret = sql_result_next_row(result);
-       if (ret <= 0)
+       if (ret <= 0) {
+               if (ret < 0) {
+                       i_error("dict sql lookup failed: %s",
+                               sql_result_get_error(result));
+               }
                *value_r = NULL;
-       else {
+       else {
                *value_r =
                        p_strdup(pool, sql_result_get_field_value(result, 0));
        }
@@ -211,64 +221,85 @@ static int sql_dict_lookup(struct dict *_dict, pool_t pool,
        return ret;
 }
 
+static const struct dict_sql_map *
+sql_dict_iterate_find_next_map(struct sql_dict_iterate_context *ctx,
+                              ARRAY_TYPE(const_string) *values)
+{
+       struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
+       const struct dict_sql_map *maps;
+       unsigned int i, count;
+
+       t_array_init(values, dict->set->max_field_count);
+       maps = array_get(&dict->set->maps, &count);
+       for (i = ctx->next_map_idx; i < count; i++) {
+               if (dict_sql_map_match(&maps[i], ctx->path, values, TRUE) &&
+                   ((ctx->flags & DICT_ITERATE_FLAG_RECURSE) != 0 ||
+                    array_count(values)+1 == array_count(&maps[i].sql_fields))) {
+                       ctx->next_map_idx = i + 1;
+                       return &maps[i];
+               }
+       }
+       return NULL;
+}
+
+static bool sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
+{
+       struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
+       const struct dict_sql_map *map;
+       ARRAY_TYPE(const_string) values;
+       const char *const *sql_fields;
+       unsigned int i, count;
+
+       map = sql_dict_iterate_find_next_map(ctx, &values);
+       if (map == NULL)
+               return FALSE;
+
+       T_BEGIN {
+               string_t *query = t_str_new(256);
+
+               str_printfa(query, "SELECT %s", map->value_field);
+               /* get all missing fields */
+               sql_fields = array_get(&map->sql_fields, &count);
+               for (i = array_count(&values); i < count; i++)
+                       str_printfa(query, ",%s", sql_fields[i]);
+               str_printfa(query, " FROM %s ", map->table);
+               sql_dict_where_build(dict, map, &values, ctx->path, query);
+
+               if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
+                       str_append(query, "ORDER BY ");
+                       for (i = array_count(&values); i < count; i++) {
+                               str_printfa(query, "%s", sql_fields[i]);
+                               if (i < count-1)
+                                       str_append_c(query, ',');
+                       }
+               } else if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_VALUE) != 0)
+                       str_printfa(query, "ORDER BY %s", map->value_field);
+               ctx->result = sql_query_s(dict->db, str_c(query));
+       } T_END;
+
+       ctx->map = map;
+       return TRUE;
+}
+
 static struct dict_iterate_context *
 sql_dict_iterate_init(struct dict *_dict, const char *path, 
                      enum dict_iterate_flags flags)
 {
-       struct sql_dict *dict = (struct sql_dict *)_dict;
        struct sql_dict_iterate_context *ctx;
-       unsigned int len;
-       bool priv;
 
        ctx = i_new(struct sql_dict_iterate_context, 1);
        ctx->ctx.dict = _dict;
-
-       if (sql_path_fix(&path, &priv) < 0) {
+       ctx->path = i_strdup(path);
+       ctx->flags = flags;
+       ctx->key = str_new(default_pool, 256);
+       str_append(ctx->key, path);
+       ctx->key_prefix_len = str_len(ctx->key);
+
+       if (!sql_dict_iterate_next_query(ctx)) {
+               i_error("sql dict iterate: Invalid/unmapped path: %s", path);
                ctx->result = NULL;
                return &ctx->ctx;
        }
-       ctx->priv = priv;
-
-       T_BEGIN {
-               string_t *query = t_str_new(256);
-               str_printfa(query, "SELECT %s, %s FROM %s WHERE ",
-                           dict->where_field, dict->select_field,
-                           dict->table);
-               len = str_len(query);
-
-               if (*path != '\0') {
-                       str_printfa(query, "%s LIKE '%s/%%' AND ",
-                                   dict->where_field,
-                                   sql_escape_string(dict->db, path));
-               }
-               if (priv) {
-                       str_printfa(query, "%s = '%s' AND ",
-                                   dict->username_field,
-                                   sql_escape_string(dict->db,
-                                                     dict->username));
-               }
-               if ((flags & DICT_ITERATE_FLAG_RECURSE) != 0) {
-                       /* get everything */
-               } else if (*path == '\0') {
-                       str_printfa(query, "%s NOT LIKE '%%/%%' AND ",
-                                   dict->where_field);
-               } else {
-                       str_printfa(query, "%s NOT LIKE '%s/%%/%%' AND ",
-                                   dict->where_field,
-                                   sql_escape_string(dict->db, path));
-               }
-               if (str_len(query) == len)
-                       str_truncate(query, str_len(query) - 6);
-               else
-                       str_truncate(query, str_len(query) - 4);
-
-               if ((flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0)
-                       str_printfa(query, "ORDER BY %s", dict->where_field);
-               else if ((flags & DICT_ITERATE_FLAG_SORT_BY_VALUE) != 0)
-                       str_printfa(query, "ORDER BY %s", dict->select_field);
-               ctx->result = sql_query_s(dict->db, str_c(query));
-       } T_END;
-
        return &ctx->ctx;
 }
 
@@ -277,24 +308,41 @@ static int sql_dict_iterate(struct dict_iterate_context *_ctx,
 {
        struct sql_dict_iterate_context *ctx =
                (struct sql_dict_iterate_context *)_ctx;
-       const char *key;
+       const char *p;
+       unsigned int i, count;
        int ret;
 
        if (ctx->result == NULL)
                return -1;
 
-       if ((ret = sql_result_next_row(ctx->result)) <= 0)
+       while ((ret = sql_result_next_row(ctx->result)) == 0) {
+               /* see if there are more results in the next map */
+               if (!sql_dict_iterate_next_query(ctx))
+                       return 0;
+       }
+       if (ret < 0) {
+               i_error("dict sql iterate failed: %s",
+                       sql_result_get_error(ctx->result));
                return ret;
+       }
 
-       key = sql_result_get_field_value(ctx->result, 0);
-       i_free(ctx->prev_key);
-       if (ctx->priv)
-               ctx->prev_key = i_strconcat(DICT_PATH_PRIVATE, key, NULL);
-       else
-               ctx->prev_key = i_strconcat(DICT_PATH_SHARED, key, NULL);
+       /* convert fetched row to dict key */
+       str_truncate(ctx->key, ctx->key_prefix_len);
+       count = sql_result_get_fields_count(ctx->result);
+       i = 1;
+       for (p = ctx->map->pattern + ctx->key_prefix_len; *p != '\0'; p++) {
+               if (*p != '$')
+                       str_append_c(ctx->key, *p);
+               else {
+                       i_assert(i < count);
+                       str_append(ctx->key,
+                                  sql_result_get_field_value(ctx->result, i));
+                       i++;
+               }
+       }
 
-       *key_r = ctx->prev_key;
-       *value_r = sql_result_get_field_value(ctx->result, 1);
+       *key_r = str_c(ctx->key);
+       *value_r = sql_result_get_field_value(ctx->result, 0);
        return 1;
 }
 
@@ -304,7 +352,8 @@ static void sql_dict_iterate_deinit(struct dict_iterate_context *_ctx)
                (struct sql_dict_iterate_context *)_ctx;
 
        sql_result_free(ctx->result);
-       i_free(ctx->prev_key);
+       str_free(&ctx->key);
+       i_free(ctx->path);
        i_free(ctx);
 }
 
@@ -354,32 +403,51 @@ static void sql_dict_transaction_rollback(struct dict_transaction_context *_ctx)
 }
 
 static const char *
-sql_dict_set_query(struct sql_dict *dict, const char *key, const char *value,
-                  bool priv)
+sql_dict_set_query(struct sql_dict *dict, const struct dict_sql_map *map,
+                  const ARRAY_TYPE(const_string) *values_arr,
+                  const char *key, const char *value, bool inc)
 {
-       string_t *str;
-
-       str = t_str_new(256);
-       if (priv) {
-               str_printfa(str, "INSERT INTO %s (%s, %s, %s) "
-                           "VALUES ('%s', '%s', '%s')",
-                           dict->table, dict->select_field, dict->where_field,
-                           dict->username_field,
-                           sql_escape_string(dict->db, value),
-                           sql_escape_string(dict->db, key),
+       const char *const *sql_fields, *const *values;
+       unsigned int i, count, count2;
+       string_t *prefix, *suffix;
+
+       prefix = t_str_new(64);
+       suffix = t_str_new(256);
+       str_printfa(prefix, "INSERT INTO %s (%s", map->table, map->value_field);
+       str_append(suffix, ") VALUES (");
+       if (inc)
+               str_append(suffix, value);
+       else
+               str_printfa(suffix, "'%s'", sql_escape_string(dict->db, value));
+       if (*key == DICT_PATH_PRIVATE[0]) {
+               str_printfa(prefix, ",%s", map->username_field);
+               str_printfa(suffix, ",'%s'",
                            sql_escape_string(dict->db, dict->username));
-       } else {
-               str_printfa(str, "INSERT INTO %s (%s, %s) VALUES ('%s', '%s')",
-                           dict->table, dict->select_field, dict->where_field,
-                           sql_escape_string(dict->db, value),
-                           sql_escape_string(dict->db, key));
        }
+
+       /* add the other fields from the key */
+       sql_fields = array_get(&map->sql_fields, &count);
+       values = array_get(values_arr, &count2);
+       i_assert(count == count2);
+       for (i = 0; i < count; i++) {
+               str_printfa(prefix, ",%s", sql_fields[i]);
+               str_printfa(suffix, ",'%s'",
+                           sql_escape_string(dict->db, values[i]));
+       }
+
+       str_append_str(prefix, suffix);
+       str_append_c(prefix, ')');
        if (dict->has_on_duplicate_key) {
-               str_printfa(str, " ON DUPLICATE KEY UPDATE %s = '%s'",
-                           dict->select_field,
-                           sql_escape_string(dict->db, value));
+               str_printfa(prefix, " ON DUPLICATE KEY UPDATE %s =",
+                           map->value_field);
+               if (inc)
+                       str_printfa(prefix, "%s+%s", map->value_field, value);
+               else {
+                       str_printfa(prefix, "'%s'",
+                                   sql_escape_string(dict->db, value));
+               }
        }
-       return str_c(str);
+       return str_c(prefix);
 }
 
 static void sql_dict_set(struct dict_transaction_context *_ctx,
@@ -388,10 +456,12 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
        struct sql_dict_transaction_context *ctx =
                (struct sql_dict_transaction_context *)_ctx;
        struct sql_dict *dict = (struct sql_dict *)_ctx->dict;
-       bool priv;
+       const struct dict_sql_map *map;
+       ARRAY_TYPE(const_string) values;
 
-       if (sql_path_fix(&key, &priv) < 0) {
-               i_error("sql dict: Invalid key: %s", key);
+       map = sql_dict_find_map(dict, key, &values);
+       if (map == NULL) {
+               i_error("sql dict set: Invalid/unmapped key: %s", key);
                ctx->failed = TRUE;
                return;
        }
@@ -399,87 +469,49 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
        T_BEGIN {
                const char *query;
 
-               query = sql_dict_set_query(dict, key, value, priv);
+               query = sql_dict_set_query(dict, map, &values, key, value,
+                                          FALSE);
                sql_update(ctx->sql_ctx, query);
        } T_END;
 }
 
-static const char *
-sql_dict_unset_query(struct sql_dict *dict, const char *key, bool priv)
-{
-       if (priv) {
-               return t_strdup_printf(
-                       "DELETE FROM %s WHERE %s = '%s' AND %s = '%s'",
-                       dict->table, dict->where_field,
-                       sql_escape_string(dict->db, key),
-                       dict->username_field,
-                       sql_escape_string(dict->db, dict->username));
-       } else {
-               return t_strdup_printf(
-                       "DELETE FROM %s WHERE %s = '%s'",
-                       dict->table, dict->where_field,
-                       sql_escape_string(dict->db, key));
-       }
-}
-
 static void sql_dict_unset(struct dict_transaction_context *_ctx,
                           const char *key)
 {
        struct sql_dict_transaction_context *ctx =
                (struct sql_dict_transaction_context *)_ctx;
        struct sql_dict *dict = (struct sql_dict *)_ctx->dict;
-       bool priv;
+       const struct dict_sql_map *map;
+       ARRAY_TYPE(const_string) values;
 
-       if (sql_path_fix(&key, &priv) < 0) {
-               i_error("sql dict: Invalid key: %s", key);
+       map = sql_dict_find_map(dict, key, &values);
+       if (map == NULL) {
+               i_error("sql dict unset: Invalid/unmapped key: %s", key);
                ctx->failed = TRUE;
                return;
        }
 
        T_BEGIN {
-               const char *query;
+               string_t *query = t_str_new(256);
 
-               query = sql_dict_unset_query(dict, key, priv);
-               sql_update(ctx->sql_ctx, query);
+               str_printfa(query, "DELETE FROM %s ", map->table);
+               sql_dict_where_build(dict, map, &values, key, query);
+               sql_update(ctx->sql_ctx, str_c(query));
        } T_END;
 }
 
-static const char *
-sql_dict_atomic_inc_query(struct sql_dict *dict, const char *key,
-                         long long diff, bool priv)
-{
-       string_t *str;
-
-       str = t_str_new(256);
-       if (priv) {
-               str_printfa(str, "INSERT INTO %s (%s, %s, %s) "
-                           "VALUES (%lld, '%s', '%s')",
-                           dict->table, dict->select_field, dict->where_field,
-                           dict->username_field,
-                           diff, sql_escape_string(dict->db, key),
-                           sql_escape_string(dict->db, dict->username));
-       } else {
-               str_printfa(str, "INSERT INTO %s (%s, %s) VALUES (%lld, '%s')",
-                           dict->table, dict->select_field, dict->where_field,
-                           diff, sql_escape_string(dict->db, key));
-       }
-       if (dict->has_on_duplicate_key) {
-               str_printfa(str, " ON DUPLICATE KEY UPDATE %s = %s + %lld",
-                           dict->select_field, dict->select_field, diff);
-       }
-       return str_c(str);
-}
-
 static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
                                const char *key, long long diff)
 {
        struct sql_dict_transaction_context *ctx =
                (struct sql_dict_transaction_context *)_ctx;
        struct sql_dict *dict = (struct sql_dict *)_ctx->dict;
-       bool priv;
+       const struct dict_sql_map *map;
+       ARRAY_TYPE(const_string) values;
 
-       if (sql_path_fix(&key, &priv) < 0) {
-               i_error("sql dict: Invalid key: %s", key);
+       map = sql_dict_find_map(dict, key, &values);
+       if (map == NULL) {
+               i_error("sql dict atomic inc: Invalid/unmapped key: %s", key);
                ctx->failed = TRUE;
                return;
        }
@@ -487,7 +519,8 @@ static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
        T_BEGIN {
                const char *query;
 
-               query = sql_dict_atomic_inc_query(dict, key, diff, priv);
+               query = sql_dict_set_query(dict, map, &values, key,
+                                          dec2str(diff), TRUE);
                sql_update(ctx->sql_ctx, query);
        } T_END;
 }
index c6c9ba0106b22062771972284ff045279141b9b0..23cfb889edb1fcda31379655c1ab3791433b9e57 100644 (file)
@@ -84,9 +84,16 @@ void dict_deinit(struct dict **_dict)
        dict->v.deinit(dict);
 }
 
+static bool dict_key_prefix_is_valid(const char *key)
+{
+       return strncmp(key, DICT_PATH_SHARED, strlen(DICT_PATH_SHARED)) == 0 ||
+               strncmp(key, DICT_PATH_PRIVATE, strlen(DICT_PATH_PRIVATE)) == 0;
+}
+
 int dict_lookup(struct dict *dict, pool_t pool, const char *key,
                const char **value_r)
 {
+       i_assert(dict_key_prefix_is_valid(key));
        return dict->v.lookup(dict, pool, key, value_r);
 }
 
@@ -94,6 +101,7 @@ struct dict_iterate_context *
 dict_iterate_init(struct dict *dict, const char *path, 
                  enum dict_iterate_flags flags)
 {
+       i_assert(dict_key_prefix_is_valid(path));
        return dict->v.iterate_init(dict, path, flags);
 }
 
@@ -135,6 +143,8 @@ void dict_transaction_rollback(struct dict_transaction_context **_ctx)
 void dict_set(struct dict_transaction_context *ctx,
              const char *key, const char *value)
 {
+       i_assert(dict_key_prefix_is_valid(key));
+
        ctx->dict->v.set(ctx, key, value);
        ctx->changed = TRUE;
 }
@@ -142,6 +152,8 @@ void dict_set(struct dict_transaction_context *ctx,
 void dict_unset(struct dict_transaction_context *ctx,
                const char *key)
 {
+       i_assert(dict_key_prefix_is_valid(key));
+
        ctx->dict->v.unset(ctx, key);
        ctx->changed = TRUE;
 }
@@ -149,6 +161,8 @@ void dict_unset(struct dict_transaction_context *ctx,
 void dict_atomic_inc(struct dict_transaction_context *ctx,
                     const char *key, long long diff)
 {
+       i_assert(dict_key_prefix_is_valid(key));
+
        if (diff != 0) {
                ctx->dict->v.atomic_inc(ctx, key, diff);
                ctx->changed = TRUE;