]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
lib-sql: Add support for prepared SQL statements.
authorTimo Sirainen <timo.sirainen@dovecot.fi>
Tue, 22 Aug 2017 08:37:17 +0000 (11:37 +0300)
committerAki Tuomi <aki.tuomi@dovecot.fi>
Fri, 8 Sep 2017 10:18:32 +0000 (13:18 +0300)
This initial implementation doesn't use prepared statements in drivers, but
simply generates the query string internally.

src/lib-sql/sql-api-private.h
src/lib-sql/sql-api.c
src/lib-sql/sql-api.h

index ca62a44770eae95c140800491a9962ebec3cf314..dad8646ef61751cb953ea9493eeea5a8d872dad5 100644 (file)
@@ -78,6 +78,33 @@ struct sql_db_vfuncs {
                       unsigned int *affected_rows);
        const char *(*escape_blob)(struct sql_db *db,
                                   const unsigned char *data, size_t size);
+
+       struct sql_prepared_statement *
+               (*prepared_statement_init)(struct sql_db *db,
+                                          const char *query_template);
+       void (*prepared_statement_deinit)(struct sql_prepared_statement *prep_stmt);
+
+
+       struct sql_statement *
+               (*statement_init)(struct sql_db *db, const char *query_template);
+       struct sql_statement *
+               (*statement_init_prepared)(struct sql_prepared_statement *prep_stmt);
+       void (*statement_abort)(struct sql_statement *stmt);
+       void (*statement_set_timestamp)(struct sql_statement *stmt,
+                                       const struct timespec *ts);
+       void (*statement_bind_str)(struct sql_statement *stmt,
+                                  unsigned int column_idx, const char *value);
+       void (*statement_bind_binary)(struct sql_statement *stmt,
+                                     unsigned int column_idx, const void *value,
+                                     size_t value_size);
+       void (*statement_bind_int64)(struct sql_statement *stmt,
+                                    unsigned int column_idx, int64_t value);
+       void (*statement_query)(struct sql_statement *stmt,
+                               sql_query_callback_t *callback, void *context);
+       struct sql_result *(*statement_query_s)(struct sql_statement *stmt);
+       void (*update_stmt)(struct sql_transaction_context *ctx,
+                           struct sql_statement *stmt,
+                           unsigned int *affected_rows);
 };
 
 struct sql_db {
@@ -127,6 +154,18 @@ struct sql_result_vfuncs {
                     sql_query_callback_t *callback, void *context);
 };
 
+struct sql_prepared_statement {
+       struct sql_db *db;
+};
+
+struct sql_statement {
+       struct sql_db *db;
+
+       pool_t pool;
+       const char *query_template;
+       ARRAY_TYPE(const_string) args;
+};
+
 struct sql_field_map {
        enum sql_field_type type;
        size_t offset;
@@ -169,5 +208,6 @@ void sql_db_set_state(struct sql_db *db, enum sql_db_state state);
 
 void sql_transaction_add_query(struct sql_transaction_context *ctx, pool_t pool,
                               const char *query, unsigned int *affected_rows);
+const char *sql_statement_get_query(struct sql_statement *stmt);
 
 #endif
index 8981104dd699f2db11139d682d5f4786a877c37b..20e0e270c8f610b2ff05d9f5c29901457f125541 100644 (file)
@@ -3,10 +3,16 @@
 #include "lib.h"
 #include "array.h"
 #include "ioloop.h"
+#include "str.h"
 #include "sql-api-private.h"
 
 #include <time.h>
 
+struct default_sql_prepared_statement {
+       struct sql_prepared_statement prep_stmt;
+       char *query_template;
+};
+
 struct sql_db_module_register sql_db_module_register = { 0 };
 ARRAY_TYPE(sql_drivers) sql_drivers;
 
@@ -148,6 +154,225 @@ struct sql_result *sql_query_s(struct sql_db *db, const char *query)
        return db->v.query_s(db, query);
 }
 
