]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Change function signature of rlm_sql_select_query() to be unlang_function_t
authorNick Porter <nick@portercomputing.co.uk>
Mon, 29 Apr 2024 15:47:05 +0000 (16:47 +0100)
committerArran Cudbard-Bell <a.cudbardb@freeradius.org>
Fri, 7 Jun 2024 02:26:58 +0000 (22:26 -0400)
src/modules/rlm_sql/rlm_sql.c
src/modules/rlm_sql/rlm_sql.h
src/modules/rlm_sql/sql.c
src/modules/rlm_sqlippool/rlm_sqlippool.c

index b82e242564b14707045a7953e869e0593cb164ee..b857015f76a04740247c88293b2d7b1118bc0770 100644 (file)
@@ -435,7 +435,7 @@ static xlat_action_t sql_xlat(TALLOC_CTX *ctx, fr_dcursor_t *out,
                if (query_ctx->rcode != RLM_SQL_OK) {
                query_error:
                        RERROR("SQL query failed: %s", fr_table_str_by_value(sql_rcode_description_table,
-                                                                            query_ctx ? query_ctx->rcode : rcode, "<INVALID>"));
+                                                                            query_ctx->rcode, "<INVALID>"));
 
                        ret = XLAT_ACTION_FAIL;
                        goto finish;
@@ -455,11 +455,13 @@ static xlat_action_t sql_xlat(TALLOC_CTX *ctx, fr_dcursor_t *out,
                goto finish;
        } /* else it's a SELECT statement */
 
