]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
dict-sql: Use sql_statement_bind_*()
authorTimo Sirainen <timo.sirainen@dovecot.fi>
Tue, 15 Aug 2017 13:38:33 +0000 (16:38 +0300)
committerTimo Sirainen <timo.sirainen@dovecot.fi>
Wed, 27 Sep 2017 10:37:02 +0000 (13:37 +0300)
src/lib-dict/dict-sql.c

index 2be9537aa858f9be56aa033d79d30c0d7e55f717..ebb2d858633c774e5a77a24983a1e44b0e267ae0 100644 (file)
@@ -33,6 +33,16 @@ struct sql_dict {
        unsigned int has_on_duplicate_key:1;
 };
 
+struct sql_dict_param {
+       enum dict_sql_type value_type;
+
+       const char *value_str;
+       int64_t value_int64;
+       const void *value_binary;
+       size_t value_binary_size;
+};
+ARRAY_DEFINE_TYPE(sql_dict_param, struct sql_dict_param);
+
 struct sql_dict_iterate_context {
        struct dict_iterate_context ctx;
        pool_t pool;
@@ -220,39 +230,74 @@ sql_dict_find_map(struct sql_dict *dict, const char *path,
        return NULL;
 }
 
+static void
+sql_dict_statement_bind(struct sql_statement *stmt, unsigned int column_idx,
+                       const struct sql_dict_param *param)
+{
+       switch (param->value_type) {
+       case DICT_SQL_TYPE_STRING:
+               sql_statement_bind_str(stmt, column_idx, param->value_str);
+               break;
+       case DICT_SQL_TYPE_INT:
+       case DICT_SQL_TYPE_UINT:
+               sql_statement_bind_int64(stmt, column_idx, param->value_int64);
+               break;
+       case DICT_SQL_TYPE_HEXBLOB:
+               sql_statement_bind_binary(stmt, column_idx, param->value_binary,
+                                         param->value_binary_size);
+               break;
+       }
+}
+
+static struct sql_statement *
+sql_dict_statement_init(struct sql_db *db, const char *query,
+                       const ARRAY_TYPE(sql_dict_param) *params)
+{
+       struct sql_statement *stmt = sql_statement_init(db, query);
+       const struct sql_dict_param *param;
+
+       array_foreach(params, param) {
+               sql_dict_statement_bind(stmt, array_foreach_idx(params, param),
+                                       param);
+       }
+       return stmt;
+}
+
 static int
-sql_dict_value_escape(string_t *str, struct sql_dict *dict,
-                     const struct dict_sql_map *map,
-                     enum dict_sql_type value_type, const char *field_name,
-                     const char *value, const char *value_suffix,
-                     const char **error_r)
+sql_dict_value_get(const struct dict_sql_map *map,
+                  enum dict_sql_type value_type, const char *field_name,
+                  const char *value, const char *value_suffix,
+                  ARRAY_TYPE(sql_dict_param) *params, const char **error_r)
 {
+       struct sql_dict_param *param;
        buffer_t *buf;
-       int64_t snum;
-       uint64_t num;
+
+       param = array_append_space(params);
+       param->value_type = value_type;
 
        switch (value_type) {
        case DICT_SQL_TYPE_STRING:
-               str_printfa(str, "'%s%s'", sql_escape_string(dict->db, value),
-                           value_suffix);
+               if (value_suffix[0] != '\0')
+                       value = t_strconcat(value, value_suffix, NULL);
+               param->value_str = value;
                return 0;
        case DICT_SQL_TYPE_INT:
-               if (value_suffix[0] != '\0' || str_to_int64(value, &snum) < 0) {
+               if (value_suffix[0] != '\0' ||
+                   str_to_int64(value, &param->value_int64) < 0) {
                        *error_r = t_strdup_printf(
                                "%s field's value isn't 64bit signed integer: %s%s (in pattern: %s)",
                                field_name, value, value_suffix, map->pattern);
                        return -1;
                }
-               str_printfa(str, "%"PRId64, snum);
                return 0;
        case DICT_SQL_TYPE_UINT:
-               if (value_suffix[0] != '\0' || str_to_uint64(value, &num) < 0) {
+               if (value_suffix[0] != '\0' || value[0] == '-' ||
+                   str_to_int64(value, &param->value_int64) < 0) {
                        *error_r = t_strdup_printf(
                                "%s field's value isn't 64bit unsigned integer: %s%s (in pattern: %s)",
                                field_name, value, value_suffix, map->pattern);
                        return -1;
                }
-               str_printfa(str, "%llu", (unsigned long long)num);
                return 0;
        case DICT_SQL_TYPE_HEXBLOB:
                break;
@@ -267,26 +312,28 @@ sql_dict_value_escape(string_t *str, struct sql_dict *dict,
                return -1;
        }
        str_append(buf, value_suffix);
-       str_append(str, sql_escape_blob(dict->db, buf->data, buf->used));
+       param->value_binary = buf->data;
+       param->value_binary_size = buf->used;
        return 0;
 }
 
 static int
-sql_dict_field_escape_value(string_t *str, struct sql_dict *dict,
-                           const struct dict_sql_map *map,
-                           const struct dict_sql_field *field,
-                           const char *value, const char *value_suffix,
-                           const char **error_r)
+sql_dict_field_get_value(const struct dict_sql_map *map,
+                        const struct dict_sql_field *field,
+                        const char *value, const char *value_suffix,
+                        ARRAY_TYPE(sql_dict_param) *params,
+                        const char **error_r)
 {
-       return sql_dict_value_escape(str, dict, map, field->value_type,
-                                    field->name, value, value_suffix, error_r);
+       return sql_dict_value_get(map, field->value_type, field->name,
+                                 value, value_suffix, params, error_r);
 }
 
 static int
 sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                     const ARRAY_TYPE(const_string) *values_arr,
                     char key1, enum sql_recurse_type recurse_type,
-                    string_t *query, const char **error_r)
+                    string_t *query, ARRAY_TYPE(sql_dict_param) *params,
+                    const char **error_r)
 {
        const struct dict_sql_field *sql_fields;
        const char *const *values;
@@ -309,9 +356,9 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
        for (i = 0; i < exact_count; i++) {
                if (i > 0)
                        str_append(query, " AND");
-               str_printfa(query, " %s = ", sql_fields[i].name);
-               if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                               values[i], "", error_r) < 0)
+               str_printfa(query, " %s = ?", sql_fields[i].name);
+               if (sql_dict_field_get_value(map, &sql_fields[i], values[i], "",
+                                            params, error_r) < 0)
                        return -1;
        }
        switch (recurse_type) {
@@ -321,13 +368,15 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                if (i > 0)
                        str_append(query, " AND");
                if (i < count2) {
-                       str_printfa(query, " %s LIKE ", sql_fields[i].name);
-                       if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                                       values[i], "/%", error_r) < 0)
+                       str_printfa(query, " %s LIKE ?", sql_fields[i].name);
+                       if (sql_dict_field_get_value(map, &sql_fields[i],
+                                                    values[i], "/%",
+                                                    params, error_r) < 0)
                                return -1;
-                       str_printfa(query, " AND %s NOT LIKE ", sql_fields[i].name);
-                       if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                                       values[i], "/%/%", error_r) < 0)
+                       str_printfa(query, " AND %s NOT LIKE ?", sql_fields[i].name);
+                       if (sql_dict_field_get_value(map, &sql_fields[i],
+                                                    values[i], "/%/%",
+                                                    params, error_r) < 0)
                                return -1;
                } else {
                        str_printfa(query, " %s LIKE '%%' AND "
@@ -341,8 +390,9 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                                str_append(query, " AND");
                        str_printfa(query, " %s LIKE ",
                                    sql_fields[i].name);
-                       if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                                       values[i], "/%", error_r) < 0)
+                       if (sql_dict_field_get_value(map, &sql_fields[i],
+                                                    values[i], "/%",
+                                                    params, error_r) < 0)
                                return -1;
                }
                break;
