]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
detect/port: use qsort instead of insert sort
authorVictor Julien <vjulien@oisf.net>
Mon, 26 Feb 2024 16:08:21 +0000 (21:38 +0530)
committerVictor Julien <victor@inliniac.net>
Mon, 4 Mar 2024 10:50:30 +0000 (11:50 +0100)
Instead of using in place insertion sort on linked list based on two
keys, convert the linked list to an array, perform sorting on it using
qsort and convert it back to a linked list. This turns out to be much
faster.

Ticket #6795

src/detect-engine-build.c

index e812f861226e9f0414b0c260a1768d3c7906f6d6..67fb7405317c24d546f016f71517f210440dd7f4 100644 (file)
@@ -1126,9 +1126,191 @@ static int RuleSetWhitelist(Signature *s)
     return wl;
 }
 
-int CreateGroupedPortList(DetectEngineCtx *de_ctx, DetectPort *port_list, DetectPort **newhead,
-        uint32_t unique_groups, int (*CompareFunc)(DetectPort *, DetectPort *));
-int CreateGroupedPortListCmpCnt(DetectPort *a, DetectPort *b);
+static int SortCompare(const void *a, const void *b)
+{
+    const DetectPort *pa = *(const DetectPort **)a;
+    const DetectPort *pb = *(const DetectPort **)b;
+
+    if (pa->sh->init->score < pb->sh->init->score) {
+        return 1;
+    } else if (pa->sh->init->score > pb->sh->init->score) {
+        return -1;
+    }
+
+    if (pa->sh->init->sig_cnt < pb->sh->init->sig_cnt) {
+        return 1;
+    } else if (pa->sh->init->sig_cnt > pb->sh->init->sig_cnt) {
+        return -1;
+    }
+
+    /* Hack to make the qsort output deterministic across platforms.
+     * This had to be done because the order of equal elements sorted
+     * by qsort is undeterministic and showed different output on BSD,
+     * MacOS and Windows. Sorting based on id makes it deterministic. */
+    if (pa->sh->id < pb->sh->id)
+        return -1;
+
+    return 1;
+}
+
+static inline void SortGroupList(
+        uint32_t *groups, DetectPort **list, int (*CompareFunc)(const void *, const void *))
+{
+    int cnt = 0;
+    for (DetectPort *x = *list; x != NULL; x = x->next) {
+        DEBUG_VALIDATE_BUG_ON(x->port > x->port2);
+        cnt++;
+    }
+    if (cnt <= 1)
+        return;
+
+    /* build temporary array to sort with qsort */
+    DetectPort **array = (DetectPort **)SCCalloc(cnt, sizeof(DetectPort *));
+    if (array == NULL)
+        return;
+
+    int idx = 0;
+    for (DetectPort *x = *list; x != NULL;) {
+        /* assign a temporary id to resolve otherwise equal groups */
+        x->sh->id = idx + 1;
+        SigGroupHeadSetSigCnt(x->sh, 0);
+        DetectPort *next = x->next;
+        x->next = x->prev = x->last = NULL;
+        DEBUG_VALIDATE_BUG_ON(x->port > x->port2);
+        array[idx++] = x;
+        x = next;
+    }
+    DEBUG_VALIDATE_BUG_ON(cnt != idx);
+
+    qsort(array, idx, sizeof(DetectPort *), SortCompare);
+
+    /* rebuild the list based on the qsort-ed array */
+    DetectPort *new_list = NULL, *tail = NULL;
+    for (int i = 0; i < idx; i++) {
+        DetectPort *p = array[i];
+        /* unset temporary group id */
+        p->sh->id = 0;
+
+        if (new_list == NULL) {
+            new_list = p;
+        }
+        if (tail != NULL) {
+            tail->next = p;
+        }
+        p->prev = tail;
+        tail = p;
+    }
+
+    *list = new_list;
+    *groups = idx;
+
+#if DEBUG
+    int dbgcnt = 0;
+    SCLogDebug("SORTED LIST:");
+    for (DetectPort *tmp = *list; tmp != NULL; tmp = tmp->next) {
+        SCLogDebug("item:= [%u:%u]; score: %d; sig_cnt: %d", tmp->port, tmp->port2,
+                tmp->sh->init->score, tmp->sh->init->sig_cnt);
+        dbgcnt++;
+        BUG_ON(dbgcnt > cnt);
+    }
+#endif
+    SCFree(array);
+}
+/** \internal
+ *  \brief Create a list of DetectPort objects sorted based on CompareFunc's
+ *         logic.
+ *
+ *  List can limit the number of groups. In this case an extra "join" group
+ *  is created that contains the sigs belonging to that. It's *appended* to
+ *  the list, meaning that if the list is walked linearly it's found last.
+ *  The joingr is meant to be a catch all.
+ *
+ */
+static int CreateGroupedPortList(DetectEngineCtx *de_ctx, DetectPort *port_list,
+        DetectPort **newhead, uint32_t unique_groups,
+        int (*CompareFunc)(const void *, const void *))
+{
+    DetectPort *tmplist = NULL, *joingr = NULL;
+    uint32_t groups = 0;
+
+    /* insert the ports into the tmplist, where it will
+     * be sorted descending on 'cnt' and on whether a group
+     * is whitelisted. */
+    tmplist = port_list;
+    SortGroupList(&groups, &tmplist, SortCompare);
+    uint32_t left = unique_groups;
+    if (left == 0)
+        left = groups;
+
+    /* create another list: take the port groups from above
+     * and add them to the 2nd list until we have met our
+     * count. The rest is added to the 'join' group. */
+    DetectPort *tmplist2 = NULL, *tmplist2_tail = NULL;
+    DetectPort *gr, *next_gr;
+    for (gr = tmplist; gr != NULL;) {
+        next_gr = gr->next;
+
+        SCLogDebug("temp list gr %p %u:%u", gr, gr->port, gr->port2);
+        DetectPortPrint(gr);
+
+        /* if we've set up all the unique groups, add the rest to the
+         * catch-all joingr */
+        if (left == 0) {
+            if (joingr == NULL) {
+                DetectPortParse(de_ctx, &joingr, "0:65535");
+                if (joingr == NULL) {
+                    goto error;
+                }
+                SCLogDebug("joingr => %u-%u", joingr->port, joingr->port2);
+                joingr->next = NULL;
+            }
+            SigGroupHeadCopySigs(de_ctx, gr->sh, &joingr->sh);
+
+            /* when a group's sigs are added to the joingr, we can free it */
+            gr->next = NULL;
+            DetectPortFree(de_ctx, gr);
+            /* append */
+        } else {
+            gr->next = NULL;
+
+            if (tmplist2 == NULL) {
+                tmplist2 = gr;
+                tmplist2_tail = gr;
+            } else {
+                tmplist2_tail->next = gr;
+                tmplist2_tail = gr;
+            }
+        }
+
+        if (left > 0)
+            left--;
+
+        gr = next_gr;
+    }
+
+    /* if present, append the joingr that covers the rest */
+    if (joingr != NULL) {
+        SCLogDebug("appending joingr %p %u:%u", joingr, joingr->port, joingr->port2);
+
+        if (tmplist2 == NULL) {
+            tmplist2 = joingr;
+            // tmplist2_tail = joingr;
+        } else {
+            tmplist2_tail->next = joingr;
+            // tmplist2_tail = joingr;
+        }
+    } else {
+        SCLogDebug("no joingr");
+    }
+
+    /* pass back our new list to the caller */
+    *newhead = tmplist2;
+    DetectPortPrintList(*newhead);
+
+    return 0;
+error:
+    return -1;
+}
 
 #define RANGE_PORT  1
 #define SINGLE_PORT 2
