]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Update mysql driver to use trunk connections and non-blocking calls
authorNick Porter <nick@portercomputing.co.uk>
Mon, 27 May 2024 14:01:35 +0000 (15:01 +0100)
committerArran Cudbard-Bell <a.cudbardb@freeradius.org>
Fri, 7 Jun 2024 02:26:58 +0000 (22:26 -0400)
src/modules/rlm_sql/drivers/rlm_sql_mysql/rlm_sql_mysql.c
src/modules/rlm_sql/rlm_sql.h

index da835ab6da044493e1cafc3db17dbda11f1a6aa4..b7e35bc968acee6ee55ea974ceed071cb91a03fe 100644 (file)
@@ -65,9 +65,13 @@ static fr_table_num_sorted_t const server_warnings_table[] = {
 static size_t server_warnings_table_len = NUM_ELEMENTS(server_warnings_table);
 
 typedef struct {
-       MYSQL           db;
-       MYSQL           *sock;
-       MYSQL_RES       *result;
+       MYSQL           db;                     //!< Structure representing connection details.
+       MYSQL           *sock;                  //!< Connection details as returned by connection init functions.
+       MYSQL_RES       *result;                //!< Result from most recent query.
+       fr_connection_t *conn;                  //!< Generic connection structure for this connection.
+       int             fd;                     //!< fd for this connection's I/O events.
+       fr_sql_query_t  *query_ctx;             //!< Current query running on this connection.
+       int             status;                 //!< returned by the most recent non-blocking function call.
 } rlm_sql_mysql_conn_t;
 
 typedef struct {
@@ -129,18 +133,6 @@ static const conf_parser_t driver_config[] = {
 /* Prototypes */
 static sql_rcode_t sql_free_result(fr_sql_query_t *, rlm_sql_config_t const *);
 
-static int _sql_socket_destructor(rlm_sql_mysql_conn_t *conn)
-{
-       DEBUG2("Socket destructor called, closing socket");
-
-       if (conn->sock) {
-               mysql_close(conn->sock);
-               conn->sock = NULL;
-       }
-
-       return 0;
-}
-
 static int mod_instantiate(module_inst_ctx_t const *mctx)
 {
        rlm_sql_mysql_t         *inst = talloc_get_type_abort(mctx->mi->data, rlm_sql_mysql_t);
@@ -189,24 +181,61 @@ static int mod_load(void)
        return 0;
 }
 
-static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t const *config, fr_time_delta_t timeout)
+/** Callback for I/O events in response to mysql_real_connect_start()
+ */
+static void _sql_connect_io_notify(fr_event_list_t *el, int fd, UNUSED int flags, void *uctx)
 {
-       rlm_sql_mysql_t *inst = talloc_get_type_abort(handle->inst->driver_submodule->data, rlm_sql_mysql_t);
-       rlm_sql_mysql_conn_t *conn;
+       rlm_sql_mysql_conn_t    *c = talloc_get_type_abort(uctx, rlm_sql_mysql_conn_t);
+
+       fr_event_fd_delete(el, fd, FR_EVENT_FILTER_IO);
+
+       if (c->status == 0) goto connected;
+       c->status = mysql_real_connect_cont(&c->sock, &c->db, c->status);
 
-       unsigned int connect_timeout = (unsigned int)fr_time_delta_to_sec(timeout);
-       unsigned long sql_flags;
+       /*
+        *      If status is not zero, we're still waiting for something.
+        *      The event will be fired again when that happens.
+        */
+       if (c->status != 0) {
+               (void) fr_event_fd_insert(c, NULL, c->conn->el, c->fd,
+                                         c->status & MYSQL_WAIT_READ ? _sql_connect_io_notify : NULL,
+                                         c->status & MYSQL_WAIT_WRITE ? _sql_connect_io_notify : NULL, NULL, c);
+               return;
+       }
 
+connected:
+       if (!c->sock) {
+               ERROR("MySQL error: %s", mysql_error(&c->db));
+               fr_connection_signal_reconnect(c->conn, FR_CONNECTION_FAILED);
+               return;
+       }
+
+       DEBUG2("Connected to database on %s, server version %s, protocol version %i",
+              mysql_get_host_info(c->sock),
+              mysql_get_server_info(c->sock), mysql_get_proto_info(c->sock));
+
+       fr_connection_signal_connected(c->conn);
+}
+
+static fr_connection_state_t _sql_connection_init(void **h, fr_connection_t *conn, void *uctx)
+{
+       rlm_sql_t const         *sql = talloc_get_type_abort_const(uctx, rlm_sql_t);
+       rlm_sql_mysql_t const   *inst = talloc_get_type_abort(sql->driver_submodule->data, rlm_sql_mysql_t);
+       rlm_sql_mysql_conn_t    *c;
+       rlm_sql_config_t const  *config = &sql->config;
+
+       unsigned long           sql_flags;
        enum mysql_option       ssl_mysql_opt;
        unsigned int            ssl_mode = 0;
        bool                    ssl_mode_isset = false;
 
-       MEM(conn = handle->conn = talloc_zero(handle, rlm_sql_mysql_conn_t));
-       talloc_set_destructor(conn, _sql_socket_destructor);
+       MEM(c = talloc_zero(conn, rlm_sql_mysql_conn_t));
+       c->conn = conn;
+       c->fd = -1;
 
        DEBUG("Starting connect to MySQL server");
 
-       mysql_init(&(conn->db));
+       mysql_init(&c->db);
 
        /*
         *      If any of the TLS options are set, configure TLS
@@ -216,7 +245,7 @@ static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t co
         */
        if (inst->tls_ca_file || inst->tls_ca_path ||
            inst->tls_certificate_file || inst->tls_private_key_file) {
-               mysql_ssl_set(&(conn->db), inst->tls_private_key_file, inst->tls_certificate_file,
+               mysql_ssl_set(&(c->db), inst->tls_private_key_file, inst->tls_certificate_file,
                              inst->tls_ca_file, inst->tls_ca_path, inst->tls_cipher);
        }
 
@@ -245,12 +274,12 @@ static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t co
                ssl_mode_isset = true;
        }
 #endif
-       if (ssl_mode_isset) mysql_options(&(conn->db), ssl_mysql_opt, &ssl_mode);
+       if (ssl_mode_isset) mysql_options(&(c->db), ssl_mysql_opt, &ssl_mode);
 
-       if (inst->tls_crl_file) mysql_options(&(conn->db), MYSQL_OPT_SSL_CRL, inst->tls_crl_file);
-       if (inst->tls_crl_path) mysql_options(&(conn->db), MYSQL_OPT_SSL_CRLPATH, inst->tls_crl_path);
+       if (inst->tls_crl_file) mysql_options(&(c->db), MYSQL_OPT_SSL_CRL, inst->tls_crl_file);
+       if (inst->tls_crl_path) mysql_options(&(c->db), MYSQL_OPT_SSL_CRLPATH, inst->tls_crl_path);
 
-       mysql_options(&(conn->db), MYSQL_READ_DEFAULT_GROUP, "freeradius");
+       mysql_options(&(c->db), MYSQL_READ_DEFAULT_GROUP, "freeradius");
 
        /*
         *      We need to know about connection errors, and are capable
@@ -258,11 +287,9 @@ static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t co
         */
        {
                bool reconnect = 0;
-               mysql_options(&(conn->db), MYSQL_OPT_RECONNECT, &reconnect);
+               mysql_options(&(c->db), MYSQL_OPT_RECONNECT, &reconnect);
        }
 
-       mysql_options(&(conn->db), MYSQL_OPT_CONNECT_TIMEOUT, &connect_timeout);
-
        if (fr_time_delta_ispos(config->query_timeout)) {
                unsigned int read_timeout = fr_time_delta_to_sec(config->query_timeout);
                unsigned int write_timeout = fr_time_delta_to_sec(config->query_timeout);
@@ -285,8 +312,8 @@ static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t co
                 *      Connect timeout is actually connect timeout (according to the
                 *      docs) there are no automatic retries.
                 */
-               mysql_options(&(conn->db), MYSQL_OPT_READ_TIMEOUT, &read_timeout);
-               mysql_options(&(conn->db), MYSQL_OPT_WRITE_TIMEOUT, &write_timeout);
+               mysql_options(&(c->db), MYSQL_OPT_READ_TIMEOUT, &read_timeout);
+               mysql_options(&(c->db), MYSQL_OPT_WRITE_TIMEOUT, &write_timeout);
        }
 
        sql_flags = CLIENT_MULTI_RESULTS | CLIENT_FOUND_ROWS;
@@ -294,28 +321,58 @@ static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t co
 #ifdef CLIENT_MULTI_STATEMENTS
        sql_flags |= CLIENT_MULTI_STATEMENTS;
 #endif
-       conn->sock = mysql_real_connect(&(conn->db),
-                                       config->sql_server,
-                                       config->sql_login,
-                                       config->sql_password,
-                                       config->sql_db,
-                                       config->sql_port,
-                                       NULL,
-                                       sql_flags);
-       if (!conn->sock) {
-               ERROR("Couldn't connect to MySQL server %s@%s:%s", config->sql_login,
+
+       mysql_options(&c->db, MYSQL_OPT_NONBLOCK, 0);
+
+       c->status = mysql_real_connect_start(&c->sock, &c->db,
+                                            config->sql_server,
+                                            config->sql_login,
+                                            config->sql_password,
+                                            config->sql_db,
+                                            config->sql_port, NULL, sql_flags);
+
+       c->fd = mysql_get_socket(&c->db);
+
+       if (c->fd <= 0) {
+               ERROR("Could't connect to MySQL server %s@%s:%s", config->sql_login,
                      config->sql_server, config->sql_db);
-               ERROR("MySQL error: %s", mysql_error(&conn->db));
+               ERROR("MySQL error: %s", mysql_error(&c->db));
+       error:
+               talloc_free(c);
+               return FR_CONNECTION_STATE_FAILED;
+       }
 
-               conn->sock = NULL;
-               return RLM_SQL_ERROR;
+       if (c->status == 0) {
+               DEBUG2("Connected to database '%s' on %s, server version %s, protocol version %i",
+                      config->sql_db, mysql_get_host_info(c->sock),
+                      mysql_get_server_info(c->sock), mysql_get_proto_info(c->sock));
+               fr_connection_signal_connected(c->conn);
+               return FR_CONNECTION_STATE_CONNECTING;
        }
 
-       DEBUG2("Connected to database '%s' on %s, server version %s, protocol version %i",
-              config->sql_db, mysql_get_host_info(conn->sock),
-              mysql_get_server_info(conn->sock), mysql_get_proto_info(conn->sock));
+       if (fr_event_fd_insert(c, NULL, c->conn->el, c->fd,
+                              c->status & MYSQL_WAIT_READ ? _sql_connect_io_notify : NULL,
+                              c->status & MYSQL_WAIT_WRITE ? _sql_connect_io_notify : NULL, NULL, c) != 0) goto error;
 
-       return RLM_SQL_OK;
+       DEBUG2("Connecting to database '%s' on %s:%d, fd %d",
+              config->sql_db, config->sql_server, config->sql_port, c->fd);
+
+       *h = c;
+
+       return FR_CONNECTION_STATE_CONNECTING;
+}
+
+static void _sql_connection_close(fr_event_list_t *el, void *h, UNUSED void *uctx)
+{
+       rlm_sql_mysql_conn_t    *c = talloc_get_type_abort(h, rlm_sql_mysql_conn_t);
+
+       if (c->fd >= 0) {
+               fr_event_fd_delete(el, c->fd, FR_EVENT_FILTER_IO);
+               c->fd = -1;
+       }
+       mysql_close(&c->db);
+       c->query_ctx = NULL;
+       talloc_free(h);
 }
 
 /** Analyse the last error that occurred on the socket, and determine an action
@@ -387,26 +444,8 @@ static sql_rcode_t sql_check_error(MYSQL *server, int client_errno)
        return RLM_SQL_OK;
 }
 
-static unlang_action_t sql_query(rlm_rcode_t *p_result, UNUSED int *priority, UNUSED request_t *request, void *uctx)
-{
-       fr_sql_query_t          *query_ctx = talloc_get_type_abort(uctx, fr_sql_query_t);
-       rlm_sql_mysql_conn_t    *conn = talloc_get_type_abort(query_ctx->handle->conn, rlm_sql_mysql_conn_t);
-       char const *info;
-
-       mysql_query(conn->sock, query_ctx->query_str);
-       query_ctx->rcode = sql_check_error(conn->sock, 0);
-       if (query_ctx->rcode != RLM_SQL_OK) RETURN_MODULE_FAIL;
-
-       /* Only returns non-null string for INSERTS */
-       info = mysql_info(conn->sock);
-       if (info) DEBUG2("%s", info);
-
-       RETURN_MODULE_OK;
-}
-
-static sql_rcode_t sql_store_result(rlm_sql_handle_t *handle, UNUSED rlm_sql_config_t const *config)
+static sql_rcode_t sql_store_result(rlm_sql_mysql_conn_t *conn, UNUSED rlm_sql_config_t const *config)
 {
-       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(handle->conn, rlm_sql_mysql_conn_t);
        sql_rcode_t rcode;
        int ret;
 
@@ -425,21 +464,9 @@ retry_store_result:
        return RLM_SQL_OK;
 }
 
-static unlang_action_t sql_select_query(rlm_rcode_t *p_result, UNUSED int *priority, request_t *request, void *uctx)
-{
-       fr_sql_query_t  *query_ctx = talloc_get_type_abort(uctx, fr_sql_query_t);
-
-       sql_query(p_result, NULL, request, query_ctx);
-       if (query_ctx->rcode != RLM_SQL_OK) RETURN_MODULE_FAIL;
-
-       query_ctx->rcode = sql_store_result(query_ctx->handle, &query_ctx->inst->config);
-       if (query_ctx->rcode != RLM_SQL_OK) RETURN_MODULE_FAIL;
-       RETURN_MODULE_OK;
-}
-
 static int sql_num_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t const *config)
 {
-       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->handle->conn, rlm_sql_mysql_conn_t);
+       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_mysql_conn_t);
 
        if (conn->result) return mysql_num_rows(conn->result);
 
@@ -448,7 +475,7 @@ static int sql_num_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t const
 
 static sql_rcode_t sql_fields(char const **out[], fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t const *config)
 {
-       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->handle->conn, rlm_sql_mysql_conn_t);
+       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_mysql_conn_t);
 
        unsigned int    fields, i;
        MYSQL_FIELD     *field_info;
