]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Re-work sqlcounter mod_authorize to push sql xlat on stack
authorNick Porter <nick@portercomputing.co.uk>
Thu, 15 Feb 2024 10:52:07 +0000 (10:52 +0000)
committerNick Porter <nick@portercomputing.co.uk>
Thu, 15 Feb 2024 18:32:00 +0000 (18:32 +0000)
src/modules/rlm_sqlcounter/rlm_sqlcounter.c

index ea09ef8d16752bb0d8f4ec0f35dbb485a7e814b4..d01b63cae89b0b13bdf32bb366c6e9205460ee86 100644 (file)
@@ -30,11 +30,10 @@ RCSID("$Id$")
 #include <freeradius-devel/server/base.h>
 #include <freeradius-devel/server/module_rlm.h>
 #include <freeradius-devel/util/debug.h>
+#include <freeradius-devel/unlang/function.h>
 
 #include <ctype.h>
 
-#define MAX_QUERY_LEN 1024
-
 /*
  *     Note: When your counter spans more than 1 period (ie 3 months
  *     or 2 weeks), this module probably does NOT do what you want! It
@@ -98,6 +97,7 @@ static const conf_parser_t module_config[] = {
 };
 
 typedef struct {
+       xlat_exp_head_t *query_xlat;            //!< Tokenized xlat to run query.
        tmpl_t          *reply_attr;            //!< Attribute to write timeout to.
        tmpl_t          *reply_msg_attr;        //!< Attribute to write reply message to.
 } sqlcounter_call_env_t;
@@ -244,79 +244,43 @@ static int find_prev_reset(rlm_sqlcounter_t *inst, fr_time_t now)
        return ret;
 }
 
-
-/*
- *     Find the named user in this modules database.  Create the set
- *     of attribute-value pairs to check and reply with for this user
- *     from the database. The authentication code only needs to check
- *     the password, the rest is done here.
+typedef struct {
+       bool                    last_success;
+       fr_value_box_list_t     result;
+       rlm_sqlcounter_t        *inst;
+       sqlcounter_call_env_t   *env;
+       fr_pair_t               *limit;
+} sqlcounter_rctx_t;
+
+/** Handle the result of calling the SQL query to retrieve the `counter` value.
+ *
+ * Create / update the `counter` attribute in the contol list
+ * If `counter` > `limit`, optionally populate a reply message and return RLM_MODULE_REJECT.
+ * Otherwise, optionally populate a reply attribute with the value of `limit` - `counter` and return RLM_MODULE_UPDATED.
+ * If no reply attribute is set, return RLM_MODULE_OK.
  */
