From: Victor Julien
Date: Mon, 26 Feb 2024 16:08:21 +0000 (+0530)
Subject: detect/port: use qsort instead of insert sort
X-Git-Tag: suricata-8.0.0-beta1~1675
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e7e4305d91a05acde921b5bc87f7adbdf566def6;p=thirdparty%2Fsuricata.git

detect/port: use qsort instead of insert sort

Instead of using in place insertion sort on linked list based on two
keys, convert the linked list to an array, perform sorting on it using
qsort and convert it back to a linked list.

This turns out to be much faster.

Ticket #6795
---

diff --git a/src/detect-engine-build.c b/src/detect-engine-build.c
index e812f86122..67fb740531 100644
--- a/src/detect-engine-build.c
+++ b/src/detect-engine-build.c
@@ -1126,9 +1126,240 @@ static int RuleSetWhitelist(Signature *s)
     return wl;
 }
 
-int CreateGroupedPortList(DetectEngineCtx *de_ctx, DetectPort *port_list, DetectPort **newhead,
-        uint32_t unique_groups, int (*CompareFunc)(DetectPort *, DetectPort *));
-int CreateGroupedPortListCmpCnt(DetectPort *a, DetectPort *b);
+/** \internal
+ *  \brief qsort() comparator used to order port groups
+ *
+ *  Sorts descending on the group's score, then descending on its
+ *  signature count. Groups equal on both are ordered ascending on the
+ *  sgh id, which SortGroupList sets to a temporary unique value.
+ *
+ *  \param a pointer to an array element of type DetectPort *
+ *  \param b pointer to an array element of type DetectPort *
+ *
+ *  \retval -1 a sorts before b
+ *  \retval 0 a and b are equal on all keys
+ *  \retval 1 a sorts after b
+ */
+static int SortCompare(const void *a, const void *b)
+{
+    const DetectPort *pa = *(const DetectPort *const *)a;
+    const DetectPort *pb = *(const DetectPort *const *)b;
+
+    if (pa->sh->init->score < pb->sh->init->score) {
+        return 1;
+    } else if (pa->sh->init->score > pb->sh->init->score) {
+        return -1;
+    }
+
+    if (pa->sh->init->sig_cnt < pb->sh->init->sig_cnt) {
+        return 1;
+    } else if (pa->sh->init->sig_cnt > pb->sh->init->sig_cnt) {
+        return -1;
+    }
+
+    /* Hack to make the qsort output deterministic across platforms.
+     * This had to be done because the order of equal elements sorted
+     * by qsort is nondeterministic and showed different output on BSD,
+     * macOS and Windows. Sorting based on id makes it deterministic. */
+    if (pa->sh->id < pb->sh->id) {
+        return -1;
+    } else if (pa->sh->id > pb->sh->id) {
+        return 1;
+    }
+
+    /* qsort requires a consistent comparator: equal keys compare equal */
+    return 0;
+}
+
+/** \internal
+ *  \brief Sort a DetectPort list using qsort()
+ *
+ *  The signature count of every group is recalculated, after which the
+ *  list is copied into a temporary array, sorted and relinked in sorted
+ *  order. If the temporary array can't be allocated the list is left in
+ *  its original order.
+ *
+ *  \param groups [out] number of groups in the list
+ *  \param list [in,out] head of the list; set to the head of the sorted list
+ *  \param CompareFunc qsort() comparator operating on DetectPort * elements
+ */
+static inline void SortGroupList(
+        uint32_t *groups, DetectPort **list, int (*CompareFunc)(const void *, const void *))
+{
+    uint32_t cnt = 0;
+    for (DetectPort *x = *list; x != NULL; x = x->next) {
+        DEBUG_VALIDATE_BUG_ON(x->port > x->port2);
+        SigGroupHeadSetSigCnt(x->sh, 0);
+        cnt++;
+    }
+    /* set the group count before any early exit: the caller falls back
+     * to it when no explicit group limit is configured */
+    *groups = cnt;
+    if (cnt <= 1)
+        return;
+
+    /* build temporary array to sort with qsort */
+    DetectPort **array = (DetectPort **)SCCalloc(cnt, sizeof(DetectPort *));
+    if (array == NULL)
+        return;
+
+    uint32_t idx = 0;
+    for (DetectPort *x = *list; x != NULL;) {
+        DetectPort *next = x->next;
+        /* assign a temporary id to resolve otherwise equal groups */
+        x->sh->id = idx + 1;
+        x->next = x->prev = x->last = NULL;
+        array[idx++] = x;
+        x = next;
+    }
+    DEBUG_VALIDATE_BUG_ON(cnt != idx);
+
+    qsort(array, idx, sizeof(DetectPort *), CompareFunc);
+
+    /* rebuild the list based on the qsort-ed array */
+    DetectPort *new_list = NULL, *tail = NULL;
+    for (uint32_t i = 0; i < idx; i++) {
+        DetectPort *p = array[i];
+        /* unset temporary group id */
+        p->sh->id = 0;
+
+        if (new_list == NULL) {
+            new_list = p;
+        }
+        if (tail != NULL) {
+            tail->next = p;
+        }
+        p->prev = tail;
+        tail = p;
+    }
+
+    *list = new_list;
+
+#if DEBUG
+    uint32_t dbgcnt = 0;
+    SCLogDebug("SORTED LIST:");
+    for (DetectPort *tmp = *list; tmp != NULL; tmp = tmp->next) {
+        SCLogDebug("item:= [%u:%u]; score: %d; sig_cnt: %u", tmp->port, tmp->port2,
+                tmp->sh->init->score, tmp->sh->init->sig_cnt);
+        dbgcnt++;
+        BUG_ON(dbgcnt > cnt);
+    }
+#endif
+    SCFree(array);
+}
+
+/** \internal
+ *  \brief Create a list of DetectPort objects sorted based on CompareFunc's
+ *         logic.
+ *
+ *  List can limit the number of groups. In this case an extra "join" group
+ *  is created that contains the sigs belonging to that. It's *appended* to
+ *  the list, meaning that if the list is walked linearly it's found last.
+ *  The joingr is meant to be a catch all.
+ *
+ *  The groups in port_list are either moved to the new list or freed, so
+ *  the caller must not use port_list after this call.
+ *
+ *  \param de_ctx detection engine ctx
+ *  \param port_list list of port groups to sort and limit
+ *  \param newhead [out] head of the resulting list, NULL on error
+ *  \param unique_groups max number of groups to keep as is, 0 for no limit
+ *  \param CompareFunc qsort() comparator defining the order of the groups
+ *
+ *  \retval 0 on success
+ *  \retval -1 if the join group could not be created
+ */
+static int CreateGroupedPortList(DetectEngineCtx *de_ctx, DetectPort *port_list,
+        DetectPort **newhead, uint32_t unique_groups,
+        int (*CompareFunc)(const void *, const void *))
+{
+    DetectPort *joingr = NULL;
+    uint32_t groups = 0;
+
+    /* sort the groups using CompareFunc, so that the groups that
+     * should be kept as is come first */
+    DetectPort *tmplist = port_list;
+    SortGroupList(&groups, &tmplist, CompareFunc);
+
+    uint32_t left = unique_groups;
+    if (left == 0)
+        left = groups;
+
+    /* create another list: take the port groups from above
+     * and add them to the 2nd list until we have met our
+     * count. The rest is added to the 'join' group. */
+    DetectPort *tmplist2 = NULL, *tmplist2_tail = NULL;
+    DetectPort *gr, *next_gr;
+    for (gr = tmplist; gr != NULL; gr = next_gr) {
+        next_gr = gr->next;
+
+        SCLogDebug("temp list gr %p %u:%u", gr, gr->port, gr->port2);
+        DetectPortPrint(gr);
+
+        /* if we've set up all the unique groups, add the rest to the
+         * catch-all joingr */
+        if (left == 0) {
+            if (joingr == NULL) {
+                DetectPortParse(de_ctx, &joingr, "0:65535");
+                if (joingr == NULL) {
+                    goto error;
+                }
+                SCLogDebug("joingr => %u-%u", joingr->port, joingr->port2);
+                joingr->next = NULL;
+            }
+            SigGroupHeadCopySigs(de_ctx, gr->sh, &joingr->sh);
+
+            /* when a group's sigs are added to the joingr, we can free it */
+            gr->next = NULL;
+            DetectPortFree(de_ctx, gr);
+        } else {
+            /* append */
+            gr->next = NULL;
+            if (tmplist2 == NULL) {
+                tmplist2 = gr;
+            } else {
+                tmplist2_tail->next = gr;
+            }
+            tmplist2_tail = gr;
+            left--;
+        }
+    }
+
+    /* if present, append the joingr that covers the rest */
+    if (joingr != NULL) {
+        SCLogDebug("appending joingr %p %u:%u", joingr, joingr->port, joingr->port2);
+
+        if (tmplist2 == NULL) {
+            tmplist2 = joingr;
+        } else {
+            tmplist2_tail->next = joingr;
+        }
+    } else {
+        SCLogDebug("no joingr");
+    }
+
+    /* pass back our new list to the caller */
+    *newhead = tmplist2;
+    DetectPortPrintList(*newhead);
+
+    return 0;
+
+error:
+    /* nothing is passed back: free the groups that were not processed
+     * yet as well as those already moved to the new list */
+    for (; gr != NULL; gr = next_gr) {
+        next_gr = gr->next;
+        gr->next = NULL;
+        DetectPortFree(de_ctx, gr);
+    }
+    for (gr = tmplist2; gr != NULL; gr = next_gr) {
+        next_gr = gr->next;
+        gr->next = NULL;
+        DetectPortFree(de_ctx, gr);
+    }
+    *newhead = NULL;
+    return -1;
+}
 
 #define RANGE_PORT 1
 #define SINGLE_PORT 2
