]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
lib-sql: mysql - Move settings to struct mysql_settings
authorTimo Sirainen <timo.sirainen@open-xchange.com>
Thu, 29 Feb 2024 12:55:07 +0000 (14:55 +0200)
committerAki Tuomi <aki.tuomi@open-xchange.com>
Fri, 17 Jan 2025 08:39:58 +0000 (10:39 +0200)
src/lib-sql/Makefile.am
src/lib-sql/driver-mysql.c
src/lib-sql/sql-api-private.h

index e59e5314ae10f7a05155532757e9ad15d1b02ee1..26344ee73d8dead30f628e7ff7f4f884a15c45aa 100644 (file)
@@ -38,6 +38,7 @@ AM_CPPFLAGS = \
        -I$(top_srcdir)/src/lib \
        -I$(top_srcdir)/src/lib-test \
        -I$(top_srcdir)/src/lib-settings \
+       -I$(top_srcdir)/src/lib-ssl-iostream \
        $(SQL_CFLAGS)
 
 dist_sources = \
index cb7eef8e0840f383db947fd81a16be99506c5810..0d5eeb7e8f90239e11808cf96ac437348a18793f 100644 (file)
@@ -7,6 +7,8 @@
 #include "str.h"
 #include "net.h"
 #include "time-util.h"
+#include "settings-parser.h"
+#include "ssl-settings.h"
 #include "sql-api-private.h"
 
 #ifdef BUILD_MYSQL
 #define MYSQL_DEFAULT_READ_TIMEOUT_SECS 30
 #define MYSQL_DEFAULT_WRITE_TIMEOUT_SECS 30
 