@@ -480,8 +507,7 @@ static sql_rcode_t sql_fields(char const **out[], fr_sql_query_t *query_ctx, UNU
 static unlang_action_t sql_fetch_row(rlm_rcode_t *p_result, UNUSED int *priority, UNUSED request_t *request, void *uctx)
 {
        fr_sql_query_t          *query_ctx = talloc_get_type_abort(uctx, fr_sql_query_t);
-       rlm_sql_handle_t        *handle = query_ctx->handle;
-       rlm_sql_mysql_conn_t    *conn = talloc_get_type_abort(handle->conn, rlm_sql_mysql_conn_t);
+       rlm_sql_mysql_conn_t    *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_mysql_conn_t);
        MYSQL_ROW               row;
        int                     ret;
        unsigned int            num_fields, i;
@@ -503,12 +529,13 @@ retry_fetch_row:
                query_ctx->rcode = sql_check_error(conn->sock, 0);
                if (query_ctx->rcode != RLM_SQL_OK) RETURN_MODULE_FAIL;
 
-               sql_free_result(query_ctx, &query_ctx->inst->config);
+               mysql_free_result(conn->result);
+               conn->result = NULL;
 
                ret = mysql_next_result(conn->sock);
                if (ret == 0) {
                        /* there are more results */
-                       if ((sql_store_result(handle, &query_ctx->inst->config) == 0) && (conn->result != NULL)) {
+                       if ((sql_store_result(conn, &query_ctx->inst->config) == 0) && (conn->result != NULL)) {
                                goto retry_fetch_row;
                        }
                } else if (ret > 0) {
@@ -541,7 +568,7 @@ retry_fetch_row:
 
 static sql_rcode_t sql_free_result(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t const *config)
 {
-       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->handle->conn, rlm_sql_mysql_conn_t);
+       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_mysql_conn_t);
 
        if (conn->result) {
                mysql_free_result(conn->result);
@@ -636,10 +663,13 @@ static size_t sql_error(TALLOC_CTX *ctx, sql_log_entry_t out[], size_t outlen,
                        fr_sql_query_t *query_ctx, rlm_sql_config_t const *config)
 {
        rlm_sql_mysql_t const   *inst = talloc_get_type_abort_const(query_ctx->inst->driver_submodule->data, rlm_sql_mysql_t);
-       rlm_sql_mysql_conn_t    *conn = talloc_get_type_abort(query_ctx->handle->conn, rlm_sql_mysql_conn_t);
+       rlm_sql_mysql_conn_t    *conn;
        char const              *error;
        size_t                  i = 0;
 
+       if (!query_ctx->tconn) return 0;
+       conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_mysql_conn_t);
+
        fr_assert(outlen > 0);
 
        error = mysql_error(conn->sock);
@@ -703,10 +733,31 @@ static size_t sql_error(TALLOC_CTX *ctx, sql_log_entry_t out[], size_t outlen,
  */
 static sql_rcode_t sql_finish_query(fr_sql_query_t *query_ctx, rlm_sql_config_t const *config)
 {
-       rlm_sql_mysql_conn_t    *conn = talloc_get_type_abort(query_ctx->handle->conn, rlm_sql_mysql_conn_t);
+       rlm_sql_mysql_conn_t    *conn;
        int                     ret;
        MYSQL_RES               *result;
 
+       /*
+        *      If the query is not in a state which would return results, then do nothing.
+        */
+       if (query_ctx->treq && !(query_ctx->treq->state &
+           (FR_TRUNK_REQUEST_STATE_SENT | FR_TRUNK_REQUEST_STATE_IDLE | FR_TRUNK_REQUEST_STATE_COMPLETE))) return RLM_SQL_OK;
+
+       /*
+        *      If the connection doesn't exist there's nothing to do
+        */
+       if (!query_ctx->tconn || !query_ctx->tconn->conn || !query_ctx->tconn->conn->h) return RLM_SQL_ERROR;
+
+       conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_mysql_conn_t);
+
+       /*
+        *      If the connection is not active, then all that we can do is free any stored results
+        */
+       if (query_ctx->tconn->conn->state != FR_CONNECTION_STATE_CONNECTED) {
+               sql_free_result(query_ctx, config);
+               return RLM_SQL_OK;
+       }
+
        /*
         *      If there's no result associated with the
         *      connection handle, assume the first result in the
@@ -748,7 +799,7 @@ static sql_rcode_t sql_finish_query(fr_sql_query_t *query_ctx, rlm_sql_config_t
 
 static int sql_affected_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t const *config)
 {
-       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->handle->conn, rlm_sql_mysql_conn_t);
+       rlm_sql_mysql_conn_t *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_mysql_conn_t);
 
        return mysql_affected_rows(conn->sock);
 }
@@ -756,8 +807,8 @@ static int sql_affected_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t
 static size_t sql_escape_func(UNUSED request_t *request, char *out, size_t outlen, char const *in, void *arg)
 {
        size_t                  inlen;
-       rlm_sql_handle_t        *handle = talloc_get_type_abort(arg, rlm_sql_handle_t);
-       rlm_sql_mysql_conn_t    *conn = talloc_get_type_abort(handle->conn, rlm_sql_mysql_conn_t);
+       fr_connection_t         *c = talloc_get_type_abort(arg, fr_connection_t);
+       rlm_sql_mysql_conn_t    *conn = talloc_get_type_abort(c->h, rlm_sql_mysql_conn_t);
 
        /* Check for potential buffer overflow */
        inlen = strlen(in);
@@ -765,9 +816,349 @@ static size_t sql_escape_func(UNUSED request_t *request, char *out, size_t outle
        /* Prevent integer overflow */
        if ((inlen * 2 + 1) <= inlen) return 0;
 
-       return mysql_real_escape_string(conn->sock, out, in, inlen);
+       return mysql_real_escape_string(&conn->db, out, in, inlen);
+}
+
+static void sql_conn_writable(UNUSED fr_event_list_t *el, UNUSED int fd, UNUSED int flags, void *uctx)
+{
+       fr_trunk_connection_t   *tconn = talloc_get_type_abort(uctx, fr_trunk_connection_t);
+       fr_trunk_connection_signal_writable(tconn);
+}
+
+static void sql_conn_readable(UNUSED fr_event_list_t *el, UNUSED int fd, UNUSED int flags, void *uctx)
+{
+       fr_trunk_connection_t   *tconn = talloc_get_type_abort(uctx, fr_trunk_connection_t);
+       fr_trunk_connection_signal_readable(tconn);
+}
+
+static void sql_conn_error(UNUSED fr_event_list_t *el, UNUSED int fd, UNUSED int flags, int fd_errno, void *uctx)
+{
+       fr_trunk_connection_t   *tconn = talloc_get_type_abort(uctx, fr_trunk_connection_t);
+       ERROR("%s - Connection failed: %s", tconn->conn->name, fr_syserror(fd_errno));
+       fr_connection_signal_reconnect(tconn->conn, FR_CONNECTION_FAILED);
+}
+
+/** Allocate an SQL trunk connection
+ *
+ * @param[in] tconn            Trunk handle.
+ * @param[in] el               Event list which will be used for I/O and timer events.
+ * @param[in] conn_conf                Configuration of the connection.
+ * @param[in] log_prefix       What to prefix log messages with.
+ * @param[in] uctx             User context passed to fr_trunk_alloc.
+ */
+static fr_connection_t *sql_trunk_connection_alloc(fr_trunk_connection_t *tconn, fr_event_list_t *el,
+                                                  fr_connection_conf_t const *conn_conf,
+                                                  char const *log_prefix, void *uctx)
+{
+       fr_connection_t         *conn;
+       rlm_sql_thread_t        *thread = talloc_get_type_abort(uctx, rlm_sql_thread_t);
+
+       conn = fr_connection_alloc(tconn, el,
+                                  &(fr_connection_funcs_t){
+                                       .init = _sql_connection_init,
+                                       .close = _sql_connection_close
+                                  },
+                                  conn_conf, log_prefix, thread->inst);
+       if (!conn) {
+               PERROR("Failed allocating state handler for new SQL connection");
+               return NULL;
+       }
+
+       return conn;
+}
+
+static void sql_trunk_connection_notify(fr_trunk_connection_t *tconn, fr_connection_t *conn,
+                                       fr_event_list_t *el,
+                                       fr_trunk_connection_event_t notify_on, UNUSED void *uctx)
+{
+       rlm_sql_mysql_conn_t    *sql_conn = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t);
+       fr_event_fd_cb_t        read_fn = NULL, write_fn = NULL;
+
+       switch (notify_on) {
+       case FR_TRUNK_CONN_EVENT_NONE:
+               fr_event_fd_delete(el, sql_conn->fd, FR_EVENT_FILTER_IO);
+               return;
+
+       case FR_TRUNK_CONN_EVENT_READ:
+               read_fn = sql_conn_readable;
+               break;
+
+       case FR_TRUNK_CONN_EVENT_WRITE:
+               write_fn = sql_conn_writable;
+               break;
+
+       case FR_TRUNK_CONN_EVENT_BOTH:
+               read_fn = sql_conn_readable;
+               write_fn = sql_conn_writable;
+               break;
+       }
+
+       if (fr_event_fd_insert(sql_conn, NULL, el, sql_conn->fd, read_fn, write_fn, sql_conn_error, tconn) < 0) {
+               PERROR("Failed inserting FD event");
+               fr_trunk_connection_signal_reconnect(tconn, FR_CONNECTION_FAILED);
+       }
 }
 
+static void sql_trunk_request_mux(UNUSED fr_event_list_t *el, fr_trunk_connection_t *tconn,
+                                 fr_connection_t *conn, UNUSED void *uctx)
+{
+       rlm_sql_mysql_conn_t    *sql_conn = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t);
+       request_t               *request;
+       fr_trunk_request_t      *treq;
+       fr_sql_query_t          *query_ctx;
+       char const              *info;
+       int                     err;
+
+       if (fr_trunk_connection_pop_request(&treq, tconn) != 0) return;
+       if (!treq) return;
+
+       query_ctx = talloc_get_type_abort(treq->preq, fr_sql_query_t);
+       request = query_ctx->request;
+
+       /*
+        *      Each of the MariaDB async "start" calls returns a non-zero value
+        *      if they are waiting on I/O.
+        *      A return value of zero means that the operation completed.
+        */
+
+       switch (query_ctx->status) {
+       case SQL_QUERY_PREPARED:
+               ROPTIONAL(RDEBUG2, DEBUG2, "Executing query: %s", query_ctx->query_str);
+               sql_conn->status = mysql_real_query_start(&err, sql_conn->sock, query_ctx->query_str, strlen(query_ctx->query_str));
+               query_ctx->tconn = tconn;
+
+               if (sql_conn->status) {
+                       ROPTIONAL(RDEBUG3, DEBUG3, "Waiting for IO");
+                       query_ctx->status = SQL_QUERY_SUBMITTED;
+                       sql_conn->query_ctx = query_ctx;
+                       fr_trunk_request_signal_sent(treq);
+                       return;
+               }
+
+               if (err) {
+                       /*
+                        *      Need to check what kind of error this is - it may
+                        *      be a unique key conflict, we run the next query.
+                        */
+                       info = mysql_info(sql_conn->sock);
+                       query_ctx->rcode = sql_check_error(sql_conn->sock, 0);
+                       if (info) ERROR("%s", info);
+                       switch (query_ctx->rcode) {
+                       case RLM_SQL_OK:
+                       case RLM_SQL_ALT_QUERY:
+                               break;
+
+                       default:
+                               query_ctx->status = SQL_QUERY_FAILED;
+                               fr_trunk_request_signal_fail(treq);
+                               return;
+                       }
+               }
+               query_ctx->status = SQL_QUERY_RETURNED;
+
+               break;
+
+       case SQL_QUERY_RETURNED:
+               ROPTIONAL(RDEBUG2, DEBUG2, "Fetching results");
+               fr_assert(query_ctx->tconn == tconn);
+               sql_conn->status = mysql_store_result_start(&sql_conn->result, sql_conn->sock);
+
+               if (sql_conn->status) {
+                       ROPTIONAL(RDEBUG3, DEBUG3, "Waiting for IO");
+                       query_ctx->status = SQL_QUERY_FETCHING_RESULTS;
+                       sql_conn->query_ctx = query_ctx;
+                       fr_trunk_request_signal_sent(treq);
+                       return;
+               }
+               query_ctx->status = SQL_QUERY_RESULTS_FETCHED;
+
+               break;
+
+       default:
+               /*
+                *      The request outstanding on this connection returned
+                *      immediately, so we are not actually waiting for I/O.
+                */
+               return;
+       }
+
+       /*
+        *      The current request is not waiting for I/O so the request can run
+        */
+       ROPTIONAL(RDEBUG3, DEBUG3, "Got immediate response");
+       fr_trunk_request_signal_idle(treq);
+       if (request) unlang_interpret_mark_runnable(request);
+}
+
+static void sql_trunk_request_demux(UNUSED fr_event_list_t *el, UNUSED fr_trunk_connection_t *tconn,
+                                   fr_connection_t *conn, UNUSED void *uctx)
+{
+       rlm_sql_mysql_conn_t    *sql_conn = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t);
+       fr_sql_query_t          *query_ctx;
+       char const              *info;
+       int                     err = 0;
+       request_t               *request;
+
+       /*
+        *      Lookup the outstanding SQL query for this connection.
+        *      There will only ever be one per tconn.
+        */
+       query_ctx = sql_conn->query_ctx;
+
+       /*
+        *      No outstanding query on this connection.
+        *      Should not happen, but added for safety.
+        */
+       if (unlikely(!query_ctx)) return;
+
+       switch (query_ctx->status) {
+       case SQL_QUERY_SUBMITTED:
+               sql_conn->status = mysql_real_query_cont(&err, sql_conn->sock, sql_conn->status);
+               break;
+
+       case SQL_QUERY_FETCHING_RESULTS:
+               sql_conn->status = mysql_store_result_cont(&sql_conn->result, sql_conn->sock, sql_conn->status);
+               break;
+
+       default:
+               /*
+                *      The request outstanding on this connection returned
+                *      immediately, so we are not actually waiting for I/O.
+                */
+               return;
+       }
+
+       /*
+        *      Are we still waiting for any further I/O?
+        */
+       if (sql_conn->status != 0) return;
+
+       sql_conn->query_ctx = NULL;
+
+       switch (query_ctx->status) {
+       case SQL_QUERY_SUBMITTED:
+               query_ctx->status = SQL_QUERY_RETURNED;
+               break;
+
+       case SQL_QUERY_FETCHING_RESULTS:
+               query_ctx->status = SQL_QUERY_RESULTS_FETCHED;
+               break;
+
+       default:
+               fr_assert(0);
+       }
+
+       request = query_ctx->request;
+       if (request) unlang_interpret_mark_runnable(request);
+
+       if (err) {
+               info = mysql_info(sql_conn->sock);
+               query_ctx->rcode = sql_check_error(sql_conn->sock, 0);
+               if (info) ROPTIONAL(RERROR, ERROR, "%s", info);
+               return;
+       }
+
+       query_ctx->rcode = RLM_SQL_OK;
+}
+
+static void sql_request_cancel(fr_connection_t *conn, void *preq, fr_trunk_cancel_reason_t reason,
+                              UNUSED void *uctx)
+{
+       fr_sql_query_t          *query_ctx = talloc_get_type_abort(preq, fr_sql_query_t);
+       rlm_sql_mysql_conn_t    *sql_conn = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t);
+
+       if (!query_ctx->treq) return;
+       if (reason != FR_TRUNK_CANCEL_REASON_SIGNAL) return;
+       if (sql_conn->query_ctx == query_ctx) sql_conn->query_ctx = NULL;
+}
+
+static void sql_request_cancel_mux(UNUSED fr_event_list_t *el, fr_trunk_connection_t *tconn,
+                                  fr_connection_t *conn, UNUSED void *uctx)
+{
+       fr_trunk_request_t      *treq;
+
+       /*
+        *      The MariaDB non-blocking API doesn't have any cancellation functions -
+        *      rather you are expected to close the connection.
+        */
+       if ((fr_trunk_connection_pop_cancellation(&treq, tconn)) == 0) {
+               fr_trunk_request_signal_cancel_complete(treq);
+               fr_connection_signal_reconnect(conn, FR_CONNECTION_FAILED);
+       }
+}
+
+static void sql_request_fail(request_t *request, void *preq, UNUSED void *rctx,
+                            UNUSED fr_trunk_request_state_t state, UNUSED void *uctx)
+{
+       fr_sql_query_t          *query_ctx = talloc_get_type_abort(preq, fr_sql_query_t);
+
+       query_ctx->treq = NULL;
+       query_ctx->rcode = RLM_SQL_ERROR;
+
+       if (request) unlang_interpret_mark_runnable(request);
+}
+
+static unlang_action_t sql_query_resume(rlm_rcode_t *p_result, UNUSED int *priority, UNUSED request_t *request, void *uctx)
+{
+       fr_sql_query_t          *query_ctx = talloc_get_type_abort(uctx, fr_sql_query_t);
+
+       if (query_ctx->rcode == RLM_SQL_OK) RETURN_MODULE_OK;
+       RETURN_MODULE_FAIL;
+}
+
+static unlang_action_t sql_select_query_resume(rlm_rcode_t *p_result, UNUSED int *priority, UNUSED request_t *request, void *uctx)
+{
+       fr_sql_query_t          *query_ctx = talloc_get_type_abort(uctx, fr_sql_query_t);
+
+       if (query_ctx->rcode != RLM_SQL_OK) RETURN_MODULE_FAIL;
+
+       if (query_ctx->status == SQL_QUERY_RETURNED) {
+               fr_trunk_request_requeue(query_ctx->treq);
+
+               if (unlang_function_repeat_set(request, sql_select_query_resume) < 0) {
+                       query_ctx->rcode = RLM_SQL_ERROR;
+                       RETURN_MODULE_FAIL;
+               }
+
+               return UNLANG_ACTION_YIELD;
+       }
+
+       RETURN_MODULE_OK;
+}
+
+/** Allocate the argument used for the SQL escape function
+ *
+ * In this case, a dedicated connection to allow the escape
+ * function to have access to server side parameters, though
+ * no packets ever flow after the connection is made.
+ */
+static void *sql_escape_arg_alloc(TALLOC_CTX *ctx, fr_event_list_t *el, void *uctx)
+{
+       rlm_sql_t const *inst = talloc_get_type_abort(uctx, rlm_sql_t);
+       fr_connection_t *conn;
+
+       conn = fr_connection_alloc(ctx, el,
+                                   &(fr_connection_funcs_t){
+                                       .init = _sql_connection_init,
+                                       .close = _sql_connection_close,
+                                   },
+                                   inst->config.trunk_conf.conn_conf,
+                                   inst->name, inst);
+
+       if (!conn) {
+               PERROR("Failed allocating state handler for SQL escape connection");
+               return NULL;
+       }
+
+       fr_connection_signal_init(conn);
+       return conn;
+}
+
+static void sql_escape_arg_free(void *uctx)
+{
+       fr_connection_t *conn = talloc_get_type_abort(uctx, fr_connection_t);
+       fr_connection_signal_halt(conn);
+}
 
 /* Exported to rlm_sql */
 extern rlm_sql_driver_t rlm_sql_mysql;