-       rcode = rlm_sql_select_query(inst, request, &handle, arg->vb_strvalue);
-       if (rcode != RLM_SQL_OK) goto query_error;
+       MEM(query_ctx = fr_sql_query_alloc(unlang_interpret_frame_talloc_ctx(request), inst, handle,
+                                          arg->vb_strvalue, SQL_QUERY_SELECT));
+       rlm_sql_select_query(&p_result, NULL, request, query_ctx);
+       if (query_ctx->rcode != RLM_SQL_OK) goto query_error;
 
        do {
-               rcode = rlm_sql_fetch_row(&row, inst, request, &handle);
+               rcode = rlm_sql_fetch_row(&row, inst, request, &query_ctx->handle);
                switch (rcode) {
                case RLM_SQL_OK:
                        if (row[0]) break;
@@ -467,7 +469,7 @@ static xlat_action_t sql_xlat(TALLOC_CTX *ctx, fr_dcursor_t *out,
                        RDEBUG2("NULL value in first column of result");
                        ret = XLAT_ACTION_FAIL;
 
-                       goto finish_query;
+                       goto finish;
 
                case RLM_SQL_NO_MORE_ROWS:
                        if (!fetched) {
@@ -475,10 +477,9 @@ static xlat_action_t sql_xlat(TALLOC_CTX *ctx, fr_dcursor_t *out,
                                ret = XLAT_ACTION_FAIL;
                        }
 
-                       goto finish_query;
+                       goto finish;
 
                default:
-                       (inst->driver->sql_finish_select_query)(handle, &inst->config);
                        goto query_error;
                }
 
@@ -490,11 +491,8 @@ static xlat_action_t sql_xlat(TALLOC_CTX *ctx, fr_dcursor_t *out,
 
        } while (1);
 
-finish_query:
-       (inst->driver->sql_finish_select_query)(handle, &inst->config);
-
 finish:
-       handle = query_ctx->handle ? query_ctx->handle : handle;
+       handle = query_ctx->handle;
        talloc_free(query_ctx);
        fr_pool_connection_release(inst->pool, request, handle);
 
@@ -584,6 +582,7 @@ static unlang_action_t mod_map_proc(rlm_rcode_t *p_result, void const *mod_inst,
 
        char const              *query_str = NULL;
        fr_value_box_t          *query_head = fr_value_box_list_head(query);
+       fr_sql_query_t          *query_ctx;
 
 #define MAX_SQL_FIELD_INDEX (64)
 
@@ -613,9 +612,12 @@ static unlang_action_t mod_map_proc(rlm_rcode_t *p_result, void const *mod_inst,
                RETURN_MODULE_FAIL;
        }
 
-       ret = rlm_sql_select_query(inst, request, &handle, query_str);
-       if (ret != RLM_SQL_OK) {
-               RERROR("SQL query failed: %s", fr_table_str_by_value(sql_rcode_description_table, ret, "<INVALID>"));
+       MEM(query_ctx = fr_sql_query_alloc(unlang_interpret_frame_talloc_ctx(request), inst, handle, query_str, SQL_QUERY_SELECT));
+       rlm_sql_select_query(p_result, NULL, request, query_ctx);
+       handle = query_ctx->handle;
+
+       if (query_ctx->rcode != RLM_SQL_OK) {
+               RERROR("SQL query failed: %s", fr_table_str_by_value(sql_rcode_description_table, query_ctx->rcode, "<INVALID>"));
                rcode = RLM_MODULE_FAIL;
                goto finish;
        }
@@ -628,7 +630,6 @@ static unlang_action_t mod_map_proc(rlm_rcode_t *p_result, void const *mod_inst,
                if (ret == 0) {
                        RDEBUG2("Server returned an empty result");
                        rcode = RLM_MODULE_NOOP;
-                       (inst->driver->sql_finish_select_query)(handle, &inst->config);
                        goto finish;
                }
 
@@ -636,7 +637,6 @@ static unlang_action_t mod_map_proc(rlm_rcode_t *p_result, void const *mod_inst,
                        RERROR("Failed retrieving row count");
                error:
                        rcode = RLM_MODULE_FAIL;
-                       (inst->driver->sql_finish_select_query)(handle, &inst->config);
                        goto finish;
                }
        }
@@ -685,7 +685,6 @@ static unlang_action_t mod_map_proc(rlm_rcode_t *p_result, void const *mod_inst,
        if (!found_field) {
                RDEBUG2("No fields matching map found in query result");
                rcode = RLM_MODULE_NOOP;
-               (inst->driver->sql_finish_select_query)(handle, &inst->config);
                goto finish;
        }
 
@@ -713,10 +712,9 @@ static unlang_action_t mod_map_proc(rlm_rcode_t *p_result, void const *mod_inst,
                rcode = RLM_MODULE_NOOP;
        }
 
-       (inst->driver->sql_finish_select_query)(handle, &inst->config);
-
 finish:
        talloc_free(fields);
+       talloc_free(query_ctx);
        fr_pool_connection_release(inst->pool, request, handle);
 
        RETURN_MODULE_RCODE(rcode);
@@ -876,7 +874,8 @@ static int sql_get_grouplist(rlm_sql_t const *inst, rlm_sql_handle_t **handle, r
        int                     num_groups = 0;
        rlm_sql_row_t           row;
        rlm_sql_grouplist_t     *entry;
-       int                     ret;
+       rlm_rcode_t             p_result;
+       fr_sql_query_t          *query_ctx;
 
        /* NOTE: sql_set_user should have been run before calling this function */
 
@@ -884,8 +883,14 @@ static int sql_get_grouplist(rlm_sql_t const *inst, rlm_sql_handle_t **handle, r
 
        if (!query || !*query) return 0;
 
-       ret = rlm_sql_select_query(inst, request, handle, query);
-       if (ret != RLM_SQL_OK) return -1;
+       MEM(query_ctx = fr_sql_query_alloc(unlang_interpret_frame_talloc_ctx(request), inst, *handle, query, SQL_QUERY_SELECT ));
+
+       rlm_sql_select_query(&p_result, NULL, request, query_ctx);
+       if (query_ctx->rcode != RLM_SQL_OK) {
+               talloc_free(query_ctx);
+               return -1;
+       }
+       *handle = query_ctx->handle;
 
        while (rlm_sql_fetch_row(&row, inst, request, handle) == RLM_SQL_OK) {
                if (!row[0]){
@@ -908,7 +913,7 @@ static int sql_get_grouplist(rlm_sql_t const *inst, rlm_sql_handle_t **handle, r
                num_groups++;
        }
 
-       (inst->driver->sql_finish_select_query)(*handle, &inst->config);
+       talloc_free(query_ctx);
 
        return num_groups;
 }
index 3cc8d63c84e241a3912312633ce0838d01de9ada..c13e7d9d732dbe9170ba8694ac648d801855feca 100644 (file)
@@ -210,7 +210,7 @@ struct sql_inst {
        xlat_escape_legacy_t    sql_escape_func;
        fr_value_box_escape_t   box_escape_func;
        unlang_function_t       query;
-       sql_rcode_t             (*select)(rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle, char const *query);
+       unlang_function_t       select;
        sql_rcode_t             (*fetch_row)(rlm_sql_row_t *out, rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle);
        fr_sql_query_t          *(*query_alloc)(TALLOC_CTX *ctx, rlm_sql_t const *inst, rlm_sql_handle_t *handle, char const *query_str, fr_sql_query_type_t type);
 
@@ -221,7 +221,7 @@ struct sql_inst {
 void           *sql_mod_conn_create(TALLOC_CTX *ctx, void *instance, fr_time_delta_t timeout);
 int            sql_get_map_list(TALLOC_CTX *ctx, rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle, map_list_t *out, char const *query, fr_dict_attr_t const *list);
 void           rlm_sql_query_log(rlm_sql_t const *inst, char const *filename, char const *query) CC_HINT(nonnull);
-sql_rcode_t    rlm_sql_select_query(rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle, char const *query) CC_HINT(nonnull (1, 3, 4));
+unlang_action_t rlm_sql_select_query(rlm_rcode_t *p_result, UNUSED int *priority, request_t *request, void *uctx);
 unlang_action_t        rlm_sql_query(rlm_rcode_t *p_result, int *priority, request_t *request, void *uctx);
 sql_rcode_t            rlm_sql_fetch_row(rlm_sql_row_t *out, rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle);
 void           rlm_sql_print_error(rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t *handle, bool force_debug);
index 754dafda8693e4d964c967abf383d2bd917f4769..e7ef13a2ddb44af47880cb8a72eecefbcd5c1081 100644 (file)
@@ -524,29 +524,32 @@ unlang_action_t rlm_sql_query(rlm_rcode_t *p_result, UNUSED int *priority, reque
  * @note Caller must call ``(inst->driver->sql_finish_select_query)(handle, &inst->config);``
  *     after they're done with the result.
  *
- * @param inst #rlm_sql_t instance data.
- * @param request Current request.
- * @param handle to query the database with. *handle should not be NULL, as this indicates
- *       previous reconnection attempt has failed.
- * @param query to execute. Should not be zero length.
- * @return
+ * The rcode within the query context is updated to
  *     - #RLM_SQL_OK on success.
- *     - #RLM_SQL_RECONNECT if a new handle is required (also sets *handle = NULL).
+ *     - #RLM_SQL_RECONNECT if a new handle is required (also sets the handle to NULL).
  *     - #RLM_SQL_QUERY_INVALID, #RLM_SQL_ERROR on invalid query or connection error.
+ *     - #RLM_SQL_ALT_QUERY on constraints violation.
+ *
+ * @param p_result     Result of current module call.
+ * @param priority     Unused.
+ * @param request      Current request.
+ * @param uctx         query context containing query to execute.
+ * @return an unlang_action_t.
  */
-sql_rcode_t rlm_sql_select_query(rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle, char const *query)
+unlang_action_t rlm_sql_select_query(rlm_rcode_t *p_result, UNUSED int *priority, request_t *request, void *uctx)
 {
-       int ret = RLM_SQL_ERROR;
+       fr_sql_query_t  *query_ctx = talloc_get_type_abort(uctx, fr_sql_query_t);
+       rlm_sql_t const *inst = query_ctx->inst;
        int i, count;
 
        /* Caller should check they have a valid handle */
-       fr_assert(*handle);
+       fr_assert(query_ctx->handle);
 
        /* There's no query to run, return an error */
-       if (query[0] == '\0') {
+       if (query_ctx->query_str[0] == '\0') {
                if (request) REDEBUG("Zero length query");
-
-               return RLM_SQL_QUERY_INVALID;
+               query_ctx->rcode = RLM_SQL_QUERY_INVALID;
+               RETURN_MODULE_INVALID;
        }
 
        /*
@@ -558,38 +561,39 @@ sql_rcode_t rlm_sql_select_query(rlm_sql_t const *inst, request_t *request, rlm_
         *  For sanity, for when no connections are viable, and we can't make a new one
         */
        for (i = 0; i < (count + 1); i++) {
-               ROPTIONAL(RDEBUG2, DEBUG2, "Executing select query: %s", query);
+               ROPTIONAL(RDEBUG2, DEBUG2, "Executing select query: %s", query_ctx->query_str);
 
-               ret = (inst->driver->sql_select_query)(*handle, &inst->config, query);
-               switch (ret) {
+               query_ctx->rcode = (inst->driver->sql_select_query)(query_ctx->handle, &inst->config, query_ctx->query_str);
+               query_ctx->status = SQL_QUERY_SUBMITTED;
+               switch (query_ctx->rcode) {
                case RLM_SQL_OK:
-                       break;
+                       RETURN_MODULE_OK;
 
                /*
                 *      Run through all available sockets until we exhaust all existing
                 *      sockets in the pool and fail to establish a *new* connection.
                 */
                case RLM_SQL_RECONNECT:
-                       *handle = fr_pool_connection_reconnect(inst->pool, request, *handle);
+                       query_ctx->handle = fr_pool_connection_reconnect(inst->pool, request, query_ctx->handle);
                        /* Reconnection failed */
-                       if (!*handle) return RLM_SQL_RECONNECT;
+                       if (!query_ctx->handle) RETURN_MODULE_FAIL;
                        /* Reconnection succeeded, try again with the new handle */
                        continue;
 
                case RLM_SQL_QUERY_INVALID:
                case RLM_SQL_ERROR:
                default:
-                       rlm_sql_print_error(inst, request, *handle, false);
-                       (inst->driver->sql_finish_select_query)(*handle, &inst->config);
-                       break;
+                       rlm_sql_print_error(inst, request, query_ctx->handle, false);
+                       (inst->driver->sql_finish_select_query)(query_ctx->handle, &inst->config);
+                       if (query_ctx->rcode == RLM_SQL_QUERY_INVALID) RETURN_MODULE_INVALID;
+                       RETURN_MODULE_FAIL;
                }
-
-               return ret;
        }
 
        ROPTIONAL(RERROR, ERROR, "Hit reconnection limit");
 
-       return RLM_SQL_ERROR;
+       query_ctx->rcode = RLM_SQL_ERROR;
+       RETURN_MODULE_FAIL;
 }
 
 
@@ -605,8 +609,9 @@ int sql_get_map_list(TALLOC_CTX *ctx, rlm_sql_t const *inst, request_t *request,
 {
        rlm_sql_row_t   row;
        int             rows = 0;
-       sql_rcode_t     rcode;
        map_t           *parent = NULL;
+       fr_sql_query_t  *query_ctx;
+       rlm_rcode_t     p_result;
        tmpl_rules_t    lhs_rules = (tmpl_rules_t) {
                .attr = {
                        .dict_def = request->dict,
@@ -631,22 +636,27 @@ int sql_get_map_list(TALLOC_CTX *ctx, rlm_sql_t const *inst, request_t *request,
 
        fr_assert(request);
 
-       rcode = rlm_sql_select_query(inst, request, handle, query);
-       if (rcode != RLM_SQL_OK) return -1; /* error handled by rlm_sql_select_query */
+       MEM(query_ctx = fr_sql_query_alloc(unlang_interpret_frame_talloc_ctx(request), inst, *handle, query, SQL_QUERY_SELECT));
+       rlm_sql_select_query(&p_result, NULL, request, query_ctx);
+       if (query_ctx->rcode != RLM_SQL_OK) {
+       error:
+               *handle = query_ctx->handle;
+               talloc_free(query_ctx);
+               return -1;
+       }
 
-       while (rlm_sql_fetch_row(&row, inst, request, handle) == RLM_SQL_OK) {
+       while (rlm_sql_fetch_row(&row, inst, request, &query_ctx->handle) == RLM_SQL_OK) {
                map_t *map;
 
                if (map_afrom_fields(ctx, &map, &parent, request, row[2], row[4], row[3], &lhs_rules, &rhs_rules) < 0) {
                        RPEDEBUG("Error parsing user data from database result");
-                       (inst->driver->sql_finish_select_query)(*handle, &inst->config);
-                       return -1;
+                       goto error;
                }
                if (!map->parent) map_list_insert_tail(out, map);
 
                rows++;
        }
-       (inst->driver->sql_finish_select_query)(*handle, &inst->config);
+       talloc_free(query_ctx);
 
        return rows;
 }
index f0b193cb8c9b1274b4dfe8f2cfc2816c03c1232d..0e7b7a3061e5268a4d3470c8a4c47c8eafdb4ba6 100644 (file)
@@ -188,17 +188,23 @@ static int CC_HINT(nonnull (1, 3, 4, 5)) sqlippool_query1(char *out, int outlen,
 {
        int             rlen, retval;
        rlm_sql_row_t   row;
+       fr_sql_query_t  *query_ctx;
+       rlm_rcode_t     p_result;
 
        *out = '\0';
 
-       retval = sql->select(sql, request, handle, query);
+       MEM(query_ctx = sql->query_alloc(unlang_interpret_frame_talloc_ctx(request), sql, *handle, query, SQL_QUERY_SELECT));
+       sql->select(&p_result, NULL, request, query_ctx);
+       retval = query_ctx->rcode;
+       *handle = query_ctx->handle;
 
        if ((retval != 0) || !*handle) {
                REDEBUG("database query error on '%s'", query);
+               talloc_free(query_ctx);
                return 0;
        }
 
-       if (sql->fetch_row(&row, sql, request, handle) < 0) {
+       if (sql->fetch_row(&row, sql, request, &query_ctx->handle) < 0) {
                REDEBUG("Failed fetching query result");
                goto finish;
        }
@@ -223,7 +229,8 @@ static int CC_HINT(nonnull (1, 3, 4, 5)) sqlippool_query1(char *out, int outlen,
        retval = rlen;
 
 finish:
-       (sql->driver->sql_finish_select_query)(*handle, &sql->config);
+       *handle = query_ctx->handle;
+       talloc_free(query_ctx);
 
        return retval;
 }