-static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, module_ctx_t const *mctx, request_t *request)
+static unlang_action_t mod_authorize_resume(rlm_rcode_t *p_result, UNUSED int *priority, request_t *request, void *uctx)
 {
-       rlm_sqlcounter_t        *inst = talloc_get_type_abort(mctx->inst->data, rlm_sqlcounter_t);
-       sqlcounter_call_env_t   *env = talloc_get_type_abort(mctx->env_data, sqlcounter_call_env_t);
+       sqlcounter_rctx_t       *rctx = talloc_get_type_abort(uctx, sqlcounter_rctx_t);
+       rlm_sqlcounter_t        *inst = rctx->inst;
+       sqlcounter_call_env_t   *env = rctx->env;
+       fr_value_box_t          *sql_result = fr_value_box_list_pop_head(&rctx->result);
        uint64_t                counter, res;
-       fr_pair_t               *limit, *vp;
-       fr_pair_t               *reply_item;
-       char                    msg[128];
+       fr_pair_t               *vp, *limit = rctx->limit;
        int                     ret;
-       size_t len;
-       char *expanded = NULL;
-       char query[MAX_QUERY_LEN];
-
-       /*
-        *      Before doing anything else, see if we have to reset
-        *      the counters.
-        */
-       if (fr_time_neq(inst->reset_time, fr_time_wrap(0)) &&
-           (fr_time_lteq(inst->reset_time, request->packet->timestamp))) {
-               /*
-                *      Re-set the next time and prev_time for this counters range
-                */
-               inst->last_reset = inst->reset_time;
-               find_next_reset(inst, request->packet->timestamp);
-       }
+       char                    msg[128];
 
-       if (tmpl_find_vp(&limit, request, inst->limit_attr) < 0) {
-               RWDEBUG2("Couldn't find %s, doing nothing...", inst->limit_attr->name);
-               RETURN_MODULE_NOOP;
+       if (!sql_result || (sscanf(sql_result->vb_strvalue, "%" PRIu64, &counter) != 1)) {
+               RDEBUG2("No integer found in result string \"%pV\".  May be first session, setting counter to 0",
+                       sql_result);
+               counter = 0;
        }
 
        /*
-        *      Populate start and end attributes for use in query expansion
+        *      Add the counter to the control list
         */
-       if (tmpl_find_or_add_vp(&vp, request, inst->start_attr) < 0) {
-               REDEBUG("Couldn't create %s", inst->start_attr->name);
-               RETURN_MODULE_FAIL;
-       }
-       vp->vp_uint64 = fr_time_to_sec(inst->last_reset);
-
-       if (tmpl_find_or_add_vp(&vp, request, inst->end_attr) < 0) {
-               REDEBUG2("Couldn't create %s", inst->end_attr->name);
-               RETURN_MODULE_FAIL;
-       }
-       vp->vp_uint64 = fr_time_to_sec(inst->reset_time);
-
-       /* Then combine that with the name of the module were using to do the query */
-       len = snprintf(query, sizeof(query), "%%%s(\"%s\")", inst->sql_name, inst->query);
-       if (len >= (sizeof(query) - 1)) {
-               REDEBUG("Insufficient query buffer space");
-
-               RETURN_MODULE_FAIL;
-       }
-
-       /* Finally, xlat resulting SQL query */
-       if (xlat_aeval(request, &expanded, request, query, NULL, NULL) < 0) {
-               RETURN_MODULE_FAIL;
-       }
-
-       if (sscanf(expanded, "%" PRIu64, &counter) != 1) {
-               RDEBUG2("No integer found in result string \"%s\".  May be first session, setting counter to 0",
-                       expanded);
-               counter = 0;
-       }
-
-       talloc_free(expanded);
+       MEM(pair_update_control(&vp, tmpl_attr_tail_da(inst->counter_attr)) >= 0);
+       vp->vp_uint64 = counter;
 
        /*
         *      Check if check item > counter
@@ -341,12 +305,6 @@ static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, mod
        RDEBUG2("Allowing user, %s value (%" PRIu64 ") is greater than counter value (%" PRIu64 ")",
                inst->limit_attr->name, limit->vp_uint64, counter);
 
-       /*
-        *      Add the counter to the control list
-        */
-       MEM(pair_update_control(&vp, tmpl_attr_tail_da(inst->counter_attr)) >= 0);
-       vp->vp_uint64 = counter;
-
        /*
         *      We are assuming that simultaneous-use=1. But
         *      even if that does not happen then our user
@@ -378,7 +336,7 @@ static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, mod
                /*
                 *      Limit the reply attribute to the minimum of the existing value, or this new one.
                 */
-               ret = tmpl_find_or_add_vp(&reply_item, request, env->reply_attr);
+               ret = tmpl_find_or_add_vp(&vp, request, env->reply_attr);
                switch (ret) {
                case 1:         /* new */
                        break;
@@ -404,9 +362,9 @@ static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, mod
                        RETURN_MODULE_OK;
                }
 
-               fr_value_box_cast(reply_item, &reply_item->data, reply_item->data.type, NULL, &vb);
+               fr_value_box_cast(vp, &vp->data, vp->data.type, NULL, &vb);
 
-               RDEBUG2("&%pP", reply_item);
+               RDEBUG2("&%pP", vp);
 
                RETURN_MODULE_UPDATED;
        }
@@ -414,6 +372,71 @@ static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, mod
        RETURN_MODULE_OK;
 }
 
