]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
dict-sql: If value isn't a hexblob as expected, log an error instead of killing the...
authorTimo Sirainen <tss@iki.fi>
Fri, 25 Sep 2015 15:16:56 +0000 (18:16 +0300)
committerTimo Sirainen <tss@iki.fi>
Fri, 25 Sep 2015 15:16:56 +0000 (18:16 +0300)
src/lib-dict/dict-sql.c

index 649bb5b0f7d89aae91944f0ec35e4bca9bfaca13..93cfd3c5cb9fd46b4d1df76f970b03fc2d3a5b1a 100644 (file)
@@ -211,44 +211,48 @@ sql_dict_find_map(struct sql_dict *dict, const char *path,
        return NULL;
 }
 
-static void
+static int
 sql_dict_value_escape(string_t *str, struct sql_dict *dict,
                      bool value_is_hexblob, const char *field_name,
-                     const char *value, const char *value_suffix)
+                     const char *value, const char *value_suffix,
+                     const char **error_r)
 {
        buffer_t *buf;
 
        if (!value_is_hexblob) {
                str_printfa(str, "'%s%s'", sql_escape_string(dict->db, value),
                            value_suffix);
-               return;
+               return 0;
        }
 
        buf = buffer_create_dynamic(pool_datastack_create(), strlen(value)/2);
        if (hex_to_binary(value, buf) < 0) {
                /* we shouldn't get untrusted input here. it's also a bit
                   annoying to handle this error. */
-               i_fatal("dict-sql: field %s value isn't hexblob: %s",
-                       field_name, value);
+               *error_r = t_strdup_printf("field %s value isn't hexblob: %s",
+                                          field_name, value);
+               return -1;
        }
        str_append(buf, value_suffix);
        str_append(str, sql_escape_blob(dict->db, buf->data, buf->used));
+       return 0;
 }
 
-static void
+static int
 sql_dict_field_escape_value(string_t *str, struct sql_dict *dict,
                            const struct dict_sql_field *field,
-                           const char *value, const char *value_suffix)
+                           const char *value, const char *value_suffix,
+                           const char **error_r)
 {
-       sql_dict_value_escape(str, dict, field->value_is_hexblob,
-                             field->name, value, value_suffix);
+       return sql_dict_value_escape(str, dict, field->value_is_hexblob,
+                                    field->name, value, value_suffix, error_r);
 }
 
-static void
+static int
 sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                     const ARRAY_TYPE(const_string) *values_arr,
                     char key1, enum sql_recurse_type recurse_type,
-                    string_t *query)
+                    string_t *query, const char **error_r)
 {
        const struct dict_sql_field *sql_fields;
        const char *const *values;
@@ -262,7 +266,7 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
 
        if (count2 == 0 && !priv) {
                /* we want everything */
-               return;
+               return 0;
        }
 
        str_append(query, " WHERE");
@@ -272,7 +276,9 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                if (i > 0)
                        str_append(query, " AND");
                str_printfa(query, " %s = ", sql_fields[i].name);
-               sql_dict_field_escape_value(query, dict, &sql_fields[i], values[i], "");
+               if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+                                               values[i], "", error_r) < 0)
+                       return -1;
        }
        switch (recurse_type) {
        case SQL_DICT_RECURSE_NONE:
@@ -282,11 +288,13 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                        str_append(query, " AND");
                if (i < count2) {
                        str_printfa(query, " %s LIKE ", sql_fields[i].name);
-                       sql_dict_field_escape_value(query, dict, &sql_fields[i],
-                                                   values[i], "/%");
+                       if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+                                                       values[i], "/%", error_r) < 0)
+                               return -1;
                        str_printfa(query, " AND %s NOT LIKE ", sql_fields[i].name);
-                       sql_dict_field_escape_value(query, dict, &sql_fields[i],
-                                                   values[i], "/%/%");
+                       if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+                                                       values[i], "/%/%", error_r) < 0)
+                               return -1;
                } else {
                        str_printfa(query, " %s LIKE '%%' AND "
                                    "%s NOT LIKE '%%/%%'",
@@ -299,8 +307,9 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                                str_append(query, " AND");
                        str_printfa(query, " %s LIKE ",
                                    sql_fields[i].name);
-                       sql_dict_field_escape_value(query, dict, &sql_fields[i],
-                                                   values[i], "/%");
+                       if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+                                                       values[i], "/%", error_r) < 0)
+                               return -1;
                }
                break;
        }
