]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Re-work rlm_sql_db2 to use trunks
authorNick Porter <nick@portercomputing.co.uk>
Tue, 26 Nov 2024 12:12:59 +0000 (12:12 +0000)
committerNick Porter <nick@portercomputing.co.uk>
Tue, 26 Nov 2024 12:12:59 +0000 (12:12 +0000)
The client library doesn't support async

src/modules/rlm_sql/drivers/rlm_sql_db2/rlm_sql_db2.c

index 2e76d4c3389ed8a086ee1d9a42864753e37f9924..99ec27a1574b142e88099ce54a00f62ff5c68d6a 100644 (file)
@@ -37,7 +37,9 @@ RCSID("$Id$")
 #include <sys/stat.h>
 
 #include <sqlcli.h>
+#include <sqlstate.h>
 #include "rlm_sql.h"
+#include "rlm_sql_trunk.h"
 
 typedef struct {
        SQLHANDLE dbc_handle;
@@ -45,8 +47,10 @@ typedef struct {
        SQLHANDLE stmt;
 } rlm_sql_db2_conn_t;
 
-static int _sql_socket_destructor(rlm_sql_db2_conn_t *conn)
+static void _sql_connection_close(UNUSED fr_event_list_t *el, void *h, UNUSED void *uctx)
 {
+       rlm_sql_db2_conn_t      *conn = talloc_get_type_abort(h, rlm_sql_db2_conn_t);
+
        DEBUG2("Socket destructor called, closing socket");
 
        if (conn->dbc_handle) {
@@ -56,24 +60,25 @@ static int _sql_socket_destructor(rlm_sql_db2_conn_t *conn)
 
        if (conn->env_handle) SQLFreeHandle(SQL_HANDLE_ENV, conn->env_handle);
 
-       return RLM_SQL_OK;
+       talloc_free(h);
 }
 
-static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t const *config,
-                                  UNUSED fr_time_delta_t timeout)
+CC_NO_UBSAN(function) /* UBSAN: false positive - public vs private connection_t trips --fsanitize=function */
+static connection_state_t _sql_connection_init(void **h, connection_t *conn, void *uctx)
 {
-       SQLRETURN row;
 #if 0
        uint32_t timeout_ms = FR_TIMEVAL_TO_MS(timeout);
 #endif
-       rlm_sql_db2_conn_t *conn;
+       rlm_sql_db2_conn_t      *c;
+       rlm_sql_t const         *sql = talloc_get_type_abort_const(uctx, rlm_sql_t);
+       rlm_sql_config_t const  *config = &sql->config;
+       SQLRETURN               ret;
 
-       MEM(conn = handle->conn = talloc_zero(handle, rlm_sql_db2_conn_t));
-       talloc_set_destructor(conn, _sql_socket_destructor);
+       MEM(c = talloc_zero(conn, rlm_sql_db2_conn_t));
 
        /* Allocate handles */
-       SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &(conn->env_handle));
-       SQLAllocHandle(SQL_HANDLE_DBC, conn->env_handle, &(conn->dbc_handle));
+       SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &(c->env_handle));
+       SQLAllocHandle(SQL_HANDLE_DBC, c->env_handle, &(c->dbc_handle));
 
        /* Set the connection timeout */
 #if 0
