]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
dict-sql: Use sql_statement_bind_*()
authorTimo Sirainen <timo.sirainen@dovecot.fi>
Tue, 15 Aug 2017 13:38:33 +0000 (16:38 +0300)
committerAki Tuomi <aki.tuomi@dovecot.fi>
Fri, 8 Sep 2017 10:18:32 +0000 (13:18 +0300)
src/lib-dict/dict-sql.c

index f4a690bfff07d0873368101ae82a188f4de970ff..d9677cc9387d4a8af60437a7b306528315e39ff6 100644 (file)
@@ -23,6 +23,16 @@ enum sql_recurse_type {
        SQL_DICT_RECURSE_FULL
 };
 
+struct sql_dict_param {
+       enum dict_sql_type value_type;
+
+       const char *value_str;
+       int64_t value_int64;
+       const void *value_binary;
+       size_t value_binary_size;
+};
+ARRAY_DEFINE_TYPE(sql_dict_param, struct sql_dict_param);
+
 struct sql_dict_iterate_context {
        struct dict_iterate_context ctx;
        pool_t pool;
@@ -209,39 +219,74 @@ sql_dict_find_map(struct sql_dict *dict, const char *path,
        return NULL;
 }
 
+static void
+sql_dict_statement_bind(struct sql_statement *stmt, unsigned int column_idx,
+                       const struct sql_dict_param *param)
+{
+       switch (param->value_type) {
+       case DICT_SQL_TYPE_STRING:
+               sql_statement_bind_str(stmt, column_idx, param->value_str);
+               break;
+       case DICT_SQL_TYPE_INT:
+       case DICT_SQL_TYPE_UINT:
+               sql_statement_bind_int64(stmt, column_idx, param->value_int64);
+               break;
+       case DICT_SQL_TYPE_HEXBLOB:
+               sql_statement_bind_binary(stmt, column_idx, param->value_binary,
+                                         param->value_binary_size);
+               break;
+       }
+}
+
+static struct sql_statement *
+sql_dict_statement_init(struct sql_db *db, const char *query,
+                       const ARRAY_TYPE(sql_dict_param) *params)
+{
+       struct sql_statement *stmt = sql_statement_init(db, query);
+       const struct sql_dict_param *param;
+
+       array_foreach(params, param) {
+               sql_dict_statement_bind(stmt, array_foreach_idx(params, param),
+                                       param);
+       }
+       return stmt;
+}
+
 static int
-sql_dict_value_escape(string_t *str, struct sql_dict *dict,
-                     const struct dict_sql_map *map,
-                     enum dict_sql_type value_type, const char *field_name,
-                     const char *value, const char *value_suffix,
-                     const char **error_r)
+sql_dict_value_get(const struct dict_sql_map *map,
+                  enum dict_sql_type value_type, const char *field_name,
+                  const char *value, const char *value_suffix,
+                  ARRAY_TYPE(sql_dict_param) *params, const char **error_r)
 {
+       struct sql_dict_param *param;
        buffer_t *buf;
-       int64_t snum;
-       uint64_t num;
+
+       param = array_append_space(params);
+       param->value_type = value_type;
 
        switch (value_type) {
        case DICT_SQL_TYPE_STRING:
-               str_printfa(str, "'%s%s'", sql_escape_string(dict->db, value),
-                           value_suffix);
+               if (value_suffix[0] != '\0')
+                       value = t_strconcat(value, value_suffix, NULL);
+               param->value_str = value;
                return 0;
        case DICT_SQL_TYPE_INT:
-               if (value_suffix[0] != '\0' || str_to_int64(value, &snum) < 0) {
+               if (value_suffix[0] != '\0' ||
+                   str_to_int64(value, &param->value_int64) < 0) {
                        *error_r = t_strdup_printf(
                                "%s field's value isn't 64bit signed integer: %s%s (in pattern: %s)",
                                field_name, value, value_suffix, map->pattern);
                        return -1;
                }
-               str_printfa(str, "%"PRId64, snum);
                return 0;
        case DICT_SQL_TYPE_UINT:
-               if (value_suffix[0] != '\0' || str_to_uint64(value, &num) < 0) {
+               if (value_suffix[0] != '\0' || value[0] == '-' ||
+                   str_to_int64(value, &param->value_int64) < 0) {
                        *error_r = t_strdup_printf(
                                "%s field's value isn't 64bit unsigned integer: %s%s (in pattern: %s)",
                                field_name, value, value_suffix, map->pattern);
                        return -1;
                }
-               str_printfa(str, "%"PRIu64, num);
                return 0;
        case DICT_SQL_TYPE_HEXBLOB:
                break;
@@ -256,26 +301,28 @@ sql_dict_value_escape(string_t *str, struct sql_dict *dict,
                return -1;
        }
        str_append(buf, value_suffix);
-       str_append(str, sql_escape_blob(dict->db, buf->data, buf->used));
+       param->value_binary = buf->data;
+       param->value_binary_size = buf->used;
        return 0;
 }
 
 static int
-sql_dict_field_escape_value(string_t *str, struct sql_dict *dict,
-                           const struct dict_sql_map *map,
-                           const struct dict_sql_field *field,
-                           const char *value, const char *value_suffix,
-                           const char **error_r)
+sql_dict_field_get_value(const struct dict_sql_map *map,
+                        const struct dict_sql_field *field,
+                        const char *value, const char *value_suffix,
+                        ARRAY_TYPE(sql_dict_param) *params,
+                        const char **error_r)
 {
-       return sql_dict_value_escape(str, dict, map, field->value_type,
-                                    field->name, value, value_suffix, error_r);
+       return sql_dict_value_get(map, field->value_type, field->name,
+                                 value, value_suffix, params, error_r);
 }
 
 static int
 sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                     const ARRAY_TYPE(const_string) *values_arr,
                     char key1, enum sql_recurse_type recurse_type,
-                    string_t *query, const char **error_r)
+                    string_t *query, ARRAY_TYPE(sql_dict_param) *params,
+                    const char **error_r)
 {
        const struct dict_sql_field *sql_fields;
        const char *const *values;
@@ -298,9 +345,9 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
        for (i = 0; i < exact_count; i++) {
                if (i > 0)
                        str_append(query, " AND");
-               str_printfa(query, " %s = ", sql_fields[i].name);
-               if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                               values[i], "", error_r) < 0)
+               str_printfa(query, " %s = ?", sql_fields[i].name);
+               if (sql_dict_field_get_value(map, &sql_fields[i], values[i], "",
+                                            params, error_r) < 0)
                        return -1;
        }
        switch (recurse_type) {
@@ -310,13 +357,15 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                if (i > 0)
                        str_append(query, " AND");
                if (i < count2) {
-                       str_printfa(query, " %s LIKE ", sql_fields[i].name);
-                       if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                                       values[i], "/%", error_r) < 0)
+                       str_printfa(query, " %s LIKE ?", sql_fields[i].name);
+                       if (sql_dict_field_get_value(map, &sql_fields[i],
+                                                    values[i], "/%",
+                                                    params, error_r) < 0)
                                return -1;
-                       str_printfa(query, " AND %s NOT LIKE ", sql_fields[i].name);
-                       if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                                       values[i], "/%/%", error_r) < 0)
+                       str_printfa(query, " AND %s NOT LIKE ?", sql_fields[i].name);
+                       if (sql_dict_field_get_value(map, &sql_fields[i],
+                                                    values[i], "/%/%",
+                                                    params, error_r) < 0)
                                return -1;
                } else {
                        str_printfa(query, " %s LIKE '%%' AND "
@@ -330,8 +379,9 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
                                str_append(query, " AND");
                        str_printfa(query, " %s LIKE ",
                                    sql_fields[i].name);
-                       if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-                                                       values[i], "/%", error_r) < 0)
+                       if (sql_dict_field_get_value(map, &sql_fields[i],
+                                                    values[i], "/%",
+                                                    params, error_r) < 0)
                                return -1;
                }
                break;