@@ -310,6 +319,7 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                str_printfa(query, " %s = '%s'", map->username_field,
                            sql_escape_string(dict->db, dict->username));
        }
+       return 0;
 }
 
 static int
@@ -318,6 +328,7 @@ sql_lookup_get_query(struct sql_dict *dict, const char *key,
 {
        const struct dict_sql_map *map;
        ARRAY_TYPE(const_string) values;
+       const char *error;
 
        map = *map_r = sql_dict_find_map(dict, key, &values);
        if (map == NULL) {
@@ -326,8 +337,11 @@ sql_lookup_get_query(struct sql_dict *dict, const char *key,
        }
        str_printfa(query, "SELECT %s FROM %s",
                    map->value_field, map->table);
-       sql_dict_where_build(dict, map, &values, key[0],
-                            SQL_DICT_RECURSE_NONE, query);
+       if (sql_dict_where_build(dict, map, &values, key[0],
+                                SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+               i_error("sql dict lookup: Failed to lookup key %s: %s", key, error);
+               return -1;
+       }
        return 0;
 }
 
@@ -483,9 +497,9 @@ sql_dict_iterate_find_next_map(struct sql_dict_iterate_context *ctx,
        return NULL;
 }
 
-static bool
+static int
 sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
-                                 string_t *query)
+                                 string_t *query, const char **error_r)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
        const struct dict_sql_map *map;
@@ -495,8 +509,10 @@ sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
        unsigned int i, count;
 
        map = sql_dict_iterate_find_next_map(ctx, &values);
-       if (map == NULL)
-               return FALSE;
+       if (map == NULL) {
+               *error_r = "Invalid/unmapped path";
+               return 0;
+       }
 
        if (ctx->result != NULL) {
                sql_result_unref(ctx->result);
@@ -529,9 +545,10 @@ sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
                recurse_type = SQL_DICT_RECURSE_NONE;
        else
                recurse_type = SQL_DICT_RECURSE_ONE;
-       sql_dict_where_build(dict, map, &values,
-                            ctx->paths[ctx->path_idx][0],
-                            recurse_type, query);
+       if (sql_dict_where_build(dict, map, &values,
+                                ctx->paths[ctx->path_idx][0],
+                                recurse_type, query, error_r) < 0)
+               return -1;
 
        if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
                str_append(query, " ORDER BY ");
@@ -544,7 +561,7 @@ sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
                str_printfa(query, " ORDER BY %s", map->value_field);
 
        ctx->map = map;
-       return TRUE;
+       return 1;
 }
 
 static void sql_dict_iterate_callback(struct sql_result *result,
@@ -556,17 +573,20 @@ static void sql_dict_iterate_callback(struct sql_result *result,
                ctx->ctx.async_callback(ctx->ctx.async_context);
 }
 
-static bool sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
+static int sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx,
+                                      const char **error_r)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
-       bool ret;
+       char *error = NULL;
+       int ret;
 
        T_BEGIN {
                string_t *query = t_str_new(256);
 
-               ret = sql_dict_iterate_build_next_query(ctx, query);
-               if (!ret) {
+               ret = sql_dict_iterate_build_next_query(ctx, query, error_r);
+               if (ret <= 0) {
                        /* failed */
+                       error = i_strdup(*error_r);
                } else if ((ctx->flags & DICT_ITERATE_FLAG_ASYNC) == 0) {
                        ctx->result = sql_query_s(dict->db, str_c(query));
                } else {
@@ -575,6 +595,8 @@ static bool sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
                                  sql_dict_iterate_callback, ctx);
                }
        } T_END;
+       *error_r = t_strdup(error);
+       i_free(error);
        return ret;
 }
 