@@ -1388,7 +1619,7 @@ static DetectPort *RulesGroupByPorts(DetectEngineCtx *de_ctx, uint8_t ipproto, u
     DetectPort *newlist = NULL;
     uint16_t groupmax = (direction == SIG_FLAG_TOCLIENT) ? de_ctx->max_uniq_toclient_groups :
                                                            de_ctx->max_uniq_toserver_groups;
-    CreateGroupedPortList(de_ctx, list, &newlist, groupmax, CreateGroupedPortListCmpCnt);
+    CreateGroupedPortList(de_ctx, list, &newlist, groupmax, SortCompare);
     list = newlist;
 
     /* step 4: deduplicate the SGH's */
@@ -1668,179 +1899,6 @@ error:
     return -1;
 }
 
-static int PortGroupWhitelist(const DetectPort *a)
-{
-    return a->sh->init->score;
-}
-
-int CreateGroupedPortListCmpCnt(DetectPort *a, DetectPort *b)
-{
-    if (PortGroupWhitelist(a) && !PortGroupWhitelist(b)) {
-        SCLogDebug("%u:%u (cnt %u, wl %d) wins against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-                a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                b->sh->init->sig_cnt, PortGroupWhitelist(b));
-        return 1;
-    } else if (!PortGroupWhitelist(a) && PortGroupWhitelist(b)) {
-        SCLogDebug("%u:%u (cnt %u, wl %d) loses against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-                a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                b->sh->init->sig_cnt, PortGroupWhitelist(b));
-        return 0;
-    } else if (PortGroupWhitelist(a) > PortGroupWhitelist(b)) {
-        SCLogDebug("%u:%u (cnt %u, wl %d) wins against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-                a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                b->sh->init->sig_cnt, PortGroupWhitelist(b));
-        return 1;
-    } else if (PortGroupWhitelist(a) == PortGroupWhitelist(b)) {
-        if (a->sh->init->sig_cnt > b->sh->init->sig_cnt) {
-            SCLogDebug("%u:%u (cnt %u, wl %d) wins against %u:%u (cnt %u, wl %d)", a->port,
-                    a->port2, a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2,
-                    b->sh->init->sig_cnt, PortGroupWhitelist(b));
-            return 1;
-        }
-    }
-
-    SCLogDebug("%u:%u (cnt %u, wl %d) loses against %u:%u (cnt %u, wl %d)", a->port, a->port2,
-            a->sh->init->sig_cnt, PortGroupWhitelist(a), b->port, b->port2, b->sh->init->sig_cnt,
-            PortGroupWhitelist(b));
-    return 0;
-}
-
-/** \internal
- *  \brief Create a list of DetectPort objects sorted based on CompareFunc's
- *         logic.
- *
- *  List can limit the number of groups. In this case an extra "join" group
- *  is created that contains the sigs belonging to that. It's *appended* to
- *  the list, meaning that if the list is walked linearly it's found last.
- *  The joingr is meant to be a catch all.
- *
- */
-int CreateGroupedPortList(DetectEngineCtx *de_ctx, DetectPort *port_list, DetectPort **newhead,
-        uint32_t unique_groups, int (*CompareFunc)(DetectPort *, DetectPort *))
-{
-    DetectPort *tmplist = NULL, *joingr = NULL;
-    char insert = 0;
-    uint32_t groups = 0;
-    DetectPort *list;
-
-   /* insert the ports into the tmplist, where it will
-    * be sorted descending on 'cnt' and on whether a group
-    * is whitelisted. */
-
-    DetectPort *oldhead = port_list;
-    while (oldhead) {
-        /* take the top of the list */
-        list = oldhead;
-        oldhead = oldhead->next;
-        list->next = NULL;
-
-        groups++;
-        SigGroupHeadSetSigCnt(list->sh, 0);
-
-        /* insert it */
-        DetectPort *tmpgr = tmplist, *prevtmpgr = NULL;
-        if (tmplist == NULL) {
-            /* empty list, set head */
-            tmplist = list;
-        } else {
-            /* look for the place to insert */
-            for ( ; tmpgr != NULL && !insert; tmpgr = tmpgr->next) {
-                if (CompareFunc(list, tmpgr) == 1) {
-                    if (tmpgr == tmplist) {
-                        list->next = tmplist;
-                        tmplist = list;
-                        SCLogDebug("new list top: %u:%u", tmplist->port, tmplist->port2);
-                    } else {
-                        list->next = prevtmpgr->next;
-                        prevtmpgr->next = list;
-                    }
-                    insert = 1;
-                    break;
-                }
-                prevtmpgr = tmpgr;
-            }
-            if (insert == 0) {
-                list->next = NULL;
-                prevtmpgr->next = list;
-            }
-            insert = 0;
-        }
-    }
-
-    uint32_t left = unique_groups;
-    if (left == 0)
-        left = groups;
-
-    /* create another list: take the port groups from above
-     * and add them to the 2nd list until we have met our
-     * count. The rest is added to the 'join' group. */
-    DetectPort *tmplist2 = NULL, *tmplist2_tail = NULL;
-    DetectPort *gr, *next_gr;
-    for (gr = tmplist; gr != NULL; ) {
-        next_gr = gr->next;
-
-        SCLogDebug("temp list gr %p %u:%u", gr, gr->port, gr->port2);
-        DetectPortPrint(gr);
-
-        /* if we've set up all the unique groups, add the rest to the
-         * catch-all joingr */
-        if (left == 0) {
-            if (joingr == NULL) {
-                DetectPortParse(de_ctx, &joingr, "0:65535");
-                if (joingr == NULL) {
-                    goto error;
-                }
-                SCLogDebug("joingr => %u-%u", joingr->port, joingr->port2);
-                joingr->next = NULL;
-            }
-            SigGroupHeadCopySigs(de_ctx,gr->sh,&joingr->sh);
-
-            /* when a group's sigs are added to the joingr, we can free it */
-            gr->next = NULL;
-            DetectPortFree(de_ctx, gr);
-        /* append */
-        } else {
-            gr->next = NULL;
-
-            if (tmplist2 == NULL) {
-                tmplist2 = gr;
-                tmplist2_tail = gr;
-            } else {
-                tmplist2_tail->next = gr;
-                tmplist2_tail = gr;
-            }
-        }
-
-        if (left > 0)
-            left--;
-
-        gr = next_gr;
-    }
-
-    /* if present, append the joingr that covers the rest */
-    if (joingr != NULL) {
-        SCLogDebug("appending joingr %p %u:%u", joingr, joingr->port, joingr->port2);
-
-        if (tmplist2 == NULL) {
-            tmplist2 = joingr;
-            //tmplist2_tail = joingr;
-        } else {
-            tmplist2_tail->next = joingr;
-            //tmplist2_tail = joingr;
-        }
-    } else {
-        SCLogDebug("no joingr");
-    }
-
-    /* pass back our new list to the caller */
-    *newhead = tmplist2;
-    DetectPortPrintList(*newhead);
-
-    return 0;
-error:
-    return -1;
-}
-
 /**
  * \internal
  * \brief add a decoder event signature to the detection engine ctx