From: Victor Julien Date: Fri, 19 Apr 2024 16:57:32 +0000 (+0200) Subject: detect/threshold: expand cache support for rule tracking X-Git-Tag: suricata-8.0.0-beta1~1084 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7bcf364095a845addae5933d1fdf18aed86558ac;p=thirdparty%2Fsuricata.git detect/threshold: expand cache support for rule tracking Use the same hash key as for the regular threshold storage, so include gid, rev, tentant id. --- diff --git a/src/detect-engine-threshold.c b/src/detect-engine-threshold.c index fa40bb9f09..4ae17b653c 100644 --- a/src/detect-engine-threshold.c +++ b/src/detect-engine-threshold.c @@ -224,12 +224,17 @@ uint32_t ThresholdsExpire(const SCTime_t ts) #include "util-hash.h" +#define TC_ADDRESS 0 +#define TC_SID 1 +#define TC_GID 2 +#define TC_REV 3 +#define TC_TENANT 4 + typedef struct ThresholdCacheItem { int8_t track; // by_src/by_dst int8_t ipv; int8_t retval; - uint32_t addr; - uint32_t sid; + uint32_t key[5]; SCTime_t expires_at; RB_ENTRY(ThresholdCacheItem) rb; } ThresholdCacheItem; @@ -297,8 +302,8 @@ static void ThresholdCacheExpire(SCTime_t now) static uint32_t ThresholdCacheHashFunc(HashTable *ht, void *data, uint16_t datalen) { - ThresholdCacheItem *tci = data; - int hash = tci->ipv * tci->track + tci->addr + tci->sid; + ThresholdCacheItem *e = data; + uint32_t hash = hashword(e->key, sizeof(e->key) / sizeof(uint32_t), 0) * (e->ipv + e->track); hash = hash % ht->array_size; return hash; } @@ -308,8 +313,8 @@ static char ThresholdCacheHashCompareFunc( { ThresholdCacheItem *tci1 = data1; ThresholdCacheItem *tci2 = data2; - return tci1->ipv == tci2->ipv && tci1->track == tci2->track && tci1->addr == tci2->addr && - tci1->sid == tci2->sid; + return tci1->ipv == tci2->ipv && tci1->track == tci2->track && + memcmp(tci1->key, tci2->key, sizeof(tci1->key)) == 0; } static void ThresholdCacheHashFreeFunc(void *data) @@ -319,7 +324,7 @@ static void ThresholdCacheHashFreeFunc(void *data) /// \brief Thread local cache static int SetupCache(const Packet *p, const int8_t track, const int8_t retval, const uint32_t sid, - SCTime_t expires) + const uint32_t gid, const uint32_t rev, SCTime_t expires) { if (!threshold_cache_ht) { threshold_cache_ht = HashTableInit(256, ThresholdCacheHashFunc, @@ -339,8 +344,11 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval, .track = track, .ipv = 4, .retval = retval, - .addr = addr, - .sid = sid, + .key[TC_ADDRESS] = addr, + .key[TC_SID] = sid, + .key[TC_GID] = gid, + .key[TC_REV] = rev, + .key[TC_TENANT] = p->tenant_id, .expires_at = expires, }; ThresholdCacheItem *found = HashTableLookup(threshold_cache_ht, &lookup, 0); @@ -350,8 +358,11 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval, n->track = track; n->ipv = 4; n->retval = retval; - n->addr = addr; - n->sid = sid; + n->key[TC_ADDRESS] = addr; + n->key[TC_SID] = sid; + n->key[TC_GID] = gid; + n->key[TC_REV] = rev; + n->key[TC_TENANT] = p->tenant_id; n->expires_at = expires; if (HashTableAdd(threshold_cache_ht, n, 0) == 0) { @@ -381,7 +392,8 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval, * \retval -4 error - unsupported tracker * \retval ret cached return code */ -static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid) +static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid, const uint32_t gid, + const uint32_t rev) { cache_lookup_cnt++; @@ -407,8 +419,11 @@ static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid) ThresholdCacheItem lookup = { .track = track, .ipv = 4, - .addr = addr, - .sid = sid, + .key[TC_ADDRESS] = addr, + .key[TC_SID] = sid, + .key[TC_GID] = gid, + .key[TC_REV] = rev, + .key[TC_TENANT] = p->tenant_id, }; ThresholdCacheItem *found = HashTableLookup(threshold_cache_ht, &lookup, 0); if (found) { @@ -652,7 +667,7 @@ static int ThresholdSetup(const DetectThresholdData *td, ThresholdEntry *te, static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *te, const Packet *p, // ts only? - cache too - const uint32_t sid, PacketAlert *pa) + const uint32_t sid, const uint32_t gid, const uint32_t rev, PacketAlert *pa) { int ret = 0; const SCTime_t packet_time = p->ts; @@ -670,7 +685,7 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t ret = 2; if (PacketIsIPv4(p)) { - SetupCache(p, td->track, (int8_t)ret, sid, entry); + SetupCache(p, td->track, (int8_t)ret, sid, gid, rev, entry); } } } else { @@ -705,7 +720,7 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t ret = 2; if (PacketIsIPv4(p)) { - SetupCache(p, td->track, (int8_t)ret, sid, entry); + SetupCache(p, td->track, (int8_t)ret, sid, gid, rev, entry); } } } else { @@ -819,7 +834,7 @@ static int ThresholdGetFromHash(struct Thresholds *tctx, const Packet *p, const r = ThresholdSetup(td, te, p->ts, s->id, s->gid, s->rev, p->tenant_id); } else { // existing, check/update - r = ThresholdCheckUpdate(td, te, p, s->id, pa); + r = ThresholdCheckUpdate(td, te, p, s->id, s->gid, s->rev, pa); } (void)THashDecrUsecnt(res.data); @@ -855,7 +870,7 @@ static int ThresholdHandlePacketFlow(Flow *f, Packet *p, const DetectThresholdDa } } else { // existing, check/update - ret = ThresholdCheckUpdate(td, found, p, sid, pa); + ret = ThresholdCheckUpdate(td, found, p, sid, gid, rev, pa); } return ret; } @@ -886,7 +901,7 @@ int PacketAlertThreshold(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx ret = ThresholdHandlePacketSuppress(p,td,s->id,s->gid); } else if (td->track == TRACK_SRC) { if (PacketIsIPv4(p) && (td->type == TYPE_LIMIT || td->type == TYPE_BOTH)) { - int cache_ret = CheckCache(p, td->track, s->id); + int cache_ret = CheckCache(p, td->track, s->id, s->gid, s->rev); if (cache_ret >= 0) { SCReturnInt(cache_ret); } @@ -895,7 +910,7 @@ int PacketAlertThreshold(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx ret = ThresholdGetFromHash(&ctx, p, s, td, pa); } else if (td->track == TRACK_DST) { if (PacketIsIPv4(p) && (td->type == TYPE_LIMIT || td->type == TYPE_BOTH)) { - int cache_ret = CheckCache(p, td->track, s->id); + int cache_ret = CheckCache(p, td->track, s->id, s->gid, s->rev); if (cache_ret >= 0) { SCReturnInt(cache_ret); }