]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
detect/threshold: expand cache support for rule tracking
authorVictor Julien <vjulien@oisf.net>
Fri, 19 Apr 2024 16:57:32 +0000 (18:57 +0200)
committerVictor Julien <vjulien@oisf.net>
Fri, 28 Jun 2024 07:46:34 +0000 (09:46 +0200)
Use the same hash key as for the regular threshold storage,
so include gid, rev, tentant id.

src/detect-engine-threshold.c

index fa40bb9f098f9f08e1c65a056ff2fa1170551727..4ae17b653c4256f526a4e6dae9669f9591f8f23f 100644 (file)
@@ -224,12 +224,17 @@ uint32_t ThresholdsExpire(const SCTime_t ts)
 
 #include "util-hash.h"
 
+#define TC_ADDRESS 0
+#define TC_SID     1
+#define TC_GID     2
+#define TC_REV     3
+#define TC_TENANT  4
+
 typedef struct ThresholdCacheItem {
     int8_t track; // by_src/by_dst
     int8_t ipv;
     int8_t retval;
-    uint32_t addr;
-    uint32_t sid;
+    uint32_t key[5];
     SCTime_t expires_at;
     RB_ENTRY(ThresholdCacheItem) rb;
 } ThresholdCacheItem;
@@ -297,8 +302,8 @@ static void ThresholdCacheExpire(SCTime_t now)
 
 static uint32_t ThresholdCacheHashFunc(HashTable *ht, void *data, uint16_t datalen)
 {
-    ThresholdCacheItem *tci = data;
-    int hash = tci->ipv * tci->track + tci->addr + tci->sid;
+    ThresholdCacheItem *e = data;
+    uint32_t hash = hashword(e->key, sizeof(e->key) / sizeof(uint32_t), 0) * (e->ipv + e->track);
     hash = hash % ht->array_size;
     return hash;
 }
