]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Add uctx for SQL escape functions
authorNick Porter <nick@portercomputing.co.uk>
Thu, 1 Feb 2024 17:56:55 +0000 (17:56 +0000)
committerNick Porter <nick@portercomputing.co.uk>
Wed, 7 Feb 2024 10:41:07 +0000 (10:41 +0000)
Allows for passing of an exising connection handle

src/modules/rlm_sql/rlm_sql.c
src/modules/rlm_sql/rlm_sql.h

index 32e81734afa43623222b6c55401c3f3cf7933fbd..d19d77db56cf08694d4450409df62171ffce8a70 100644 (file)
@@ -182,22 +182,23 @@ static size_t sql_escape_func(request_t *, char *out, size_t outlen, char const
 /** Escape a tainted VB used as an xlat argument
  *
  */
-static int sql_xlat_escape(request_t *request, fr_value_box_t *vb, void *uctx)
+static int CC_HINT(nonnull(2,3)) sql_xlat_escape(request_t *request, fr_value_box_t *vb, void *uctx)
 {
-       fr_sbuff_t                              sbuff;
-       fr_sbuff_uctx_talloc_t                  sbuff_ctx;
+       fr_sbuff_t                      sbuff;
+       fr_sbuff_uctx_talloc_t          sbuff_ctx;
 
-       size_t                                  len;
-       rlm_sql_handle_t                        *handle;
-       rlm_sql_t                               *inst = talloc_get_type_abort(uctx, rlm_sql_t);
-       fr_value_box_entry_t                    entry;
+       size_t                          len;
+       rlm_sql_handle_t                *handle;
+       rlm_sql_escape_uctx_t           *ctx = uctx;
+       rlm_sql_t const                 *inst = talloc_get_type_abort_const(ctx->sql, rlm_sql_t);
+       fr_value_box_entry_t            entry;
 
        /*
         *      If it's already safe, don't do anything.
         */
        if (fr_value_box_is_safe_for(vb, inst->driver)) return 0;
 
-       handle = fr_pool_connection_get(inst->pool, request);
+       handle = ctx->handle ? ctx->handle : fr_pool_connection_get(inst->pool, request);
        if (!handle) {
        error:
                fr_value_box_clear_value(vb);
@@ -236,7 +237,7 @@ static int sql_xlat_escape(request_t *request, fr_value_box_t *vb, void *uctx)
        fr_value_box_mark_safe_for(vb, inst->driver);
        vb->entry = entry;
 
-       fr_pool_connection_release(inst->pool, request, handle);
+       if (!ctx->handle) fr_pool_connection_release(inst->pool, request, handle);
        return 0;
 }
 
@@ -1522,6 +1523,7 @@ static int mod_bootstrap(module_inst_ctx_t const *mctx)
        CONF_SECTION            *conf = mctx->inst->conf;
        xlat_t                  *xlat;
        xlat_arg_parser_t       *sql_xlat_arg;
+       rlm_sql_escape_uctx_t   *uctx;
 
        inst->name = mctx->inst->name;  /* Need this for functions in sql.c */
        inst->driver = (rlm_sql_driver_t const *)inst->driver_submodule->module; /* Public symbol exported by the submodule */
@@ -1596,12 +1598,14 @@ static int mod_bootstrap(module_inst_ctx_t const *mctx)
         *      argument parser details need to be defined here
         */
        sql_xlat_arg = talloc_zero_array(inst, xlat_arg_parser_t, 2);
+       uctx = talloc_zero(sql_xlat_arg, rlm_sql_escape_uctx_t);
+       *uctx = (rlm_sql_escape_uctx_t){ .sql = inst, .handle = NULL };
        sql_xlat_arg[0].type = FR_TYPE_STRING;
        sql_xlat_arg[0].required = true;
        sql_xlat_arg[0].concat = true;
        sql_xlat_arg[0].func = sql_xlat_escape;
        sql_xlat_arg[0].safe_for = (fr_value_box_safe_for_t)inst->driver;
-       sql_xlat_arg[0].uctx = inst;
+       sql_xlat_arg[0].uctx = uctx;
        sql_xlat_arg[1] = (xlat_arg_parser_t)XLAT_ARG_PARSER_TERMINATOR;
 
        xlat_func_mono_set(xlat, sql_xlat_arg);
index df9ca1d2cd0158687506919a6e68bf15a8fe7466..315ca9eba90f553bfd6e157dff349f96fd5077a0 100644 (file)
@@ -174,6 +174,11 @@ extern size_t sql_rcode_table_len;
 typedef size_t (*sql_error_t)(TALLOC_CTX *ctx, sql_log_entry_t out[], size_t outlen, rlm_sql_handle_t *handle,
                              rlm_sql_config_t const *config);
 
+typedef struct {
+       rlm_sql_t const         *sql;
+       rlm_sql_handle_t        *handle;
+} rlm_sql_escape_uctx_t;
+
 typedef struct {
        module_t        common;                         //!< Common fields for all loadable modules.