]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Amend rlm_sql mod_authorize and group membership xlat to push queries on the stack
authorNick Porter <nick@portercomputing.co.uk>
Mon, 27 May 2024 11:08:31 +0000 (12:08 +0100)
committerArran Cudbard-Bell <a.cudbardb@freeradius.org>
Fri, 7 Jun 2024 02:26:58 +0000 (22:26 -0400)
They share a common function for retrieving group membership so have to
be amended at the same time

src/modules/rlm_sql/rlm_sql.c
src/modules/rlm_sql/rlm_sql.h
src/modules/rlm_sql/sql.c

index b5f33baef062b6836c71cfb0c768f13c02f904ee..6b2c19e471ce8d62fd66155ed632f4d694f7913b 100644 (file)
@@ -149,14 +149,21 @@ typedef struct rlm_sql_grouplist_s rlm_sql_grouplist_t;
 /** Status of the authorization process
  */
 typedef enum {
-       SQL_AUTZ_CHECK          = 0x11,         //!< Running user `check` query
-       SQL_AUTZ_REPLY          = 0x12,         //!< Running user `reply` query
-       SQL_AUTZ_GROUP_MEMB     = 0x20,         //!< Running group membership query
-       SQL_AUTZ_GROUP_CHECK    = 0x21,         //!< Running group `check` query
-       SQL_AUTZ_GROUP_REPLY    = 0x22,         //!< Running group `reply` query
-       SQL_AUTZ_PROFILE_START  = 0x40,         //!< Starting processing user profiles
-       SQL_AUTZ_PROFILE_CHECK  = 0x41,         //!< Running profile `check` query
-       SQL_AUTZ_PROFILE_REPLY  = 0x42,         //!< Running profile `reply` query
+       SQL_AUTZ_CHECK                  = 0x12,         //!< Running user `check` query
+       SQL_AUTZ_CHECK_RESUME           = 0x13,         //!< Completed user `check` query
+       SQL_AUTZ_REPLY                  = 0x14,         //!< Running user `reply` query
+       SQL_AUTZ_REPLY_RESUME           = 0x15,         //!< Completed user `reply` query
+       SQL_AUTZ_GROUP_MEMB             = 0x20,         //!< Running group membership query
+       SQL_AUTZ_GROUP_MEMB_RESUME      = 0x21,         //!< Completed group membership query
+       SQL_AUTZ_GROUP_CHECK            = 0x22,         //!< Running group `check` query
+       SQL_AUTZ_GROUP_CHECK_RESUME     = 0x23,         //!< Completed group `check` query
+       SQL_AUTZ_GROUP_REPLY            = 0x24,         //!< Running group `reply` query
+       SQL_AUTZ_GROUP_REPLY_RESUME     = 0x25,         //!< Completed group `reply` query
+       SQL_AUTZ_PROFILE_START          = 0x40,         //!< Starting processing user profiles
+       SQL_AUTZ_PROFILE_CHECK          = 0x42,         //!< Running profile `check` query
+       SQL_AUTZ_PROFILE_CHECK_RESUME   = 0x43,         //!< Completed profile `check` query
+       SQL_AUTZ_PROFILE_REPLY          = 0x44,         //!< Running profile `reply` query
+       SQL_AUTZ_PROFILE_REPLY_RESUME   = 0x45,         //!< Completed profile `reply` query
 } sql_autz_status_t;
 
 #define SQL_AUTZ_STAGE_GROUP 0x20
