]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
detect: use generic integer functions for streamsize
authorPhilippe Antoine <contact@catenacyber.fr>
Mon, 25 Apr 2022 15:59:00 +0000 (17:59 +0200)
committerVictor Julien <vjulien@oisf.net>
Thu, 2 Jun 2022 05:39:33 +0000 (07:39 +0200)
By the way, adds the prefilter feature

Ticket: #2697
Ticket: #4112

rust/src/detect.rs
src/detect-stream_size.c
src/detect-stream_size.h

index 27c53671720f3f03314bde552f73c28cf1da2b1f..6bc0ed0f3b24c1f78534331772d0bf83212b1cd5 100644 (file)
 
 use nom7::branch::alt;
 use nom7::bytes::complete::{is_a, tag, take_while};
-use nom7::character::complete::digit1;
-use nom7::combinator::{all_consuming, map_opt, opt, value, verify};
+use nom7::character::complete::{alpha0, char, digit1};
+use nom7::combinator::{all_consuming, map_opt, map_res, opt, value, verify};
 use nom7::error::{make_error, ErrorKind};
 use nom7::Err;
 use nom7::IResult;
 
 use std::ffi::CStr;
+use std::str::FromStr;
 
 #[derive(PartialEq, Clone, Debug)]
 #[repr(u8)]
@@ -92,7 +93,9 @@ fn detect_parse_uint_mode(i: &str) -> IResult<&str, DetectUintMode> {
         value(DetectUintMode::DetectUintModeLte, tag("<=")),
         value(DetectUintMode::DetectUintModeGt, tag(">")),
         value(DetectUintMode::DetectUintModeLt, tag("<")),
+        value(DetectUintMode::DetectUintModeNe, tag("!=")),
         value(DetectUintMode::DetectUintModeNe, tag("!")),
+        value(DetectUintMode::DetectUintModeEqual, tag("=")),
     ))(i)?;
     return Ok((i, mode));
 }
@@ -310,3 +313,78 @@ pub unsafe extern "C" fn rs_detect_u16_free(ctx: &mut DetectUintData<u16>) {
     // Just unbox...
     std::mem::drop(Box::from_raw(ctx));
 }