@@ -359,7 +409,7 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
 static int
 sql_lookup_get_query(struct sql_dict *dict, const char *key,
                     string_t *query, const struct dict_sql_map **map_r,
-                    const char **error_r)
+                    ARRAY_TYPE(sql_dict_param) *params, const char **error_r)
 {
        const struct dict_sql_map *map;
        ARRAY_TYPE(const_string) values;
@@ -374,7 +424,8 @@ sql_lookup_get_query(struct sql_dict *dict, const char *key,
        str_printfa(query, "SELECT %s FROM %s",
                    map->value_field, map->table);
        if (sql_dict_where_build(dict, map, &values, key[0],
-                                SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+                                SQL_DICT_RECURSE_NONE, query,
+                                params, &error) < 0) {
                *error_r = t_strdup_printf(
                        "sql dict lookup: Failed to lookup key %s: %s", key, error);
                return -1;
@@ -446,18 +497,20 @@ static int sql_dict_lookup(struct dict *_dict, pool_t pool,
        const struct dict_sql_map *map;
        struct sql_result *result = NULL;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        const char *error;
        int ret;
 
-       ret = sql_lookup_get_query(dict, key, query, &map, &error);
-       if (ret < 0) {
+       *value_r = NULL;
+
+       t_array_init(&params, 4);
+       if (sql_lookup_get_query(dict, key, query, &map, &params, &error) < 0) {
                i_error("%s", error);
-               *value_r = NULL;
                return -1;
        }
 
        struct sql_statement *stmt =
-               sql_statement_init(dict->db, str_c(query));
+               sql_dict_statement_init(dict->db, str_c(query), &params);
        result = sql_statement_query_s(&stmt);
        ret = sql_result_next_row(result);
        if (ret <= 0) {
@@ -514,9 +567,11 @@ sql_dict_lookup_async(struct dict *_dict, const char *key,
        const struct dict_sql_map *map;
        struct sql_dict_lookup_context *ctx;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        const char *error;
 
-       if (sql_lookup_get_query(dict, key, query, &map, &error) < 0) {
+       t_array_init(&params, 4);
+       if (sql_lookup_get_query(dict, key, query, &map, &params, &error) < 0) {
                struct dict_lookup_result result;
 
                i_zero(&result);
@@ -529,7 +584,7 @@ sql_dict_lookup_async(struct dict *_dict, const char *key,
                ctx->context = context;
                ctx->map = map;
                struct sql_statement *stmt =
-                       sql_statement_init(dict->db, str_c(query));
+                       sql_dict_statement_init(dict->db, str_c(query), &params);
                sql_statement_query(&stmt, sql_dict_lookup_async_callback, ctx);
        }
 }
@@ -570,7 +625,9 @@ sql_dict_iterate_find_next_map(struct sql_dict_iterate_context *ctx,
 
 static int
 sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
-                                 string_t *query, const char **error_r)
+                                 string_t *query,
+                                 ARRAY_TYPE(sql_dict_param) *params,
+                                 const char **error_r)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
        const struct dict_sql_map *map;
@@ -622,7 +679,7 @@ sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
                recurse_type = SQL_DICT_RECURSE_ONE;
        if (sql_dict_where_build(dict, map, &values,
                                 ctx->paths[ctx->path_idx][0],
-                                recurse_type, query, error_r) < 0)
+                                recurse_type, query, params, error_r) < 0)
                return -1;
 
        if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
@@ -658,11 +715,13 @@ static int sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        const char *error;
        unsigned int path_idx = ctx->path_idx;
        int ret;
 
-       ret = sql_dict_iterate_build_next_query(ctx, query, &error);
+       t_array_init(&params, 4);
+       ret = sql_dict_iterate_build_next_query(ctx, query, &params, &error);
        if (ret < 0) {
                /* failed */
                i_error("sql dict iterate failed for %s: %s",
@@ -674,7 +733,7 @@ static int sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
        }
 
        struct sql_statement *stmt =
-               sql_statement_init(dict->db, str_c(query));
+               sql_dict_statement_init(dict->db, str_c(query), &params);
        if ((ctx->flags & DICT_ITERATE_FLAG_ASYNC) == 0) {
                ctx->result = sql_statement_query_s(&stmt);
        } else {
@@ -928,10 +987,12 @@ static void sql_dict_transaction_rollback(struct dict_transaction_context *_ctx)
 
 static struct sql_statement *
 sql_dict_transaction_stmt_init(struct sql_dict_transaction_context *ctx,
-                              const char *query)
+                              const char *query,
+                              const ARRAY_TYPE(sql_dict_param) *params)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
-       struct sql_statement *stmt = sql_statement_init(dict->db, query);
+       struct sql_statement *stmt =
+               sql_dict_statement_init(dict->db, query, params);
 
        if (ctx->ctx.timestamp.tv_sec != 0)
                sql_statement_set_timestamp(stmt, &ctx->ctx.timestamp);
@@ -952,7 +1013,9 @@ struct dict_sql_build_query {
 };
 
 static int sql_dict_set_query(const struct dict_sql_build_query *build,
-                             const char **query_r, const char **error_r)
+                             const char **query_r,
+                             ARRAY_TYPE(sql_dict_param) *params,
+                             const char **error_r)
 {
        struct sql_dict *dict = build->dict;
        const struct dict_sql_build_query_field *fields;
@@ -978,9 +1041,10 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
 
                enum dict_sql_type value_type =
                        fields[i].map->value_types[0];
-               if (sql_dict_value_escape(suffix, dict, fields[i].map,
-                                         value_type, "value", fields[i].value,
-                                         "", error_r) < 0)
+               str_append_c(suffix, '?');
+               if (sql_dict_value_get(fields[i].map,
+                                      value_type, "value", fields[i].value,
+                                      "", params, error_r) < 0)
                        return -1;
        }
        if (build->key1 == DICT_PATH_PRIVATE[0]) {
@@ -995,9 +1059,10 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
        i_assert(count == count2);
        for (i = 0; i < count; i++) {
                str_printfa(prefix, ",%s", sql_fields[i].name);
-               str_append_c(suffix, ',');
-               if (sql_dict_field_escape_value(suffix, dict, fields[0].map, &sql_fields[i],
-                                               extra_values[i], "", error_r) < 0)
+               str_append(suffix, ",?");
+               if (sql_dict_field_get_value(fields[0].map, &sql_fields[i],
+                                            extra_values[i], "",
+                                            params, error_r) < 0)
                        return -1;
        }
 
@@ -1019,9 +1084,10 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
 
                enum dict_sql_type value_type =
                        fields[i].map->value_types[0];
-               if (sql_dict_value_escape(prefix, dict, fields[i].map,
-                                         value_type, "value", fields[i].value,
-                                         "", error_r) < 0)
+               str_append_c(prefix, '?');
+               if (sql_dict_value_get(fields[i].map,
+                                      value_type, "value", fields[i].value,
+                                      "", params, error_r) < 0)
                        return -1;
        }
        *query_r = str_c(prefix);
@@ -1030,7 +1096,8 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
 
 static int
 sql_dict_update_query(const struct dict_sql_build_query *build,
-                     const char **query_r, const char **error_r)
+                     const char **query_r, ARRAY_TYPE(sql_dict_param) *params,
+                     const char **error_r)
 {
        struct sql_dict *dict = build->dict;
        const struct dict_sql_build_query_field *fields;
@@ -1047,15 +1114,13 @@ sql_dict_update_query(const struct dict_sql_build_query *build,
                        t_strcut(fields[i].map->value_field, ',');
                if (i > 0)
                        str_append_c(query, ',');
-               str_printfa(query, "%s=%s", first_value_field,
+               str_printfa(query, "%s=%s+?", first_value_field,
                            first_value_field);
-               if (fields[i].value[0] != '-')
-                       str_append_c(query, '+');
-               str_append(query, fields[i].value);
        }
 
        if (sql_dict_where_build(dict, fields[0].map, build->extra_values,
-                                build->key1, SQL_DICT_RECURSE_NONE, query, error_r) < 0)
+                                build->key1, SQL_DICT_RECURSE_NONE, query,
+                                params, error_r) < 0)
                return -1;
        *query_r = str_c(query);
        return 0;
@@ -1071,6 +1136,7 @@ static void sql_dict_set_real(struct dict_transaction_context *_ctx,
        ARRAY_TYPE(const_string) values;
        struct dict_sql_build_query build;
        struct dict_sql_build_query_field field;
+       ARRAY_TYPE(sql_dict_param) params;
        const char *query, *error;
 
        map = sql_dict_find_map(dict, key, &values);
@@ -1090,13 +1156,13 @@ static void sql_dict_set_real(struct dict_transaction_context *_ctx,
        build.extra_values = &values;
        build.key1 = key[0];
 
-       if (sql_dict_set_query(&build, &query, &error) < 0) {
+       if (sql_dict_set_query(&build, &query, &params, &error) < 0) {
                i_error("dict-sql: Failed to set %s=%s: %s",
                        key, value, error);
                ctx->failed = TRUE;
        } else {
                struct sql_statement *stmt =
-                       sql_dict_transaction_stmt_init(ctx, query);
+                       sql_dict_transaction_stmt_init(ctx, query, &params);
                sql_update_stmt(ctx->sql_ctx, &stmt);
        }
 }
@@ -1110,6 +1176,7 @@ static void sql_dict_unset(struct dict_transaction_context *_ctx,
        const struct dict_sql_map *map;
        ARRAY_TYPE(const_string) values;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        const char *error;
 
        if (ctx->prev_inc_map != NULL)
@@ -1125,13 +1192,15 @@ static void sql_dict_unset(struct dict_transaction_context *_ctx,
        }
 
        str_printfa(query, "DELETE FROM %s", map->table);
+       t_array_init(&params, 4);
        if (sql_dict_where_build(dict, map, &values, key[0],
-                                SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+                                SQL_DICT_RECURSE_NONE, query,
+                                &params, &error) < 0) {
                i_error("dict-sql: Failed to delete %s: %s", key, error);
                ctx->failed = TRUE;
        } else {
                struct sql_statement *stmt =
-                       sql_dict_transaction_stmt_init(ctx, str_c(query));
+                       sql_dict_transaction_stmt_init(ctx, str_c(query), &params);
                sql_update_stmt(ctx->sql_ctx, &stmt);
        }
 }
@@ -1171,6 +1240,8 @@ static void sql_dict_atomic_inc_real(struct sql_dict_transaction_context *ctx,
        ARRAY_TYPE(const_string) values;
        struct dict_sql_build_query build;
        struct dict_sql_build_query_field field;
+       ARRAY_TYPE(sql_dict_param) params;
+       struct sql_dict_param *param;
        const char *query, *error;
 
        map = sql_dict_find_map(dict, key, &values);
@@ -1186,12 +1257,17 @@ static void sql_dict_atomic_inc_real(struct sql_dict_transaction_context *ctx,
        build.extra_values = &values;
        build.key1 = key[0];
 
-       if (sql_dict_update_query(&build, &query, &error) < 0) {
+       t_array_init(&params, 4);
+       param = array_append_space(&params);
+       param->value_type = DICT_SQL_TYPE_INT;
+       param->value_int64 = diff;
+
+       if (sql_dict_update_query(&build, &query, &params, &error) < 0) {
                i_error("dict-sql: Failed to increase %s: %s", key, error);
                ctx->failed = TRUE;
        } else {
                struct sql_statement *stmt =
-                       sql_dict_transaction_stmt_init(ctx, query);
+                       sql_dict_transaction_stmt_init(ctx, query, &params);
                sql_update_stmt_get_rows(ctx->sql_ctx, &stmt,
                                         sql_dict_next_inc_row(ctx));
        }
@@ -1283,6 +1359,7 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
        } else {
                struct dict_sql_build_query build;
                struct dict_sql_build_query_field *field;
+               ARRAY_TYPE(sql_dict_param) params;
                const char *query, *error;
 
                i_zero(&build);
@@ -1298,12 +1375,13 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
                field->map = map;
                field->value = value;
 
-               if (sql_dict_set_query(&build, &query, &error) < 0) {
+               t_array_init(&params, 4);
+               if (sql_dict_set_query(&build, &query, &params, &error) < 0) {
                        i_error("dict-sql: Failed to set %s: %s", key, error);
                        ctx->failed = TRUE;
                } else {
                        struct sql_statement *stmt =
-                               sql_dict_transaction_stmt_init(ctx, query);
+                               sql_dict_transaction_stmt_init(ctx, query, &params);
                        sql_update_stmt(ctx->sql_ctx, &stmt);
                }
                i_free_and_null(ctx->prev_set_value);
@@ -1344,6 +1422,8 @@ static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
        } else {
                struct dict_sql_build_query build;
                struct dict_sql_build_query_field *field;
+               ARRAY_TYPE(sql_dict_param) params;
+               struct sql_dict_param *param;
                const char *query, *error;
 
                i_zero(&build);
@@ -1359,12 +1439,21 @@ static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
                field->map = map;
                field->value = t_strdup_printf("%lld", diff);
 
-               if (sql_dict_update_query(&build, &query, &error) < 0) {
+               t_array_init(&params, 4);
+               param = array_append_space(&params);
+               param->value_type = DICT_SQL_TYPE_INT;
+               param->value_int64 = ctx->prev_inc_diff;
+
+               param = array_append_space(&params);
+               param->value_type = DICT_SQL_TYPE_INT;
+               param->value_int64 = diff;
+
+               if (sql_dict_update_query(&build, &query, &params, &error) < 0) {
                        i_error("dict-sql: Failed to increase %s: %s", key, error);
                        ctx->failed = TRUE;
                } else {
                        struct sql_statement *stmt =
-                               sql_dict_transaction_stmt_init(ctx, query);
+                               sql_dict_transaction_stmt_init(ctx, query, &params);
                        sql_update_stmt_get_rows(ctx->sql_ctx, &stmt,
                                                 sql_dict_next_inc_row(ctx));
                }