]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
detect: improve prepare mpms routine
authorVictor Julien <vjulien@oisf.net>
Tue, 25 Apr 2023 09:23:47 +0000 (11:23 +0200)
committerVictor Julien <vjulien@oisf.net>
Tue, 25 Apr 2023 09:36:37 +0000 (11:36 +0200)
Based on hash table work in:
e624328deb25 ("detect: split mpm per alproto for file.data & others")

Instead of using a large stack array use a hash table for the intermediate
steps of the mpm build.

src/detect-engine-mpm.c

index e43593aa114852334d72f1b51a2a0550e66022af..66cfdc859098aae127825163ff8933ba6dbf8e46 100644 (file)
@@ -1556,28 +1556,90 @@ static void SetRawReassemblyFlag(DetectEngineCtx *de_ctx, SigGroupHead *sgh)
     SCLogDebug("rule group %p does NOT have SIG_GROUP_HEAD_HAVERAWSTREAM set", sgh);
 }
 
+typedef struct DetectBufferInstance {
+    // key
+    int list;
+
+    struct SidsArray ts;
+    struct SidsArray tc;
+} DetectBufferInstance;
+
+static uint32_t DetectBufferInstanceHashFunc(HashListTable *ht, void *data, uint16_t datalen)
+{
+    const DetectBufferInstance *ms = (const DetectBufferInstance *)data;
+    uint32_t hash = ms->list;
+    return hash % ht->array_size;
+}
+
+static char DetectBufferInstanceCompareFunc(void *data1, uint16_t len1, void *data2, uint16_t len2)
+{
+    const DetectBufferInstance *ms1 = (DetectBufferInstance *)data1;
+    const DetectBufferInstance *ms2 = (DetectBufferInstance *)data2;
+    return (ms1->list == ms2->list);
+}
+
+static void DetectBufferInstanceFreeFunc(void *ptr)
+{
+    DetectBufferInstance *ms = ptr;
+    if (ms->ts.sids_array != NULL)
+        SCFree(ms->ts.sids_array);
+    if (ms->tc.sids_array != NULL)
+        SCFree(ms->tc.sids_array);
+    SCFree(ms);
+}
+
+static HashListTable *DetectBufferInstanceInit(void)
+{
+    return HashListTableInit(4096, DetectBufferInstanceHashFunc, DetectBufferInstanceCompareFunc,
+            DetectBufferInstanceFreeFunc);
+}
+
 static void PrepareMpms(DetectEngineCtx *de_ctx, SigGroupHead *sh)
 {
+    HashListTable *bufs = DetectBufferInstanceInit();
+    BUG_ON(bufs == NULL);
+
     const int max_buffer_id = de_ctx->buffer_type_id + 1;
-    struct SidsArray sids[max_buffer_id][2];
-    memset(sids, 0, sizeof(sids));
     const uint32_t max_sid = DetectEngineGetMaxSigId(de_ctx) / 8 + 1;
 
+    int types[max_buffer_id];
+    memset(types, 0, sizeof(types));
+
     /* flag the list+directions we have engines for as active */
     for (DetectBufferMpmRegistery *a = de_ctx->pkt_mpms_list; a != NULL; a = a->next) {
-        struct SidsArray *sa = &sids[a->sm_list][0];
-        sa->active = true;
-        sa->type = a->type;
+        types[a->sm_list] = a->type;
+
+        DetectBufferInstance lookup = {
+            .list = a->sm_list,
+        };
+        DetectBufferInstance *instance = HashListTableLookup(bufs, &lookup, 0);
+        if (instance == NULL) {
+            instance = SCCalloc(1, sizeof(*instance));
+            BUG_ON(instance == NULL);
+            instance->list = a->sm_list;
+            HashListTableAdd(bufs, instance, 0);
+        }
+        instance->ts.active = true;
+        instance->tc.active = true;
     }
     for (DetectBufferMpmRegistery *a = de_ctx->app_mpms_list; a != NULL; a = a->next) {
-        sids[a->sm_list][0].type = a->type;
-        if ((a->direction == SIG_FLAG_TOSERVER) && SGH_DIRECTION_TS(sh)) {
-            struct SidsArray *sa = &sids[a->sm_list][0];
-            sa->active = true;
-        }
-        if ((a->direction == SIG_FLAG_TOCLIENT) && SGH_DIRECTION_TC(sh)) {
-            struct SidsArray *sa = &sids[a->sm_list][1];
-            sa->active = true;
+        const bool add_ts = ((a->direction == SIG_FLAG_TOSERVER) && SGH_DIRECTION_TS(sh));
+        const bool add_tc = ((a->direction == SIG_FLAG_TOCLIENT) && SGH_DIRECTION_TC(sh));
+        if (add_ts || add_tc) {
+            types[a->sm_list] = a->type;
+
+            DetectBufferInstance lookup = {
+                .list = a->sm_list,
+            };
+            DetectBufferInstance *instance = HashListTableLookup(bufs, &lookup, 0);
+            if (instance == NULL) {
+                instance = SCCalloc(1, sizeof(*instance));
+                BUG_ON(instance == NULL);
+                instance->list = a->sm_list;
+                HashListTableAdd(bufs, instance, 0);
+            }
+            instance->ts.active |= add_ts;
+            instance->tc.active |= add_tc;
         }
     }
 
@@ -1593,11 +1655,17 @@ static void PrepareMpms(DetectEngineCtx *de_ctx, SigGroupHead *sh)
         if (list == DETECT_SM_LIST_PMATCH)
             continue;
 
-        switch (sids[list][0].type) {
+        switch (types[list]) {
             /* app engines are direction aware */
-            case DETECT_BUFFER_MPM_TYPE_APP:
+            case DETECT_BUFFER_MPM_TYPE_APP: {
+                DetectBufferInstance lookup = {
+                    .list = list,
+                };
+                DetectBufferInstance *instance = HashListTableLookup(bufs, &lookup, 0);
+                if (instance == NULL)
+                    continue;
                 if (s->flags & SIG_FLAG_TOSERVER) {
-                    struct SidsArray *sa = &sids[list][0];
+                    struct SidsArray *sa = &instance->ts;
                     if (sa->active) {
                         if (sa->sids_array == NULL) {
                             sa->sids_array = SCCalloc(1, max_sid);
@@ -1605,10 +1673,11 @@ static void PrepareMpms(DetectEngineCtx *de_ctx, SigGroupHead *sh)
                             BUG_ON(sa->sids_array == NULL); // TODO
                         }
                         sa->sids_array[s->num / 8] |= 1 << (s->num % 8);
+                        SCLogDebug("instance %p: stored %u/%u ts", instance, s->id, s->num);
                     }
                 }
                 if (s->flags & SIG_FLAG_TOCLIENT) {
-                    struct SidsArray *sa = &sids[list][1];
+                    struct SidsArray *sa = &instance->tc;
                     if (sa->active) {
                         if (sa->sids_array == NULL) {
                             sa->sids_array = SCCalloc(1, max_sid);
@@ -1616,12 +1685,20 @@ static void PrepareMpms(DetectEngineCtx *de_ctx, SigGroupHead *sh)
                             BUG_ON(sa->sids_array == NULL); // TODO
                         }
                         sa->sids_array[s->num / 8] |= 1 << (s->num % 8);
+                        SCLogDebug("instance %p: stored %u/%u tc", instance, s->id, s->num);
                     }
                 }
                 break;
+            }
             /* pkt engines are directionless, so only use index 0 */
             case DETECT_BUFFER_MPM_TYPE_PKT: {
-                struct SidsArray *sa = &sids[list][0];
+                DetectBufferInstance lookup = {
+                    .list = list,
+                };
+                DetectBufferInstance *instance = HashListTableLookup(bufs, &lookup, 0);
+                if (instance == NULL)
+                    continue;
+                struct SidsArray *sa = &instance->ts;
                 if (sa->active) {
                     if (sa->sids_array == NULL) {
                         sa->sids_array = SCCalloc(1, max_sid);
@@ -1645,7 +1722,16 @@ static void PrepareMpms(DetectEngineCtx *de_ctx, SigGroupHead *sh)
     BUG_ON(sh->init->pkt_mpms == NULL);
 
     for (DetectBufferMpmRegistery *a = de_ctx->pkt_mpms_list; a != NULL; a = a->next) {
-        struct SidsArray *sa = &sids[a->sm_list][0];
+        DetectBufferInstance lookup = {
+            .list = a->sm_list,
+        };
+        DetectBufferInstance *instance = HashListTableLookup(bufs, &lookup, 0);
+        if (instance == NULL) {
+            continue;
+        }
+        struct SidsArray *sa = &instance->ts;
+        if (!sa->active)
+            continue;
 
         MpmStore *mpm_store = MpmStorePrepareBufferPkt(de_ctx, sh, a, sa);
         if (mpm_store != NULL) {
@@ -1668,8 +1754,18 @@ static void PrepareMpms(DetectEngineCtx *de_ctx, SigGroupHead *sh)
     for (DetectBufferMpmRegistery *a = de_ctx->app_mpms_list; a != NULL; a = a->next) {
         if ((a->direction == SIG_FLAG_TOSERVER && SGH_DIRECTION_TS(sh)) ||
                 (a->direction == SIG_FLAG_TOCLIENT && SGH_DIRECTION_TC(sh))) {
-            const int dir = a->direction == SIG_FLAG_TOCLIENT;
-            struct SidsArray *sa = &sids[a->sm_list][dir];
+
+            DetectBufferInstance lookup = {
+                .list = a->sm_list,
+            };
+            DetectBufferInstance *instance = HashListTableLookup(bufs, &lookup, 0);
+            if (instance == NULL) {
+                continue;
+            }
+            struct SidsArray *sa =
+                    (a->direction == SIG_FLAG_TOSERVER) ? &instance->ts : &instance->tc;
+            if (!sa->active)
+                continue;
 
             MpmStore *mpm_store = MpmStorePrepareBufferAppLayer(de_ctx, sh, a, sa);
             if (mpm_store != NULL) {
@@ -1689,17 +1785,7 @@ static void PrepareMpms(DetectEngineCtx *de_ctx, SigGroupHead *sh)
             }
         }
     }
-
-    /* free temp sig arrays */
-    for (int i = 0; i < max_buffer_id; i++) {
-        struct SidsArray *sa;
-        sa = &sids[i][0];
-        if (sa->sids_array != NULL)
-            SCFree(sa->sids_array);
-        sa = &sids[i][1];
-        if (sa->sids_array != NULL)
-            SCFree(sa->sids_array);
-    }
+    HashListTableFree(bufs);
 }
 
 /** \brief Prepare the pattern matcher ctx in a sig group head.