]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Rework] Return back N-ary optimizations for arithmetic-alike expressions
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 14 Jul 2025 20:22:06 +0000 (21:22 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 14 Jul 2025 20:22:06 +0000 (21:22 +0100)
src/libutil/expression.c

index e36964e7254a82cf25c2f53872f9d658d5b62f1c..cac7594d6a179e5d6bdf8ac3ef7ea2516f8eaf15 100644 (file)
@@ -80,6 +80,10 @@ struct rspamd_expr_process_data {
        /* != NULL if trace is collected */
        GPtrArray *trace;
        rspamd_expression_process_cb process_closure;
+       /* Optimization thresholds for arithmetic operations */
+       double threshold;
+       enum rspamd_expression_op threshold_op;
+       gboolean has_threshold;
 };
 
 #define msg_debug_expression(...) rspamd_conditional_debug_fast(NULL, NULL,                                        \
@@ -1197,11 +1201,56 @@ error_label:
        return FALSE;
 }
 
+/*
+ * Analyze AST node to determine if arithmetic operations can be optimized
+ * based on comparison context (e.g., A + B + C > 2 can stop at 3)
+ */
+static void
+rspamd_ast_analyze_node(GNode *node, struct rspamd_expr_process_data *process_data)
+{
+       struct rspamd_expression_elt *elt;
+       GNode *child;
+
+       if (!node || !node->data) {
+               return;
+       }
+
+       elt = (struct rspamd_expression_elt *) node->data;
+
+       /* Check if this is a comparison operation with arithmetic child */
+       if (elt->type == ELT_OP && (elt->p.op.op_flags & RSPAMD_EXPRESSION_COMPARISON)) {
+               /* Look for arithmetic operations in children */
+               child = node->children;
+               if (child && child->next) {
+                       GNode *left = child;
+                       GNode *right = child->next;
+                       struct rspamd_expression_elt *left_elt = left->data;
+                       struct rspamd_expression_elt *right_elt = right->data;
+
+                       /* Check if left child is arithmetic operation and right is limit */
+                       if (left_elt->type == ELT_OP &&
+                               (left_elt->p.op.op_flags & RSPAMD_EXPRESSION_ARITHMETIC) &&
+                               (left_elt->p.op.op_flags & RSPAMD_EXPRESSION_NARY) &&
+                               right_elt->type == ELT_LIMIT) {
+
+                               /* Set threshold for arithmetic optimization */
+                               process_data->has_threshold = TRUE;
+                               process_data->threshold = right_elt->p.lim;
+                               process_data->threshold_op = elt->p.op.op;
+
+                               msg_debug_expression_verbose("detected arithmetic optimization: %s %.1f",
+                                                                                        rspamd_expr_op_to_str(elt->p.op.op),
+                                                                                        right_elt->p.lim);
+                       }
+               }
+       }
+}
+
 /*
  *  Node optimizer function: skip nodes that are not relevant
  */
 static gboolean
