From: Jeff Lucovsky Date: Sun, 9 Jul 2023 14:44:26 +0000 (-0400) Subject: detect/byte_math: Permit var name for bytes value X-Git-Tag: suricata-7.0.0~45 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=690b65ae881df234fbbb13b77e0b0e37a8bdda41;p=thirdparty%2Fsuricata.git detect/byte_math: Permit var name for bytes value Issue: 6145 Modifications to permit a variable name to be used for the byte_math bytes value. --- diff --git a/rust/src/detect/byte_math.rs b/rust/src/detect/byte_math.rs index a2bead3745..0cc60e52bf 100644 --- a/rust/src/detect/byte_math.rs +++ b/rust/src/detect/byte_math.rs @@ -33,6 +33,7 @@ pub const DETECT_BYTEMATH_FLAG_STRING: u8 = 0x02; pub const DETECT_BYTEMATH_FLAG_BITMASK: u8 = 0x04; pub const DETECT_BYTEMATH_FLAG_ENDIAN: u8 = 0x08; pub const DETECT_BYTEMATH_FLAG_RVALUE_VAR: u8 = 0x10; +pub const DETECT_BYTEMATH_FLAG_NBYTES_VAR: u8 = 0x20; // Ensure required values are provided const DETECT_BYTEMATH_FLAG_NBYTES: u8 = 0x1; @@ -98,6 +99,7 @@ enum ResultValue { pub struct DetectByteMathData { rvalue_str: *const c_char, result: *const c_char, + nbytes_str: *const c_char, rvalue: u32, offset: i32, bitmask_val: u32, @@ -120,6 +122,9 @@ impl Drop for DetectByteMathData { if !self.rvalue_str.is_null() { let _ = CString::from_raw(self.rvalue_str as *mut c_char); } + if !self.nbytes_str.is_null() { + let _ = CString::from_raw(self.nbytes_str as *mut c_char); + } } } } @@ -133,6 +138,7 @@ impl Default for DetectByteMathData { offset: 0, oper: ByteMathOperator::OperatorNone, rvalue_str: std::ptr::null_mut(), + nbytes_str: std::ptr::null_mut(), rvalue: 0, result: std::ptr::null_mut(), endian: DETECT_BYTEMATH_ENDIAN_DEFAULT, @@ -190,12 +196,12 @@ fn get_endian_value(value: &str) -> Result { // Parsed as a u64 for validation with u32 {min,max} so values greater than uint32 // are not treated as a string value. -fn parse_rvalue(input: &str) -> IResult<&str, ResultValue, RuleParseError<&str>> { - let (input, rvalue) = parse_token(input)?; - if let Ok(val) = rvalue.parse::() { +fn parse_var(input: &str) -> IResult<&str, ResultValue, RuleParseError<&str>> { + let (input, value) = parse_token(input)?; + if let Ok(val) = value.parse::() { Ok((input, ResultValue::Numeric(val))) } else { - Ok((input, ResultValue::String(rvalue.to_string()))) + Ok((input, ResultValue::String(value.to_string()))) } } @@ -259,7 +265,7 @@ fn parse_bytemath(input: &str) -> IResult<&str, DetectByteMathData, RuleParseErr if 0 != (required_flags & DETECT_BYTEMATH_FLAG_RVALUE) { return Err(make_error("rvalue already set".to_string())); } - let (_, res) = parse_rvalue(val)?; + let (_, res) = parse_var(val)?; match res { ResultValue::Numeric(val) => { if val >= u32::MIN.into() && val <= u32::MAX.into() { @@ -358,14 +364,29 @@ fn parse_bytemath(input: &str) -> IResult<&str, DetectByteMathData, RuleParseErr if 0 != (required_flags & DETECT_BYTEMATH_FLAG_NBYTES) { return Err(make_error("nbytes already set".to_string())); } - byte_math.nbytes = val - .parse() - .map_err(|_| make_error(format!("invalid bytes value: {}", val)))?; - if byte_math.nbytes < 1 || byte_math.nbytes > 10 { - return Err(make_error(format!( - "invalid bytes value: must be between 1 and 10: {}", - byte_math.nbytes - ))); + let (_, res) = parse_var(val)?; + match res { + ResultValue::Numeric(val) => { + if (1..=10).contains(&val) { + byte_math.nbytes = val as u8 + } else { + return Err(make_error(format!( + "invalid nbytes value: must be between 1 and 10: {}", + val + ))); + } + } + ResultValue::String(val) => match CString::new(val) { + Ok(newval) => { + byte_math.nbytes_str = newval.into_raw(); + byte_math.flags |= DETECT_BYTEMATH_FLAG_NBYTES_VAR; + } + _ => { + return Err(make_error( + "parse string not safely convertible to C".to_string(), + )) + } + }, } required_flags |= DETECT_BYTEMATH_FLAG_NBYTES; } @@ -439,6 +460,14 @@ mod tests { return false; } + if !self.nbytes_str.is_null() && !other.nbytes_str.is_null() { + let s_val = unsafe { CStr::from_ptr(self.nbytes_str) }; + let o_val = unsafe { CStr::from_ptr(other.nbytes_str) }; + res = s_val == o_val; + } else if !self.nbytes_str.is_null() || !other.nbytes_str.is_null() { + return false; + } + if !self.result.is_null() && !self.result.is_null() { let s_val = unsafe { CStr::from_ptr(self.result) }; let o_val = unsafe { CStr::from_ptr(other.result) }; @@ -462,7 +491,7 @@ mod tests { } fn valid_test( - args: &str, nbytes: u8, offset: i32, oper: ByteMathOperator, rvalue_str: &str, rvalue: u32, + args: &str, nbytes: u8, offset: i32, oper: ByteMathOperator, rvalue_str: &str, nbytes_str: &str, rvalue: u32, result: &str, base: ByteMathBase, endian: ByteMathEndian, bitmask_val: u32, flags: u8, ) { let bmd = DetectByteMathData { @@ -474,6 +503,11 @@ mod tests { } else { std::ptr::null_mut() }, + nbytes_str: if !nbytes_str.is_empty() { + CString::new(nbytes_str).unwrap().into_raw() + } else { + std::ptr::null_mut() + }, rvalue, result: CString::new(result).unwrap().into_raw(), base, @@ -501,6 +535,7 @@ mod tests { 3933, ByteMathOperator::Addition, "myrvalue", + "", 0, "myresult", ByteMathBase::BaseDec, @@ -517,6 +552,7 @@ mod tests { 3933, ByteMathOperator::Addition, "", + "", 99, "other", ByteMathBase::BaseDec, @@ -531,6 +567,7 @@ mod tests { -3933, ByteMathOperator::Addition, "rvalue", + "", 0, "foo", BASE_DEFAULT, @@ -539,6 +576,21 @@ mod tests { DETECT_BYTEMATH_FLAG_RVALUE_VAR, ); + valid_test( + "bytes nbytes_var, offset -3933, oper +, rvalue myrvalue, result foo", + 0, + -3933, + ByteMathOperator::Addition, + "rvalue", + "nbytes_var", + 0, + "foo", + BASE_DEFAULT, + ByteMathEndian::BigEndian, + 0, + DETECT_BYTEMATH_FLAG_RVALUE_VAR | DETECT_BYTEMATH_FLAG_NBYTES_VAR, + ); + // Out of order valid_test( "string dec, endian big, result other, rvalue 99, oper +, offset 3933, bytes 4", @@ -546,6 +598,7 @@ mod tests { 3933, ByteMathOperator::Addition, "", + "", 99, "other", ByteMathBase::BaseDec, diff --git a/src/detect-bytemath.c b/src/detect-bytemath.c index abcc1761d4..9064b06fcf 100644 --- a/src/detect-bytemath.c +++ b/src/detect-bytemath.c @@ -76,20 +76,30 @@ void DetectBytemathRegister(void) #endif } +static inline bool DetectByteMathValidateNbytesOnly(const DetectByteMathData *data, int32_t nbytes) +{ + return nbytes >= 1 && + (((data->flags & DETECT_BYTEMATH_FLAG_STRING) && nbytes <= 10) || (nbytes <= 4)); +} + int DetectByteMathDoMatch(DetectEngineThreadCtx *det_ctx, const SigMatchData *smd, - const Signature *s, const uint8_t *payload, - uint16_t payload_len, uint64_t rvalue, uint64_t *value, uint8_t endian) + const Signature *s, const uint8_t *payload, uint16_t payload_len, uint8_t nbytes, + uint64_t rvalue, uint64_t *value, uint8_t endian) { const DetectByteMathData *data = (DetectByteMathData *)smd->ctx; + if (payload_len == 0) { + return 0; + } + + if (!DetectByteMathValidateNbytesOnly(data, nbytes)) { + return 0; + } + const uint8_t *ptr; int32_t len; uint64_t val; int extbytes; - if (payload_len == 0) { - return 0; - } - /* Calculate the ptr value for the byte-math op and length remaining in * the packet from that point. */ @@ -116,33 +126,30 @@ int DetectByteMathDoMatch(DetectEngineThreadCtx *det_ctx, const SigMatchData *sm } /* Validate that the to-be-extracted is within the packet */ - if (ptr < payload || data->nbytes > len) { - SCLogDebug("Data not within payload pkt=%p, ptr=%p, len=%"PRIu32", nbytes=%d", - payload, ptr, len, data->nbytes); + if (ptr < payload || nbytes > len) { + SCLogDebug("Data not within payload pkt=%p, ptr=%p, len=%" PRIu32 ", nbytes=%d", payload, + ptr, len, nbytes); return 0; } /* Extract the byte data */ if (data->flags & DETECT_BYTEMATH_FLAG_STRING) { - extbytes = ByteExtractStringUint64(&val, data->base, - data->nbytes, (const char *)ptr); + extbytes = ByteExtractStringUint64(&val, data->base, nbytes, (const char *)ptr); if (extbytes <= 0) { if (val == 0) { SCLogDebug("No Numeric value"); return 0; } else { - SCLogDebug("error extracting %d bytes of string data: %d", - data->nbytes, extbytes); + SCLogDebug("error extracting %d bytes of string data: %d", nbytes, extbytes); return -1; } } } else { ByteMathEndian bme = endian; int endianness = (bme == BigEndian) ? BYTE_BIG_ENDIAN : BYTE_LITTLE_ENDIAN; - extbytes = ByteExtractUint64(&val, endianness, data->nbytes, ptr); - if (extbytes != data->nbytes) { - SCLogDebug("error extracting %d bytes of numeric data: %d", - data->nbytes, extbytes); + extbytes = ByteExtractUint64(&val, endianness, nbytes, ptr); + if (extbytes != nbytes) { + SCLogDebug("error extracting %d bytes of numeric data: %d", nbytes, extbytes); return 0; } } @@ -206,7 +213,8 @@ int DetectByteMathDoMatch(DetectEngineThreadCtx *det_ctx, const SigMatchData *sm * \retval bmd On success an instance containing the parsed data. * On failure, NULL. */ -static DetectByteMathData *DetectByteMathParse(DetectEngineCtx *de_ctx, const char *arg, char **rvalue) +static DetectByteMathData *DetectByteMathParse( + DetectEngineCtx *de_ctx, const char *arg, char **nbytes, char **rvalue) { DetectByteMathData *bmd; if ((bmd = ScByteMathParse(arg)) == NULL) { @@ -214,6 +222,19 @@ static DetectByteMathData *DetectByteMathParse(DetectEngineCtx *de_ctx, const ch return NULL; } + if (bmd->nbytes_str) { + if (nbytes == NULL) { + SCLogError("byte_math supplied with " + "var name for nbytes. \"nbytes\" argument supplied to " + "this function must be non-NULL"); + goto error; + } + *nbytes = SCStrdup(bmd->nbytes_str); + if (*nbytes == NULL) { + goto error; + } + } + if (bmd->rvalue_str) { if (rvalue == NULL) { SCLogError("byte_math supplied with " @@ -262,9 +283,10 @@ static int DetectByteMathSetup(DetectEngineCtx *de_ctx, Signature *s, const char SigMatch *prev_pm = NULL; DetectByteMathData *data; char *rvalue = NULL; + char *nbytes = NULL; int ret = -1; - data = DetectByteMathParse(de_ctx, arg, &rvalue); + data = DetectByteMathParse(de_ctx, arg, &nbytes, &rvalue); if (data == NULL) goto error; @@ -336,6 +358,18 @@ static int DetectByteMathSetup(DetectEngineCtx *de_ctx, Signature *s, const char } } + if (nbytes != NULL) { + DetectByteIndexType index; + if (!DetectByteRetrieveSMVar(nbytes, s, &index)) { + SCLogError("unknown byte_ keyword var seen in byte_math - %s", nbytes); + goto error; + } + data->nbytes = index; + data->flags |= DETECT_BYTEMATH_FLAG_NBYTES_VAR; + SCFree(nbytes); + nbytes = NULL; + } + if (rvalue != NULL) { DetectByteIndexType index; if (!DetectByteRetrieveSMVar(rvalue, s, &index)) { @@ -386,6 +420,8 @@ static int DetectByteMathSetup(DetectEngineCtx *de_ctx, Signature *s, const char error: if (rvalue) SCFree(rvalue); + if (nbytes) + SCFree(nbytes); DetectByteMathFree(de_ctx, data); return ret; } @@ -448,8 +484,10 @@ SigMatch *DetectByteMathRetrieveSMVar(const char *arg, const Signature *s) static int DetectByteMathParseTest01(void) { - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +," - "rvalue 10, result bar", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +," + "rvalue 10, result bar", + NULL, NULL); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); @@ -468,8 +506,10 @@ static int DetectByteMathParseTest01(void) static int DetectByteMathParseTest02(void) { /* bytes value invalid */ - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 257, offset 2, oper +, " - "rvalue 39, result bar", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 257, offset 2, oper +, " + "rvalue 39, result bar", + NULL, NULL); FAIL_IF_NOT(bmd == NULL); @@ -479,8 +519,10 @@ static int DetectByteMathParseTest02(void) static int DetectByteMathParseTest03(void) { /* bytes value invalid */ - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 11, offset 2, oper +, " - "rvalue 39, result bar", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 11, offset 2, oper +, " + "rvalue 39, result bar", + NULL, NULL); FAIL_IF_NOT(bmd == NULL); PASS; @@ -489,8 +531,10 @@ static int DetectByteMathParseTest03(void) static int DetectByteMathParseTest04(void) { /* offset value invalid */ - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 70000, oper +," - " rvalue 39, result bar", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 70000, oper +," + " rvalue 39, result bar", + NULL, NULL); FAIL_IF_NOT(bmd == NULL); @@ -500,8 +544,10 @@ static int DetectByteMathParseTest04(void) static int DetectByteMathParseTest05(void) { /* oper value invalid */ - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 11, offset 16, oper &," - "rvalue 39, result bar", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 11, offset 16, oper &," + "rvalue 39, result bar", + NULL, NULL); FAIL_IF_NOT(bmd == NULL); PASS; @@ -512,9 +558,10 @@ static int DetectByteMathParseTest06(void) uint8_t flags = DETECT_BYTEMATH_FLAG_RELATIVE; char *rvalue = NULL; - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 0, oper +," - "rvalue 248, result var, relative", - &rvalue); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 0, oper +," + "rvalue 248, result var, relative", + NULL, &rvalue); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); @@ -535,9 +582,10 @@ static int DetectByteMathParseTest07(void) { char *rvalue = NULL; - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +," - "rvalue foo, result bar", - &rvalue); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +," + "rvalue foo, result bar", + NULL, &rvalue); FAIL_IF_NOT(rvalue); FAIL_IF_NOT(bmd->nbytes == 4); FAIL_IF_NOT(bmd->offset == 2); @@ -557,8 +605,10 @@ static int DetectByteMathParseTest07(void) static int DetectByteMathParseTest08(void) { /* ensure Parse checks the pointer value when rvalue is a var */ - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +," - "rvalue foo, result bar", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +," + "rvalue foo, result bar", + NULL, NULL); FAIL_IF_NOT(bmd == NULL); PASS; @@ -568,9 +618,10 @@ static int DetectByteMathParseTest09(void) { uint8_t flags = DETECT_BYTEMATH_FLAG_RELATIVE; - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +," - "rvalue 39, result bar, relative", - NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +," + "rvalue 39, result bar, relative", + NULL, NULL); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); @@ -591,9 +642,11 @@ static int DetectByteMathParseTest10(void) { uint8_t flags = DETECT_BYTEMATH_FLAG_ENDIAN; - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +," - "rvalue 39, result bar, endian" - " big", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +," + "rvalue 39, result bar, endian" + " big", + NULL, NULL); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); @@ -614,9 +667,10 @@ static int DetectByteMathParseTest11(void) { uint8_t flags = DETECT_BYTEMATH_FLAG_ENDIAN; - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +, " - "rvalue 39, result bar, dce", - NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +, " + "rvalue 39, result bar, dce", + NULL, NULL); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); @@ -637,9 +691,11 @@ static int DetectByteMathParseTest12(void) { uint8_t flags = DETECT_BYTEMATH_FLAG_RELATIVE | DETECT_BYTEMATH_FLAG_STRING; - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +," - "rvalue 39, result bar, " - "relative, string dec", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +," + "rvalue 39, result bar, " + "relative, string dec", + NULL, NULL); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); @@ -662,10 +718,12 @@ static int DetectByteMathParseTest13(void) DETECT_BYTEMATH_FLAG_RELATIVE | DETECT_BYTEMATH_FLAG_BITMASK; - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +, " - "rvalue 39, result bar, " - "relative, string dec, bitmask " - "0x8f40", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +, " + "rvalue 39, result bar, " + "relative, string dec, bitmask " + "0x8f40", + NULL, NULL); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); @@ -688,8 +746,10 @@ static int DetectByteMathParseTest13(void) static int DetectByteMathParseTest14(void) { /* incomplete */ - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +," - "rvalue foo", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +," + "rvalue foo", + NULL, NULL); FAIL_IF_NOT(bmd == NULL); @@ -700,8 +760,10 @@ static int DetectByteMathParseTest15(void) { /* incomplete */ - DetectByteMathData *bmd = DetectByteMathParse(NULL, "bytes 4, offset 2, oper +, " - "result bar", NULL); + DetectByteMathData *bmd = DetectByteMathParse(NULL, + "bytes 4, offset 2, oper +, " + "result bar", + NULL, NULL); FAIL_IF_NOT(bmd == NULL); @@ -718,7 +780,7 @@ static int DetectByteMathParseTest16(void) "rvalue 39, result bar, " "relative, string dec, bitmask " "0x8f40", - NULL); + NULL, NULL); FAIL_IF(bmd == NULL); FAIL_IF_NOT(bmd->nbytes == 4); diff --git a/src/detect-bytemath.h b/src/detect-bytemath.h index 099bd81f10..672f799ca4 100644 --- a/src/detect-bytemath.h +++ b/src/detect-bytemath.h @@ -28,6 +28,6 @@ void DetectBytemathRegister(void); SigMatch *DetectByteMathRetrieveSMVar(const char *, const Signature *); int DetectByteMathDoMatch(DetectEngineThreadCtx *, const SigMatchData *, const Signature *, - const uint8_t *, uint16_t, uint64_t, uint64_t *, uint8_t); + const uint8_t *, uint16_t, uint8_t, uint64_t, uint64_t *, uint8_t); #endif /* __DETECT_BYTEMATH_H__ */ diff --git a/src/detect-engine-content-inspection.c b/src/detect-engine-content-inspection.c index 6f57ad55f5..77ebb3f827 100644 --- a/src/detect-engine-content-inspection.c +++ b/src/detect-engine-content-inspection.c @@ -588,8 +588,15 @@ uint8_t DetectEngineContentInspection(DetectEngineCtx *de_ctx, DetectEngineThrea rvalue = bmd->rvalue; } + uint8_t nbytes; + if (bmd->flags & DETECT_BYTEMATH_FLAG_NBYTES_VAR) { + nbytes = (uint8_t)det_ctx->byte_values[bmd->nbytes]; + } else { + nbytes = bmd->nbytes; + } + DEBUG_VALIDATE_BUG_ON(buffer_len > UINT16_MAX); - if (DetectByteMathDoMatch(det_ctx, smd, s, buffer, (uint16_t)buffer_len, rvalue, + if (DetectByteMathDoMatch(det_ctx, smd, s, buffer, (uint16_t)buffer_len, nbytes, rvalue, &det_ctx->byte_values[bmd->local_id], endian) != 1) { goto no_match; }