]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
detect: use generic integer functions for bsize
authorPhilippe Antoine <pantoine@oisf.net>
Thu, 9 Jun 2022 12:53:46 +0000 (14:53 +0200)
committerVictor Julien <vjulien@oisf.net>
Fri, 1 Jul 2022 15:04:08 +0000 (17:04 +0200)
Ticket: #4112

src/detect-bsize.c
src/tests/detect-bsize.c

index d3908514592c7bc270cf344f2f03e2abb17a7ad2..6cb63a6e00faece5df8779acffd476eb65497d76 100644 (file)
@@ -31,6 +31,7 @@
 #include "detect-parse.h"
 #include "detect-engine.h"
 #include "detect-content.h"
+#include "detect-engine-uint.h"
 
 #include "detect-bsize.h"
 
@@ -60,17 +61,6 @@ void DetectBsizeRegister(void)
 #endif
 }
 
-#define DETECT_BSIZE_LT 0
-#define DETECT_BSIZE_GT 1
-#define DETECT_BSIZE_RA 2
-#define DETECT_BSIZE_EQ 3
-
-typedef struct DetectBsizeData {
-    uint8_t mode;
-    uint64_t lo;
-    uint64_t hi;
-} DetectBsizeData;
-
 /** \brief bsize match function
  *
  *  \param ctx match ctx
@@ -78,31 +68,29 @@ typedef struct DetectBsizeData {
  *  \param eof is the buffer closed?
  *
  *  \retval r 1 match, 0 no match, -1 can't match
- *
- *  \todo check logic around < vs <=
  */
 int DetectBsizeMatch(const SigMatchCtx *ctx, const uint64_t buffer_size, bool eof)
 {
-    const DetectBsizeData *bsz = (const DetectBsizeData *)ctx;
+    const DetectU64Data *bsz = (const DetectU64Data *)ctx;
+    if (DetectU64Match(buffer_size, bsz)) {
+        return 1;
+    }
     switch (bsz->mode) {
-        case DETECT_BSIZE_LT:
-            if (buffer_size < bsz->lo) {
-                return 1;
-            }
+        case DETECT_UINT_LTE:
+            return -1;
+        case DETECT_UINT_LT:
             return -1;
 
-        case DETECT_BSIZE_GT:
-            if (buffer_size > bsz->lo) {
-                return 1;
-            } else if (eof) {
+        case DETECT_UINT_GTE:
+            // fallthrough
+        case DETECT_UINT_GT:
+            if (eof) {
                 return -1;
             }
             return 0;
 
-        case DETECT_BSIZE_EQ:
-            if (buffer_size == bsz->lo) {
-                return 1;
-            } else if (buffer_size > bsz->lo) {
+        case DETECT_UINT_EQ:
+            if (buffer_size > bsz->arg1) {
                 return -1;
             } else if (eof) {
                 return -1;
@@ -110,160 +98,30 @@ int DetectBsizeMatch(const SigMatchCtx *ctx, const uint64_t buffer_size, bool eo
                 return 0;
             }
 
-        case DETECT_BSIZE_RA:
-            if (buffer_size > bsz->lo && buffer_size < bsz->hi) {
-                return 1;
-            } else if (buffer_size <= bsz->lo && eof) {
+        case DETECT_UINT_RA:
+            if (buffer_size <= bsz->arg1 && eof) {
                 return -1;
-            } else if (buffer_size <= bsz->lo) {
+            } else if (buffer_size <= bsz->arg1) {
                 return 0;
-            } else if (buffer_size >= bsz->hi) {
+            } else if (buffer_size >= bsz->arg2) {
                 return -1;
             }
     }
     return 0;
 }
 
-#define ERR(...) do { \
-    char _buf[2048];              \
-    snprintf(_buf, sizeof(_buf), __VA_ARGS__);  \
-    SCLogError(SC_ERR_INVALID_RULE_ARGUMENT, "bsize: bad input, %s", _buf); \
-} while(0)
-
 /**
  * \brief This function is used to parse bsize options passed via bsize: keyword
  *
  * \param bsizestr Pointer to the user provided bsize options
  *
- * \retval bsized pointer to DetectBsizeData on success
+ * \retval bsized pointer to DetectU64Data on success
  * \retval NULL on failure
  */
 
-static DetectBsizeData *DetectBsizeParse (const char *str)
+static DetectU64Data *DetectBsizeParse(const char *str)
 {
-    uint32_t lo = 0;
-    uint32_t hi = 0;
-
-    if (str == NULL)
-        return NULL;
-
-    size_t len = strlen(str);
-    if (len == 0)
-        return NULL;
-
-    /* allow for leading spaces */
-    while (isspace(*str))
-        (str++);
-    len = strlen(str);
-    if (len == 0)
-        return NULL;
-
-    int mode = DETECT_BSIZE_EQ;
-    switch (*str) {
-        case '>':
-            mode = DETECT_BSIZE_GT;
-            str++;
-            break;
-        case '<':
-            mode = DETECT_BSIZE_LT;
-            str++;
-            break;
-    }
-
-    /* allow for spaces between mode and value */
-    while (isspace(*str))
-        (str++);
-
-    char str1[11], *p = str1;
-    memset(str1, 0, sizeof(str1));
-    while (*str && isdigit(*str)) {
-        if (p - str1 >= ((int)sizeof(str1) - 1))
-            return NULL;
-        *p++ = *str++;
-    }
-    /* skip trailing space */
-    while (*str && isspace(*str)) {
-        str++;
-    }
-    if (*str == '\0') {
-        // done
-        SCLogDebug("str1 '%s'", str1);
-
-        uint64_t val = 0;
-        if (ParseSizeStringU64(str1, &val) < 0) {
-            return NULL;
-        }
-        lo = val;
-
-    } else if (*str == '<') {
-        str++;
-        if (*str != '>') {
-            ERR("only '<>' allowed");
-            return NULL;
-        }
-        str++;
-
-        // range
-        if (mode != DETECT_BSIZE_EQ) {
-            ERR("mode already set");
-            return NULL;
-        }
-        mode = DETECT_BSIZE_RA;
-
-        uint64_t val = 0;
-        if (ParseSizeStringU64(str1, &val) < 0) {
-            return NULL;
-        }
-        lo = val;
-
-        /* allow for spaces between mode and value */
-        while (*str && isspace(*str))
-            (str++);
-
-        char str2[11];
-        p = str2;
-        memset(str2, 0, sizeof(str2));
-        while (*str && isdigit(*str)) {
-            if (p - str2 >= ((int)sizeof(str2) - 1))
-                return NULL;
-            *p++ = *str++;
-        }
-        /* skip trailing space */
-        while (*str && isspace(*str)) {
-            str++;
-        }
-        if (*str == '\0') {
-            // done
-            SCLogDebug("str2 '%s'", str2);
-
-            if (ParseSizeStringU64(str2, &val) < 0) {
-                ERR("'%s' is not a valid u32", str2);
-                return NULL;
-            }
-            hi = val;
-            if (lo >= hi) {
-                ERR("%u > %u", lo, hi);
-                return NULL;
-            }
-
-        } else {
-            ERR("trailing data");
-            return NULL;
-        }
-
-    } else {
-        ERR("'%s'", str);
-        return NULL;
-    }
-
-    DetectBsizeData *bsz = SCCalloc(1, sizeof(*bsz));
-    if (bsz == NULL) {
-        return NULL;
-    }
-    bsz->mode = (uint8_t)mode;
-    bsz->lo = lo;
-    bsz->hi = hi;
-    return bsz;
+    return DetectU64Parse(str);
 }
 
 /**
@@ -288,7 +146,7 @@ static int DetectBsizeSetup (DetectEngineCtx *de_ctx, Signature *s, const char *
     if (list == DETECT_SM_LIST_NOTSET)
         SCReturnInt(-1);
 
-    DetectBsizeData *bsz = DetectBsizeParse(sizestr);
+    DetectU64Data *bsz = DetectBsizeParse(sizestr);
     if (bsz == NULL)
         goto error;
     sm = SigMatchAlloc();
@@ -307,17 +165,17 @@ error:
 }
 
 /**
- * \brief this function will free memory associated with DetectBsizeData
+ * \brief this function will free memory associated with DetectU64Data
  *
- * \param ptr pointer to DetectBsizeData
+ * \param ptr pointer to DetectU64Data
  */
 void DetectBsizeFree(DetectEngineCtx *de_ctx, void *ptr)
 {
     if (ptr == NULL)
         return;
 
-    DetectBsizeData *bsz = (DetectBsizeData *)ptr;
-    SCFree(bsz);
+    DetectU64Data *bsz = (DetectU64Data *)ptr;
+    rs_detect_u64_free(bsz);
 }
 
 #ifdef UNITTESTS
index 1e6b6cf68ad57d57f58bff411b8fb6023dbaf305..6c317427a9c998b3e78cd9aeeeb3eb74d060c5cc 100644 (file)
 
 #include "../util-unittest.h"
 
-#define TEST_OK(str, m, lo, hi) {                       \
-    DetectBsizeData *bsz = DetectBsizeParse((str));     \
-    FAIL_IF_NULL(bsz);                                  \
-    FAIL_IF_NOT(bsz->mode == (m));                      \
-    DetectBsizeFree(NULL, bsz);                         \
-    SCLogDebug("str %s OK", (str));                     \
-}
-#define TEST_FAIL(str) {                                \
-    DetectBsizeData *bsz = DetectBsizeParse((str));     \
-    FAIL_IF_NOT_NULL(bsz);                              \
-}
+#define TEST_OK(str, m, lo, hi)                                                                    \
+    {                                                                                              \
+        DetectU64Data *bsz = DetectBsizeParse((str));                                              \
+        FAIL_IF_NULL(bsz);                                                                         \
+        FAIL_IF_NOT(bsz->mode == (m));                                                             \
+        DetectBsizeFree(NULL, bsz);                                                                \
+        SCLogDebug("str %s OK", (str));                                                            \
+    }
+#define TEST_FAIL(str)                                                                             \
+    {                                                                                              \
+        DetectU64Data *bsz = DetectBsizeParse((str));                                              \
+        FAIL_IF_NOT_NULL(bsz);                                                                     \
+    }
 
 static int DetectBsizeTest01(void)
 {
-    TEST_OK("50", DETECT_BSIZE_EQ, 50, 0);
-    TEST_OK(" 50", DETECT_BSIZE_EQ, 50, 0);
-    TEST_OK("  50", DETECT_BSIZE_EQ, 50, 0);
-    TEST_OK("  50 ", DETECT_BSIZE_EQ, 50, 0);
-    TEST_OK("  50  ", DETECT_BSIZE_EQ, 50, 0);
+    TEST_OK("50", DETECT_UINT_EQ, 50, 0);
+    TEST_OK(" 50", DETECT_UINT_EQ, 50, 0);
+    TEST_OK("  50", DETECT_UINT_EQ, 50, 0);
+    TEST_OK("  50 ", DETECT_UINT_EQ, 50, 0);
+    TEST_OK("  50  ", DETECT_UINT_EQ, 50, 0);
 
     TEST_FAIL("AA");
     TEST_FAIL("5A");
     TEST_FAIL("A5");
-    TEST_FAIL("10000000001");
-    TEST_OK("  1000000001  ", DETECT_BSIZE_EQ, 1000000001, 0);
+    // bigger than UINT64_MAX
+    TEST_FAIL("100000000000000000001");
+    TEST_OK("  1000000001  ", DETECT_UINT_EQ, 1000000001, 0);
     PASS;
 }
 
 static int DetectBsizeTest02(void)
 {
-    TEST_OK(">50", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK("> 50", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK(">  50", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK(" >50", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK(" > 50", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK(" >  50", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK(" >50 ", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK(" > 50  ", DETECT_BSIZE_GT, 50, 0);
-    TEST_OK(" >  50   ", DETECT_BSIZE_GT, 50, 0);
+    TEST_OK(">50", DETECT_UINT_GT, 50, 0);
+    TEST_OK("> 50", DETECT_UINT_GT, 50, 0);
+    TEST_OK(">  50", DETECT_UINT_GT, 50, 0);
+    TEST_OK(" >50", DETECT_UINT_GT, 50, 0);
+    TEST_OK(" > 50", DETECT_UINT_GT, 50, 0);
+    TEST_OK(" >  50", DETECT_UINT_GT, 50, 0);
+    TEST_OK(" >50 ", DETECT_UINT_GT, 50, 0);
+    TEST_OK(" > 50  ", DETECT_UINT_GT, 50, 0);
+    TEST_OK(" >  50   ", DETECT_UINT_GT, 50, 0);
 
     TEST_FAIL(">>50");
     TEST_FAIL("<>50");
@@ -65,15 +68,15 @@ static int DetectBsizeTest02(void)
 
 static int DetectBsizeTest03(void)
 {
-    TEST_OK("<50", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK("< 50", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK("<  50", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK(" <50", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK(" < 50", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK(" <  50", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK(" <50 ", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK(" < 50  ", DETECT_BSIZE_LT, 50, 0);
-    TEST_OK(" <  50   ", DETECT_BSIZE_LT, 50, 0);
+    TEST_OK("<50", DETECT_UINT_LT, 50, 0);
+    TEST_OK("< 50", DETECT_UINT_LT, 50, 0);
+    TEST_OK("<  50", DETECT_UINT_LT, 50, 0);
+    TEST_OK(" <50", DETECT_UINT_LT, 50, 0);
+    TEST_OK(" < 50", DETECT_UINT_LT, 50, 0);
+    TEST_OK(" <  50", DETECT_UINT_LT, 50, 0);
+    TEST_OK(" <50 ", DETECT_UINT_LT, 50, 0);
+    TEST_OK(" < 50  ", DETECT_UINT_LT, 50, 0);
+    TEST_OK(" <  50   ", DETECT_UINT_LT, 50, 0);
 
     TEST_FAIL(">>50");
     TEST_FAIL(" < 50A");
@@ -82,7 +85,7 @@ static int DetectBsizeTest03(void)
 
 static int DetectBsizeTest04(void)
 {
-    TEST_OK("50<>100", DETECT_BSIZE_RA, 50, 100);
+    TEST_OK("50<>100", DETECT_UINT_RA, 50, 100);
 
     TEST_FAIL("50<$50");
     TEST_FAIL("100<>50");