+/** Check the value of a `counter` retrieved from an SQL query with a `limit`
+ *
+ * Module specific attributes containing the start / end times are created / updated,
+ * the query is tokenized as an xlat call to the relevant SQL module and then
+ * pushed on the stack for evaluation.
+ */
+static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, module_ctx_t const *mctx, request_t *request)
+{
+       rlm_sqlcounter_t        *inst = talloc_get_type_abort(mctx->inst->data, rlm_sqlcounter_t);
+       sqlcounter_call_env_t   *env = talloc_get_type_abort(mctx->env_data, sqlcounter_call_env_t);
+       fr_pair_t               *limit, *vp;
+       sqlcounter_rctx_t       *rctx;
+
+       /*
+        *      Before doing anything else, see if we have to reset
+        *      the counters.
+        */
+       if (fr_time_neq(inst->reset_time, fr_time_wrap(0)) &&
+           (fr_time_lteq(inst->reset_time, request->packet->timestamp))) {
+               /*
+                *      Re-set the next time and prev_time for this counters range
+                */
+               inst->last_reset = inst->reset_time;
+               find_next_reset(inst, request->packet->timestamp);
+       }
+
+       if (tmpl_find_vp(&limit, request, inst->limit_attr) < 0) {
+               RWDEBUG2("Couldn't find %s, doing nothing...", inst->limit_attr->name);
+               RETURN_MODULE_NOOP;
+       }
+
+       /*
+        *      Populate start and end attributes for use in query expansion
+        */
+       if (tmpl_find_or_add_vp(&vp, request, inst->start_attr) < 0) {
+               REDEBUG("Couldn't create %s", inst->start_attr->name);
+               RETURN_MODULE_FAIL;
+       }
+       vp->vp_uint64 = fr_time_to_sec(inst->last_reset);
+
+       if (tmpl_find_or_add_vp(&vp, request, inst->end_attr) < 0) {
+               REDEBUG2("Couldn't create %s", inst->end_attr->name);
+               RETURN_MODULE_FAIL;
+       }
+       vp->vp_uint64 = fr_time_to_sec(inst->reset_time);
+
+       MEM(rctx = talloc(unlang_interpret_frame_talloc_ctx(request), sqlcounter_rctx_t));
+       *rctx = (sqlcounter_rctx_t) {
+               .inst = inst,
+               .env = env,
+               .limit = limit
+       };
+
+       if (unlang_function_push(request, NULL, mod_authorize_resume, NULL, 0, UNLANG_SUB_FRAME, rctx) < 0) {
+       error:
+               talloc_free(rctx);
+               RETURN_MODULE_FAIL;
+       }
+
+       fr_value_box_list_init(&rctx->result);
+       if (unlang_xlat_push(rctx, &rctx->last_success, &rctx->result, request, env->query_xlat, UNLANG_SUB_FRAME) < 0) goto error;
+
+       return UNLANG_ACTION_PUSHED_CHILD;
+}
+
 /*
  *     Do any per-module initialization that is separate to each
  *     configured instance of the module.  e.g. set up connections
@@ -489,9 +512,49 @@ static int mod_bootstrap(module_inst_ctx_t const *mctx)
        return 0;
 }
 
+/** Custom call_env parser to tokenize the SQL query xlat used for counter retrieval
+ */
+static int call_env_query_parse(TALLOC_CTX *ctx, void *out, tmpl_rules_t const *t_rules, CONF_ITEM *ci, void const *data,
+                               UNUSED call_env_parser_t const *rule)
+{
+       rlm_sqlcounter_t const  *inst = talloc_get_type_abort_const(data, rlm_sqlcounter_t);
+       CONF_PAIR const         *to_parse = cf_item_to_pair(ci);
+       char                    *query;
+       xlat_exp_head_t         *ex;
+
+       query = talloc_asprintf(NULL, "%%%s(\"%s\")", inst->sql_name, cf_pair_value(to_parse));
+
+       if (xlat_tokenize(ctx, &ex,
+                 &FR_SBUFF_IN(query, talloc_array_length(query)),
+                 &(fr_sbuff_parse_rules_t){
+                       .escapes = &(fr_sbuff_unescape_rules_t) {
+                               .name = "xlat",
+                               .chr = '\\',
+                               .subs = {
+                                       ['%'] = '%',
+                                       ['\\'] = '\\',
+                               },
+                 }}, t_rules) < 0) {
+               talloc_free(query);
+               return -1;
+       }
+       talloc_free(query);
+
+       if (xlat_needs_resolving(ex) &&
+           (xlat_resolve(ex, &(xlat_res_rules_t){ .allow_unresolved = false }) < 0)) {
+               talloc_free(ex);
+               return -1;
+       }
+
+       *(void**)out = ex;
+       return 0;
+}
+
 static const call_env_method_t sqlcounter_call_env = {
        FR_CALL_ENV_METHOD_OUT(sqlcounter_call_env_t),
        .env = (call_env_parser_t[]){
+               { FR_CALL_ENV_PARSE_ONLY_OFFSET("query", FR_TYPE_VOID, CALL_ENV_FLAG_REQUIRED | CALL_ENV_FLAG_PARSE_ONLY, sqlcounter_call_env_t, query_xlat),
+                 .pair.func = call_env_query_parse },
                { FR_CALL_ENV_PARSE_ONLY_OFFSET("reply_name", FR_TYPE_VOID, CALL_ENV_FLAG_PARSE_ONLY, sqlcounter_call_env_t, reply_attr) },
                { FR_CALL_ENV_PARSE_ONLY_OFFSET("reply_message_name", FR_TYPE_VOID, CALL_ENV_FLAG_PARSE_ONLY, sqlcounter_call_env_t, reply_msg_attr) },
                CALL_ENV_TERMINATOR