-rspamd_ast_node_done(struct rspamd_expression_elt *elt, double acc)
+rspamd_ast_node_done(struct rspamd_expression_elt *elt, double acc, struct rspamd_expr_process_data *process_data)
 {
        gboolean ret = FALSE;
 
@@ -1217,6 +1266,47 @@ rspamd_ast_node_done(struct rspamd_expression_elt *elt, double acc)
        case OP_OR:
                ret = acc != 0;
                break;
+       case OP_PLUS:
+       case OP_MULT:
+               /* Handle arithmetic operations with thresholds */
+               if (process_data->has_threshold) {
+                       switch (process_data->threshold_op) {
+                       case OP_GT:
+                               /* For A + B + C > 2, stop when acc > threshold */
+                               ret = acc > process_data->threshold;
+                               break;
+                       case OP_GE:
+                               /* For A + B + C >= 2, stop when acc >= threshold */
+                               ret = acc >= process_data->threshold;
+                               break;
+                       case OP_LT:
+                               /* For A + B + C < 2, stop when acc >= threshold (result will be false) */
+                               ret = acc >= process_data->threshold;
+                               break;
+                       case OP_LE:
+                               /* For A + B + C <= 2, stop when acc > threshold (result will be false) */
+                               ret = acc > process_data->threshold;
+                               break;
+                       case OP_EQ:
+                               /* For A + B + C == 2, stop when acc > threshold (result will be false) */
+                               ret = acc > process_data->threshold;
+                               break;
+                       case OP_NE:
+                               /* For A + B + C != 2, stop when acc > threshold (result will be true) */
+                               ret = acc > process_data->threshold;
+                               break;
+                       default:
+                               break;
+                       }
+
+                       if (ret) {
+                               msg_debug_expression_verbose("arithmetic optimization triggered: %s %.1f %s %.1f",
+                                                                                        rspamd_expr_op_to_str(elt->p.op.op), acc,
+                                                                                        rspamd_expr_op_to_str(process_data->threshold_op),
+                                                                                        process_data->threshold);
+                       }
+               }
+               break;
        default:
                break;
        }
@@ -1340,9 +1430,28 @@ rspamd_ast_process_node(struct rspamd_expression *e, GNode *node,
        double val;
        gboolean calc_ticks = FALSE;
        __attribute__((unused)) const char *op_name = NULL;
+       gboolean saved_has_threshold = FALSE;
+       double saved_threshold = 0.0;
+       enum rspamd_expression_op saved_threshold_op = OP_INVALID;
 
        elt = node->data;
 
+       /* Analyze node for optimization opportunities */
+       if (elt->type == ELT_OP && (elt->p.op.op_flags & RSPAMD_EXPRESSION_COMPARISON)) {
+               /* Save current threshold state */
+               saved_has_threshold = process_data->has_threshold;
+               saved_threshold = process_data->threshold;
+               saved_threshold_op = process_data->threshold_op;
+
+               /* Reset threshold state */
+               process_data->has_threshold = FALSE;
+               process_data->threshold = 0.0;
+               process_data->threshold_op = OP_INVALID;
+
+               /* Analyze for optimization opportunities */
+               rspamd_ast_analyze_node(node, process_data);
+       }
+
        switch (elt->type) {
        case ELT_ATOM:
                if (!(elt->flags & RSPAMD_EXPR_FLAG_PROCESSED)) {
@@ -1400,7 +1509,7 @@ rspamd_ast_process_node(struct rspamd_expression *e, GNode *node,
 
                                /* Check if we need to process further */
                                if (!(process_data->flags & RSPAMD_EXPRESSION_FLAG_NOOPT)) {
-                                       if (rspamd_ast_node_done(elt, acc)) {
+                                       if (rspamd_ast_node_done(elt, acc, process_data)) {
                                                msg_debug_expression_verbose("optimizer: done");
                                                return acc;
                                        }
@@ -1443,6 +1552,13 @@ rspamd_ast_process_node(struct rspamd_expression *e, GNode *node,
                break;
        }
 
+       /* Restore threshold state if it was saved */
+       if (elt->type == ELT_OP && (elt->p.op.op_flags & RSPAMD_EXPRESSION_COMPARISON)) {
+               process_data->has_threshold = saved_has_threshold;
+               process_data->threshold = saved_threshold;
+               process_data->threshold_op = saved_threshold_op;
+       }
+
        return acc;
 }
 
@@ -1477,6 +1593,9 @@ rspamd_process_expression_closure(struct rspamd_expression *expr,
        pd.process_closure = cb;
        pd.flags = flags;
        pd.ud = runtime_ud;
+       pd.has_threshold = FALSE;
+       pd.threshold = 0.0;
+       pd.threshold_op = OP_INVALID;
 
        if (track) {
                pd.trace = g_ptr_array_sized_new(32);