]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
detect: add callback for when rate filter changes action
authorJason Ish <jason.ish@oisf.net>
Tue, 22 Apr 2025 23:15:12 +0000 (17:15 -0600)
committerVictor Julien <victor@inliniac.net>
Fri, 2 May 2025 18:10:09 +0000 (20:10 +0200)
This callback will be called when alert action has been changed due to a
rate filter. The user can then reset or customize the action in their
callback per their own logic.

As the callback is added to the current detection engine, make sure its
copied to the new detection engine on reload.

Ticket: #7673

examples/lib/custom/main.c
src/detect-engine-threshold.c
src/detect-engine.c
src/detect.h

index 4a08e4a3fe55d17f2943794f2df0a575e1386625..200936cae06c39f3b8a5bb7e784921321a3ba100 100644 (file)
@@ -16,6 +16,8 @@
  */
 
 #include "suricata.h"
+#include "detect.h"
+#include "detect-engine.h"
 #include "runmodes.h"
 #include "conf.h"
 #include "pcap.h"
@@ -145,6 +147,13 @@ done:
     pthread_exit(NULL);
 }
 
+static uint8_t RateFilterCallback(const Packet *p, const uint32_t sid, const uint32_t gid,
+        const uint32_t rev, uint8_t original_action, uint8_t new_action, void *arg)
+{
+    /* Don't change the action. */
+    return new_action;
+}
+
 int main(int argc, char **argv)
 {
     SuricataPreInit(argv[0]);
@@ -208,6 +217,8 @@ int main(int argc, char **argv)
 
     SuricataInit();
 
+    SCDetectEngineRegisterRateFilterCallback(RateFilterCallback, NULL);
+
     /* Create and start worker on its own thread, passing the PCAP
      * file as argument. This needs to be done in between SuricataInit
      * and SuricataPostInit. */
@@ -238,6 +249,7 @@ int main(int argc, char **argv)
      * function and SCTmThreadsSlotPacketLoopFinish that require them
      * to be run concurrently at this time. */
     SuricataShutdown();
+
     GlobalsDestroy();
 
     return EXIT_SUCCESS;
index b61661b911ad43468d5c37f9a6ef458562984779..5fd80d305d1fc53981f7bd0d172704ebf8a294cb 100644 (file)
@@ -709,7 +709,8 @@ static int ThresholdSetup(const DetectThresholdData *td, ThresholdEntry *te,
     return 0;
 }
 
-static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *te,
+static int ThresholdCheckUpdate(const DetectEngineCtx *de_ctx, const DetectThresholdData *td,
+        ThresholdEntry *te,
         const Packet *p, // ts only? - cache too
         const uint32_t sid, const uint32_t gid, const uint32_t rev, PacketAlert *pa)
 {
@@ -793,8 +794,9 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t
                 te->current_count = 1;
             }
             break;
-        case TYPE_RATE:
+        case TYPE_RATE: {
             SCLogDebug("rate_filter");
+            const uint8_t original_action = pa->action;
             ret = 1;
             /* Check if we have a timeout enabled, if so,
              * we still matching (and enabling the new_action) */
@@ -821,7 +823,16 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t
                     te->current_count = 1;
                 }
             }
+            if (de_ctx->RateFilterCallback && original_action != pa->action) {
+                pa->action = de_ctx->RateFilterCallback(p, sid, gid, rev, original_action,
+                        pa->action, de_ctx->rate_filter_callback_arg);
+                if (pa->action == original_action) {
+                    /* Reset back to original action, clear modified flag. */
+                    pa->flags &= ~PACKET_ALERT_FLAG_RATE_FILTER_MODIFIED;
+                }
+            }
             break;
+        }
         case TYPE_BACKOFF:
             SCLogDebug("backoff");
 
@@ -844,8 +855,8 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t
     return ret;
 }
 