@@ -584,6 +606,7 @@ sql_dict_iterate_init(struct dict *_dict, const char *const *paths,
 {
        struct sql_dict_iterate_context *ctx;
        unsigned int i, path_count;
+       const char *error;
        pool_t pool;
 
        pool = pool_alloconly_create("sql dict iterate", 512);
@@ -598,9 +621,9 @@ sql_dict_iterate_init(struct dict *_dict, const char *const *paths,
                ctx->paths[i] = p_strdup(pool, paths[i]);
 
        ctx->key = str_new(pool, 256);
-       if (!sql_dict_iterate_next_query(ctx)) {
-               i_error("sql dict iterate: Invalid/unmapped path: %s",
-                       paths[0]);
+       if (sql_dict_iterate_next_query(ctx, &error) <= 0) {
+               i_error("sql dict iterate failed for %s: %s",
+                       paths[0], error);
                ctx->result = NULL;
                ctx->failed = TRUE;
                return &ctx->ctx;
@@ -613,7 +636,7 @@ static bool sql_dict_iterate(struct dict_iterate_context *_ctx,
 {
        struct sql_dict_iterate_context *ctx =
                (struct sql_dict_iterate_context *)_ctx;
-       const char *p, *value;
+       const char *p, *value, *error;
        unsigned int i, sql_field_i, count;
        int ret;
 
@@ -635,8 +658,9 @@ static bool sql_dict_iterate(struct dict_iterate_context *_ctx,
                /* see if there are more results in the next map.
                   don't do it if we're looking for an exact match, since we
                   already should have handled it. */
-               if ((ctx->flags & DICT_ITERATE_FLAG_EXACT_KEY) != 0 ||
-                   !sql_dict_iterate_next_query(ctx))
+               if ((ctx->flags & DICT_ITERATE_FLAG_EXACT_KEY) != 0)
+                       return FALSE;
+               if ((ret = sql_dict_iterate_next_query(ctx, &error)) == 0)
                        return FALSE;
        }
        if (ret < 0) {
@@ -806,7 +830,8 @@ struct dict_sql_build_query {
        bool inc;
 };
 
-static const char *sql_dict_set_query(const struct dict_sql_build_query *build)
+static int sql_dict_set_query(const struct dict_sql_build_query *build,
+                             const char **query_r, const char **error_r)
 {
        struct sql_dict *dict = build->dict;
        const struct dict_sql_build_query_field *fields;
@@ -831,9 +856,10 @@ static const char *sql_dict_set_query(const struct dict_sql_build_query *build)
                if (build->inc)
                        str_append(suffix, fields[i].value);
                else {
-                       sql_dict_value_escape(suffix, dict,
+                       if (sql_dict_value_escape(suffix, dict,
                                fields[i].map->value_hexblob,
-                               "value", fields[i].value, "");
+                               "value", fields[i].value, "", error_r) < 0)
+                               return -1;
                }
        }
        if (build->key1 == DICT_PATH_PRIVATE[0]) {
@@ -849,14 +875,17 @@ static const char *sql_dict_set_query(const struct dict_sql_build_query *build)
        for (i = 0; i < count; i++) {
                str_printfa(prefix, ",%s", sql_fields[i].name);
                str_append_c(suffix, ',');
-               sql_dict_field_escape_value(suffix, dict, &sql_fields[i],
-                                           extra_values[i], "");
+               if (sql_dict_field_escape_value(suffix, dict, &sql_fields[i],
+                                               extra_values[i], "", error_r) < 0)
+                       return -1;
        }
 
        str_append_str(prefix, suffix);
        str_append_c(prefix, ')');
-       if (!dict->has_on_duplicate_key)
-               return str_c(prefix);
+       if (!dict->has_on_duplicate_key) {
+               *query_r = str_c(prefix);
+               return 0;
+       }
 
        str_append(prefix, " ON DUPLICATE KEY UPDATE ");
        for (i = 0; i < field_count; i++) {
@@ -869,16 +898,19 @@ static const char *sql_dict_set_query(const struct dict_sql_build_query *build)
                                    fields[i].map->value_field,
                                    fields[i].value);
                } else {
-                       sql_dict_value_escape(prefix, dict,
+                       if (sql_dict_value_escape(prefix, dict,
                                fields[i].map->value_hexblob,
-                               "value", fields[i].value, "");
+                               "value", fields[i].value, "", error_r) < 0)
+                               return -1;
                }
        }
-       return str_c(prefix);
+       *query_r = str_c(prefix);
+       return 0;
 }
 
-static const char *
-sql_dict_update_query(const struct dict_sql_build_query *build)
+static int
+sql_dict_update_query(const struct dict_sql_build_query *build,
+                     const char **query_r, const char **error_r)
 {
        struct sql_dict *dict = build->dict;
        const struct dict_sql_build_query_field *fields;
@@ -902,9 +934,11 @@ sql_dict_update_query(const struct dict_sql_build_query *build)
                str_append(query, fields[i].value);
        }
 
-       sql_dict_where_build(dict, fields[0].map, build->extra_values,
-                            build->key1, SQL_DICT_RECURSE_NONE, query);
-       return str_c(query);
+       if (sql_dict_where_build(dict, fields[0].map, build->extra_values,
+                                build->key1, SQL_DICT_RECURSE_NONE, query, error_r) < 0)
+               return -1;
+       *query_r = str_c(query);
+       return 0;
 }
 
 static void sql_dict_set(struct dict_transaction_context *_ctx,
@@ -929,7 +963,7 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
        T_BEGIN {
                struct dict_sql_build_query build;
                struct dict_sql_build_query_field field;
-               const char *query;
+               const char *query, *error;
 
                field.map = map;
                field.value = value;
@@ -941,8 +975,13 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
                build.extra_values = &values;
                build.key1 = key[0];
 
-               query = sql_dict_set_query(&build);
-               sql_update(ctx->sql_ctx, query);
+               if (sql_dict_set_query(&build, &query, &error) < 0) {
+                       i_error("dict-sql: Failed to set %s=%s: %s",
+                               key, value, error);
+                       ctx->failed = TRUE;
+               } else {
+                       sql_update(ctx->sql_ctx, query);
+               }
        } T_END;
 }
 
@@ -967,11 +1006,16 @@ static void sql_dict_unset(struct dict_transaction_context *_ctx,
 
        T_BEGIN {
                string_t *query = t_str_new(256);
+               const char *error;
 
                str_printfa(query, "DELETE FROM %s", map->table);
-               sql_dict_where_build(dict, map, &values, key[0],
-                                    SQL_DICT_RECURSE_NONE, query);
-               sql_update(ctx->sql_ctx, str_c(query));
+               if (sql_dict_where_build(dict, map, &values, key[0],
+                                        SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+                       i_error("dict-sql: Failed to delete %s: %s", key, error);
+                       ctx->failed = TRUE;
+               } else {
+                       sql_update(ctx->sql_ctx, str_c(query));
+               }
        } T_END;
 }
 
@@ -1015,6 +1059,7 @@ static void sql_dict_atomic_inc_real(struct sql_dict_transaction_context *ctx,
        T_BEGIN {
                struct dict_sql_build_query build;
                struct dict_sql_build_query_field field;
+               const char *query, *error;
 
                field.map = map;
                field.value = t_strdup_printf("%lld", diff);
@@ -1027,8 +1072,13 @@ static void sql_dict_atomic_inc_real(struct sql_dict_transaction_context *ctx,
                build.key1 = key[0];
                build.inc = TRUE;
 
-               sql_update_get_rows(ctx->sql_ctx, sql_dict_update_query(&build),
-                                   sql_dict_next_inc_row(ctx));
+               if (sql_dict_update_query(&build, &query, &error) < 0) {
+                       i_error("dict-sql: Failed to increase %s: %s", key, error);
+                       ctx->failed = TRUE;
+               } else {
+                       sql_update_get_rows(ctx->sql_ctx, query,
+                                           sql_dict_next_inc_row(ctx));
+               }
        } T_END;
 }
 
@@ -1107,6 +1157,7 @@ static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
        } else T_BEGIN {
                struct dict_sql_build_query build;
                struct dict_sql_build_query_field *field;
+               const char *query, *error;
 
                memset(&build, 0, sizeof(build));
                build.dict = dict;
@@ -1122,8 +1173,13 @@ static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
                field->map = map;
                field->value = t_strdup_printf("%lld", diff);
 
-               sql_update_get_rows(ctx->sql_ctx, sql_dict_update_query(&build),
-                                   sql_dict_next_inc_row(ctx));
+               if (sql_dict_update_query(&build, &query, &error) < 0) {
+                       i_error("dict-sql: Failed to increase %s: %s", key, error);
+                       ctx->failed = TRUE;
+               } else {
+                       sql_update_get_rows(ctx->sql_ctx, query,
+                                           sql_dict_next_inc_row(ctx));
+               }
 
                i_free_and_null(ctx->prev_inc_key);
                ctx->prev_inc_map = NULL;