@@ -186,7 +193,6 @@ typedef struct {
        sql_autz_status_t       status;         //!< Current status of the authorization.
        fr_value_box_list_t     query;          //!< Where expanded query tmpls will be written.
        bool                    user_found;     //!< Has the user been found anywhere?
-       rlm_sql_grouplist_t     *groups;        //!< List of groups returned by the group membership query.
        rlm_sql_grouplist_t     *group;         //!< Current group being processed.
        fr_pair_t               *sql_group;     //!< Pair to update with group being processed.
        fr_pair_t               *profile;       //!< Current profile being processed.
@@ -959,125 +965,134 @@ struct rlm_sql_grouplist_s {
        rlm_sql_grouplist_t     *next;
 };
 
-static int sql_get_grouplist(rlm_sql_t const *inst, rlm_sql_handle_t **handle, fr_trunk_t *trunk, request_t *request,
-                            char const *query, rlm_sql_grouplist_t **phead)
+static unlang_action_t sql_get_grouplist_resume(rlm_rcode_t *p_result, UNUSED int *priority, request_t *request, void *uctx)
 {
-       int                     num_groups = 0;
+       sql_group_ctx_t         *group_ctx = talloc_get_type_abort(uctx, sql_group_ctx_t);
+       fr_sql_query_t          *query_ctx = group_ctx->query_ctx;
+       rlm_sql_t const         *inst = group_ctx->inst;
        rlm_sql_row_t           row;
-       rlm_sql_grouplist_t     *entry;
-       rlm_rcode_t             p_result;
-       fr_sql_query_t          *query_ctx;
-
-       /* NOTE: sql_set_user should have been run before calling this function */
-
-       entry = *phead = NULL;
-
-       if (!query || !*query) return 0;
-
-       MEM(query_ctx = fr_sql_query_alloc(unlang_interpret_frame_talloc_ctx(request), inst, request, *handle, trunk,
-                                          query, SQL_QUERY_SELECT));
+       rlm_sql_grouplist_t     *entry = group_ctx->groups;
 
-       inst->select(&p_result, NULL, request, query_ctx);
        if (query_ctx->rcode != RLM_SQL_OK) {
+       error:
                talloc_free(query_ctx);
-               return -1;
+               RETURN_MODULE_FAIL;
        }
-       *handle = query_ctx->handle;
 
-       while ((inst->fetch_row(&p_result, NULL, request, query_ctx) == UNLANG_ACTION_CALCULATE_RESULT) &&
+       while ((inst->fetch_row(p_result, NULL, request, query_ctx) == UNLANG_ACTION_CALCULATE_RESULT) &&
                (query_ctx->rcode == RLM_SQL_OK)) {
                row = query_ctx->row;
                if (!row[0]){
                        RDEBUG2("row[0] returned NULL");
-                       (inst->driver->sql_finish_select_query)(query_ctx, &inst->config);
-                       talloc_free(entry);
-                       return -1;
+                       goto error;
                }
 
-               if (!*phead || !entry) {        /* clang scan couldn't tell that when *phead != NULL then entry != NULL */
-                       *phead = talloc_zero(*handle, rlm_sql_grouplist_t);
-                       entry = *phead;
+               if (!group_ctx->groups || !entry) {     /* clang scan couldn't tell that when groups_ctx->groups != NULL then entry != NULL */
+                       group_ctx->groups = talloc_zero(group_ctx, rlm_sql_grouplist_t);
+                       entry = group_ctx->groups;
                } else {
-                       entry->next = talloc_zero(*phead, rlm_sql_grouplist_t);
+                       entry->next = talloc_zero(group_ctx, rlm_sql_grouplist_t);
                        entry = entry->next;
                }
                entry->next = NULL;
                entry->name = talloc_typed_strdup(entry, row[0]);
 
-               num_groups++;
+               group_ctx->num_groups++;
        }
 
        talloc_free(query_ctx);
+       RETURN_MODULE_OK;
+}
 
-       return num_groups;
+static unlang_action_t sql_get_grouplist(sql_group_ctx_t *group_ctx, rlm_sql_handle_t **handle, fr_trunk_t *trunk, request_t *request)
+{
+       rlm_sql_t const         *inst = group_ctx->inst;
+
+       /* NOTE: sql_set_user should have been run before calling this function */
+
+       if (!group_ctx->query || (group_ctx->query->vb_length == 0)) return UNLANG_ACTION_CALCULATE_RESULT;
+
+       MEM(group_ctx->query_ctx = fr_sql_query_alloc(group_ctx, inst, request, *handle, trunk,
+                                                     group_ctx->query->vb_strvalue, SQL_QUERY_SELECT));
+
+       if (unlang_function_push(request, NULL, sql_get_grouplist_resume, NULL, 0, UNLANG_SUB_FRAME, group_ctx) < 0) return UNLANG_ACTION_FAIL;
+
+       return unlang_function_push(request, inst->select, NULL, NULL, 0, UNLANG_SUB_FRAME, group_ctx->query_ctx);
 }
 
-/** Check if a given group is in the SQL group for this user.
+typedef struct {
+       fr_value_box_list_t     query;
+       sql_group_ctx_t         *group_ctx;
+       rlm_sql_handle_t        *handle;
+} sql_group_xlat_ctx_t;
+
+/**  Compare list of groups returned from SQL query to xlat argument.
  *
+ * Called after the SQL query has completed and group list has been built.
  */
-static bool CC_HINT(nonnull) sql_check_group(rlm_sql_t const *inst, request_t *request, char const *query, char const *name)
+static xlat_action_t sql_group_xlat_query_resume(TALLOC_CTX *ctx, fr_dcursor_t *out, xlat_ctx_t const *xctx,
+                                          request_t *request, fr_value_box_list_t *in)
 {
-       bool rcode = false;
-       rlm_sql_handle_t        *handle;
-       rlm_sql_grouplist_t     *entry, *head = NULL;
-       rlm_sql_thread_t        *thread = talloc_get_type_abort(module_thread(inst->mi)->data, rlm_sql_thread_t);
-
-       /*
-        *      Get a socket for this lookup
-        */
-       handle = fr_pool_connection_get(inst->pool, request);
-       if (!handle) {
-               REDEBUG("Failed getting connection handle");
-               return false;
-       }
+       rlm_sql_t const         *inst = talloc_get_type_abort(xctx->mctx->mi->data, rlm_sql_t);
+       sql_group_xlat_ctx_t    *xlat_ctx = talloc_get_type_abort(xctx->rctx, sql_group_xlat_ctx_t);
+       sql_group_ctx_t         *group_ctx = talloc_get_type_abort(xlat_ctx->group_ctx, sql_group_ctx_t);
+       fr_value_box_t          *arg = fr_value_box_list_head(in);
+       char const              *name = arg->vb_strvalue;
+       fr_value_box_t          *vb;
+       rlm_sql_grouplist_t     *entry;
 
-       /*
-        *      Get the list of groups this user is a member of
-        */
-       if (sql_get_grouplist(inst, &handle, thread->trunk, request, query, &head) < 0) {
-               talloc_free(head);
-               REDEBUG("Error getting group membership");
-               fr_pool_connection_release(inst->pool, request, handle);
-               return false;
-       }
+       fr_skip_whitespace(name);
 
-       for (entry = head; entry != NULL; entry = entry->next) {
+       MEM(vb = fr_value_box_alloc(ctx, FR_TYPE_BOOL, attr_expr_bool_enum));
+       for (entry = group_ctx->groups; entry != NULL; entry = entry->next) {
                if (strcmp(entry->name, name) == 0) {
-                       rcode = true;
+                       vb->vb_bool = true;
                        break;
                }
        }
+       fr_dcursor_append(out, vb);
 
-       /* Free the grouplist */
-       talloc_free(head);
-       fr_pool_connection_release(inst->pool, request, handle);
+       if (!inst->driver->uses_trunks && xlat_ctx->handle) fr_pool_connection_release(inst->pool, request, xlat_ctx->handle);
 
-       return rcode;
+       return XLAT_ACTION_DONE;
 }
 
-typedef struct {
-       fr_value_box_list_t     query;
-} sql_group_xlat_ctx_t;
-
-static xlat_action_t sql_group_xlat_resume(TALLOC_CTX *ctx, fr_dcursor_t *out, xlat_ctx_t const *xctx,
-                                          request_t *request, fr_value_box_list_t *in)
+/** Run SQL query for group membership to return list of groups
+ *
+ * Called after group membership query tmpl is expanded
+ */
+static xlat_action_t sql_group_xlat_resume(UNUSED TALLOC_CTX *ctx, UNUSED fr_dcursor_t *out, xlat_ctx_t const *xctx,
+                                          request_t *request, UNUSED fr_value_box_list_t *in)
 {
        sql_group_xlat_ctx_t    *xlat_ctx = talloc_get_type_abort(xctx->rctx, sql_group_xlat_ctx_t);
        rlm_sql_t const         *inst = talloc_get_type_abort(xctx->mctx->mi->data, rlm_sql_t);
-       fr_value_box_t          *arg = fr_value_box_list_head(in);
-       char const              *p = arg->vb_strvalue;
-       fr_value_box_t          *query, *vb;
+       rlm_sql_thread_t        *thread = talloc_get_type_abort(xctx->mctx->thread, rlm_sql_thread_t);
+       fr_value_box_t          *query;
 
        query = fr_value_box_list_head(&xlat_ctx->query);
        if (!query) return XLAT_ACTION_FAIL;
 
-       fr_skip_whitespace(p);
+       MEM(xlat_ctx->group_ctx = talloc(xlat_ctx, sql_group_ctx_t));
 
-       MEM(vb = fr_value_box_alloc(ctx, FR_TYPE_BOOL, attr_expr_bool_enum));
-       vb->vb_bool = sql_check_group(inst, request, query->vb_strvalue, p);
-       fr_dcursor_append(out, vb);
+       *xlat_ctx->group_ctx = (sql_group_ctx_t) {
+               .inst = inst,
+               .query = query,
+       };
 
-       return XLAT_ACTION_DONE;
+       if (!inst->driver->uses_trunks) {
+               xlat_ctx->handle = fr_pool_connection_get(inst->pool, request);
+               if (!xlat_ctx->handle) {
+                       REDEBUG("Failed getting conneciton handle");
+                       return XLAT_ACTION_FAIL;
+               }
+       }
+
+       if (unlang_xlat_yield(request, sql_group_xlat_query_resume, NULL, 0, xlat_ctx) != XLAT_ACTION_YIELD) return XLAT_ACTION_FAIL;
+
+       if (sql_get_grouplist(xlat_ctx->group_ctx, &xlat_ctx->handle, thread->trunk, request) != UNLANG_ACTION_PUSHED_CHILD)
+                       return XLAT_ACTION_FAIL;
+
+       return XLAT_ACTION_PUSH_UNLANG;
 }
 
 
@@ -1196,32 +1211,51 @@ static unlang_action_t mod_autz_group_resume(rlm_rcode_t *p_result, UNUSED int *
 {
        sql_autz_ctx_t          *autz_ctx = talloc_get_type_abort(uctx, sql_autz_ctx_t);
        sql_autz_call_env_t     *call_env = autz_ctx->call_env;
+       sql_group_ctx_t         *group_ctx = autz_ctx->group_ctx;
+       fr_sql_map_ctx_t        *map_ctx = autz_ctx->map_ctx;
        rlm_sql_t const         *inst = autz_ctx->inst;
        fr_value_box_t          *query = fr_value_box_list_pop_head(&autz_ctx->query);
-       int                     rows;
        sql_fall_through_t      do_fall_through = FALL_THROUGH_DEFAULT;
        fr_pair_t               *vp;
 
+       switch (*p_result) {
+       case RLM_MODULE_USER_SECTION_REJECT:
+               return UNLANG_ACTION_CALCULATE_RESULT;
+
+       default:
+               break;
+       }
+
        switch(autz_ctx->status) {
        case SQL_AUTZ_GROUP_MEMB:
-               rows = sql_get_grouplist(inst, &autz_ctx->handle, autz_ctx->trunk, request, query->vb_strvalue, &autz_ctx->groups);
-               talloc_free(query);
+               if (unlang_function_repeat_set(request, mod_autz_group_resume) < 0) RETURN_MODULE_FAIL;
+               MEM(autz_ctx->group_ctx = talloc(autz_ctx, sql_group_ctx_t));
+               *autz_ctx->group_ctx = (sql_group_ctx_t) {
+                       .inst = inst,
+                       .query = query,
+               };
 
-               if (rows < 0) {
-                       talloc_free(autz_ctx->groups);
-                       REDEBUG("Error retrieving group list");
-                       RETURN_MODULE_FAIL;
+               if (sql_get_grouplist(autz_ctx->group_ctx, &autz_ctx->handle, autz_ctx->trunk, request) == UNLANG_ACTION_PUSHED_CHILD) {
+                       autz_ctx->status = SQL_AUTZ_GROUP_MEMB_RESUME;
+                       return UNLANG_ACTION_PUSHED_CHILD;
                }
 
-               if (rows == 0) {
+               group_ctx = autz_ctx->group_ctx;
+
+               FALL_THROUGH;
+
+       case SQL_AUTZ_GROUP_MEMB_RESUME:
+               talloc_free(group_ctx->query);
+
+               if (group_ctx->num_groups == 0) {
                        RDEBUG2("User not found in any groups");
                        break;
                }
-               fr_assert(autz_ctx->groups);
+               fr_assert(group_ctx->groups);
 
                RDEBUG2("User found in the group table");
                autz_ctx->user_found = true;
-               autz_ctx->group = autz_ctx->groups;
+               autz_ctx->group = group_ctx->groups;
                MEM(pair_update_request(&autz_ctx->sql_group, inst->group_da) >= 0);
 
        next_group:
@@ -1255,22 +1289,31 @@ static unlang_action_t mod_autz_group_resume(rlm_rcode_t *p_result, UNUSED int *
 
        case SQL_AUTZ_GROUP_CHECK:
        case SQL_AUTZ_PROFILE_CHECK:
-               rows = sql_get_map_list(autz_ctx, inst, request, &autz_ctx->handle, autz_ctx->trunk,
-                                       &autz_ctx->check_tmp, query->vb_strvalue, request_attr_request);
-               talloc_free(query);
-
-               if (rows < 0) {
-                       REDEBUG("Error retrieving check pairs for %s %pV",
-                               autz_ctx->status & SQL_AUTZ_STAGE_GROUP ? "group" : "profile",
-                               &autz_ctx->sql_group->data);
-                       RETURN_MODULE_FAIL;
+               *autz_ctx->map_ctx = (fr_sql_map_ctx_t) {
+                       .ctx = autz_ctx,
+                       .inst = inst,
+                       .out = &autz_ctx->check_tmp,
+                       .list = request_attr_request,
+                       .query = query,
+               };
+
+               if (unlang_function_repeat_set(request, mod_autz_group_resume) < 0) RETURN_MODULE_FAIL;
+               if (sql_get_map_list(request, map_ctx, &autz_ctx->handle, autz_ctx->trunk) == UNLANG_ACTION_PUSHED_CHILD) {
+                       autz_ctx->status = autz_ctx->status & SQL_AUTZ_STAGE_GROUP ? SQL_AUTZ_GROUP_CHECK_RESUME : SQL_AUTZ_PROFILE_CHECK_RESUME;
+                       return UNLANG_ACTION_PUSHED_CHILD;
                }
 
+               FALL_THROUGH;
+
+       case SQL_AUTZ_GROUP_CHECK_RESUME:
+       case SQL_AUTZ_PROFILE_CHECK_RESUME:
+               talloc_free(map_ctx->query);
+
                /*
                 *      If we got check rows we need to process them before we decide to
                 *      process the reply rows
                 */
-               if (rows > 0) {
+               if (map_ctx->rows > 0) {
                        if (check_map_process(request, &autz_ctx->check_tmp, &autz_ctx->reply_tmp) < 0) {
                                map_list_talloc_free(&autz_ctx->check_tmp);
                                goto next_group_find;
@@ -1301,17 +1344,27 @@ static unlang_action_t mod_autz_group_resume(rlm_rcode_t *p_result, UNUSED int *
 
        case SQL_AUTZ_GROUP_REPLY:
        case SQL_AUTZ_PROFILE_REPLY:
-               rows = sql_get_map_list(autz_ctx, inst, request, &autz_ctx->handle, autz_ctx->trunk,
-                                       &autz_ctx->reply_tmp, query->vb_strvalue, request_attr_reply);
-               talloc_free(query);
+               *autz_ctx->map_ctx = (fr_sql_map_ctx_t) {
+                       .ctx = autz_ctx,
+                       .inst = inst,
+                       .out = &autz_ctx->reply_tmp,
+                       .list = request_attr_reply,
+                       .query = query,
+               };
 
-               if (rows < 0) {
-                       REDEBUG("Error retrieving reply pairs for %s %pV",
-                               autz_ctx->status & SQL_AUTZ_STAGE_GROUP ? "group" : "profile", &autz_ctx->sql_group->data);
-                       RETURN_MODULE_FAIL;
+               if (unlang_function_repeat_set(request, mod_autz_group_resume) < 0) RETURN_MODULE_FAIL;
+               if (sql_get_map_list(request, map_ctx, &autz_ctx->handle, autz_ctx->trunk) == UNLANG_ACTION_PUSHED_CHILD) {
+                       autz_ctx->status = autz_ctx->status & SQL_AUTZ_STAGE_GROUP ? SQL_AUTZ_GROUP_REPLY_RESUME : SQL_AUTZ_PROFILE_REPLY_RESUME;
+                       return UNLANG_ACTION_PUSHED_CHILD;
                }
 
-               if (rows == 0) {
+               FALL_THROUGH;
+
+       case SQL_AUTZ_GROUP_REPLY_RESUME:
+       case SQL_AUTZ_PROFILE_REPLY_RESUME:
+               talloc_free(map_ctx->query);
+
+               if (map_ctx->rows == 0) {
                        do_fall_through = FALL_THROUGH_DEFAULT;
                        goto group_attr_cache;
                }
@@ -1390,21 +1443,42 @@ static unlang_action_t mod_authorize_resume(rlm_rcode_t *p_result, int *priority
        sql_autz_call_env_t     *call_env = autz_ctx->call_env;
        rlm_sql_t const         *inst = autz_ctx->inst;
        fr_value_box_t          *query = fr_value_box_list_pop_head(&autz_ctx->query);
-       int                     rows;
        sql_fall_through_t      do_fall_through = FALL_THROUGH_DEFAULT;
+       fr_sql_map_ctx_t        *map_ctx = autz_ctx->map_ctx;
+
+       /*
+        *      If a previous async call returned one of the "failure" results just return.
+        */
+       switch (*p_result) {
+       case RLM_MODULE_USER_SECTION_REJECT:
+               return UNLANG_ACTION_CALCULATE_RESULT;
+
+       default:
+               break;
+       }
 
        switch(autz_ctx->status) {
        case SQL_AUTZ_CHECK:
-               rows = sql_get_map_list(autz_ctx, inst, request, &autz_ctx->handle, autz_ctx->trunk,
-                                       &autz_ctx->check_tmp, query->vb_strvalue, request_attr_request);
-               talloc_free(query);
+               *autz_ctx->map_ctx = (fr_sql_map_ctx_t) {
+                       .ctx = autz_ctx,
+                       .inst = inst,
+                       .out = &autz_ctx->check_tmp,
+                       .list = request_attr_request,
+                       .query = query,
+               };
 
-               if (rows < 0) {
-                       REDEBUG("Failed getting check attributes");
-                       RETURN_MODULE_FAIL;
+               if (unlang_function_repeat_set(request, mod_authorize_resume) < 0) RETURN_MODULE_FAIL;
+               if (sql_get_map_list(request, map_ctx, &autz_ctx->handle, autz_ctx->trunk) == UNLANG_ACTION_PUSHED_CHILD){
+                       autz_ctx->status = SQL_AUTZ_CHECK_RESUME;
+                       return UNLANG_ACTION_PUSHED_CHILD;
                }
 
-               if (rows == 0) goto skip_reply; /* Don't need to handle map entries we don't have */
+               FALL_THROUGH;
+
+       case SQL_AUTZ_CHECK_RESUME:
+               talloc_free(map_ctx->query);
+
+               if (map_ctx->rows == 0) goto skip_reply;        /* Don't need to handle map entries we don't have */
 
                /*
                 *      Only do this if *some* check pairs were returned
@@ -1426,16 +1500,26 @@ static unlang_action_t mod_authorize_resume(rlm_rcode_t *p_result, int *priority
                return UNLANG_ACTION_PUSHED_CHILD;
 
        case SQL_AUTZ_REPLY:
-               rows = sql_get_map_list(autz_ctx, inst, request, &autz_ctx->handle, autz_ctx->trunk,
-                                       &autz_ctx->reply_tmp, query->vb_strvalue, request_attr_reply);
-               talloc_free(query);
+               *autz_ctx->map_ctx = (fr_sql_map_ctx_t) {
+                       .ctx = autz_ctx,
+                       .inst = inst,
+                       .out = &autz_ctx->reply_tmp,
+                       .list = request_attr_reply,
+                       .query = query,
+               };
 
-               if (rows < 0) {
-                       REDEBUG("SQL query error getting reply attributes");
-                       RETURN_MODULE_FAIL;
+               if (unlang_function_repeat_set(request, mod_authorize_resume) < 0) RETURN_MODULE_FAIL;
+               if (sql_get_map_list(request, map_ctx, &autz_ctx->handle, autz_ctx->trunk) == UNLANG_ACTION_PUSHED_CHILD){
+                       autz_ctx->status = SQL_AUTZ_REPLY_RESUME;
+                       return UNLANG_ACTION_PUSHED_CHILD;
                }
 
-               if (rows == 0) goto skip_reply;
+               FALL_THROUGH;
+
+       case SQL_AUTZ_REPLY_RESUME:
+               talloc_free(map_ctx->query);
+
+               if (map_ctx->rows == 0) goto skip_reply;
 
                do_fall_through = fall_through(&autz_ctx->reply_tmp);
 
@@ -1528,7 +1612,7 @@ static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, mod
         */
        sql_set_user(inst, request, &call_env->user);
 
-       MEM(autz_ctx = talloc_zero(unlang_interpret_frame_talloc_ctx(request), sql_autz_ctx_t));
+       MEM(autz_ctx = talloc(unlang_interpret_frame_talloc_ctx(request), sql_autz_ctx_t));
        *autz_ctx = (sql_autz_ctx_t) {
                .inst = inst,
                .call_env = call_env,
@@ -1538,6 +1622,7 @@ static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, mod
        };
        map_list_init(&autz_ctx->check_tmp);
        map_list_init(&autz_ctx->reply_tmp);
+       MEM(autz_ctx->map_ctx = talloc_zero(autz_ctx, fr_sql_map_ctx_t));
        talloc_set_destructor(autz_ctx, sql_autz_ctx_free);
 
        /*
@@ -1545,14 +1630,17 @@ static unlang_action_t CC_HINT(nonnull) mod_authorize(rlm_rcode_t *p_result, mod
         *
         *      This is freed by the talloc destructor for autz_ctx
         */
-       autz_ctx->handle = fr_pool_connection_get(inst->pool, request);
-       if (!autz_ctx->handle) RETURN_MODULE_FAIL;
+       if (!inst->driver->uses_trunks) {
+               autz_ctx->handle = fr_pool_connection_get(inst->pool, request);
+               if (!autz_ctx->handle) RETURN_MODULE_FAIL;
+       }
 
        if (!inst->sql_escape_arg && !thread->sql_escape_arg) request_data_add(request, (void *)sql_escape_uctx_alloc, 0,
                                                                               autz_ctx->handle, false, false, false);
 
-       if (unlang_function_push(request, NULL, mod_authorize_resume, NULL, 0,
-                                UNLANG_SUB_FRAME, autz_ctx) < 0) {
+       if (unlang_function_push(request, NULL,
+                                (call_env->check_query || call_env->reply_query) ? mod_authorize_resume : mod_autz_group_resume,
+                                NULL, 0, UNLANG_SUB_FRAME, autz_ctx) < 0) {
        error:
                talloc_free(autz_ctx);
                RETURN_MODULE_FAIL;
index 271b1f2ec341f07e692f977aca38ec0482086512..207916c3d5b16546793f49fbf1a85344d28722cf 100644 (file)
@@ -252,7 +252,7 @@ struct sql_inst {
 };
 
 void           *sql_mod_conn_create(TALLOC_CTX *ctx, void *instance, fr_time_delta_t timeout);
-int            sql_get_map_list(TALLOC_CTX *ctx, rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle, fr_trunk_t *trunk, map_list_t *out, char const *query, fr_dict_attr_t const *list);
+unlang_action_t        sql_get_map_list(request_t *request, fr_sql_map_ctx_t *map_ctx, rlm_sql_handle_t **handle, fr_trunk_t *trunk);
 void           rlm_sql_query_log(rlm_sql_t const *inst, char const *filename, char const *query) CC_HINT(nonnull);
 unlang_action_t rlm_sql_select_query(rlm_rcode_t *p_result, UNUSED int *priority, request_t *request, void *uctx);
 unlang_action_t        rlm_sql_query(rlm_rcode_t *p_result, int *priority, request_t *request, void *uctx);
index 2515541fcef2ee876b1ffba063d9ba832ccab222..3b5648c16f1853fd942ef205d6fa99eeee660c2f 100644 (file)
@@ -608,71 +608,65 @@ unlang_action_t rlm_sql_select_query(rlm_rcode_t *p_result, UNUSED int *priority
        RETURN_MODULE_FAIL;
 }
 
-
-/*************************************************************************
- *
- *     Function: sql_getvpdata
- *
- *     Purpose: Get any group check or reply pairs
+/** Process the results of an SQL query to produce a map list.
  *
- *************************************************************************/
-int sql_get_map_list(TALLOC_CTX *ctx, rlm_sql_t const *inst, request_t *request, rlm_sql_handle_t **handle,
-                    fr_trunk_t *trunk, map_list_t *out, char const *query, fr_dict_attr_t const *list)
+ */
+static unlang_action_t sql_get_map_list_resume(rlm_rcode_t *p_result, UNUSED int *priority, request_t *request, void *uctx)
 {
-       rlm_sql_row_t   row;
-       int             rows = 0;
-       map_t           *parent = NULL;
-       fr_sql_query_t  *query_ctx;
-       rlm_rcode_t     p_result;
-       tmpl_rules_t    lhs_rules = (tmpl_rules_t) {
+       fr_sql_map_ctx_t        *map_ctx = talloc_get_type_abort(uctx, fr_sql_map_ctx_t);
+       tmpl_rules_t            lhs_rules = (tmpl_rules_t) {
                .attr = {
                        .dict_def = request->dict,
                        .prefix = TMPL_ATTR_REF_PREFIX_AUTO,
-                       .list_def = list,
-                       .list_presence = TMPL_ATTR_LIST_ALLOW,
-
-                       /*
-                        *      Otherwise the tmpl code returns 0 when asked
-                        *      to parse unknown names.  So we say "please
-                        *      parse unknown names as unresolved attributes",
-                        *      and then do a second pass to complain that the
-                        *      thing isn't known.
-                        */
-                       .allow_unresolved = false
+                       .list_def = map_ctx->list,
+                       .list_presence = TMPL_ATTR_LIST_ALLOW
                }
        };
        tmpl_rules_t    rhs_rules = lhs_rules;
+       fr_sql_query_t  *query_ctx = map_ctx->query_ctx;
+       rlm_sql_row_t   row;
+       map_t           *parent = NULL;
+       rlm_sql_t const *inst = map_ctx->inst;
 
        rhs_rules.attr.prefix = TMPL_ATTR_REF_PREFIX_YES;
        rhs_rules.attr.list_def = request_attr_request;
 
-       fr_assert(request);
-
-       MEM(query_ctx = fr_sql_query_alloc(unlang_interpret_frame_talloc_ctx(request), inst, request, *handle, trunk, query, SQL_QUERY_SELECT));
-       inst->select(&p_result, NULL, request, query_ctx);
-       if (query_ctx->rcode != RLM_SQL_OK) {
-       error:
-               *handle = query_ctx->handle;
-               talloc_free(query_ctx);
-               return -1;
-       }
+       if (query_ctx->rcode != RLM_SQL_OK) RETURN_MODULE_FAIL;
 
-       while ((inst->fetch_row(&p_result, NULL, request, query_ctx) == UNLANG_ACTION_CALCULATE_RESULT) &&
+       while ((inst->fetch_row(p_result, NULL, request, query_ctx) == UNLANG_ACTION_CALCULATE_RESULT) &&
               (query_ctx->rcode == RLM_SQL_OK)) {
                map_t *map;
 
                row = query_ctx->row;
-               if (map_afrom_fields(ctx, &map, &parent, request, row[2], row[4], row[3], &lhs_rules, &rhs_rules) < 0) {
+               if (map_afrom_fields(map_ctx->ctx, &map, &parent, request, row[2], row[4], row[3], &lhs_rules, &rhs_rules) < 0) {
                        RPEDEBUG("Error parsing user data from database result");
-                       goto error;
+                       RETURN_MODULE_FAIL;
                }
-               if (!map->parent) map_list_insert_tail(out, map);
+               if (!map->parent) map_list_insert_tail(map_ctx->out, map);
 
-               rows++;
+               map_ctx->rows++;
        }
        talloc_free(query_ctx);
 
-       return rows;
+       RETURN_MODULE_OK;
+}
+
+/** Submit the query to get any user / group check or reply pairs
+ *
+ */
+unlang_action_t sql_get_map_list(request_t *request, fr_sql_map_ctx_t *map_ctx, rlm_sql_handle_t **handle,
+                                fr_trunk_t *trunk)
+{
+       rlm_sql_t const *inst = map_ctx->inst;
+
+       fr_assert(request);
+
+       MEM(map_ctx->query_ctx = fr_sql_query_alloc(map_ctx->ctx, inst, request, *handle, trunk,
+                                                   map_ctx->query->vb_strvalue, SQL_QUERY_SELECT));
+
+       if (unlang_function_push(request, NULL, sql_get_map_list_resume, NULL, 0, UNLANG_SUB_FRAME, map_ctx) < 0) return UNLANG_ACTION_FAIL;
+
+       return unlang_function_push(request, inst->select, NULL, NULL, 0, UNLANG_SUB_FRAME, map_ctx->query_ctx);
 }
 
 /*