+struct mysql_settings {
+       pool_t pool;
+
+       const char *host;
+       in_port_t port;
+       const char *user;
+       const char *password;
+       const char *dbname;
+
+       bool ssl;
+       const char *option_file;
+       const char *option_group;
+       unsigned int client_flags;
+
+       unsigned int connect_timeout_secs;
+       unsigned int read_timeout_secs;
+       unsigned int write_timeout_secs;
+};
+
+static struct mysql_settings mysql_default_settings = {
+       .host = "",
+       .port = 0,
+       .user = "",
+       .password = "",
+       .dbname = "",
+
+       .ssl = FALSE,
+       .option_file = "",
+       .option_group = "client",
+       .client_flags = 0,
+
+       .connect_timeout_secs = SQL_CONNECT_TIMEOUT_SECS,
+       .read_timeout_secs = MYSQL_DEFAULT_READ_TIMEOUT_SECS,
+       .write_timeout_secs = MYSQL_DEFAULT_WRITE_TIMEOUT_SECS,
+};
+
 struct mysql_db {
        struct sql_db api;
 
        pool_t pool;
-       const char *user, *password, *dbname, *host, *unix_socket;
-       const char *ssl_cert, *ssl_key, *ssl_ca, *ssl_ca_path, *ssl_cipher;
-       int ssl_verify_server_cert;
-       const char *option_file, *option_group;
-       in_port_t port;
-       unsigned int client_flags;
-       unsigned int connect_timeout, read_timeout, write_timeout;
+       const struct mysql_settings *set;
+       const struct ssl_settings *ssl_set;
+
        time_t last_success;
 
        MYSQL *mysql;
        unsigned int next_query_connection;
-
-       bool ssl_set:1;
 };
 
 struct mysql_result {
@@ -82,56 +114,65 @@ static int driver_mysql_connect(struct sql_db *_db)
 {
        struct mysql_db *db = container_of(_db, struct mysql_db, api);
        const char *unix_socket, *host;
-       unsigned long client_flags = db->client_flags;
+       unsigned long client_flags = db->set->client_flags;
        unsigned int secs_used;
        time_t start_time;
        bool failed;
 
        i_assert(db->api.state == SQL_DB_STATE_DISCONNECTED);
 
-       if (db->host == NULL) {
+       if (db->set->host[0] == '\0') {
                /* assume option_file overrides the host, or if not we'll just
                   connect to localhost */
                unix_socket = NULL;
                host = NULL;
-       } else if (*db->host == '/') {
-               unix_socket = db->host;
+       } else if (*db->set->host == '/') {
+               unix_socket = db->set->host;
                host = NULL;
        } else {
                unix_socket = NULL;
-               host = db->host;
+               host = db->set->host;
        }
 
-       if (db->option_file != NULL) {
+       if (db->set->option_file[0] != '\0') {
                mysql_options(db->mysql, MYSQL_READ_DEFAULT_FILE,
-                             db->option_file);
+                             db->set->option_file);
        }
 
-       if (db->host != NULL)
-               event_set_append_log_prefix(_db->event, t_strdup_printf("mysql(%s): ", db->host));
-
-       mysql_options(db->mysql, MYSQL_OPT_CONNECT_TIMEOUT, &db->connect_timeout);
-       mysql_options(db->mysql, MYSQL_OPT_READ_TIMEOUT, &db->read_timeout);
-       mysql_options(db->mysql, MYSQL_OPT_WRITE_TIMEOUT, &db->write_timeout);
-       mysql_options(db->mysql, MYSQL_READ_DEFAULT_GROUP,
-                     db->option_group != NULL ? db->option_group : "client");
+       mysql_options(db->mysql, MYSQL_OPT_CONNECT_TIMEOUT, &db->set->connect_timeout_secs);
+       mysql_options(db->mysql, MYSQL_OPT_READ_TIMEOUT, &db->set->read_timeout_secs);
+       mysql_options(db->mysql, MYSQL_OPT_WRITE_TIMEOUT, &db->set->write_timeout_secs);
+       mysql_options(db->mysql, MYSQL_READ_DEFAULT_GROUP, db->set->option_group);
 
-       if (!db->ssl_set && (db->ssl_ca != NULL || db->ssl_ca_path != NULL)) {
+       if (db->set->ssl) {
 #ifdef HAVE_MYSQL_SSL
-               mysql_ssl_set(db->mysql, db->ssl_key, db->ssl_cert,
-                             db->ssl_ca, db->ssl_ca_path
+               struct settings_file key_file, cert_file, ca_file;
+               settings_file_get(db->ssl_set->ssl_client_key_file,
+                                 unsafe_data_stack_pool, &key_file);
+               settings_file_get(db->ssl_set->ssl_client_cert_file,
+                                 unsafe_data_stack_pool, &cert_file);
+               settings_file_get(db->ssl_set->ssl_client_ca_file,
+                                 unsafe_data_stack_pool, &ca_file);
+               mysql_ssl_set(db->mysql,
+                             key_file.path[0] == '\0' ? NULL : key_file.path,
+                             cert_file.path[0] == '\0' ? NULL : cert_file.path,
+                             ca_file.path[0] == '\0' ? NULL : ca_file.path,
+                             (db->ssl_set->ssl_client_ca_dir[0] != '\0' ?
+                              db->ssl_set->ssl_client_ca_dir : NULL)
 #ifdef HAVE_MYSQL_SSL_CIPHER
-                             , db->ssl_cipher
+                             , db->ssl_set->ssl_cipher_list
 #endif
                             );
 #ifdef HAVE_MYSQL_SSL_VERIFY_SERVER_CERT
+               int ssl_verify_server_cert =
+                       ssl_set->ssl_client_require_valid_cert ? 1 : 0;
+
                mysql_options(db->mysql, MYSQL_OPT_SSL_VERIFY_SERVER_CERT,
-                             (void *)&db->ssl_verify_server_cert);
+                             (void *)&ssl_verify_server_cert);
 #endif
-               db->ssl_set = TRUE;
 #else
-               const char *error = "SSL support not compiled in "
-                       "(remove ssl_ca and ssl_ca_path settings)";
+               const char *error = "mysql: SSL support not compiled in "
+                       "(remove ssl_client_ca_file and ssl_client_ca_dir settings)";
                i_free(_db->last_connect_error);
                _db->last_connect_error = i_strdup(error);
                e_error(_db->event, "%s", error);
@@ -147,9 +188,11 @@ static int driver_mysql_connect(struct sql_db *_db)
 #endif
        /* CLIENT_MULTI_RESULTS allows the use of stored procedures */
        start_time = time(NULL);
-       failed = mysql_real_connect(db->mysql, host, db->user, db->password,
-                                   db->dbname, db->port, unix_socket,
-                                   client_flags) == NULL;
+       failed = mysql_real_connect(db->mysql, host,
+               db->set->user[0] == '\0' ? NULL : db->set->user,
+               db->set->password[0] == '\0' ? NULL : db->set->password,
+               db->set->dbname, db->set->port,
+               unix_socket, client_flags) == NULL;
        secs_used = time(NULL) - start_time;
        if (failed) {
                /* connecting could have taken a while. make sure that any
@@ -162,7 +205,8 @@ static int driver_mysql_connect(struct sql_db *_db)
                sql_db_set_state(&db->api, SQL_DB_STATE_DISCONNECTED);
                e_error(_db->event, "Connect failed to database (%s): %s - "
                        "waiting for %u seconds before retry",
-                       db->dbname, mysql_error(db->mysql), db->api.connect_delay);
+                       db->set->dbname, mysql_error(db->mysql),
+                       db->api.connect_delay);
                i_free(_db->last_connect_error);
                _db->last_connect_error = i_strdup(mysql_error(db->mysql));
                sql_disconnect(&db->api);
@@ -181,18 +225,25 @@ static void driver_mysql_disconnect(struct sql_db *_db)
                mysql_close(db->mysql);
 }
 
-static int driver_mysql_parse_connect_string(struct mysql_db *db,
-                                            const char *connect_string,
-                                            const char **error_r)
+static int
+driver_mysql_parse_connect_string(pool_t pool, const char *connect_string,
+                                 const struct mysql_settings **set_r,
+                                 const struct ssl_settings **ssl_set_r,
+                                 const char **error_r)
 {
+       struct mysql_settings *set;
+       struct ssl_settings *ssl_set;
        const char *const *args, *name, *value;
        const char **field;
 
-       db->ssl_cipher = "HIGH";
-       db->ssl_verify_server_cert = 1;
-       db->connect_timeout = SQL_CONNECT_TIMEOUT_SECS;
-       db->read_timeout = MYSQL_DEFAULT_READ_TIMEOUT_SECS;
-       db->write_timeout = MYSQL_DEFAULT_WRITE_TIMEOUT_SECS;
+       set = p_new(pool, struct mysql_settings, 1);
+       *set = mysql_default_settings;
+       set->pool = pool;
+       ssl_set = p_new(pool, struct ssl_settings, 1);
+       *ssl_set = ssl_default_settings;
+       ssl_set->pool = pool;
+
+       ssl_set->ssl_cipher_list = "HIGH";
 
        args = t_strsplit_spaces(connect_string, " ");
        for (; *args != NULL; args++) {
@@ -208,79 +259,78 @@ static int driver_mysql_parse_connect_string(struct mysql_db *db,
                field = NULL;
                if (strcmp(name, "host") == 0 ||
                    strcmp(name, "hostaddr") == 0)
-                       field = &db->host;
+                       field = &set->host;
                else if (strcmp(name, "user") == 0)
-                       field = &db->user;
+                       field = &set->user;
                else if (strcmp(name, "password") == 0)
-                       field = &db->password;
+                       field = &set->password;
                else if (strcmp(name, "dbname") == 0)
-                       field = &db->dbname;
+                       field = &set->dbname;
                else if (strcmp(name, "port") == 0) {
-                       if (net_str2port(value, &db->port) < 0) {
+                       if (net_str2port(value, &set->port) < 0) {
                                *error_r = t_strdup_printf("Invalid port number: %s", value);
                                return -1;
                        }
                } else if (strcmp(name, "client_flags") == 0) {
-                       if (str_to_uint(value, &db->client_flags) < 0) {
+                       if (str_to_uint(value, &set->client_flags) < 0) {
                                *error_r = t_strdup_printf("Invalid client flags: %s", value);
                                return -1;
                        }
                } else if (strcmp(name, "connect_timeout") == 0) {
-                       if (str_to_uint(value, &db->connect_timeout) < 0) {
+                       if (str_to_uint(value, &set->connect_timeout_secs) < 0) {
                                *error_r = t_strdup_printf("Invalid read_timeout: %s", value);
                                return -1;
                        }
                } else if (strcmp(name, "read_timeout") == 0) {
-                       if (str_to_uint(value, &db->read_timeout) < 0) {
+                       if (str_to_uint(value, &set->read_timeout_secs) < 0) {
                                *error_r = t_strdup_printf("Invalid read_timeout: %s", value);
                                return -1;
                        }
                } else if (strcmp(name, "write_timeout") == 0) {
-                       if (str_to_uint(value, &db->write_timeout) < 0) {
+                       if (str_to_uint(value, &set->write_timeout_secs) < 0) {
                                *error_r = t_strdup_printf("Invalid read_timeout: %s", value);
                                return -1;
                        }
-               } else if (strcmp(name, "ssl_cert") == 0)
-                       field = &db->ssl_cert;
-               else if (strcmp(name, "ssl_key") == 0)
-                       field = &db->ssl_key;
-               else if (strcmp(name, "ssl_ca") == 0)
-                       field = &db->ssl_ca;
-               else if (strcmp(name, "ssl_ca_path") == 0)
-                       field = &db->ssl_ca_path;
+               } else if (strcmp(name, "ssl_cert") == 0) {
+                       field = &ssl_set->ssl_client_cert_file;
+                       value = t_strconcat(value, "\n", NULL);
+               } else if (strcmp(name, "ssl_key") == 0) {
+                       field = &ssl_set->ssl_client_key_file;
+                       value = t_strconcat(value, "\n", NULL);
+               } else if (strcmp(name, "ssl_ca") == 0) {
+                       field = &ssl_set->ssl_client_ca_file;
+                       value = t_strconcat(value, "\n", NULL);
+               } else if (strcmp(name, "ssl_ca_path") == 0)
+                       field = &ssl_set->ssl_client_ca_dir;
                else if (strcmp(name, "ssl_cipher") == 0)
-                       field = &db->ssl_cipher;
+                       field = &ssl_set->ssl_cipher_list;
                else if (strcmp(name, "ssl_verify_server_cert") == 0) {
                        if (strcmp(value, "yes") == 0)
-                               db->ssl_verify_server_cert = 1;
+                               ssl_set->ssl_client_require_valid_cert = TRUE;
                        else if (strcmp(value, "no") == 0)
-                               db->ssl_verify_server_cert = 0;
+                               ssl_set->ssl_client_require_valid_cert = FALSE;
                        else {
                                *error_r = t_strdup_printf("Invalid boolean: %s", value);
                                return -1;
                        }
                } else if (strcmp(name, "option_file") == 0)
-                       field = &db->option_file;
+                       field = &set->option_file;
                else if (strcmp(name, "option_group") == 0)
-                       field = &db->option_group;
+                       field = &set->option_group;
                else {
                        *error_r = t_strdup_printf("Unknown connect string: %s", name);
                        return -1;
                }
                if (field != NULL)
-                       *field = p_strdup(db->pool, value);
+                       *field = p_strdup(pool, value);
        }
 
-       if (db->host == NULL && db->option_file == NULL) {
+       if (set->host[0] == '\0' && set->option_file[0] == '\0') {
                *error_r = "No hosts given in connect string";
                return -1;
        }
-       if (db->mysql == NULL) {
-               db->mysql = p_new(db->pool, MYSQL, 1);
-               MYSQL *ptr = mysql_init(db->mysql);
-               if (ptr == NULL)
-                       i_fatal_status(FATAL_OUTOFMEM, "mysql_init() failed");
-       }
+       *set_r = set;
+       *ssl_set_r = ssl_set;
        return 0;
 }
 
@@ -300,7 +350,8 @@ static int driver_mysql_init_full_v(const struct sql_legacy_settings *set,
        event_add_category(db->api.event, &event_category_mysql);
        event_set_append_log_prefix(db->api.event, "mysql: ");
        T_BEGIN {
-               ret = driver_mysql_parse_connect_string(db, set->connect_string, &error);
+               ret = driver_mysql_parse_connect_string(pool,
+                       set->connect_string, &db->set, &db->ssl_set, error_r);
                error = p_strdup(db->pool, error);
        } T_END;
 
index a582000e199ec86239bc7caf65bca6d2f1031c3e..afe2af6b8aed15f7b493d400fcd91ea3055a3958 100644 (file)
@@ -15,6 +15,7 @@ enum sql_db_state {
        SQL_DB_STATE_BUSY
 };
 
+/* <settings checks> */
 /* Minimum delay between reconnecting to same server */
 #define SQL_CONNECT_MIN_DELAY 1
 /* Maximum time to avoiding reconnecting to same server */
@@ -30,6 +31,7 @@ enum sql_db_state {
 #define SQL_QUERY_TIMEOUT_SECS 60
 /* Default max. number of connections to create per host */
 #define SQL_DEFAULT_CONNECTION_LIMIT 5
+/* </settings checks> */
 
 #define SQL_DB_IS_READY(db) \
        ((db)->state == SQL_DB_STATE_IDLE)