-static int ThresholdGetFromHash(struct Thresholds *tctx, const Packet *p, const Signature *s,
-        const DetectThresholdData *td, PacketAlert *pa)
+static int ThresholdGetFromHash(const DetectEngineCtx *de_ctx, struct Thresholds *tctx,
+        const Packet *p, const Signature *s, const DetectThresholdData *td, PacketAlert *pa)
 {
     /* fast track for count 1 threshold */
     if (td->count == 1 && td->type == TYPE_THRESHOLD) {
@@ -894,7 +905,7 @@ static int ThresholdGetFromHash(struct Thresholds *tctx, const Packet *p, const
             r = ThresholdSetup(td, te, p->ts, s->id, s->gid, s->rev, p->tenant_id);
         } else {
             // existing, check/update
-            r = ThresholdCheckUpdate(td, te, p, s->id, s->gid, s->rev, pa);
+            r = ThresholdCheckUpdate(de_ctx, td, te, p, s->id, s->gid, s->rev, pa);
         }
 
         (void)THashDecrUsecnt(res.data);
@@ -909,8 +920,8 @@ static int ThresholdGetFromHash(struct Thresholds *tctx, const Packet *p, const
  *  \retval 1 normal match
  *  \retval 0 no match
  */
-static int ThresholdHandlePacketFlow(Flow *f, Packet *p, const DetectThresholdData *td,
-        uint32_t sid, uint32_t gid, uint32_t rev, PacketAlert *pa)
+static int ThresholdHandlePacketFlow(const DetectEngineCtx *de_ctx, Flow *f, Packet *p,
+        const DetectThresholdData *td, uint32_t sid, uint32_t gid, uint32_t rev, PacketAlert *pa)
 {
     int ret = 0;
     ThresholdEntry *found = ThresholdFlowLookupEntry(f, sid, gid, rev, p->tenant_id);
@@ -930,7 +941,7 @@ static int ThresholdHandlePacketFlow(Flow *f, Packet *p, const DetectThresholdDa
         }
     } else {
         // existing, check/update
-        ret = ThresholdCheckUpdate(td, found, p, sid, gid, rev, pa);
+        ret = ThresholdCheckUpdate(de_ctx, td, found, p, sid, gid, rev, pa);
     }
     return ret;
 }
@@ -967,7 +978,7 @@ int PacketAlertThreshold(const DetectEngineCtx *de_ctx, DetectEngineThreadCtx *d
             }
         }
 
-        ret = ThresholdGetFromHash(&ctx, p, s, td, pa);
+        ret = ThresholdGetFromHash(de_ctx, &ctx, p, s, td, pa);
     } else if (td->track == TRACK_DST) {
         if (PacketIsIPv4(p) && (td->type == TYPE_LIMIT || td->type == TYPE_BOTH)) {
             int cache_ret = CheckCache(p, td->track, s->id, s->gid, s->rev);
@@ -976,14 +987,14 @@ int PacketAlertThreshold(const DetectEngineCtx *de_ctx, DetectEngineThreadCtx *d
             }
         }
 
-        ret = ThresholdGetFromHash(&ctx, p, s, td, pa);
+        ret = ThresholdGetFromHash(de_ctx, &ctx, p, s, td, pa);
     } else if (td->track == TRACK_BOTH) {
-        ret = ThresholdGetFromHash(&ctx, p, s, td, pa);
+        ret = ThresholdGetFromHash(de_ctx, &ctx, p, s, td, pa);
     } else if (td->track == TRACK_RULE) {
-        ret = ThresholdGetFromHash(&ctx, p, s, td, pa);
+        ret = ThresholdGetFromHash(de_ctx, &ctx, p, s, td, pa);
     } else if (td->track == TRACK_FLOW) {
         if (p->flow) {
-            ret = ThresholdHandlePacketFlow(p->flow, p, td, s->id, s->gid, s->rev, pa);
+            ret = ThresholdHandlePacketFlow(de_ctx, p->flow, p, td, s->id, s->gid, s->rev, pa);
         }
     }
 
index b9b22d01dd69d59cbd7beedacc33832d785c96e4..0a381f9408a99fab6d8a98de29df708068d9034d 100644 (file)
@@ -4869,6 +4869,10 @@ int DetectEngineReload(const SCInstance *suri)
     }
     SCLogDebug("set up new_de_ctx %p", new_de_ctx);
 
+    /* Copy over callbacks. */
+    new_de_ctx->RateFilterCallback = old_de_ctx->RateFilterCallback;
+    new_de_ctx->rate_filter_callback_arg = old_de_ctx->rate_filter_callback_arg;
+
     /* add to master */
     DetectEngineAddToMaster(new_de_ctx);
 
@@ -5069,6 +5073,14 @@ bool DetectMd5ValidateCallback(
     return true;
 }
 
+void SCDetectEngineRegisterRateFilterCallback(SCDetectRateFilterFunc fn, void *arg)
+{
+    DetectEngineCtx *de_ctx = DetectEngineGetCurrent();
+    de_ctx->RateFilterCallback = fn;
+    de_ctx->rate_filter_callback_arg = arg;
+    DetectEngineDeReference(&de_ctx);
+}
+
 /*************************************Unittest*********************************/
 
 #ifdef UNITTESTS
index 37de9b9f58f5facacbf4190adbc2fbc0094982be..94ae258b4b3546a00dd3e3e814a9b0e906202052 100644 (file)
@@ -917,6 +917,16 @@ typedef struct {
     uint32_t content_inspect_min_size;
 } DetectFileDataCfg;
 
+/**
+ * \brief Function type for rate filter callback.
+ *
+ * This function should return the new action to be applied. If no change to the
+ * action is to be made, the callback should return the current action provided
+ * in the new_action parameter.
+ */
+typedef uint8_t (*SCDetectRateFilterFunc)(const Packet *p, uint32_t sid, uint32_t gid, uint32_t rev,
+        uint8_t original_action, uint8_t new_action, void *arg);
+
 /** \brief main detection engine ctx */
 typedef struct DetectEngineCtx_ {
     bool failure_fatal;
@@ -1131,8 +1141,23 @@ typedef struct DetectEngineCtx_ {
     HashTable *non_pf_engine_names;
 
     const char *firewall_rule_file_exclusive;
+
+    /* user provided rate filter callbacks. */
+    SCDetectRateFilterFunc RateFilterCallback;
+
+    /* use provided data to be passed to rate_filter_callback. */
+    void *rate_filter_callback_arg;
 } DetectEngineCtx;
 
+/**
+ * \brief Register a callback when a rate_filter has been applied to
+ *     an alert.
+ *
+ * This callback is added to the current detection engine and will be
+ * copied to all future detection engines over rule reloads.
+ */
+void SCDetectEngineRegisterRateFilterCallback(SCDetectRateFilterFunc cb, void *arg);
+
 /* Engine groups profiles (low, medium, high, custom) */
 enum {
     ENGINE_PROFILE_UNKNOWN,