@@ -308,8 +313,8 @@ static char ThresholdCacheHashCompareFunc(
 {
     ThresholdCacheItem *tci1 = data1;
     ThresholdCacheItem *tci2 = data2;
-    return tci1->ipv == tci2->ipv && tci1->track == tci2->track && tci1->addr == tci2->addr &&
-           tci1->sid == tci2->sid;
+    return tci1->ipv == tci2->ipv && tci1->track == tci2->track &&
+           memcmp(tci1->key, tci2->key, sizeof(tci1->key)) == 0;
 }
 
 static void ThresholdCacheHashFreeFunc(void *data)
@@ -319,7 +324,7 @@ static void ThresholdCacheHashFreeFunc(void *data)
 
 /// \brief Thread local cache
 static int SetupCache(const Packet *p, const int8_t track, const int8_t retval, const uint32_t sid,
-        SCTime_t expires)
+        const uint32_t gid, const uint32_t rev, SCTime_t expires)
 {
     if (!threshold_cache_ht) {
         threshold_cache_ht = HashTableInit(256, ThresholdCacheHashFunc,
@@ -339,8 +344,11 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval,
         .track = track,
         .ipv = 4,
         .retval = retval,
-        .addr = addr,
-        .sid = sid,
+        .key[TC_ADDRESS] = addr,
+        .key[TC_SID] = sid,
+        .key[TC_GID] = gid,
+        .key[TC_REV] = rev,
+        .key[TC_TENANT] = p->tenant_id,
         .expires_at = expires,
     };
     ThresholdCacheItem *found = HashTableLookup(threshold_cache_ht, &lookup, 0);
@@ -350,8 +358,11 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval,
             n->track = track;
             n->ipv = 4;
             n->retval = retval;
-            n->addr = addr;
-            n->sid = sid;
+            n->key[TC_ADDRESS] = addr;
+            n->key[TC_SID] = sid;
+            n->key[TC_GID] = gid;
+            n->key[TC_REV] = rev;
+            n->key[TC_TENANT] = p->tenant_id;
             n->expires_at = expires;
 
             if (HashTableAdd(threshold_cache_ht, n, 0) == 0) {
@@ -381,7 +392,8 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval,
  *  \retval -4 error - unsupported tracker
  *  \retval ret cached return code
  */
-static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid)
+static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid, const uint32_t gid,
+        const uint32_t rev)
 {
     cache_lookup_cnt++;
 
@@ -407,8 +419,11 @@ static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid)
     ThresholdCacheItem lookup = {
         .track = track,
         .ipv = 4,
-        .addr = addr,
-        .sid = sid,
+        .key[TC_ADDRESS] = addr,
+        .key[TC_SID] = sid,
+        .key[TC_GID] = gid,
+        .key[TC_REV] = rev,
+        .key[TC_TENANT] = p->tenant_id,
     };
     ThresholdCacheItem *found = HashTableLookup(threshold_cache_ht, &lookup, 0);
     if (found) {
@@ -652,7 +667,7 @@ static int ThresholdSetup(const DetectThresholdData *td, ThresholdEntry *te,
 
 static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *te,
         const Packet *p, // ts only? - cache too
-        const uint32_t sid, PacketAlert *pa)
+        const uint32_t sid, const uint32_t gid, const uint32_t rev, PacketAlert *pa)
 {
     int ret = 0;
     const SCTime_t packet_time = p->ts;
@@ -670,7 +685,7 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t
                     ret = 2;
 
                     if (PacketIsIPv4(p)) {
-                        SetupCache(p, td->track, (int8_t)ret, sid, entry);
+                        SetupCache(p, td->track, (int8_t)ret, sid, gid, rev, entry);
                     }
                 }
             } else {
@@ -705,7 +720,7 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t
                     ret = 2;
 
                     if (PacketIsIPv4(p)) {
-                        SetupCache(p, td->track, (int8_t)ret, sid, entry);
+                        SetupCache(p, td->track, (int8_t)ret, sid, gid, rev, entry);
                     }
                 }
             } else {
@@ -819,7 +834,7 @@ static int ThresholdGetFromHash(struct Thresholds *tctx, const Packet *p, const
             r = ThresholdSetup(td, te, p->ts, s->id, s->gid, s->rev, p->tenant_id);
         } else {
             // existing, check/update
-            r = ThresholdCheckUpdate(td, te, p, s->id, pa);
+            r = ThresholdCheckUpdate(td, te, p, s->id, s->gid, s->rev, pa);
         }
 
         (void)THashDecrUsecnt(res.data);
@@ -855,7 +870,7 @@ static int ThresholdHandlePacketFlow(Flow *f, Packet *p, const DetectThresholdDa
         }
     } else {
         // existing, check/update
-        ret = ThresholdCheckUpdate(td, found, p, sid, pa);
+        ret = ThresholdCheckUpdate(td, found, p, sid, gid, rev, pa);
     }
     return ret;
 }
@@ -886,7 +901,7 @@ int PacketAlertThreshold(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx
         ret = ThresholdHandlePacketSuppress(p,td,s->id,s->gid);
     } else if (td->track == TRACK_SRC) {
         if (PacketIsIPv4(p) && (td->type == TYPE_LIMIT || td->type == TYPE_BOTH)) {
-            int cache_ret = CheckCache(p, td->track, s->id);
+            int cache_ret = CheckCache(p, td->track, s->id, s->gid, s->rev);
             if (cache_ret >= 0) {
                 SCReturnInt(cache_ret);
             }
@@ -895,7 +910,7 @@ int PacketAlertThreshold(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx
         ret = ThresholdGetFromHash(&ctx, p, s, td, pa);
     } else if (td->track == TRACK_DST) {
         if (PacketIsIPv4(p) && (td->type == TYPE_LIMIT || td->type == TYPE_BOTH)) {
-            int cache_ret = CheckCache(p, td->track, s->id);
+            int cache_ret = CheckCache(p, td->track, s->id, s->gid, s->rev);
             if (cache_ret >= 0) {
                 SCReturnInt(cache_ret);
             }