@@ -348,7 +398,7 @@ sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
 static int
 sql_lookup_get_query(struct sql_dict *dict, const char *key,
                     string_t *query, const struct dict_sql_map **map_r,
-                    const char **error_r)
+                    ARRAY_TYPE(sql_dict_param) *params, const char **error_r)
 {
        const struct dict_sql_map *map;
        ARRAY_TYPE(const_string) values;
@@ -363,7 +413,8 @@ sql_lookup_get_query(struct sql_dict *dict, const char *key,
        str_printfa(query, "SELECT %s FROM %s",
                    map->value_field, map->table);
        if (sql_dict_where_build(dict, map, &values, key[0],
-                                SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+                                SQL_DICT_RECURSE_NONE, query,
+                                params, &error) < 0) {
                *error_r = t_strdup_printf(
                        "sql dict lookup: Failed to lookup key %s: %s", key, error);
                return -1;
@@ -435,15 +486,17 @@ static int sql_dict_lookup(struct dict *_dict, pool_t pool, const char *key,
        const struct dict_sql_map *map;
        struct sql_result *result = NULL;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        int ret;
 
        *value_r = NULL;
 
-       if (sql_lookup_get_query(dict, key, query, &map, error_r) < 0)
+       t_array_init(&params, 4);
+       if (sql_lookup_get_query(dict, key, query, &map, &params, error_r) < 0)
                return -1;
 
        struct sql_statement *stmt =
-               sql_statement_init(dict->db, str_c(query));
+               sql_dict_statement_init(dict->db, str_c(query), &params);
        result = sql_statement_query_s(&stmt);
        ret = sql_result_next_row(result);
        if (ret < 0) {
@@ -497,9 +550,11 @@ sql_dict_lookup_async(struct dict *_dict, const char *key,
        const struct dict_sql_map *map;
        struct sql_dict_lookup_context *ctx;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        const char *error;
 
-       if (sql_lookup_get_query(dict, key, query, &map, &error) < 0) {
+       t_array_init(&params, 4);
+       if (sql_lookup_get_query(dict, key, query, &map, &params, &error) < 0) {
                struct dict_lookup_result result;
 
                i_zero(&result);
@@ -512,7 +567,7 @@ sql_dict_lookup_async(struct dict *_dict, const char *key,
                ctx->context = context;
                ctx->map = map;
                struct sql_statement *stmt =
-                       sql_statement_init(dict->db, str_c(query));
+                       sql_dict_statement_init(dict->db, str_c(query), &params);
                sql_statement_query(&stmt, sql_dict_lookup_async_callback, ctx);
        }
 }
@@ -553,7 +608,9 @@ sql_dict_iterate_find_next_map(struct sql_dict_iterate_context *ctx,
 
 static int
 sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
-                                 string_t *query, const char **error_r)
+                                 string_t *query,
+                                 ARRAY_TYPE(sql_dict_param) *params,
+                                 const char **error_r)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
        const struct dict_sql_map *map;
@@ -605,7 +662,7 @@ sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
                recurse_type = SQL_DICT_RECURSE_ONE;
        if (sql_dict_where_build(dict, map, &values,
                                 ctx->paths[ctx->path_idx][0],
-                                recurse_type, query, error_r) < 0)
+                                recurse_type, query, params, error_r) < 0)
                return -1;
 
        if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
@@ -641,11 +698,13 @@ static int sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        const char *error;
        unsigned int path_idx = ctx->path_idx;
        int ret;
 
-       ret = sql_dict_iterate_build_next_query(ctx, query, &error);
+       t_array_init(&params, 4);
+       ret = sql_dict_iterate_build_next_query(ctx, query, &params, &error);
        if (ret <= 0) {
                /* this is expected error */
                if (ret == 0)
@@ -658,7 +717,7 @@ static int sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
        }
 
        struct sql_statement *stmt =
-               sql_statement_init(dict->db, str_c(query));
+               sql_dict_statement_init(dict->db, str_c(query), &params);
        if ((ctx->flags & DICT_ITERATE_FLAG_ASYNC) == 0) {
                ctx->result = sql_statement_query_s(&stmt);
        } else {
@@ -923,10 +982,12 @@ static void sql_dict_transaction_rollback(struct dict_transaction_context *_ctx)
 
 static struct sql_statement *
 sql_dict_transaction_stmt_init(struct sql_dict_transaction_context *ctx,
-                              const char *query)
+                              const char *query,
+                              const ARRAY_TYPE(sql_dict_param) *params)
 {
        struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
-       struct sql_statement *stmt = sql_statement_init(dict->db, query);
+       struct sql_statement *stmt =
+               sql_dict_statement_init(dict->db, query, params);
 
        if (ctx->ctx.timestamp.tv_sec != 0)
                sql_statement_set_timestamp(stmt, &ctx->ctx.timestamp);
@@ -947,7 +1008,9 @@ struct dict_sql_build_query {
 };
 
 static int sql_dict_set_query(const struct dict_sql_build_query *build,
-                             const char **query_r, const char **error_r)
+                             const char **query_r,
+                             ARRAY_TYPE(sql_dict_param) *params,
+                             const char **error_r)
 {
        struct sql_dict *dict = build->dict;
        const struct dict_sql_build_query_field *fields;
@@ -973,9 +1036,10 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
 
                enum dict_sql_type value_type =
                        fields[i].map->value_types[0];
-               if (sql_dict_value_escape(suffix, dict, fields[i].map,
-                                         value_type, "value", fields[i].value,
-                                         "", error_r) < 0)
+               str_append_c(suffix, '?');
+               if (sql_dict_value_get(fields[i].map,
+                                      value_type, "value", fields[i].value,
+                                      "", params, error_r) < 0)
                        return -1;
        }
        if (build->key1 == DICT_PATH_PRIVATE[0]) {
@@ -990,9 +1054,10 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
        i_assert(count == count2);
        for (i = 0; i < count; i++) {
                str_printfa(prefix, ",%s", sql_fields[i].name);
-               str_append_c(suffix, ',');
-               if (sql_dict_field_escape_value(suffix, dict, fields[0].map, &sql_fields[i],
-                                               extra_values[i], "", error_r) < 0)
+               str_append(suffix, ",?");
+               if (sql_dict_field_get_value(fields[0].map, &sql_fields[i],
+                                            extra_values[i], "",
+                                            params, error_r) < 0)
                        return -1;
        }
 
@@ -1014,9 +1079,10 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
 
                enum dict_sql_type value_type =
                        fields[i].map->value_types[0];
-               if (sql_dict_value_escape(prefix, dict, fields[i].map,
-                                         value_type, "value", fields[i].value,
-                                         "", error_r) < 0)
+               str_append_c(prefix, '?');
+               if (sql_dict_value_get(fields[i].map,
+                                      value_type, "value", fields[i].value,
+                                      "", params, error_r) < 0)
                        return -1;
        }
        *query_r = str_c(prefix);
@@ -1025,7 +1091,8 @@ static int sql_dict_set_query(const struct dict_sql_build_query *build,
 
 static int
 sql_dict_update_query(const struct dict_sql_build_query *build,
-                     const char **query_r, const char **error_r)
+                     const char **query_r, ARRAY_TYPE(sql_dict_param) *params,
+                     const char **error_r)
 {
        struct sql_dict *dict = build->dict;
        const struct dict_sql_build_query_field *fields;
@@ -1042,15 +1109,13 @@ sql_dict_update_query(const struct dict_sql_build_query *build,
                        t_strcut(fields[i].map->value_field, ',');
                if (i > 0)
                        str_append_c(query, ',');
-               str_printfa(query, "%s=%s", first_value_field,
+               str_printfa(query, "%s=%s+?", first_value_field,
                            first_value_field);
-               if (fields[i].value[0] != '-')
-                       str_append_c(query, '+');
-               str_append(query, fields[i].value);
        }
 
        if (sql_dict_where_build(dict, fields[0].map, build->extra_values,
-                                build->key1, SQL_DICT_RECURSE_NONE, query, error_r) < 0)
+                                build->key1, SQL_DICT_RECURSE_NONE, query,
+                                params, error_r) < 0)
                return -1;
        *query_r = str_c(query);
        return 0;
@@ -1066,6 +1131,7 @@ static void sql_dict_set_real(struct dict_transaction_context *_ctx,
        ARRAY_TYPE(const_string) values;
        struct dict_sql_build_query build;
        struct dict_sql_build_query_field field;
+       ARRAY_TYPE(sql_dict_param) params;
        const char *query, *error;
 
        if (ctx->error != NULL)
@@ -1088,12 +1154,13 @@ static void sql_dict_set_real(struct dict_transaction_context *_ctx,
        build.extra_values = &values;
        build.key1 = key[0];
 
-       if (sql_dict_set_query(&build, &query, &error) < 0) {
+       t_array_init(&params, 4);
+       if (sql_dict_set_query(&build, &query, &params, &error) < 0) {
                ctx->error = i_strdup_printf("dict-sql: Failed to set %s=%s: %s",
                                             key, value, error);
        } else {
                struct sql_statement *stmt =
-                       sql_dict_transaction_stmt_init(ctx, query);
+                       sql_dict_transaction_stmt_init(ctx, query, &params);
                sql_update_stmt(ctx->sql_ctx, &stmt);
        }
 }
@@ -1107,6 +1174,7 @@ static void sql_dict_unset(struct dict_transaction_context *_ctx,
        const struct dict_sql_map *map;
        ARRAY_TYPE(const_string) values;
        string_t *query = t_str_new(256);
+       ARRAY_TYPE(sql_dict_param) params;
        const char *error;
 
        if (ctx->error != NULL)
@@ -1124,13 +1192,15 @@ static void sql_dict_unset(struct dict_transaction_context *_ctx,
        }
 
        str_printfa(query, "DELETE FROM %s", map->table);
+       t_array_init(&params, 4);
        if (sql_dict_where_build(dict, map, &values, key[0],
-                                SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+                                SQL_DICT_RECURSE_NONE, query,
+                                &params, &error) < 0) {
                ctx->error = i_strdup_printf(
                        "dict-sql: Failed to delete %s: %s", key, error);
        } else {
                struct sql_statement *stmt =
-                       sql_dict_transaction_stmt_init(ctx, str_c(query));
+                       sql_dict_transaction_stmt_init(ctx, str_c(query), &params);
                sql_update_stmt(ctx->sql_ctx, &stmt);
        }
 }
@@ -1159,6 +1229,8 @@ static void sql_dict_atomic_inc_real(struct sql_dict_transaction_context *ctx,
        ARRAY_TYPE(const_string) values;
        struct dict_sql_build_query build;
        struct dict_sql_build_query_field field;
+       ARRAY_TYPE(sql_dict_param) params;
+       struct sql_dict_param *param;
        const char *query, *error;
 
        if (ctx->error != NULL)
@@ -1177,12 +1249,17 @@ static void sql_dict_atomic_inc_real(struct sql_dict_transaction_context *ctx,
        build.extra_values = &values;
        build.key1 = key[0];
 
-       if (sql_dict_update_query(&build, &query, &error) < 0) {
+       t_array_init(&params, 4);
+       param = array_append_space(&params);
+       param->value_type = DICT_SQL_TYPE_INT;
+       param->value_int64 = diff;
+
+       if (sql_dict_update_query(&build, &query, &params, &error) < 0) {
                ctx->error = i_strdup_printf(
                        "dict-sql: Failed to increase %s: %s", key, error);
        } else {
                struct sql_statement *stmt =
-                       sql_dict_transaction_stmt_init(ctx, query);
+                       sql_dict_transaction_stmt_init(ctx, query, &params);
                sql_update_stmt_get_rows(ctx->sql_ctx, &stmt,
                                         sql_dict_next_inc_row(ctx));
        }
@@ -1274,6 +1351,7 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
        } else {
                struct dict_sql_build_query build;
                struct dict_sql_build_query_field *field;
+               ARRAY_TYPE(sql_dict_param) params;
                const char *query, *error;
 
                i_zero(&build);
@@ -1289,12 +1367,13 @@ static void sql_dict_set(struct dict_transaction_context *_ctx,
                field->map = map;
                field->value = value;
 
-               if (sql_dict_set_query(&build, &query, &error) < 0) {
+               t_array_init(&params, 4);
+               if (sql_dict_set_query(&build, &query, &params, &error) < 0) {
                        ctx->error = i_strdup_printf(
                                "dict-sql: Failed to set %s: %s", key, error);
                } else {
                        struct sql_statement *stmt =
-                               sql_dict_transaction_stmt_init(ctx, query);
+                               sql_dict_transaction_stmt_init(ctx, query, &params);
                        sql_update_stmt(ctx->sql_ctx, &stmt);
                }
                i_free_and_null(ctx->prev_set_value);
@@ -1338,6 +1417,8 @@ static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
        } else {
                struct dict_sql_build_query build;
                struct dict_sql_build_query_field *field;
+               ARRAY_TYPE(sql_dict_param) params;
+               struct sql_dict_param *param;
                const char *query, *error;
 
                i_zero(&build);
@@ -1353,12 +1434,21 @@ static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx,
                field->map = map;
                field->value = t_strdup_printf("%lld", diff);
 
-               if (sql_dict_update_query(&build, &query, &error) < 0) {
+               t_array_init(&params, 4);
+               param = array_append_space(&params);
+               param->value_type = DICT_SQL_TYPE_INT;
+               param->value_int64 = ctx->prev_inc_diff;
+
+               param = array_append_space(&params);
+               param->value_type = DICT_SQL_TYPE_INT;
+               param->value_int64 = diff;
+
+               if (sql_dict_update_query(&build, &query, &params, &error) < 0) {
                        ctx->error = i_strdup_printf(
                                "dict-sql: Failed to increase %s: %s", key, error);
                } else {
                        struct sql_statement *stmt =
-                               sql_dict_transaction_stmt_init(ctx, query);
+                               sql_dict_transaction_stmt_init(ctx, query, &params);
                        sql_update_stmt_get_rows(ctx->sql_ctx, &stmt,
                                                 sql_dict_next_inc_row(ctx));
                }