+
+#[repr(u8)]
+#[derive(Clone, Copy, PartialEq, FromPrimitive, Debug)]
+pub enum DetectStreamSizeDataFlags {
+    StreamSizeServer = 1,
+    StreamSizeClient = 2,
+    StreamSizeBoth = 3,
+    StreamSizeEither = 4,
+}
+
+impl std::str::FromStr for DetectStreamSizeDataFlags {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "server" => Ok(DetectStreamSizeDataFlags::StreamSizeServer),
+            "client" => Ok(DetectStreamSizeDataFlags::StreamSizeClient),
+            "both" => Ok(DetectStreamSizeDataFlags::StreamSizeBoth),
+            "either" => Ok(DetectStreamSizeDataFlags::StreamSizeEither),
+            _ => Err(format!(
+                "'{}' is not a valid value for DetectStreamSizeDataFlags",
+                s
+            )),
+        }
+    }
+}
+
+#[derive(Debug)]
+#[repr(C)]
+pub struct DetectStreamSizeData {
+    pub flags: DetectStreamSizeDataFlags,
+    pub du32: DetectUintData<u32>,
+}
+
+pub fn detect_parse_stream_size(i: &str) -> IResult<&str, DetectStreamSizeData> {
+    let (i, _) = opt(is_a(" "))(i)?;
+    let (i, flags) = map_res(alpha0, |s: &str| {
+        DetectStreamSizeDataFlags::from_str(s)
+    })(i)?;
+    let (i, _) = opt(is_a(" "))(i)?;
+    let (i, _) = char(',')(i)?;
+    let (i, _) = opt(is_a(" "))(i)?;
+    let (i, mode) = detect_parse_uint_mode(i)?;
+    let (i, _) = opt(is_a(" "))(i)?;
+    let (i, _) = char(',')(i)?;
+    let (i, _) = opt(is_a(" "))(i)?;
+    let (i, arg1) = map_opt(digit1, |s: &str| s.parse::<u32>().ok())(i)?;
+    let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?;
+    let du32 = DetectUintData::<u32> {
+        arg1: arg1,
+        arg2: 0,
+        mode: mode,
+    };
+    Ok((i, DetectStreamSizeData { flags, du32 }))
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn rs_detect_stream_size_parse(
+    ustr: *const std::os::raw::c_char,
+) -> *mut DetectStreamSizeData {
+    let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe
+    if let Ok(s) = ft_name.to_str() {
+        if let Ok((_, ctx)) = detect_parse_stream_size(s) {
+            let boxed = Box::new(ctx);
+            return Box::into_raw(boxed) as *mut _;
+        }
+    }
+    return std::ptr::null_mut();
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn rs_detect_stream_size_free(ctx: &mut DetectStreamSizeData) {
+    // Just unbox...
+    std::mem::drop(Box::from_raw(ctx));
+}
index 689732562228ba0eb891748b53426576c669bb63..86aabd77c8e8736572e9dbd8429d22ac0f679ae8 100644 (file)
 #include "flow.h"
 #include "detect-stream_size.h"
 #include "stream-tcp-private.h"
+#include "detect-engine-prefilter-common.h"
+#include "detect-engine-uint.h"
 #include "util-debug.h"
 #include "util-byte.h"
 
-/**
- * \brief Regex for parsing our flow options
- */
-#define PARSE_REGEX  "^\\s*([A-z_]+)\\s*,\\s*([<=>!]+)\\s*,\\s*([0-9]+)\\s*$"
-
-static DetectParseRegex parse_regex;
 
 /*prototypes*/
 static int DetectStreamSizeMatch (DetectEngineThreadCtx *, Packet *,
@@ -51,6 +47,8 @@ void DetectStreamSizeFree(DetectEngineCtx *de_ctx, void *);
 #ifdef UNITTESTS
 static void DetectStreamSizeRegisterTests(void);
 #endif
+static int PrefilterSetupStreamSize(DetectEngineCtx *de_ctx, SigGroupHead *sgh);
+static bool PrefilterStreamSizeIsPrefilterable(const Signature *s);
 
 /**
  * \brief Registration function for stream_size: keyword
@@ -67,217 +65,69 @@ void DetectStreamSizeRegister(void)
 #ifdef UNITTESTS
     sigmatch_table[DETECT_STREAM_SIZE].RegisterTests = DetectStreamSizeRegisterTests;
 #endif
-    DetectSetupParseRegexes(PARSE_REGEX, &parse_regex);
+    sigmatch_table[DETECT_STREAM_SIZE].SupportsPrefilter = PrefilterStreamSizeIsPrefilterable;
+    sigmatch_table[DETECT_STREAM_SIZE].SetupPrefilter = PrefilterSetupStreamSize;
 }
 
-/**
- * \brief Function to comapre the stream size against defined size in the user
- *  options.
- *
- *  \param  diff    The stream size of server or client stream.
- *  \param  stream_size User defined stream size
- *  \param  mode    The mode defined by user.
- *
- *  \retval 1 on success and 0 on failure.
- */
-
-static int DetectStreamSizeCompare (uint32_t diff, uint32_t stream_size, uint8_t mode)
+static int DetectStreamSizeMatchAux(const DetectStreamSizeData *sd, const TcpSession *ssn)
 {
-    SCLogDebug("diff %u stream_size %u mode %u", diff, stream_size, mode);
-
-    int ret = 0;
-    switch (mode) {
-        case DETECTSSIZE_LT:
-            if (diff < stream_size)
-                ret = 1;
-            break;
-        case DETECTSSIZE_LEQ:
-            if (diff <= stream_size)
-                ret = 1;
-            break;
-        case DETECTSSIZE_EQ:
-            if (diff == stream_size)
-                ret = 1;
-            break;
-        case DETECTSSIZE_NEQ:
-            if (diff != stream_size)
-                ret = 1;
-            break;
-        case DETECTSSIZE_GEQ:
-            if (diff >= stream_size)
-                ret = 1;
-            break;
-        case DETECTSSIZE_GT:
-            if (diff > stream_size)
-                ret = 1;
-            break;
-    }
-
-    SCReturnInt(ret);
-}
-
-/**
- * \brief This function is used to match Stream size rule option on a packet with those passed via stream_size:
- *
- * \param t pointer to thread vars
- * \param det_ctx pointer to the pattern matcher thread
- * \param p pointer to the current packet
- * \param m pointer to the sigmatch that we will cast into DetectStreamSizeData
- *
- * \retval 0 no match
- * \retval 1 match
- */
-static int DetectStreamSizeMatch (DetectEngineThreadCtx *det_ctx, Packet *p,
-        const Signature *s, const SigMatchCtx *ctx)
-{
-
-    const DetectStreamSizeData *sd = (const DetectStreamSizeData *)ctx;
-
-    if (!(PKT_IS_TCP(p)))
-        return 0;
-    if (p->flow == NULL || p->flow->protoctx == NULL)
-        return 0;
-
-    const TcpSession *ssn = (TcpSession *)p->flow->protoctx;
     int ret = 0;
     uint32_t csdiff = 0;
     uint32_t ssdiff = 0;
 
-    if (sd->flags & STREAM_SIZE_SERVER) {
+    if (sd->flags == StreamSizeServer) {
         /* get the server stream size */
         ssdiff = ssn->server.next_seq - ssn->server.isn;
-        ret = DetectStreamSizeCompare(ssdiff, sd->ssize, sd->mode);
+        ret = DetectU32Match(ssdiff, &sd->du32);
 
-    } else if (sd->flags & STREAM_SIZE_CLIENT) {
+    } else if (sd->flags == StreamSizeClient) {
         /* get the client stream size */
         csdiff = ssn->client.next_seq - ssn->client.isn;
-        ret = DetectStreamSizeCompare(csdiff, sd->ssize, sd->mode);
+        ret = DetectU32Match(csdiff, &sd->du32);
 
-    } else if (sd->flags & STREAM_SIZE_BOTH) {
+    } else if (sd->flags == StreamSizeBoth) {
         ssdiff = ssn->server.next_seq - ssn->server.isn;
         csdiff = ssn->client.next_seq - ssn->client.isn;
 
-        if (DetectStreamSizeCompare(ssdiff, sd->ssize, sd->mode) &&
-            DetectStreamSizeCompare(csdiff, sd->ssize, sd->mode))
+        if (DetectU32Match(ssdiff, &sd->du32) && DetectU32Match(csdiff, &sd->du32))
             ret = 1;
 
-    } else if (sd->flags & STREAM_SIZE_EITHER) {
+    } else if (sd->flags == StreamSizeEither) {
         ssdiff = ssn->server.next_seq - ssn->server.isn;
         csdiff = ssn->client.next_seq - ssn->client.isn;
 
-        if (DetectStreamSizeCompare(ssdiff, sd->ssize, sd->mode) ||
-            DetectStreamSizeCompare(csdiff, sd->ssize, sd->mode))
+        if (DetectU32Match(ssdiff, &sd->du32) || DetectU32Match(csdiff, &sd->du32))
             ret = 1;
     }
-
-    SCReturnInt(ret);
+    return ret;
 }
 
 /**
- * \brief This function is used to parse stream options passed via stream_size: keyword
+ * \brief This function is used to match Stream size rule option on a packet with those passed via
+ * stream_size:
  *
- * \param de_ctx Pointer to the detection engine context
- * \param streamstr Pointer to the user provided stream_size options
+ * \param t pointer to thread vars
+ * \param det_ctx pointer to the pattern matcher thread
+ * \param p pointer to the current packet
+ * \param m pointer to the sigmatch that we will cast into DetectStreamSizeData
  *
- * \retval sd pointer to DetectStreamSizeData on success
- * \retval NULL on failure
+ * \retval 0 no match
+ * \retval 1 match
  */
-static DetectStreamSizeData *DetectStreamSizeParse (DetectEngineCtx *de_ctx, const char *streamstr)
+static int DetectStreamSizeMatch(
+        DetectEngineThreadCtx *det_ctx, Packet *p, const Signature *s, const SigMatchCtx *ctx)
 {
-    DetectStreamSizeData *sd = NULL;
-    char *arg = NULL;
-    char *value = NULL;
-    char *mode = NULL;
-    int res = 0;
-    size_t pcre2_len;
-
-    int ret = DetectParsePcreExec(&parse_regex, streamstr, 0, 0);
-    if (ret != 4) {
-        SCLogError(SC_ERR_PCRE_MATCH, "pcre_exec parse error, ret %" PRId32 ", string %s", ret, streamstr);
-        goto error;
-    }
 
-    const char *str_ptr;
-    res = pcre2_substring_get_bynumber(parse_regex.match, 1, (PCRE2_UCHAR8 **)&str_ptr, &pcre2_len);
-    if (res < 0) {
-        SCLogError(SC_ERR_PCRE_GET_SUBSTRING, "pcre2_substring_get_bynumber failed");
-        goto error;
-    }
-    arg = (char *)str_ptr;
-
-    res = pcre2_substring_get_bynumber(parse_regex.match, 2, (PCRE2_UCHAR8 **)&str_ptr, &pcre2_len);
-    if (res < 0) {
-        SCLogError(SC_ERR_PCRE_GET_SUBSTRING, "pcre2_substring_get_bynumber failed");
-        goto error;
-    }
-    mode = (char *)str_ptr;
+    const DetectStreamSizeData *sd = (const DetectStreamSizeData *)ctx;
 
-    res = pcre2_substring_get_bynumber(parse_regex.match, 3, (PCRE2_UCHAR8 **)&str_ptr, &pcre2_len);
-    if (res < 0) {
-        SCLogError(SC_ERR_PCRE_GET_SUBSTRING, "pcre2_substring_get_bynumber failed");
-        goto error;
-    }
-    value = (char *)str_ptr;
-
-    sd = SCMalloc(sizeof(DetectStreamSizeData));
-    if (unlikely(sd == NULL))
-        goto error;
-    sd->ssize = 0;
-    sd->flags = 0;
-
-    if (strlen(mode) == 0)
-        goto error;
-
-    if (mode[0] == '=') {
-        sd->mode = DETECTSSIZE_EQ;
-    } else if (mode[0] == '<') {
-        sd->mode = DETECTSSIZE_LT;
-        if (strcmp("<=", mode) == 0)
-            sd->mode = DETECTSSIZE_LEQ;
-    } else if (mode[0] == '>') {
-        sd->mode = DETECTSSIZE_GT;
-        if (strcmp(">=", mode) == 0)
-            sd->mode = DETECTSSIZE_GEQ;
-    } else if (strcmp("!=", mode) == 0) {
-        sd->mode = DETECTSSIZE_NEQ;
-    } else {
-        SCLogError(SC_ERR_INVALID_OPERATOR, "Invalid operator");
-        goto error;
-    }
+    if (!(PKT_IS_TCP(p)))
+        return 0;
+    if (p->flow == NULL || p->flow->protoctx == NULL)
+        return 0;
 
-    /* set the value */
-    if (StringParseUint32(&sd->ssize, 10, 0, (const char *)value) < 0) {
-        SCLogError(SC_ERR_INVALID_VALUE, "Invalid value for stream size: %s", value);
-        goto error;
-    }
-    /* inspect our options and set the flags */
-    if (strcmp(arg, "server") == 0) {
-        sd->flags |= STREAM_SIZE_SERVER;
-    } else if (strcmp(arg, "client") == 0) {
-        sd->flags |= STREAM_SIZE_CLIENT;
-    } else if ((strcmp(arg, "both") == 0)) {
-        sd->flags |= STREAM_SIZE_BOTH;
-    } else if (strcmp(arg, "either") == 0) {
-        sd->flags |= STREAM_SIZE_EITHER;
-    } else {
-        goto error;
-    }
+    const TcpSession *ssn = (TcpSession *)p->flow->protoctx;
 
-    pcre2_substring_free((PCRE2_UCHAR8 *)mode);
-    pcre2_substring_free((PCRE2_UCHAR8 *)arg);
-    pcre2_substring_free((PCRE2_UCHAR8 *)value);
-    return sd;
-
-error:
-    if (mode != NULL)
-        pcre2_substring_free((PCRE2_UCHAR8 *)mode);
-    if (arg != NULL)
-        pcre2_substring_free((PCRE2_UCHAR8 *)arg);
-    if (value != NULL)
-        pcre2_substring_free((PCRE2_UCHAR8 *)value);
-    if (sd != NULL)
-        DetectStreamSizeFree(de_ctx, sd);
-    return NULL;
+    SCReturnInt(DetectStreamSizeMatchAux(sd, ssn));
 }
 
 /**
@@ -292,7 +142,7 @@ error:
  */
 static int DetectStreamSizeSetup (DetectEngineCtx *de_ctx, Signature *s, const char *streamstr)
 {
-    DetectStreamSizeData *sd = DetectStreamSizeParse(de_ctx, streamstr);
+    DetectStreamSizeData *sd = rs_detect_stream_size_parse(streamstr);
     if (sd == NULL)
         return -1;
 
@@ -316,8 +166,70 @@ static int DetectStreamSizeSetup (DetectEngineCtx *de_ctx, Signature *s, const c
  */
 void DetectStreamSizeFree(DetectEngineCtx *de_ctx, void *ptr)
 {
-    DetectStreamSizeData *sd = (DetectStreamSizeData *)ptr;
-    SCFree(sd);
+    rs_detect_stream_size_free(ptr);
+}
+
+/* prefilter code */
+
+static void PrefilterPacketStreamsizeMatch(
+        DetectEngineThreadCtx *det_ctx, Packet *p, const void *pectx)
+{
+    if (!(PKT_IS_TCP(p)) || PKT_IS_PSEUDOPKT(p))
+        return;
+
+    if (p->flow == NULL || p->flow->protoctx == NULL)
+        return;
+
+    /* during setup Suricata will automatically see if there is another
+     * check that can be added: alproto, sport or dport */
+    const PrefilterPacketHeaderCtx *ctx = pectx;
+    if (!PrefilterPacketHeaderExtraMatch(ctx, p))
+        return;
+
+    DetectStreamSizeData dsd;
+    dsd.du32.mode = ctx->v1.u8[0];
+    dsd.flags = ctx->v1.u8[1];
+    dsd.du32.arg1 = ctx->v1.u32[2];
+    const TcpSession *ssn = (TcpSession *)p->flow->protoctx;
+    /* if we match, add all the sigs that use this prefilter. This means
+     * that these will be inspected further */
+    if (DetectStreamSizeMatchAux(&dsd, ssn)) {
+        PrefilterAddSids(&det_ctx->pmq, ctx->sigs_array, ctx->sigs_cnt);
+    }
+}
+
+static void PrefilterPacketStreamSizeSet(PrefilterPacketHeaderValue *v, void *smctx)
+{
+    const DetectStreamSizeData *a = smctx;
+    v->u8[0] = a->du32.mode;
+    v->u8[1] = a->flags;
+    v->u32[2] = a->du32.arg1;
+}
+
+static bool PrefilterPacketStreamSizeCompare(PrefilterPacketHeaderValue v, void *smctx)
+{
+    const DetectStreamSizeData *a = smctx;
+    if (v.u8[0] == a->du32.mode && v.u8[1] == a->flags && v.u32[2] == a->du32.arg1)
+        return true;
+    return false;
+}
+
+static int PrefilterSetupStreamSize(DetectEngineCtx *de_ctx, SigGroupHead *sgh)
+{
+    return PrefilterSetupPacketHeader(de_ctx, sgh, DETECT_TCPMSS, PrefilterPacketStreamSizeSet,
+            PrefilterPacketStreamSizeCompare, PrefilterPacketStreamsizeMatch);
+}
+
+static bool PrefilterStreamSizeIsPrefilterable(const Signature *s)
+{
+    const SigMatch *sm;
+    for (sm = s->init_data->smlists[DETECT_SM_LIST_MATCH]; sm != NULL; sm = sm->next) {
+        switch (sm->type) {
+            case DETECT_STREAM_SIZE:
+                return true;
+        }
+    }
+    return false;
 }
 
 #ifdef UNITTESTS
@@ -330,9 +242,9 @@ static int DetectStreamSizeParseTest01 (void)
 {
     int result = 0;
     DetectStreamSizeData *sd = NULL;
-    sd = DetectStreamSizeParse(NULL, "server,<,6");
+    sd = rs_detect_stream_size_parse("server,<,6");
     if (sd != NULL) {
-        if (sd->flags & STREAM_SIZE_SERVER && sd->mode == DETECTSSIZE_LT && sd->ssize == 6)
+        if (sd->flags & StreamSizeServer && sd->du32.mode == DETECT_UINT_LT && sd->du32.arg1 == 6)
             result = 1;
         DetectStreamSizeFree(NULL, sd);
     }
@@ -349,9 +261,9 @@ static int DetectStreamSizeParseTest02 (void)
 {
     int result = 1;
     DetectStreamSizeData *sd = NULL;
-    sd = DetectStreamSizeParse(NULL, "invalidoption,<,6");
+    sd = rs_detect_stream_size_parse("invalidoption,<,6");
     if (sd != NULL) {
-        printf("expected: NULL got 0x%02X %" PRIu32 ": ",sd->flags, sd->ssize);
+        printf("expected: NULL got 0x%02X %" PRIu32 ": ", sd->flags, sd->du32.arg1);
         result = 0;
         DetectStreamSizeFree(NULL, sd);
     }
@@ -390,24 +302,24 @@ static int DetectStreamSizeParseTest03 (void)
     memset(&f, 0, sizeof(Flow));
     memset(&tcph, 0, sizeof(TCPHdr));
 
-    sd = DetectStreamSizeParse(NULL, "client,>,8");
+    sd = rs_detect_stream_size_parse("client,>,8");
     if (sd != NULL) {
-        if (!(sd->flags & STREAM_SIZE_CLIENT)) {
+        if (!(sd->flags & StreamSizeClient)) {
             printf("sd->flags not STREAM_SIZE_CLIENT: ");
             DetectStreamSizeFree(NULL, sd);
             SCFree(p);
             return 0;
         }
 
-        if (sd->mode != DETECTSSIZE_GT) {
+        if (sd->du32.mode != DETECT_UINT_GT) {
             printf("sd->mode not DETECTSSIZE_GT: ");
             DetectStreamSizeFree(NULL, sd);
             SCFree(p);
             return 0;
         }
 
-        if (sd->ssize != 8) {
-            printf("sd->ssize is %"PRIu32", not 8: ", sd->ssize);
+        if (sd->du32.arg1 != 8) {
+            printf("sd->ssize is %" PRIu32 ", not 8: ", sd->du32.arg1);
             DetectStreamSizeFree(NULL, sd);
             SCFree(p);
             return 0;
@@ -466,11 +378,12 @@ static int DetectStreamSizeParseTest04 (void)
     memset(&f, 0, sizeof(Flow));
     memset(&ip4h, 0, sizeof(IPV4Hdr));
 
-    sd = DetectStreamSizeParse(NULL, " client , > , 8 ");
+    sd = rs_detect_stream_size_parse(" client , > , 8 ");
     if (sd != NULL) {
-        if (!(sd->flags & STREAM_SIZE_CLIENT) && sd->mode != DETECTSSIZE_GT && sd->ssize != 8) {
-        SCFree(p);
-        return 0;
+        if (!(sd->flags & StreamSizeClient) && sd->du32.mode != DETECT_UINT_GT &&
+                sd->du32.arg1 != 8) {
+            SCFree(p);
+            return 0;
         }
     } else
         {
index 32f5c50b190bad751ded2d7053f39b84a558396a..3a460bf5e28b3c532cc1bec346d4855c5f210f97 100644 (file)
 #ifndef _DETECT_STREAM_SIZE_H
 #define        _DETECT_STREAM_SIZE_H
 
-#define DETECTSSIZE_LT 0
-#define DETECTSSIZE_LEQ 1
-#define DETECTSSIZE_EQ 2
-#define DETECTSSIZE_NEQ 3
-#define DETECTSSIZE_GT 4
-#define DETECTSSIZE_GEQ 5
-
-#define STREAM_SIZE_SERVER 0x01
-#define STREAM_SIZE_CLIENT 0x02
-#define STREAM_SIZE_BOTH   0x04
-#define STREAM_SIZE_EITHER 0x08
-
-typedef struct DetectStreamSizeData_ {
-    uint8_t flags;
-    uint8_t mode;
-    uint32_t ssize;
-}DetectStreamSizeData;
-
 void DetectStreamSizeRegister(void);
 
 #endif /* _DETECT_STREAM_SIZE_H */