]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Duplicate a bunch of code in unlang/function.c *sigh* and allow it to work when the...
authorArran Cudbard-Bell <a.cudbardb@freeradius.org>
Mon, 2 Jun 2025 05:12:54 +0000 (23:12 -0600)
committerNick Porter <nick@portercomputing.co.uk>
Wed, 18 Jun 2025 12:53:16 +0000 (13:53 +0100)
src/lib/unlang/function.c

index ab7f611d6c3b80a30ef8b51a846d37e020821ab7..4a9a4de352af58ddba2d373050de109cc1a2b2b1 100644 (file)
@@ -121,20 +121,30 @@ static void unlang_function_signal(request_t *request,
  * @param[in] request          The current request.
  * @param[in] frame            The current frame.
  */
-static unlang_action_t call_with_result(unlang_result_t *p_result, request_t *request, unlang_stack_frame_t *frame)
+static unlang_action_t call_with_result_repeat(unlang_result_t *p_result, request_t *request, unlang_stack_frame_t *frame)
 {
        unlang_action_t                 ua;
        unlang_frame_state_func_t       *state = talloc_get_type_abort(frame->state, unlang_frame_state_func_t);
+       unlang_function_with_result_t   func = REPEAT(state);
 
        STORE_CALLER;
 
+       if (!REPEAT(state)) {
+               RDEBUG4("Repeat function is NULL, likely due to previous yield, skipping call");
+               ua = UNLANG_ACTION_CALCULATE_RESULT;
+               goto done;
+       }
+
 again:
-       RDEBUG4("Calling function %p (%s)", state->func.wres, state->func_name);
-       ua = state->func.wres(p_result, request, state->uctx);
-       state->func.wres = state->repeat.wres;
-       state->repeat.wres = NULL;
+       RDEBUG4("Calling repeat function %p (%s)", REPEAT(state), state->repeat_name);
+
+       /*
+        *      Only called once...
+        */
+       REPEAT(state) = NULL;
        state->repeat_name = NULL;
-       if (state->func.wres) {
+       ua = func(p_result, request, state->uctx);
+       if (REPEAT(state)) { /* set again by func */
                switch (ua) {
                case UNLANG_ACTION_STOP_PROCESSING:
                        break;
@@ -143,40 +153,81 @@ again:
                        goto again;
 
                default:
-                       frame_repeat(frame, call_with_result);
+                       frame_repeat(frame, call_with_result_repeat);
                }
        }
+
+done:
        RESTORE_CALLER;
 
        return ua;
 }
 
-/** Call a generic function that produces no result
+/** Call a generic function that produces a result
  *
- * These functions report results by modifying the rctx passed into the function.
- * They are not allowed to return UNLANG_ACTION_FAIL.  In non-debug builds this
- * is rewritten to UNLANG_ACTION_CALCULATE_RESULT, and in debug builds it triggers
- * an assert.
+ * @param[out] p_result                The frame result.
+ * @param[in] request          The current request.
+ * @param[in] frame            The current frame.
+ */
+static unlang_action_t call_with_result(unlang_result_t *p_result, request_t *request, unlang_stack_frame_t *frame)
+{
+       unlang_action_t                 ua;
+       unlang_frame_state_func_t       *state = talloc_get_type_abort(frame->state, unlang_frame_state_func_t);
+
+       STORE_CALLER;
+
+       RDEBUG4("Calling function %p (%s)", FUNC(state), state->func_name);
+       ua = state->func.wres(p_result, request, state->uctx);
+       FUNC(state) = NULL;
+       state->func_name = NULL;
+       if (REPEAT(state)) {
+               switch (ua) {
+               case UNLANG_ACTION_STOP_PROCESSING:
+                       break;
+
+               case UNLANG_ACTION_CALCULATE_RESULT:
+                       ua = call_with_result_repeat(p_result, request, frame);
+                       break;
+
+               default:
+                       frame_repeat(frame, call_with_result_repeat);
+               }
+       }
+       RESTORE_CALLER;
+
+       return ua;
+}
+
+/** Call a generic function that produces a result
  *
  * @param[out] p_result                The frame result.
  * @param[in] request          The current request.
  * @param[in] frame            The current frame.
  */
-static unlang_action_t call_no_result(UNUSED unlang_result_t *p_result, request_t *request, unlang_stack_frame_t *frame)
+static unlang_action_t call_no_result_repeat(UNUSED unlang_result_t *p_result, request_t *request, unlang_stack_frame_t *frame)
 {
        unlang_action_t                 ua;
        unlang_frame_state_func_t       *state = talloc_get_type_abort(frame->state, unlang_frame_state_func_t);
+       unlang_function_no_result_t     func = REPEAT(state);
 
        STORE_CALLER;
 
+       if (!REPEAT(state)) {
+               RDEBUG4("Repeat function is NULL, likely due to previous yield, skipping call");
+               ua = UNLANG_ACTION_CALCULATE_RESULT;
+               goto done;
+       }
+
 again:
-       RDEBUG4("Calling function %p (%s)", state->func.nres, state->func_name);
-       ua = state->func.nres(request, state->uctx);
-       state->func.nres = state->repeat.nres;
-       state->func_name = state->repeat_name;
-       state->repeat.nres = NULL;
+       RDEBUG4("Calling repeat function %p (%s)", REPEAT(state), state->repeat_name);
+
+       /*
+        *      Only called once...
+        */
+       REPEAT(state) = NULL;
        state->repeat_name = NULL;
-       if (state->func.nres) {
+       ua = func(request, state->uctx);
+       if (REPEAT(state)) { /* set again by func */
                switch (ua) {
                case UNLANG_ACTION_STOP_PROCESSING:
                        break;
@@ -184,15 +235,65 @@ again:
                case UNLANG_ACTION_CALCULATE_RESULT:
                        goto again;
 
+               case UNLANG_ACTION_FAIL:
+               no_action_fail:
+                       fr_assert_msg(0, "Function %s (%p) is not allowed to indicate failure via UNLANG_ACTION_FAIL",
+                                     state->repeat_name, REPEAT(state));
+                       ua = UNLANG_ACTION_CALCULATE_RESULT;
+                       break;
+
                default:
-                       frame_repeat(frame, call_no_result);
+                       frame_repeat(frame, call_no_result_repeat);
                }
        }
-       if (ua == UNLANG_ACTION_FAIL) {
-               fr_assert_msg(0, "Function %s (%p) is not allowed to indicate failure via UNLANG_ACTION_FAIL",
-                             state->func_name, state->func.nres);
-               ua = UNLANG_ACTION_CALCULATE_RESULT;
+
+       if (ua == UNLANG_ACTION_FAIL) goto no_action_fail;
+
+done:
+       RESTORE_CALLER;
+
+       return ua;
+}
+
+/** Call a generic function that produces a result
+ *
+ * @param[out] p_result                The frame result.
+ * @param[in] request          The current request.
+ * @param[in] frame            The current frame.
+ */
+static unlang_action_t call_no_result(UNUSED unlang_result_t *p_result, request_t *request, unlang_stack_frame_t *frame)
+{
+       unlang_action_t                 ua;
+       unlang_frame_state_func_t       *state = talloc_get_type_abort(frame->state, unlang_frame_state_func_t);
+
+       STORE_CALLER;
+
+       RDEBUG4("Calling function %p (%s)", FUNC(state), state->func_name);
+       ua = state->func.nres(request, state->uctx);
+       FUNC(state) = NULL;
+       state->func_name = NULL;
+       if (REPEAT(state)) {
+               switch (ua) {
+               case UNLANG_ACTION_STOP_PROCESSING:
+                       break;
+
+               case UNLANG_ACTION_CALCULATE_RESULT:
+                       ua = call_no_result_repeat(p_result, request, frame);
+                       break;
+
+               case UNLANG_ACTION_FAIL:
+               no_action_fail:
+                       fr_assert_msg(0, "Function %s (%p) is not allowed to indicate failure via UNLANG_ACTION_FAIL",
+                                     state->func_name, state->func.nres);
+                       ua = UNLANG_ACTION_CALCULATE_RESULT;
+                       break;
+
+               default:
+                       frame_repeat(frame, call_no_result_repeat);
+               }
        }
+       if (ua == UNLANG_ACTION_FAIL) goto no_action_fail;
+
        RESTORE_CALLER;
 
        return ua;
@@ -326,6 +427,11 @@ unlang_action_t unlang_function_push_common(unlang_result_t *p_result,
        unlang_stack_frame_t            *frame;
        unlang_frame_state_func_t       *state;
 
+       if (!func && !repeat) {
+               fr_assert_msg(0, "function push must push at least one function!");
+               return UNLANG_ACTION_FAIL;
+       }
+
        /*
         *      Push module's function
         */
@@ -346,26 +452,10 @@ unlang_action_t unlang_function_push_common(unlang_result_t *p_result,
        state->type = type;
        state->uctx = uctx;
 
-       /*
-        *      Just skip to the repeat state directly
-        */
-       if (!func && repeat) {
-               FUNC(state) = repeat;
-               state->func_name = repeat_name;
-               repeatable_set(frame);  /* execute on the way back up */
-       /*
-        *      If we have both a function and a repeat,
-        *      then record them both, and execute
-        *      'func' first.  This will set the repeat
-        *      function to call 'repeat' on the way
-        *      back up the stack.
-        */
-       } else {
-               FUNC(state) = func;
-               state->func_name = func_name;
-               REPEAT(state) = repeat;
-               state->repeat_name = repeat_name;
-       }
+       FUNC(state) = func;
+       state->func_name = func_name;
+       REPEAT(state) = repeat;
+       state->repeat_name = repeat_name;
 
        return UNLANG_ACTION_PUSHED_CHILD;
 }
@@ -414,7 +504,12 @@ unlang_action_t _unlang_function_push_with_result(unlang_result_t *p_result,
        if (unlikely(ua == UNLANG_ACTION_FAIL)) return UNLANG_ACTION_FAIL;
 
        frame = frame_current(request);
-       frame->process = call_with_result;
+       if (!func && repeat) {
+               frame->process = call_with_result_repeat;
+               repeatable_set(frame);                          /* execute on the way back up */
+       } else {
+               frame->process = call_with_result;
+       }
 
        return ua;
 }
@@ -448,6 +543,7 @@ unlang_action_t _unlang_function_push_no_result(request_t *request,
                                                bool top_frame, void *uctx)
 {
        unlang_action_t ua;
+       unlang_stack_frame_t *frame;
 
        ua = unlang_function_push_common(NULL,
                                         request,
@@ -458,6 +554,12 @@ unlang_action_t _unlang_function_push_no_result(request_t *request,
 
        if (unlikely(ua == UNLANG_ACTION_FAIL)) return UNLANG_ACTION_FAIL;
 
+       frame = frame_current(request);
+       if (!func && repeat) {
+               frame->process = call_no_result_repeat;
+               repeatable_set(frame);                          /* execute on the way back up */
+       }
+
        /* frame->process = call_no_result - This is the default, we don't need to set it again */
 
        return ua;