+static struct sql_prepared_statement *
+default_sql_prepared_statement_init(struct sql_db *db,
+                                   const char *query_template)
+{
+       struct default_sql_prepared_statement *prep_stmt;
+
+       prep_stmt = i_new(struct default_sql_prepared_statement, 1);
+       prep_stmt->prep_stmt.db = db;
+       prep_stmt->query_template = i_strdup(query_template);
+       return &prep_stmt->prep_stmt;
+}
+
+static void
+default_sql_prepared_statement_deinit(struct sql_prepared_statement *_prep_stmt)
+{
+       struct default_sql_prepared_statement *prep_stmt =
+               (struct default_sql_prepared_statement *)_prep_stmt;
+
+       i_free(prep_stmt->query_template);
+       i_free(prep_stmt);
+}
+
+static struct sql_statement *
+default_sql_statement_init_prepared(struct sql_prepared_statement *_stmt)
+{
+       struct default_sql_prepared_statement *stmt =
+               (struct default_sql_prepared_statement *)_stmt;
+       return sql_statement_init(_stmt->db, stmt->query_template);
+}
+
+const char *sql_statement_get_query(struct sql_statement *stmt)
+{
+       string_t *query = t_str_new(128);
+       const char *const *args;
+       unsigned int i, args_count, arg_pos = 0;
+
+       args = array_get(&stmt->args, &args_count);
+
+       for (i = 0; stmt->query_template[i] != '\0'; i++) {
+               if (stmt->query_template[i] == '?') {
+                       if (arg_pos >= args_count ||
+                           args[arg_pos] == NULL) {
+                               i_panic("lib-sql: Missing bind for arg #%u in statement: %s",
+                                       arg_pos, stmt->query_template);
+                       }
+                       str_append(query, args[arg_pos++]);
+               } else {
+                       str_append_c(query, stmt->query_template[i]);
+               }
+       }
+       if (arg_pos != args_count) {
+               i_panic("lib-sql: Too many bind args (%u) for statement: %s",
+                       args_count, stmt->query_template);
+       }
+       return str_c(query);
+}
+
+static void
+default_sql_statement_query(struct sql_statement *stmt,
+                           sql_query_callback_t *callback, void *context)
+{
+       sql_query(stmt->db, sql_statement_get_query(stmt),
+                 callback, context);
+       pool_unref(&stmt->pool);
+}
+
+static struct sql_result *
+default_sql_statement_query_s(struct sql_statement *stmt)
+{
+       struct sql_result *result =
+               sql_query_s(stmt->db, sql_statement_get_query(stmt));
+       pool_unref(&stmt->pool);
+       return result;
+}
+
+static void default_sql_update_stmt(struct sql_transaction_context *ctx,
+                                   struct sql_statement *stmt,
+                                   unsigned int *affected_rows)
+{
+       ctx->db->v.update(ctx, sql_statement_get_query(stmt),
+                         affected_rows);
+       pool_unref(&stmt->pool);
+}
+
+struct sql_prepared_statement *
+sql_prepared_statement_init(struct sql_db *db, const char *query_template)
+{
+       if (db->v.prepared_statement_init != NULL)
+               return db->v.prepared_statement_init(db, query_template);
+       else
+               return default_sql_prepared_statement_init(db, query_template);
+}
+
+void sql_prepared_statement_deinit(struct sql_prepared_statement **_prep_stmt)
+{
+       struct sql_prepared_statement *prep_stmt = *_prep_stmt;
+
+       *_prep_stmt = NULL;
+       if (prep_stmt->db->v.prepared_statement_deinit != NULL)
+               prep_stmt->db->v.prepared_statement_deinit(prep_stmt);
+       else
+               default_sql_prepared_statement_deinit(prep_stmt);
+}
+
+static void
+sql_statement_init_fields(struct sql_statement *stmt, struct sql_db *db)
+{
+       stmt->db = db;
+       p_array_init(&stmt->args, stmt->pool, 8);
+}
+
+struct sql_statement *
+sql_statement_init(struct sql_db *db, const char *query_template)
+{
+       struct sql_statement *stmt;
+
+       if (db->v.statement_init != NULL)
+               stmt = db->v.statement_init(db, query_template);
+       else {
+               pool_t pool = pool_alloconly_create("sql statement", 1024);
+               stmt = p_new(pool, struct sql_statement, 1);
+               stmt->pool = pool;
+       }
+       stmt->query_template = p_strdup(stmt->pool, query_template);
+       sql_statement_init_fields(stmt, db);
+       return stmt;
+}
+
+struct sql_statement *
+sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt)
+{
+       struct sql_statement *stmt;
+
+       if (prep_stmt->db->v.statement_init_prepared == NULL)
+               return default_sql_statement_init_prepared(prep_stmt);
+
+       stmt = prep_stmt->db->v.statement_init_prepared(prep_stmt);
+       sql_statement_init_fields(stmt, prep_stmt->db);
+       return stmt;
+}
+
+void sql_statement_abort(struct sql_statement **_stmt)
+{
+       struct sql_statement *stmt = *_stmt;
+
+       *_stmt = NULL;
+       if (stmt->db->v.statement_abort != NULL)
+               stmt->db->v.statement_abort(stmt);
+       pool_unref(&stmt->pool);
+}
+
+void sql_statement_set_timestamp(struct sql_statement *stmt,
+                                const struct timespec *ts)
+{
+       if (stmt->db->v.statement_set_timestamp != NULL)
+               stmt->db->v.statement_set_timestamp(stmt, ts);
+}
+
+void sql_statement_bind_str(struct sql_statement *stmt,
+                           unsigned int column_idx, const char *value)
+{
+       const char *escaped_value =
+               p_strdup_printf(stmt->pool, "'%s'",
+                               sql_escape_string(stmt->db, value));
+       array_idx_set(&stmt->args, column_idx, &escaped_value);
+
+       if (stmt->db->v.statement_bind_str != NULL)
+               stmt->db->v.statement_bind_str(stmt, column_idx, value);
+}
+
+void sql_statement_bind_binary(struct sql_statement *stmt,
+                              unsigned int column_idx, const void *value,
+                              size_t value_size)
+{
+       const char *value_str =
+               p_strdup_printf(stmt->pool, "%s",
+                               sql_escape_blob(stmt->db, value, value_size));
+       array_idx_set(&stmt->args, column_idx, &value_str);
+
+       if (stmt->db->v.statement_bind_binary != NULL) {
+               stmt->db->v.statement_bind_binary(stmt, column_idx,
+                                                 value, value_size);
+       }
+}
+
+void sql_statement_bind_int64(struct sql_statement *stmt,
+                             unsigned int column_idx, int64_t value)
+{
+       const char *value_str = p_strdup_printf(stmt->pool, "%"PRId64, value);
+       array_idx_set(&stmt->args, column_idx, &value_str);
+
+       if (stmt->db->v.statement_bind_int64 != NULL)
+               stmt->db->v.statement_bind_int64(stmt, column_idx, value);
+}
+
+#undef sql_statement_query
+void sql_statement_query(struct sql_statement **_stmt,
+                        sql_query_callback_t *callback, void *context)
+{
+       struct sql_statement *stmt = *_stmt;
+
+       *_stmt = NULL;
+       if (stmt->db->v.statement_query != NULL)
+               stmt->db->v.statement_query(stmt, callback, context);
+       else
+               default_sql_statement_query(stmt, callback, context);
+}
+
+struct sql_result *sql_statement_query_s(struct sql_statement **_stmt)
+{
+       struct sql_statement *stmt = *_stmt;
+
+       *_stmt = NULL;
+       if (stmt->db->v.statement_query_s != NULL)
+               return stmt->db->v.statement_query_s(stmt);
+       else
+               return default_sql_statement_query_s(stmt);
+}
+
 void sql_result_ref(struct sql_result *result)
 {
        result->refcount++;
@@ -408,12 +633,37 @@ void sql_update(struct sql_transaction_context *ctx, const char *query)
        ctx->db->v.update(ctx, query, NULL);
 }
 