@@ -782,9 +1173,8 @@ rlm_sql_driver_t rlm_sql_mysql = {
                .instantiate                    = mod_instantiate
        },
        .flags                          = RLM_SQL_RCODE_FLAGS_ALT_QUERY,
-       .sql_socket_init                = sql_socket_init,
-       .sql_query                      = sql_query,
-       .sql_select_query               = sql_select_query,
+       .sql_query_resume               = sql_query_resume,
+       .sql_select_query_resume        = sql_select_query_resume,
        .sql_num_rows                   = sql_num_rows,
        .sql_affected_rows              = sql_affected_rows,
        .sql_fields                     = sql_fields,
@@ -793,5 +1183,17 @@ rlm_sql_driver_t rlm_sql_mysql = {
        .sql_error                      = sql_error,
        .sql_finish_query               = sql_finish_query,
        .sql_finish_select_query        = sql_finish_query,
-       .sql_escape_func                = sql_escape_func
+       .sql_escape_func                = sql_escape_func,
+       .sql_escape_arg_alloc           = sql_escape_arg_alloc,
+       .sql_escape_arg_free            = sql_escape_arg_free,
+       .uses_trunks                    = true,
+       .trunk_io_funcs = {
+               .connection_alloc       = sql_trunk_connection_alloc,
+               .connection_notify      = sql_trunk_connection_notify,
+               .request_mux            = sql_trunk_request_mux,
+               .request_demux          = sql_trunk_request_demux,
+               .request_cancel_mux     = sql_request_cancel_mux,
+               .request_cancel         = sql_request_cancel,
+               .request_fail           = sql_request_fail,
+       }
 };
index 5fe28cc46121aabafa996b857fe5537fbfbad127..513169703fbb1fbb87099f6d9ccd616e148151d2 100644 (file)
@@ -128,6 +128,8 @@ typedef enum {
        SQL_QUERY_FAILED = -1,                                  //!< Failed to submit.
        SQL_QUERY_PREPARED = 0,                                 //!< Ready to submit.
        SQL_QUERY_SUBMITTED,                                    //!< Submitted for execution.
+       SQL_QUERY_RETURNED,                                     //!< Query has executed.
+       SQL_QUERY_FETCHING_RESULTS,                             //!< Fetching results from server.
        SQL_QUERY_RESULTS_FETCHED                               //!< Results fetched from the server.
 } fr_sql_query_status_t;