@@ -1388,7 +1570,7 @@ static DetectPort *RulesGroupByPorts(DetectEngineCtx *de_ctx, uint8_t ipproto, u
     DetectPort *newlist = NULL;
     uint16_t groupmax = (direction == SIG_FLAG_TOCLIENT) ? de_ctx->max_uniq_toclient_groups :
                                                            de_ctx->max_uniq_toserver_groups;
-    CreateGroupedPortList(de_ctx, list, &newlist, groupmax, CreateGroupedPortListCmpCnt);
+    CreateGroupedPortList(de_ctx, list, &newlist, groupmax, SortCompare);
     list = newlist;
 
     /* step 4: deduplicate the SGH's */
@@ -1668,179 +1850,6 @@ error:
     return -1;
 }
 
-static int PortGroupWhitelist(const DetectPort *a)
-{
-    return a->sh->init->score;
-}
-
-int CreateGroupedPortListCmpCnt(DetectPort *a, DetectPort *b)
-{
-    if (PortGroupWhitelist(a) && !PortGroupWhitelist(b)) {
-        SCLogDebug("%u:%u (cnt %u, wl %d) wins against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-                a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                b->sh->init->sig_cnt, PortGroupWhitelist(b));
-        return 1;
-    } else if (!PortGroupWhitelist(a) && PortGroupWhitelist(b)) {
-        SCLogDebug("%u:%u (cnt %u, wl %d) loses against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-                a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                b->sh->init->sig_cnt, PortGroupWhitelist(b));
-        return 0;
-    } else if (PortGroupWhitelist(a) > PortGroupWhitelist(b)) {
-        SCLogDebug("%u:%u (cnt %u, wl %d) wins against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-                a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                b->sh->init->sig_cnt, PortGroupWhitelist(b));
-        return 1;
-    } else if (PortGroupWhitelist(a) == PortGroupWhitelist(b)) {
-        if (a->sh->init->sig_cnt > b->sh->init->sig_cnt) {
-            SCLogDebug("%u:%u (cnt %u, wl %d) wins against %u:%u (cnt %u, wl %d)", a->port,
-                    a->port2, a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                    b->sh->init->sig_cnt, PortGroupWhitelist(b));
-            return 1;
-        }
-    }
-
-    SCLogDebug("%u:%u (cnt %u, wl %d) loses against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-            a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2, b->sh->init->sig_cnt,
-            PortGroupWhitelist(b));
-    return 0;
-}
-
-/** \internal
- *  \brief Create a list of DetectPort objects sorted based on CompareFunc's
- *         logic.
- *
- *  List can limit the number of groups. In this case an extra "join" group
- *  is created that contains the sigs belonging to that. It's *appended* to
- *  the list, meaning that if the list is walked linearly it's found last.
- *  The joingr is meant to be a catch all.
- *
- */
-int CreateGroupedPortList(DetectEngineCtx *de_ctx, DetectPort *port_list, DetectPort **newhead,
-        uint32_t unique_groups, int (*CompareFunc)(DetectPort *, DetectPort *))
-{
-    DetectPort *tmplist = NULL, *joingr = NULL;
-    char insert = 0;
-    uint32_t groups = 0;
-    DetectPort *list;
-
-    /* insert the ports into the tmplist, where it will
-     * be sorted descending on 'cnt' and on whether a group
-     * is whitelisted. */
-
-    DetectPort *oldhead = port_list;
-    while (oldhead) {
-        /* take the top of the list */
-        list = oldhead;
-        oldhead = oldhead->next;
-        list->next = NULL;
-
-        groups++;
-        SigGroupHeadSetSigCnt(list->sh, 0);
-
-        /* insert it */
-        DetectPort *tmpgr = tmplist, *prevtmpgr = NULL;
-        if (tmplist == NULL) {
-            /* empty list, set head */
-            tmplist = list;
-        } else {
-            /* look for the place to insert */
-            for ( ; tmpgr != NULL && !insert; tmpgr = tmpgr->next) {
-                if (CompareFunc(list, tmpgr) == 1) {
-                    if (tmpgr == tmplist) {
-                        list->next = tmplist;
-                        tmplist = list;
-                        SCLogDebug("new list top: %u:%u", tmplist->port, tmplist->port2);
-                    } else {
-                        list->next = prevtmpgr->next;
-                        prevtmpgr->next = list;
-                    }
-                    insert = 1;
-                    break;
-                }
-                prevtmpgr = tmpgr;
-            }
-            if (insert == 0) {
-                list->next = NULL;
-                prevtmpgr->next = list;
-            }
-            insert = 0;
-        }
-    }
-
-    uint32_t left = unique_groups;
-    if (left == 0)
-        left = groups;
-
-    /* create another list: take the port groups from above
-     * and add them to the 2nd list until we have met our
-     * count. The rest is added to the 'join' group. */
-    DetectPort *tmplist2 = NULL, *tmplist2_tail = NULL;
-    DetectPort *gr, *next_gr;
-    for (gr = tmplist; gr != NULL; ) {
-        next_gr = gr->next;
-
-        SCLogDebug("temp list gr %p %u:%u", gr, gr->port, gr->port2);
-        DetectPortPrint(gr);
-
-        /* if we've set up all the unique groups, add the rest to the
-         * catch-all joingr */
-        if (left == 0) {
-            if (joingr == NULL) {
-                DetectPortParse(de_ctx, &joingr, "0:65535");
-                if (joingr == NULL) {
-                    goto error;
-                }
-                SCLogDebug("joingr => %u-%u", joingr->port, joingr->port2);
-                joingr->next = NULL;
-            }
-            SigGroupHeadCopySigs(de_ctx,gr->sh,&joingr->sh);
-
-            /* when a group's sigs are added to the joingr, we can free it */
-            gr->next = NULL;
-            DetectPortFree(de_ctx, gr);
-        /* append */
-        } else {
-            gr->next = NULL;
-
-            if (tmplist2 == NULL) {
-                tmplist2 = gr;
-                tmplist2_tail = gr;
-            } else {
-                tmplist2_tail->next = gr;
-                tmplist2_tail = gr;
-            }
-        }
-
-        if (left > 0)
-            left--;
-
-        gr = next_gr;
-    }
-
-    /* if present, append the joingr that covers the rest */
-    if (joingr != NULL) {
-        SCLogDebug("appending joingr %p %u:%u", joingr, joingr->port, joingr->port2);
-
-        if (tmplist2 == NULL) {
-            tmplist2 = joingr;
-            //tmplist2_tail = joingr;
-        } else {
-            tmplist2_tail->next = joingr;
-            //tmplist2_tail = joingr;
-        }
-    } else {
-        SCLogDebug("no joingr");
-    }
-
-    /* pass back our new list to the caller */
-    *newhead = tmplist2;
-    DetectPortPrintList(*newhead);
-
-    return 0;
-error:
-    return -1;
-}
-
 /**
  *  \internal
  *  \brief add a decoder event signature to the detection engine ctx