@@ -93,51 +98,79 @@ static sql_rcode_t sql_socket_init(rlm_sql_handle_t *handle, rlm_sql_config_t co
         *
         *      Driver={IBM DB2 ODBC Driver};Database=testDb;Hostname=remoteHostName.com;UID=username;PWD=mypasswd;PORT=50000
         */
-       row = SQLConnect(conn->dbc_handle,
+       ret = SQLConnect(c->dbc_handle,
                         UNCONST(SQLCHAR *, config->sql_server), SQL_NTS,
                         UNCONST(SQLCHAR *, config->sql_login), SQL_NTS,
                         UNCONST(SQLCHAR *, config->sql_password), SQL_NTS);
-       if (row != SQL_SUCCESS) {
+       if (ret != SQL_SUCCESS) {
                ERROR("could not connect to DB2 server %s", config->sql_server);
 
-               return RLM_SQL_ERROR;
+               return CONNECTION_STATE_FAILED;
        }
 
-       return RLM_SQL_OK;
+       *h = c;
+       return CONNECTION_STATE_CONNECTED;
 }
 
-static unlang_action_t sql_query(rlm_rcode_t *p_result, UNUSED int *priority, UNUSED request_t *request, void *uctx)
+SQL_TRUNK_CONNECTION_ALLOC
+
+SQL_QUERY_RESUME
+
+CC_NO_UBSAN(function) /* UBSAN: false positive - public vs private connection_t trips --fsanitize=function */
+static void sql_trunk_request_mux(UNUSED fr_event_list_t *el, trunk_connection_t *tconn,
+                                 connection_t *conn, UNUSED void *uctx)
 {
-       fr_sql_query_t          *query_ctx = talloc_get_type_abort(uctx, fr_sql_query_t);
-       SQLRETURN row;
-       rlm_sql_db2_conn_t *conn;
+       rlm_sql_db2_conn_t      *sql_conn = talloc_get_type_abort(conn->h, rlm_sql_db2_conn_t);
+       trunk_request_t         *treq;
+       request_t               *request;
+       fr_sql_query_t          *query_ctx;
+       SQLRETURN               ret;
+       SQLCHAR                 *db2_query;
 
-       conn = query_ctx->handle->conn;
+       if (trunk_connection_pop_request(&treq, tconn) != 0) return;
+       if (!treq) return;
+
+       query_ctx = talloc_get_type_abort(treq->preq, fr_sql_query_t);
+       request = query_ctx->request;
+       query_ctx->tconn = tconn;
+
+       ROPTIONAL(RDEBUG2, DEBUG2, "Executing query: %s", query_ctx->query_str);
 
        /* allocate handle for statement */
-       SQLAllocHandle(SQL_HANDLE_STMT, conn->dbc_handle, &(conn->stmt));
+       SQLAllocHandle(SQL_HANDLE_STMT, sql_conn->dbc_handle, &(sql_conn->stmt));
 
        /* execute query */
-       {
-               SQLCHAR *db2_query;
-               memcpy(&db2_query, &query_ctx->query_str, sizeof(query_ctx->query_str));
-
-               row = SQLExecDirect(conn->stmt, db2_query, SQL_NTS);
-               if(row != SQL_SUCCESS) {
-                       /* XXX Check if row means we should return RLM_SQL_RECONNECT */
-                       ERROR("Could not execute statement \"%s\"", query_ctx->query_str);
-                       query_ctx->rcode = RLM_SQL_ERROR;
-                       RETURN_MODULE_FAIL;
+       memcpy(&db2_query, &query_ctx->query_str, sizeof(query_ctx->query_str));
+
+       ret = SQLExecDirect(sql_conn->stmt, db2_query, SQL_NTS);
+       if (ret != SQL_SUCCESS) {
+               SQLCHAR         state[6];
+               SQLSMALLINT     len;
+
+               SQLGetDiagField(SQL_HANDLE_STMT, sql_conn->dbc_handle, 1, SQL_DIAG_SQLSTATE, state, sizeof(state), &len);
+
+               if (strncmp((char *)state, SQL_CONSTR_INDEX_UNIQUE, 5)) {
+                       query_ctx->rcode = RLM_SQL_ALT_QUERY;
+                       goto finish;
                }
+
+               /* XXX Check if ret means we should return RLM_SQL_RECONNECT */
+               ERROR("Could not execute statement \"%s\"", query_ctx->query_str);
+               query_ctx->rcode = RLM_SQL_ERROR;
+               trunk_request_signal_fail(treq);
+               return;
        }
 
        query_ctx->rcode = RLM_SQL_OK;
-       RETURN_MODULE_OK;
+finish:
+       query_ctx->status = SQL_QUERY_RETURNED;
+       trunk_request_signal_reapable(treq);
+       if (request) unlang_interpret_mark_runnable(request);
 }
 
 static sql_rcode_t sql_fields(char const **out[], fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t const *config)
 {
-       rlm_sql_db2_conn_t *conn = query_ctx->handle->conn;
+       rlm_sql_db2_conn_t *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_db2_conn_t);
 
        SQLSMALLINT     fields, len, i;
 
@@ -180,12 +213,10 @@ static unlang_action_t sql_fetch_row(rlm_rcode_t *p_result, UNUSED int *priority
        SQLINTEGER              len, slen;
        SQLSMALLINT             c;
        rlm_sql_row_t           row;
-       rlm_sql_db2_conn_t      *conn;
-       rlm_sql_handle_t        *handle = query_ctx->handle;
+       rlm_sql_db2_conn_t      *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_db2_conn_t);
 
        TALLOC_FREE(query_ctx->row);
 
-       conn = handle->conn;
        SQLNumResultCols(conn->stmt, &c);
 
        /* advance cursor */
@@ -241,7 +272,7 @@ static size_t sql_error(TALLOC_CTX *ctx, sql_log_entry_t out[], NDEBUG_UNUSED si
        char                    errbuff[1024];
        SQLINTEGER              err;
        SQLSMALLINT             rl;
-       rlm_sql_db2_conn_t      *conn = query_ctx->handle->conn;
+       rlm_sql_db2_conn_t      *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_db2_conn_t);
 
        fr_assert(conn);
        fr_assert(outlen > 0);
@@ -265,13 +296,23 @@ static sql_rcode_t sql_finish_query(UNUSED fr_sql_query_t *query_ctx, UNUSED rlm
 static int sql_affected_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t const *config)
 {
        SQLINTEGER c;
-       rlm_sql_db2_conn_t *conn = query_ctx->handle->conn;
+       rlm_sql_db2_conn_t      *conn = talloc_get_type_abort(query_ctx->tconn->conn->h, rlm_sql_db2_conn_t);
 
        SQLRowCount(conn->stmt, &c);
 
        return c;
 }
 
+static void sql_request_fail(request_t *request, void *preq, UNUSED void *rctx,
+                            UNUSED trunk_request_state_t state, UNUSED void *uctx)
+{
+       fr_sql_query_t          *query_ctx = talloc_get_type_abort(preq, fr_sql_query_t);
+
+       query_ctx->treq = NULL;
+       if (query_ctx->rcode == RLM_SQL_OK) query_ctx->rcode = RLM_SQL_ERROR;
+       if (request) unlang_interpret_mark_runnable(request);
+}
+
 /* Exported to rlm_sql */
 extern rlm_sql_driver_t rlm_sql_db2;
 rlm_sql_driver_t rlm_sql_db2 = {
@@ -279,14 +320,20 @@ rlm_sql_driver_t rlm_sql_db2 = {
                .magic                          = MODULE_MAGIC_INIT,
                .name                           = "sql_db2",
        },
-       .sql_socket_init                = sql_socket_init,
-       .sql_query                      = sql_query,
-       .sql_select_query               = sql_query,
+       .flags                          = RLM_SQL_RCODE_FLAGS_ALT_QUERY,
+       .sql_query_resume               = sql_query_resume,
+       .sql_select_query_resume        = sql_query_resume,
        .sql_affected_rows              = sql_affected_rows,
        .sql_fields                     = sql_fields,
        .sql_fetch_row                  = sql_fetch_row,
        .sql_free_result                = sql_free_result,
        .sql_error                      = sql_error,
        .sql_finish_query               = sql_finish_query,
-       .sql_finish_select_query        = sql_finish_query
+       .sql_finish_select_query        = sql_finish_query,
+       .uses_trunks                    = true,
+       .trunk_io_funcs = {
+               .connection_alloc       = sql_trunk_connection_alloc,
+               .request_mux            = sql_trunk_request_mux,
+               .request_fail           = sql_request_fail,
+       }
 };