+void sql_update_stmt(struct sql_transaction_context *ctx,
+                    struct sql_statement **_stmt)
+{
+       struct sql_statement *stmt = *_stmt;
+
+       *_stmt = NULL;
+       if (ctx->db->v.update_stmt != NULL)
+               ctx->db->v.update_stmt(ctx, stmt, NULL);
+       else
+               default_sql_update_stmt(ctx, stmt, NULL);
+}
+
 void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query,
                         unsigned int *affected_rows)
 {
        ctx->db->v.update(ctx, query, affected_rows);
 }
 
+void sql_update_stmt_get_rows(struct sql_transaction_context *ctx,
+                             struct sql_statement **_stmt,
+                             unsigned int *affected_rows)
+{
+       struct sql_statement *stmt = *_stmt;
+
+       *_stmt = NULL;
+       if (ctx->db->v.update_stmt != NULL)
+               ctx->db->v.update_stmt(ctx, stmt, affected_rows);
+       else
+               default_sql_update_stmt(ctx, stmt, affected_rows);
+}
+
 void sql_db_set_state(struct sql_db *db, enum sql_db_state state)
 {
        enum sql_db_state old_state = db->state;
index 256b77995de80420498cc8dfa7297c5f43ed6574..768dc7de6a06dcb7a030e85bfc3555fbcbb95367 100644 (file)
@@ -8,7 +8,11 @@ enum sql_db_flags {
        /* Set if queries are not executed asynchronously */
        SQL_DB_FLAG_BLOCKING            = 0x01,
        /* Set if database wants to use connection pooling */
-       SQL_DB_FLAG_POOLED              = 0x02
+       SQL_DB_FLAG_POOLED              = 0x02,
+       /* Prepared statements are supported by the database. If they aren't,
+          the functions can still be used, but they're just internally
+          convered into regular statements. */
+       SQL_DB_FLAG_PREP_STATEMENTS     = 0x04,
 };
 
 enum sql_field_type {
@@ -112,6 +116,33 @@ void sql_query(struct sql_db *db, const char *query,
 /* Execute blocking SQL query and return result. */
 struct sql_result *sql_query_s(struct sql_db *db, const char *query);
 
+struct sql_prepared_statement *
+sql_prepared_statement_init(struct sql_db *db, const char *query_template);
+void sql_prepared_statement_deinit(struct sql_prepared_statement **prep_stmt);
+
+struct sql_statement *
+sql_statement_init(struct sql_db *db, const char *query_template);
+struct sql_statement *
+sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt);
+void sql_statement_abort(struct sql_statement **stmt);
+void sql_statement_set_timestamp(struct sql_statement *stmt,
+                                const struct timespec *ts);
+void sql_statement_bind_str(struct sql_statement *stmt,
+                           unsigned int column_idx, const char *value);
+void sql_statement_bind_binary(struct sql_statement *stmt,
+                              unsigned int column_idx, const void *value,
+                              size_t value_size);
+void sql_statement_bind_int64(struct sql_statement *stmt,
+                             unsigned int column_idx, int64_t value);
+void sql_statement_query(struct sql_statement **stmt,
+                        sql_query_callback_t *callback, void *context);
+#define sql_statement_query(stmt, callback, context) \
+       sql_statement_query(stmt, \
+               (sql_query_callback_t *)callback, context + \
+               CALLBACK_TYPECHECK(callback, void (*)( \
+                       struct sql_result *, typeof(context))))
+struct sql_result *sql_statement_query_s(struct sql_statement **stmt);
+
 void sql_result_setup_fetch(struct sql_result *result,
                            const struct sql_field_def *fields,
                            void *dest, size_t dest_size);
@@ -179,9 +210,14 @@ void sql_transaction_rollback(struct sql_transaction_context **ctx);
 
 /* Execute query in given transaction. */
 void sql_update(struct sql_transaction_context *ctx, const char *query);
+void sql_update_stmt(struct sql_transaction_context *ctx,
+                    struct sql_statement **stmt);
 /* Save the number of rows updated by this query. The value is set before
    commit callback is called. */
 void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query,
                         unsigned int *affected_rows);
+void sql_update_stmt_get_rows(struct sql_transaction_context *ctx,
+                             struct sql_statement **stmt,
+                             unsigned int *affected_rows);
 
 #endif