From: Philippe Antoine Date: Wed, 23 Mar 2022 19:44:44 +0000 (+0100) Subject: detect: rust generic functions for integers X-Git-Tag: suricata-7.0.0-beta1~519 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f29b43defd15e5f1722763a4955271450b951b55;p=thirdparty%2Fsuricata.git detect: rust generic functions for integers Move it away from http2 to generic core crate. And use it for DCERPC (and SMB) And remove the C version. Main change in API is the free function is not free itself, but a rust wrapper around unbox. Ticket: #4112 --- diff --git a/rust/src/dcerpc/detect.rs b/rust/src/dcerpc/detect.rs index 7cc3a59977..63e5188cc5 100644 --- a/rust/src/dcerpc/detect.rs +++ b/rust/src/dcerpc/detect.rs @@ -19,23 +19,17 @@ use super::dcerpc::{ DCERPCState, DCERPCTransaction, DCERPC_TYPE_REQUEST, DCERPC_TYPE_RESPONSE, DCERPC_UUID_ENTRY_FLAG_FF, }; +use crate::detect::{detect_match_uint, detect_parse_uint, DetectUintData}; use std::ffi::CStr; use std::os::raw::{c_char, c_void}; use uuid::Uuid; -pub const DETECT_DCE_IFACE_OP_NONE: u8 = 0; -pub const DETECT_DCE_IFACE_OP_LT: u8 = 1; -pub const DETECT_DCE_IFACE_OP_GT: u8 = 2; -pub const DETECT_DCE_IFACE_OP_EQ: u8 = 3; -pub const DETECT_DCE_IFACE_OP_NE: u8 = 4; - pub const DETECT_DCE_OPNUM_RANGE_UNINITIALIZED: u32 = 100000; #[derive(Debug)] pub struct DCEIfaceData { pub if_uuid: Vec, - pub op: u8, - pub version: u16, + pub du16: Option>, pub any_frag: u8, } @@ -59,47 +53,6 @@ pub struct DCEOpnumData { pub data: Vec, } -fn extract_op_version(opver: &str) -> Result<(u8, u16), ()> { - if !opver.is_char_boundary(1){ - return Err(()); - } - let (op, version) = opver.split_at(1); - let opval: u8 = match op { - ">" => DETECT_DCE_IFACE_OP_GT, - "<" => DETECT_DCE_IFACE_OP_LT, - "=" => DETECT_DCE_IFACE_OP_EQ, - "!" => DETECT_DCE_IFACE_OP_NE, - _ => DETECT_DCE_IFACE_OP_NONE, - }; - - let version: u16 = match version.parse::() { - Ok(res) => res, - _ => { - return Err(()); - } - }; - if opval == DETECT_DCE_IFACE_OP_NONE - || (opval == DETECT_DCE_IFACE_OP_LT && version == std::u16::MIN) - || (opval == DETECT_DCE_IFACE_OP_GT && version == std::u16::MAX) - { - return Err(()); - } - - Ok((opval, version)) -} - -fn match_iface_version(version: u16, if_data: &DCEIfaceData) -> bool { - match if_data.op { - DETECT_DCE_IFACE_OP_LT => version < if_data.version, - DETECT_DCE_IFACE_OP_GT => version > if_data.version, - DETECT_DCE_IFACE_OP_EQ => version == if_data.version, - DETECT_DCE_IFACE_OP_NE => version != if_data.version, - _ => { - return true; - } - } -} - fn match_backuuid( tx: &mut DCERPCTransaction, state: &mut DCERPCState, if_data: &mut DCEIfaceData, ) -> u8 { @@ -132,11 +85,11 @@ fn match_backuuid( continue; } - if if_data.op != DETECT_DCE_IFACE_OP_NONE - && !match_iface_version(uuidentry.version, if_data) - { - SCLogDebug!("Interface version did not match"); - ret &= 0; + if let Some(x) = &if_data.du16 { + if !detect_match_uint(&x, uuidentry.version) { + SCLogDebug!("Interface version did not match"); + ret &= 0; + } } if ret == 1 { @@ -150,7 +103,7 @@ fn match_backuuid( fn parse_iface_data(arg: &str) -> Result { let split_args: Vec<&str> = arg.split(',').collect(); - let mut op_version = (0, 0); + let mut du16 = None; let mut any_frag: u8 = 0; let if_uuid = match Uuid::parse_str(split_args[0]) { Ok(res) => res.as_bytes().to_vec(), @@ -166,8 +119,8 @@ fn parse_iface_data(arg: &str) -> Result { any_frag = 1; } _ => { - op_version = match extract_op_version(split_args[1]) { - Ok((op, ver)) => (op, ver), + match detect_parse_uint(split_args[1]) { + Ok((_, x)) => du16 = Some(x), _ => { return Err(()); } @@ -175,8 +128,8 @@ fn parse_iface_data(arg: &str) -> Result { } }, 3 => { - op_version = match extract_op_version(split_args[1]) { - Ok((op, ver)) => (op, ver), + match detect_parse_uint(split_args[1]) { + Ok((_, x)) => du16 = Some(x), _ => { return Err(()); } @@ -193,8 +146,7 @@ fn parse_iface_data(arg: &str) -> Result { Ok(DCEIfaceData { if_uuid: if_uuid, - op: op_version.0, - version: op_version.1, + du16: du16, any_frag: any_frag, }) } @@ -343,30 +295,39 @@ pub unsafe extern "C" fn rs_dcerpc_opnum_free(ptr: *mut c_void) { #[cfg(test)] mod test { use super::*; + use crate::detect::DetectUintMode; + fn extract_op_version(i: &str) -> Result<(DetectUintMode, u16), ()> { + match detect_parse_uint(i) { + Ok((_, d)) => return Ok((d.mode, d.arg1)), + _ => { + return Err(()); + } + } + } #[test] fn test_extract_op_version() { let op_version = "<1"; assert_eq!( - Ok((DETECT_DCE_IFACE_OP_LT, 1)), + Ok((DetectUintMode::DetectUintModeLt, 1)), extract_op_version(op_version) ); let op_version = ">10"; assert_eq!( - Ok((DETECT_DCE_IFACE_OP_GT, 10)), + Ok((DetectUintMode::DetectUintModeGt, 10)), extract_op_version(op_version) ); let op_version = "=45"; assert_eq!( - Ok((DETECT_DCE_IFACE_OP_EQ, 45)), + Ok((DetectUintMode::DetectUintModeEqual, 45)), extract_op_version(op_version) ); let op_version = "!0"; assert_eq!( - Ok((DETECT_DCE_IFACE_OP_NE, 0)), + Ok((DetectUintMode::DetectUintModeNe, 0)), extract_op_version(op_version) ); @@ -374,26 +335,21 @@ mod test { assert_eq!(true, extract_op_version(op_version).is_err()); let op_version = ""; - assert_eq!( - Err(()), - extract_op_version(op_version) - ); - + assert_eq!(Err(()), extract_op_version(op_version)); } #[test] fn test_match_iface_version() { - let iface_data = DCEIfaceData { - if_uuid: Vec::new(), - op: 3, - version: 10, - any_frag: 0, + let iface_data = DetectUintData:: { + mode: DetectUintMode::DetectUintModeEqual, + arg1: 10, + arg2: 0, }; - let version = 10; - assert_eq!(true, match_iface_version(version, &iface_data)); + let version: u16 = 10; + assert_eq!(true, detect_match_uint(&iface_data, version)); - let version = 2; - assert_eq!(false, match_iface_version(version, &iface_data)); + let version: u16 = 2; + assert_eq!(false, detect_match_uint(&iface_data, version)); } #[test] @@ -411,8 +367,9 @@ mod test { let uuid = Uuid::from_slice(iface_data.if_uuid.as_slice()); let uuid = uuid.map(|uuid| uuid.to_hyphenated().to_string()); assert_eq!(expected_uuid, uuid); - assert_eq!(DETECT_DCE_IFACE_OP_GT, iface_data.op); - assert_eq!(1, iface_data.version); + let du16 = iface_data.du16.unwrap(); + assert_eq!(DetectUintMode::DetectUintModeGt, du16.mode); + assert_eq!(1, du16.arg1); let arg = "12345678-1234-1234-1234-123456789ABC,any_frag"; let iface_data = parse_iface_data(arg).unwrap(); @@ -420,9 +377,8 @@ mod test { let uuid = Uuid::from_slice(iface_data.if_uuid.as_slice()); let uuid = uuid.map(|uuid| uuid.to_hyphenated().to_string()); assert_eq!(expected_uuid, uuid); - assert_eq!(DETECT_DCE_IFACE_OP_NONE, iface_data.op); + assert!(iface_data.du16.is_none()); assert_eq!(1, iface_data.any_frag); - assert_eq!(0, iface_data.version); let arg = "12345678-1234-1234-1234-123456789ABC,!10,any_frag"; let iface_data = parse_iface_data(arg).unwrap(); @@ -430,9 +386,10 @@ mod test { let uuid = Uuid::from_slice(iface_data.if_uuid.as_slice()); let uuid = uuid.map(|uuid| uuid.to_hyphenated().to_string()); assert_eq!(expected_uuid, uuid); - assert_eq!(DETECT_DCE_IFACE_OP_NE, iface_data.op); assert_eq!(1, iface_data.any_frag); - assert_eq!(10, iface_data.version); + let du16 = iface_data.du16.unwrap(); + assert_eq!(DetectUintMode::DetectUintModeNe, du16.mode); + assert_eq!(10, du16.arg1); let arg = "12345678-1234-1234-1234-123456789ABC,>1,ay_frag"; let iface_data = parse_iface_data(arg); @@ -458,7 +415,7 @@ mod test { let iface_data = parse_iface_data(arg); assert_eq!(iface_data.is_err(), true); - let arg = "12345678-1234-1234-1234-123456789ABC,>=1,any_frag"; + let arg = "12345678-1234-1234-1234-123456789ABC,>=0,any_frag"; let iface_data = parse_iface_data(arg); assert_eq!(iface_data.is_err(), true); diff --git a/rust/src/detect.rs b/rust/src/detect.rs new file mode 100644 index 0000000000..582728b285 --- /dev/null +++ b/rust/src/detect.rs @@ -0,0 +1,302 @@ +/* Copyright (C) 2022 Open Information Security Foundation + * + * You can copy, redistribute or modify this Program under the terms of + * the GNU General Public License version 2 as published by the Free + * Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * version 2 along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA + * 02110-1301, USA. + */ + +use nom7::branch::alt; +use nom7::bytes::complete::{is_a, tag, take_while}; +use nom7::character::complete::digit1; +use nom7::combinator::{all_consuming, map_opt, opt, value, verify}; +use nom7::error::{make_error, ErrorKind}; +use nom7::Err; +use nom7::IResult; + +use std::ffi::CStr; + +#[derive(PartialEq, Clone, Debug)] +#[repr(u8)] +pub enum DetectUintMode { + DetectUintModeEqual, + DetectUintModeLt, + DetectUintModeLte, + DetectUintModeGt, + DetectUintModeGte, + DetectUintModeRange, + DetectUintModeNe, +} + +#[derive(Debug)] +#[repr(C)] +pub struct DetectUintData { + pub arg1: T, + pub arg2: T, + pub mode: DetectUintMode, +} + +pub trait DetectIntType: + std::str::FromStr + std::cmp::PartialOrd + num::PrimInt + num::Bounded +{ +} +impl DetectIntType for T where + T: std::str::FromStr + std::cmp::PartialOrd + num::PrimInt + num::Bounded +{ +} + +fn detect_parse_uint_start_equal(i: &str) -> IResult<&str, DetectUintData> { + let (i, _) = opt(tag("="))(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; + Ok(( + i, + DetectUintData { + arg1, + arg2: T::min_value(), + mode: DetectUintMode::DetectUintModeEqual, + }, + )) +} + +fn detect_parse_uint_start_interval(i: &str) -> IResult<&str, DetectUintData> { + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = alt((tag("-"), tag("<>")))(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg2) = verify(map_opt(digit1, |s: &str| s.parse::().ok()), |x| { + x > &arg1 && *x - arg1 > T::one() + })(i)?; + Ok(( + i, + DetectUintData { + arg1, + arg2, + mode: DetectUintMode::DetectUintModeRange, + }, + )) +} + +fn detect_parse_uint_mode(i: &str) -> IResult<&str, DetectUintMode> { + let (i, mode) = alt(( + value(DetectUintMode::DetectUintModeGte, tag(">=")), + value(DetectUintMode::DetectUintModeLte, tag("<=")), + value(DetectUintMode::DetectUintModeGt, tag(">")), + value(DetectUintMode::DetectUintModeLt, tag("<")), + value(DetectUintMode::DetectUintModeNe, tag("!")), + ))(i)?; + return Ok((i, mode)); +} + +fn detect_parse_uint_start_symbol(i: &str) -> IResult<&str, DetectUintData> { + let (i, mode) = detect_parse_uint_mode(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; + + match mode { + DetectUintMode::DetectUintModeNe => {} + DetectUintMode::DetectUintModeLt => { + if arg1 == T::min_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + DetectUintMode::DetectUintModeLte => { + if arg1 == T::max_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + DetectUintMode::DetectUintModeGt => { + if arg1 == T::max_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + DetectUintMode::DetectUintModeGte => { + if arg1 == T::min_value() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + } + _ => { + return Err(Err::Error(make_error(i, ErrorKind::MapOpt))); + } + } + + Ok(( + i, + DetectUintData { + arg1, + arg2: T::min_value(), + mode: mode, + }, + )) +} + +pub fn detect_match_uint(x: &DetectUintData, val: T) -> bool { + match x.mode { + DetectUintMode::DetectUintModeEqual => { + if val == x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeNe => { + if val != x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeLt => { + if val < x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeLte => { + if val <= x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeGt => { + if val > x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeGte => { + if val >= x.arg1 { + return true; + } + } + DetectUintMode::DetectUintModeRange => { + if val > x.arg1 && val < x.arg2 { + return true; + } + } + } + return false; +} + +pub fn detect_parse_uint(i: &str) -> IResult<&str, DetectUintData> { + let (i, _) = opt(is_a(" "))(i)?; + let (i, uint) = alt(( + detect_parse_uint_start_interval, + detect_parse_uint_start_equal, + detect_parse_uint_start_symbol, + ))(i)?; + let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; + Ok((i, uint)) +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u64_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u64_free(ctx: *mut std::os::raw::c_void) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx as *mut DetectUintData)); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u32_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u32_match( + arg: u32, ctx: &DetectUintData, +) -> std::os::raw::c_int { + if detect_match_uint(ctx, arg) { + return 1; + } + return 0; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u32_free(ctx: &mut DetectUintData) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u8_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u8_match( + arg: u8, ctx: &DetectUintData, +) -> std::os::raw::c_int { + if detect_match_uint(ctx, arg) { + return 1; + } + return 0; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u8_free(ctx: &mut DetectUintData) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u16_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectUintData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_uint::(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u16_match( + arg: u16, ctx: &DetectUintData, +) -> std::os::raw::c_int { + if detect_match_uint(ctx, arg) { + return 1; + } + return 0; +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_u16_free(ctx: &mut DetectUintData) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} diff --git a/rust/src/http2/detect.rs b/rust/src/http2/detect.rs index b4e834a225..eed1999a08 100644 --- a/rust/src/http2/detect.rs +++ b/rust/src/http2/detect.rs @@ -20,6 +20,7 @@ use super::http2::{ }; use super::parser; use crate::core::Direction; +use crate::detect::{detect_match_uint, DetectUintData}; use std::ffi::CStr; use std::str::FromStr; @@ -254,28 +255,11 @@ fn http2_detect_settings_match( None => { return 1; } - Some(x) => match x.mode { - parser::DetectUintMode::DetectUintModeEqual => { - if set[i].value == x.value { - return 1; - } - } - parser::DetectUintMode::DetectUintModeLt => { - if set[i].value <= x.value { - return 1; - } - } - parser::DetectUintMode::DetectUintModeGt => { - if set[i].value >= x.value { - return 1; - } - } - parser::DetectUintMode::DetectUintModeRange => { - if set[i].value <= x.value && set[i].value >= x.valrange { - return 1; - } + Some(x) => { + if detect_match_uint(&x, set[i].value) { + return 1; } - }, + } } } } @@ -320,59 +304,13 @@ pub unsafe extern "C" fn rs_http2_detect_settingsctx_match( return http2_detect_settingsctx_match(ctx, tx, direction.into()); } -#[no_mangle] -pub unsafe extern "C" fn rs_detect_u64_parse( - str: *const std::os::raw::c_char, -) -> *mut std::os::raw::c_void { - let ft_name: &CStr = CStr::from_ptr(str); //unsafe - if let Ok(s) = ft_name.to_str() { - if let Ok((_, ctx)) = parser::detect_parse_u64(s) { - let boxed = Box::new(ctx); - return Box::into_raw(boxed) as *mut _; - } - } - return std::ptr::null_mut(); -} - -#[no_mangle] -pub unsafe extern "C" fn rs_detect_u64_free(ctx: *mut std::os::raw::c_void) { - // Just unbox... - std::mem::drop(Box::from_raw(ctx as *mut parser::DetectU64Data)); -} - fn http2_detect_sizeupdate_match( - blocks: &[parser::HTTP2FrameHeaderBlock], ctx: &parser::DetectU64Data, + blocks: &[parser::HTTP2FrameHeaderBlock], ctx: &DetectUintData, ) -> std::os::raw::c_int { for block in blocks.iter() { - match ctx.mode { - parser::DetectUintMode::DetectUintModeEqual => { - if block.sizeupdate == ctx.value - && block.error == parser::HTTP2HeaderDecodeStatus::HTTP2HeaderDecodeSizeUpdate - { - return 1; - } - } - parser::DetectUintMode::DetectUintModeLt => { - if block.sizeupdate <= ctx.value - && block.error == parser::HTTP2HeaderDecodeStatus::HTTP2HeaderDecodeSizeUpdate - { - return 1; - } - } - parser::DetectUintMode::DetectUintModeGt => { - if block.sizeupdate >= ctx.value - && block.error == parser::HTTP2HeaderDecodeStatus::HTTP2HeaderDecodeSizeUpdate - { - return 1; - } - } - parser::DetectUintMode::DetectUintModeRange => { - if block.sizeupdate <= ctx.value - && block.sizeupdate >= ctx.valrange - && block.error == parser::HTTP2HeaderDecodeStatus::HTTP2HeaderDecodeSizeUpdate - { - return 1; - } + if block.error == parser::HTTP2HeaderDecodeStatus::HTTP2HeaderDecodeSizeUpdate { + if detect_match_uint(&ctx, block.sizeupdate) { + return 1; } } } @@ -396,7 +334,7 @@ fn http2_header_blocks(frame: &HTTP2Frame) -> Option<&[parser::HTTP2FrameHeaderB } fn http2_detect_sizeupdatectx_match( - ctx: &mut parser::DetectU64Data, tx: &mut HTTP2Transaction, direction: Direction, + ctx: &mut DetectUintData, tx: &mut HTTP2Transaction, direction: Direction, ) -> std::os::raw::c_int { if direction == Direction::ToServer { for i in 0..tx.frames_ts.len() { @@ -422,7 +360,7 @@ fn http2_detect_sizeupdatectx_match( pub unsafe extern "C" fn rs_http2_detect_sizeupdatectx_match( ctx: *const std::os::raw::c_void, tx: *mut std::os::raw::c_void, direction: u8, ) -> std::os::raw::c_int { - let ctx = cast_pointer!(ctx, parser::DetectU64Data); + let ctx = cast_pointer!(ctx, DetectUintData); let tx = cast_pointer!(tx, HTTP2Transaction); return http2_detect_sizeupdatectx_match(ctx, tx, direction.into()); } diff --git a/rust/src/http2/parser.rs b/rust/src/http2/parser.rs index f768026547..bf2d2d13c7 100644 --- a/rust/src/http2/parser.rs +++ b/rust/src/http2/parser.rs @@ -17,11 +17,11 @@ use super::huffman; use crate::common::nom7::bits; +use crate::detect::{detect_parse_uint, DetectUintData}; use crate::http2::http2::{HTTP2DynTable, HTTP2_MAX_TABLESIZE}; use nom7::bits::streaming::take as take_bits; use nom7::branch::alt; -use nom7::bytes::streaming::{is_a, is_not, tag, take, take_while}; -use nom7::character::complete::digit1; +use nom7::bytes::streaming::{is_a, is_not, take, take_while}; use nom7::combinator::{complete, cond, map_opt, opt, rest, verify}; use nom7::error::{make_error, ErrorKind}; use nom7::multi::many0; @@ -715,96 +715,9 @@ impl std::str::FromStr for HTTP2SettingsId { } } -//TODOask move elsewhere generic with DetectU64Data and such -#[derive(PartialEq, Debug)] -pub enum DetectUintMode { - DetectUintModeEqual, - DetectUintModeLt, - DetectUintModeGt, - DetectUintModeRange, -} - -pub struct DetectU32Data { - pub value: u32, - pub valrange: u32, - pub mode: DetectUintMode, -} - pub struct DetectHTTP2settingsSigCtx { - pub id: HTTP2SettingsId, //identifier - pub value: Option, //optional value -} - -fn detect_parse_u32_start_equal(i: &str) -> IResult<&str, DetectU32Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = opt(tag("="))(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU32Data { - value, - valrange: 0, - mode: DetectUintMode::DetectUintModeEqual, - }, - )) -} - -fn detect_parse_u32_start_interval(i: &str) -> IResult<&str, DetectU32Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = tag("-")(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, valrange) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU32Data { - value, - valrange, - mode: DetectUintMode::DetectUintModeRange, - }, - )) -} - -fn detect_parse_u32_start_lesser(i: &str) -> IResult<&str, DetectU32Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = tag("<")(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU32Data { - value, - valrange: 0, - mode: DetectUintMode::DetectUintModeLt, - }, - )) -} - -fn detect_parse_u32_start_greater(i: &str) -> IResult<&str, DetectU32Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = tag(">")(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU32Data { - value, - valrange: 0, - mode: DetectUintMode::DetectUintModeGt, - }, - )) -} - -fn detect_parse_u32(i: &str) -> IResult<&str, DetectU32Data> { - let (i, u32) = alt(( - detect_parse_u32_start_lesser, - detect_parse_u32_start_greater, - complete(detect_parse_u32_start_interval), - detect_parse_u32_start_equal, - ))(i)?; - Ok((i, u32)) + pub id: HTTP2SettingsId, //identifier + pub value: Option>, //optional value } pub fn http2_parse_settingsctx(i: &str) -> IResult<&str, DetectHTTP2settingsSigCtx> { @@ -812,88 +725,10 @@ pub fn http2_parse_settingsctx(i: &str) -> IResult<&str, DetectHTTP2settingsSigC let (i, id) = map_opt(alt((complete(is_not(" <>=")), rest)), |s: &str| { HTTP2SettingsId::from_str(s).ok() })(i)?; - let (i, value) = opt(complete(detect_parse_u32))(i)?; + let (i, value) = opt(complete(detect_parse_uint))(i)?; Ok((i, DetectHTTP2settingsSigCtx { id, value })) } -pub struct DetectU64Data { - pub value: u64, - pub valrange: u64, - pub mode: DetectUintMode, -} - -fn detect_parse_u64_start_equal(i: &str) -> IResult<&str, DetectU64Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = opt(tag("="))(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU64Data { - value, - valrange: 0, - mode: DetectUintMode::DetectUintModeEqual, - }, - )) -} - -fn detect_parse_u64_start_interval(i: &str) -> IResult<&str, DetectU64Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = tag("-")(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, valrange) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU64Data { - value, - valrange, - mode: DetectUintMode::DetectUintModeRange, - }, - )) -} - -fn detect_parse_u64_start_lesser(i: &str) -> IResult<&str, DetectU64Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = tag("<")(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU64Data { - value, - valrange: 0, - mode: DetectUintMode::DetectUintModeLt, - }, - )) -} - -fn detect_parse_u64_start_greater(i: &str) -> IResult<&str, DetectU64Data> { - let (i, _) = opt(is_a(" "))(i)?; - let (i, _) = tag(">")(i)?; - let (i, _) = opt(is_a(" "))(i)?; - let (i, value) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; - Ok(( - i, - DetectU64Data { - value, - valrange: 0, - mode: DetectUintMode::DetectUintModeGt, - }, - )) -} - -pub fn detect_parse_u64(i: &str) -> IResult<&str, DetectU64Data> { - let (i, u64) = alt(( - detect_parse_u64_start_lesser, - detect_parse_u64_start_greater, - complete(detect_parse_u64_start_interval), - detect_parse_u64_start_equal, - ))(i)?; - Ok((i, u64)) -} - #[derive(Clone, Copy, Debug)] pub struct HTTP2FrameSettings { pub id: HTTP2SettingsId, @@ -914,6 +749,7 @@ pub fn http2_parse_frame_settings(i: &[u8]) -> IResult<&[u8], Vec { - assert_eq!(ctxval.value, 42); + assert_eq!(ctxval.arg1, 42); } None => { panic!("No value"); @@ -1096,9 +932,9 @@ mod tests { assert_eq!(ctx.id, HTTP2SettingsId::SETTINGSMAXCONCURRENTSTREAMS); match ctx.value { Some(ctxval) => { - assert_eq!(ctxval.value, 42); + assert_eq!(ctxval.arg1, 42); assert_eq!(ctxval.mode, DetectUintMode::DetectUintModeRange); - assert_eq!(ctxval.valrange, 68); + assert_eq!(ctxval.arg2, 68); } None => { panic!("No value"); @@ -1118,7 +954,7 @@ mod tests { assert_eq!(ctx.id, HTTP2SettingsId::SETTINGSMAXCONCURRENTSTREAMS); match ctx.value { Some(ctxval) => { - assert_eq!(ctxval.value, 54); + assert_eq!(ctxval.arg1, 54); assert_eq!(ctxval.mode, DetectUintMode::DetectUintModeLt); } None => { @@ -1139,7 +975,7 @@ mod tests { assert_eq!(ctx.id, HTTP2SettingsId::SETTINGSMAXCONCURRENTSTREAMS); match ctx.value { Some(ctxval) => { - assert_eq!(ctxval.value, 76); + assert_eq!(ctxval.arg1, 76); assert_eq!(ctxval.mode, DetectUintMode::DetectUintModeGt); } None => { diff --git a/rust/src/lib.rs b/rust/src/lib.rs index fa7ce05815..a8b729bf61 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -101,6 +101,7 @@ pub mod frames; pub mod filecontainer; pub mod filetracker; pub mod kerberos; +pub mod detect; #[cfg(feature = "lua")] pub mod lua; diff --git a/rust/src/smb/detect.rs b/rust/src/smb/detect.rs index a6e1560a1a..97ab525d9f 100644 --- a/rust/src/smb/detect.rs +++ b/rust/src/smb/detect.rs @@ -20,6 +20,7 @@ use crate::core::*; use crate::smb::smb::*; use crate::dcerpc::detect::{DCEIfaceData, DCEOpnumData, DETECT_DCE_OPNUM_RANGE_UNINITIALIZED}; use crate::dcerpc::dcerpc::DCERPC_TYPE_REQUEST; +use crate::detect::detect_match_uint; #[no_mangle] pub unsafe extern "C" fn rs_smb_tx_get_share(tx: &mut SMBTransaction, @@ -125,40 +126,6 @@ pub extern "C" fn rs_smb_tx_match_dce_opnum(tx: &mut SMBTransaction, return 0; } -/* based on: - * typedef enum DetectDceIfaceOperators_ { - * DETECT_DCE_IFACE_OP_NONE = 0, - * DETECT_DCE_IFACE_OP_LT, - * DETECT_DCE_IFACE_OP_GT, - * DETECT_DCE_IFACE_OP_EQ, - * DETECT_DCE_IFACE_OP_NE, - * } DetectDceIfaceOperators; - */ -#[inline] -fn match_version(op: u8, them: u16, us: u16) -> bool { - let result = match op { - 0 => { // NONE - true - }, - 1 => { // LT - them < us - }, - 2 => { // GT - them > us - }, - 3 => { // EQ - them == us - }, - 4 => { // NE - them != us - }, - _ => { - panic!("called with invalid op {}", op); - }, - }; - result -} - /* mimic logic that is/was in the C code: * - match on REQUEST (so not on BIND/BINDACK (probably for mixing with * dce_opnum and dce_stub_data) @@ -170,8 +137,6 @@ pub extern "C" fn rs_smb_tx_get_dce_iface(state: &mut SMBState, -> u8 { let if_uuid = dce_data.if_uuid.as_slice(); - let if_op = dce_data.op; - let if_version = dce_data.version; let is_dcerpc_request = match tx.type_data { Some(SMBTransactionTypeData::DCERPC(ref x)) => { x.req_cmd == DCERPC_TYPE_REQUEST @@ -194,7 +159,11 @@ pub extern "C" fn rs_smb_tx_get_dce_iface(state: &mut SMBState, SCLogDebug!("stored UUID {:?} acked {} ack_result {}", i, i.acked, i.ack_result); if i.acked && i.ack_result == 0 && i.uuid == if_uuid { - if match_version(if_op as u8, if_version as u16, i.ver) { + if let Some(x) = &dce_data.du16 { + if detect_match_uint(&x, i.ver) { + return 1; + } + } else { return 1; } } diff --git a/src/detect-engine-prefilter-common.h b/src/detect-engine-prefilter-common.h index 6b5c79e7c7..f5ea765215 100644 --- a/src/detect-engine-prefilter-common.h +++ b/src/detect-engine-prefilter-common.h @@ -18,6 +18,8 @@ #ifndef __DETECT_ENGINE_PREFILTER_COMMON_H__ #define __DETECT_ENGINE_PREFILTER_COMMON_H__ +#include "rust.h" + typedef union { uint8_t u8[16]; uint16_t u16[8]; @@ -51,10 +53,10 @@ typedef struct PrefilterPacketU8HashCtx_ { SigsArray *array[256]; } PrefilterPacketU8HashCtx; -#define PREFILTER_U8HASH_MODE_EQ 0 -#define PREFILTER_U8HASH_MODE_LT 1 -#define PREFILTER_U8HASH_MODE_GT 2 -#define PREFILTER_U8HASH_MODE_RA 3 +#define PREFILTER_U8HASH_MODE_EQ DetectUintModeEqual +#define PREFILTER_U8HASH_MODE_LT DetectUintModeLt +#define PREFILTER_U8HASH_MODE_GT DetectUintModeGt +#define PREFILTER_U8HASH_MODE_RA DetectUintModeRange int PrefilterSetupPacketHeader(DetectEngineCtx *de_ctx, SigGroupHead *sgh, int sm_type, diff --git a/src/detect-engine-uint.c b/src/detect-engine-uint.c index c8cf28dd72..6388bb6992 100644 --- a/src/detect-engine-uint.c +++ b/src/detect-engine-uint.c @@ -28,79 +28,10 @@ #include "detect-parse.h" #include "detect-engine-uint.h" -/** - * \brief Regex for parsing our options - */ -#define PARSE_REGEX "^\\s*([0-9]*)?\\s*([<>=-]+)?\\s*([0-9]+)?\\s*$" - -static DetectParseRegex uint_pcre; - -int DetectU32Match(const uint32_t parg, const DetectU32Data *du32) +int DetectU32Match(const uint32_t parg, const DetectUintData_u32 *du32) { - switch (du32->mode) { - case DETECT_UINT_EQ: - if (parg == du32->arg1) { - return 1; - } - return 0; - case DETECT_UINT_LT: - if (parg < du32->arg1) { - return 1; - } - return 0; - case DETECT_UINT_LTE: - if (parg <= du32->arg1) { - return 1; - } - return 0; - case DETECT_UINT_GT: - if (parg > du32->arg1) { - return 1; - } - return 0; - case DETECT_UINT_GTE: - if (parg >= du32->arg1) { - return 1; - } - return 0; - case DETECT_UINT_RA: - if (parg > du32->arg1 && parg < du32->arg2) { - return 1; - } - return 0; - default: - BUG_ON(1); // unknown mode - } - return 0; -} - -static int DetectU32Validate(DetectU32Data *du32) -{ - switch (du32->mode) { - case DETECT_UINT_LT: - if (du32->arg1 == 0) { - return 1; - } - break; - case DETECT_UINT_GT: - if (du32->arg1 == UINT32_MAX) { - return 1; - } - break; - case DETECT_UINT_RA: - if (du32->arg1 >= du32->arg2) { - return 1; - } - // we need at least one value that can match parg > du32->arg1 && parg < du32->arg2 - if (du32->arg1 + 1 >= du32->arg2) { - return 1; - } - break; - default: - break; - } - return 0; + return rs_detect_u32_match(parg, du32); } /** @@ -112,157 +43,15 @@ static int DetectU32Validate(DetectU32Data *du32) * \retval NULL on failure */ -DetectU32Data *DetectU32Parse (const char *u32str) +DetectUintData_u32 *DetectU32Parse(const char *u32str) { - /* We initialize these to please static checkers, these values will - either be updated or not used later on */ - DetectU32Data u32da = {0, 0, 0}; - DetectU32Data *u32d = NULL; - char arg1[16] = ""; - char arg2[16] = ""; - char arg3[16] = ""; - - int ret = 0, res = 0; - size_t pcre2len; - - ret = DetectParsePcreExec(&uint_pcre, u32str, 0, 0); - if (ret < 2 || ret > 4) { - SCLogError(SC_ERR_PCRE_MATCH, "parse error, ret %" PRId32 "", ret); - return NULL; - } - - pcre2len = sizeof(arg1); - res = pcre2_substring_copy_bynumber(uint_pcre.match, 1, (PCRE2_UCHAR8 *)arg1, &pcre2len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_COPY_SUBSTRING, "pcre2_substring_copy_bynumber failed"); - return NULL; - } - SCLogDebug("Arg1 \"%s\"", arg1); - - if (ret >= 3) { - pcre2len = sizeof(arg2); - res = pcre2_substring_copy_bynumber(uint_pcre.match, 2, (PCRE2_UCHAR8 *)arg2, &pcre2len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_COPY_SUBSTRING, "pcre2_substring_copy_bynumber failed"); - return NULL; - } - SCLogDebug("Arg2 \"%s\"", arg2); - - if (ret >= 4) { - pcre2len = sizeof(arg3); - res = pcre2_substring_copy_bynumber( - uint_pcre.match, 3, (PCRE2_UCHAR8 *)arg3, &pcre2len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_COPY_SUBSTRING, "pcre2_substring_copy_bynumber failed"); - return NULL; - } - SCLogDebug("Arg3 \"%s\"", arg3); - } - } - - if (strlen(arg2) > 0) { - /*set the values*/ - switch(arg2[0]) { - case '<': - case '>': - if (strlen(arg2) == 1) { - if (strlen(arg3) == 0) - return NULL; - - if (ByteExtractStringUint32(&u32da.arg1, 10, strlen(arg3), arg3) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint32 failed"); - return NULL; - } - - SCLogDebug("u32 is %" PRIu32 "", u32da.arg1); - if (strlen(arg1) > 0) - return NULL; - - if (arg2[0] == '<') { - u32da.mode = DETECT_UINT_LT; - } else { // arg2[0] == '>' - u32da.mode = DETECT_UINT_GT; - } - break; - } else if (strlen(arg2) == 2) { - if (arg2[0] == '<' && arg2[1] == '=') { - u32da.mode = DETECT_UINT_LTE; - break; - } else if (arg2[0] == '>' || arg2[1] == '=') { - u32da.mode = DETECT_UINT_GTE; - break; - } else if (arg2[0] != '<' || arg2[1] != '>') { - return NULL; - } - } else { - return NULL; - } - // fall through - case '-': - if (strlen(arg1)== 0) - return NULL; - if (strlen(arg3)== 0) - return NULL; - - u32da.mode = DETECT_UINT_RA; - if (ByteExtractStringUint32(&u32da.arg1, 10, strlen(arg1), arg1) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint32 failed"); - return NULL; - } - if (ByteExtractStringUint32(&u32da.arg2, 10, strlen(arg3), arg3) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint32 failed"); - return NULL; - } - - SCLogDebug("u32 is %"PRIu32" to %"PRIu32"", u32da.arg1, u32da.arg2); - if (u32da.arg1 >= u32da.arg2) { - SCLogError(SC_ERR_INVALID_SIGNATURE, "Invalid u32 range. "); - return NULL; - } - break; - default: - u32da.mode = DETECT_UINT_EQ; - - if (strlen(arg2) > 0 || - strlen(arg3) > 0 || - strlen(arg1) == 0) - return NULL; - - if (ByteExtractStringUint32(&u32da.arg1, 10, strlen(arg1), arg1) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint32 failed"); - return NULL; - } - } - } else { - u32da.mode = DETECT_UINT_EQ; - - if (strlen(arg3) > 0 || - strlen(arg1) == 0) - return NULL; - - if (ByteExtractStringUint32(&u32da.arg1, 10, strlen(arg1), arg1) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint32 failed"); - return NULL; - } - } - if (DetectU32Validate(&u32da)) { - SCLogError(SC_ERR_INVALID_VALUE, "Impossible value for uint32 condition : %s", u32str); - return NULL; - } - u32d = SCCalloc(1, sizeof (DetectU32Data)); - if (unlikely(u32d == NULL)) - return NULL; - u32d->arg1 = u32da.arg1; - u32d->arg2 = u32da.arg2; - u32d->mode = u32da.mode; - - return u32d; + return rs_detect_u32_parse(u32str); } void PrefilterPacketU32Set(PrefilterPacketHeaderValue *v, void *smctx) { - const DetectU32Data *a = smctx; + const DetectUintData_u32 *a = smctx; v->u8[0] = a->mode; v->u32[1] = a->arg1; v->u32[2] = a->arg2; @@ -271,7 +60,7 @@ PrefilterPacketU32Set(PrefilterPacketHeaderValue *v, void *smctx) bool PrefilterPacketU32Compare(PrefilterPacketHeaderValue v, void *smctx) { - const DetectU32Data *a = smctx; + const DetectUintData_u32 *a = smctx; if (v.u8[0] == a->mode && v.u32[1] == a->arg1 && v.u32[2] == a->arg2) @@ -279,83 +68,10 @@ PrefilterPacketU32Compare(PrefilterPacketHeaderValue v, void *smctx) return false; } -static bool g_detect_uint_registered = false; - -void DetectUintRegister(void) -{ - if (g_detect_uint_registered == false) { - // register only once - DetectSetupParseRegexes(PARSE_REGEX, &uint_pcre); - g_detect_uint_registered = true; - } -} - //same as u32 but with u8 -int DetectU8Match(const uint8_t parg, const DetectU8Data *du8) -{ - switch (du8->mode) { - case DETECT_UINT_EQ: - if (parg == du8->arg1) { - return 1; - } - return 0; - case DETECT_UINT_LT: - if (parg < du8->arg1) { - return 1; - } - return 0; - case DETECT_UINT_LTE: - if (parg <= du8->arg1) { - return 1; - } - return 0; - case DETECT_UINT_GT: - if (parg > du8->arg1) { - return 1; - } - return 0; - case DETECT_UINT_GTE: - if (parg >= du8->arg1) { - return 1; - } - return 0; - case DETECT_UINT_RA: - if (parg > du8->arg1 && parg < du8->arg2) { - return 1; - } - return 0; - default: - BUG_ON(1); // unknown mode - } - return 0; -} - -static int DetectU8Validate(DetectU8Data *du8) +int DetectU8Match(const uint8_t parg, const DetectUintData_u8 *du8) { - switch (du8->mode) { - case DETECT_UINT_LT: - if (du8->arg1 == 0) { - return 1; - } - break; - case DETECT_UINT_GT: - if (du8->arg1 == UINT8_MAX) { - return 1; - } - break; - case DETECT_UINT_RA: - if (du8->arg1 >= du8->arg2) { - return 1; - } - // we need at least one value that can match parg > du8->arg1 && parg < du8->arg2 - if (du8->arg1 + 1 >= du8->arg2) { - return 1; - } - break; - default: - break; - } - return 0; + return rs_detect_u8_match(parg, du8); } /** @@ -367,139 +83,59 @@ static int DetectU8Validate(DetectU8Data *du8) * \retval NULL on failure */ -DetectU8Data *DetectU8Parse (const char *u8str) +DetectUintData_u8 *DetectU8Parse(const char *u8str) { - /* We initialize these to please static checkers, these values will - either be updated or not used later on */ - DetectU8Data u8da = {0, 0, 0}; - DetectU8Data *u8d = NULL; - char arg1[16] = ""; - char arg2[16] = ""; - char arg3[16] = ""; - - int ret = 0, res = 0; - size_t pcre2len; - - ret = DetectParsePcreExec(&uint_pcre, u8str, 0, 0); - if (ret < 2 || ret > 4) { - SCLogError(SC_ERR_PCRE_MATCH, "parse error, ret %" PRId32 "", ret); - return NULL; - } - - pcre2len = sizeof(arg1); - res = pcre2_substring_copy_bynumber(uint_pcre.match, 1, (PCRE2_UCHAR8 *)arg1, &pcre2len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_COPY_SUBSTRING, "pcre2_substring_copy_bynumber failed"); - return NULL; - } - SCLogDebug("Arg1 \"%s\"", arg1); - - if (ret >= 3) { - pcre2len = sizeof(arg2); - res = pcre2_substring_copy_bynumber(uint_pcre.match, 2, (PCRE2_UCHAR8 *)arg2, &pcre2len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_COPY_SUBSTRING, "pcre2_substring_copy_bynumber failed"); - return NULL; - } - SCLogDebug("Arg2 \"%s\"", arg2); - - if (ret >= 4) { - pcre2len = sizeof(arg3); - res = pcre2_substring_copy_bynumber( - uint_pcre.match, 3, (PCRE2_UCHAR8 *)arg3, &pcre2len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_COPY_SUBSTRING, "pcre2_substring_copy_bynumber failed"); - return NULL; - } - SCLogDebug("Arg3 \"%s\"", arg3); - } - } - - if (strlen(arg2) > 0) { - /*set the values*/ - switch(arg2[0]) { - case '<': - case '>': - if (strlen(arg2) == 1) { - if (StringParseUint8(&u8da.arg1, 10, strlen(arg3), arg3) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint8 failed"); - return NULL; - } - - SCLogDebug("u8 is %" PRIu8 "", u8da.arg1); - if (strlen(arg1) > 0) - return NULL; + return rs_detect_u8_parse(u8str); +} - if (arg2[0] == '<') { - u8da.mode = DETECT_UINT_LT; - } else { // arg2[0] == '>' - u8da.mode = DETECT_UINT_GT; - } - break; - } else if (strlen(arg2) == 2) { - if (arg2[0] == '<' && arg2[1] == '=') { - u8da.mode = DETECT_UINT_LTE; - break; - } else if (arg2[0] == '>' || arg2[1] == '=') { - u8da.mode = DETECT_UINT_GTE; - break; - } else if (arg2[0] != '<' || arg2[1] != '>') { - return NULL; - } - } else { - return NULL; - } - // fall through - case '-': - u8da.mode = DETECT_UINT_RA; - if (StringParseUint8(&u8da.arg1, 10, strlen(arg1), arg1) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint8 failed"); - return NULL; - } - if (StringParseUint8(&u8da.arg2, 10, strlen(arg3), arg3) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint8 failed"); - return NULL; - } +void PrefilterPacketU8Set(PrefilterPacketHeaderValue *v, void *smctx) +{ + const DetectUintData_u8 *a = smctx; + v->u8[0] = a->mode; + v->u8[1] = a->arg1; + v->u8[2] = a->arg2; +} - SCLogDebug("u8 is %"PRIu8" to %"PRIu8"", u8da.arg1, u8da.arg2); - if (u8da.arg1 >= u8da.arg2) { - SCLogError(SC_ERR_INVALID_SIGNATURE, "Invalid u8 range. "); - return NULL; - } - break; - default: - u8da.mode = DETECT_UINT_EQ; +bool PrefilterPacketU8Compare(PrefilterPacketHeaderValue v, void *smctx) +{ + const DetectUintData_u8 *a = smctx; + if (v.u8[0] == a->mode && v.u8[1] == a->arg1 && v.u8[2] == a->arg2) + return true; + return false; +} - if (strlen(arg2) > 0 || - strlen(arg3) > 0) - return NULL; +// same as u32 but with u16 +int DetectU16Match(const uint16_t parg, const DetectUintData_u16 *du16) +{ + return rs_detect_u16_match(parg, du16); +} - if (StringParseUint8(&u8da.arg1, 10, strlen(arg1), arg1) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint8 failed"); - return NULL; - } - } - } else { - u8da.mode = DETECT_UINT_EQ; +/** + * \brief This function is used to parse u16 options passed via some u16 keyword + * + * \param u16str Pointer to the user provided u16 options + * + * \retval DetectU16Data pointer to DetectU16Data on success + * \retval NULL on failure + */ - if (strlen(arg3) > 0) - return NULL; +DetectUintData_u16 *DetectU16Parse(const char *u16str) +{ + return rs_detect_u16_parse(u16str); +} - if (StringParseUint8(&u8da.arg1, 10, strlen(arg1), arg1) < 0) { - SCLogError(SC_ERR_BYTE_EXTRACT_FAILED, "ByteExtractStringUint8 failed"); - return NULL; - } - } - if (DetectU8Validate(&u8da)) { - SCLogError(SC_ERR_INVALID_VALUE, "Impossible value for uint8 condition : %s", u8str); - return NULL; - } - u8d = SCCalloc(1, sizeof (DetectU8Data)); - if (unlikely(u8d == NULL)) - return NULL; - u8d->arg1 = u8da.arg1; - u8d->arg2 = u8da.arg2; - u8d->mode = u8da.mode; +void PrefilterPacketU16Set(PrefilterPacketHeaderValue *v, void *smctx) +{ + const DetectUintData_u16 *a = smctx; + v->u8[0] = a->mode; + v->u16[1] = a->arg1; + v->u16[2] = a->arg2; +} - return u8d; +bool PrefilterPacketU16Compare(PrefilterPacketHeaderValue v, void *smctx) +{ + const DetectUintData_u16 *a = smctx; + if (v.u8[0] == a->mode && v.u16[1] == a->arg1 && v.u16[2] == a->arg2) + return true; + return false; } diff --git a/src/detect-engine-uint.h b/src/detect-engine-uint.h index e3fc7f5b16..e20c3f713d 100644 --- a/src/detect-engine-uint.h +++ b/src/detect-engine-uint.h @@ -24,39 +24,36 @@ #ifndef __DETECT_ENGINE_UINT_H #define __DETECT_ENGINE_UINT_H +#include "rust.h" #include "detect-engine-prefilter-common.h" -typedef enum { - DETECT_UINT_LT = PREFILTER_U8HASH_MODE_LT, - DETECT_UINT_EQ = PREFILTER_U8HASH_MODE_EQ, - DETECT_UINT_GT = PREFILTER_U8HASH_MODE_GT, - DETECT_UINT_RA = PREFILTER_U8HASH_MODE_RA, - DETECT_UINT_LTE, - DETECT_UINT_GTE, -} DetectUintMode; - -typedef struct DetectU32Data_ { - uint32_t arg1; /**< first arg value in the signature*/ - uint32_t arg2; /**< second arg value in the signature, in case of range - operator*/ - DetectUintMode mode; /**< operator used in the signature */ -} DetectU32Data; - -int DetectU32Match(const uint32_t parg, const DetectU32Data *du32); -DetectU32Data *DetectU32Parse (const char *u32str); +// These definitions are kept to minimize the diff +// We can run a big sed commit next +#define DETECT_UINT_GT DetectUintModeGt +#define DETECT_UINT_GTE DetectUintModeGte +#define DETECT_UINT_RA DetectUintModeRange +#define DETECT_UINT_EQ DetectUintModeEqual +#define DETECT_UINT_NE DetectUintModeNe +#define DETECT_UINT_LT DetectUintModeLt +#define DETECT_UINT_LTE DetectUintModeLte + +typedef DetectUintData_u32 DetectU32Data; +typedef DetectUintData_u16 DetectU16Data; +typedef DetectUintData_u8 DetectU8Data; + +int DetectU32Match(const uint32_t parg, const DetectUintData_u32 *du32); +DetectUintData_u32 *DetectU32Parse(const char *u32str); void PrefilterPacketU32Set(PrefilterPacketHeaderValue *v, void *smctx); bool PrefilterPacketU32Compare(PrefilterPacketHeaderValue v, void *smctx); -void DetectUintRegister(void); +int DetectU8Match(const uint8_t parg, const DetectUintData_u8 *du8); +DetectUintData_u8 *DetectU8Parse(const char *u8str); +void PrefilterPacketU8Set(PrefilterPacketHeaderValue *v, void *smctx); +bool PrefilterPacketU8Compare(PrefilterPacketHeaderValue v, void *smctx); -typedef struct DetectU8Data_ { - uint8_t arg1; /**< first arg value in the signature*/ - uint8_t arg2; /**< second arg value in the signature, in case of range - operator*/ - DetectUintMode mode; /**< operator used in the signature */ -} DetectU8Data; - -int DetectU8Match(const uint8_t parg, const DetectU8Data *du8); -DetectU8Data *DetectU8Parse (const char *u8str); +int DetectU16Match(const uint16_t parg, const DetectUintData_u16 *du16); +DetectUintData_u16 *DetectU16Parse(const char *u16str); +void PrefilterPacketU16Set(PrefilterPacketHeaderValue *v, void *smctx); +bool PrefilterPacketU16Compare(PrefilterPacketHeaderValue v, void *smctx); #endif /* __DETECT_UTIL_UINT_H */ diff --git a/src/detect-http2.c b/src/detect-http2.c index cffbcbbfa4..5e4a41724b 100644 --- a/src/detect-http2.c +++ b/src/detect-http2.c @@ -244,8 +244,6 @@ void DetectHttp2Register(void) "http2", ALPROTO_HTTP2, SIG_FLAG_TOCLIENT, 0, DetectEngineInspectHTTP2, NULL); g_http2_match_buffer_id = DetectBufferTypeRegister("http2"); - DetectUintRegister(); - return; } @@ -465,7 +463,7 @@ static int DetectHTTP2prioritySetup (DetectEngineCtx *de_ctx, Signature *s, cons SigMatch *sm = SigMatchAlloc(); if (sm == NULL) { - SCFree(prio); + rs_detect_u8_free(prio); return -1; } @@ -484,7 +482,7 @@ static int DetectHTTP2prioritySetup (DetectEngineCtx *de_ctx, Signature *s, cons */ void DetectHTTP2priorityFree(DetectEngineCtx *de_ctx, void *ptr) { - SCFree(ptr); + rs_detect_u8_free(ptr); } /** @@ -532,7 +530,7 @@ static int DetectHTTP2windowSetup (DetectEngineCtx *de_ctx, Signature *s, const SigMatch *sm = SigMatchAlloc(); if (sm == NULL) { - SCFree(wu); + rs_detect_u32_free(wu); return -1; } @@ -551,7 +549,7 @@ static int DetectHTTP2windowSetup (DetectEngineCtx *de_ctx, Signature *s, const */ void DetectHTTP2windowFree(DetectEngineCtx *de_ctx, void *ptr) { - SCFree(ptr); + rs_detect_u32_free(ptr); } /** diff --git a/src/detect-icmpv6-mtu.c b/src/detect-icmpv6-mtu.c index 46bdb0363f..ecb87343d1 100644 --- a/src/detect-icmpv6-mtu.c +++ b/src/detect-icmpv6-mtu.c @@ -58,8 +58,6 @@ void DetectICMPv6mtuRegister(void) #endif sigmatch_table[DETECT_ICMPV6MTU].SupportsPrefilter = PrefilterIcmpv6mtuIsPrefilterable; sigmatch_table[DETECT_ICMPV6MTU].SetupPrefilter = PrefilterSetupIcmpv6mtu; - - DetectUintRegister(); return; } @@ -139,7 +137,7 @@ static int DetectICMPv6mtuSetup (DetectEngineCtx *de_ctx, Signature *s, const ch */ void DetectICMPv6mtuFree(DetectEngineCtx *de_ctx, void *ptr) { - SCFree(ptr); + rs_detect_u32_free(ptr); } /* prefilter code */ diff --git a/src/detect-icode.c b/src/detect-icode.c index 7058853c31..6a42c0c40b 100644 --- a/src/detect-icode.c +++ b/src/detect-icode.c @@ -136,7 +136,7 @@ static int DetectICodeSetup(DetectEngineCtx *de_ctx, Signature *s, const char *i error: if (icd != NULL) - SCFree(icd); + rs_detect_u8_free(icd); if (sm != NULL) SCFree(sm); return -1; } @@ -148,7 +148,7 @@ error: */ void DetectICodeFree(DetectEngineCtx *de_ctx, void *ptr) { - SCFree(ptr); + rs_detect_u8_free(ptr); } /* prefilter code */ @@ -177,30 +177,10 @@ static void PrefilterPacketICodeMatch(DetectEngineThreadCtx *det_ctx, } } -static void -PrefilterPacketICodeSet(PrefilterPacketHeaderValue *v, void *smctx) -{ - const DetectU8Data *a = smctx; - v->u8[0] = a->mode; - v->u8[1] = a->arg1; - v->u8[2] = a->arg2; -} - -static bool -PrefilterPacketICodeCompare(PrefilterPacketHeaderValue v, void *smctx) -{ - const DetectU8Data *a = smctx; - if (v.u8[0] == a->mode && v.u8[1] == a->arg1 && v.u8[2] == a->arg2) - return true; - return false; -} - static int PrefilterSetupICode(DetectEngineCtx *de_ctx, SigGroupHead *sgh) { - return PrefilterSetupPacketHeaderU8Hash(de_ctx, sgh, DETECT_ICODE, - PrefilterPacketICodeSet, - PrefilterPacketICodeCompare, - PrefilterPacketICodeMatch); + return PrefilterSetupPacketHeaderU8Hash(de_ctx, sgh, DETECT_ICODE, PrefilterPacketU8Set, + PrefilterPacketU8Compare, PrefilterPacketICodeMatch); } static bool PrefilterICodeIsPrefilterable(const Signature *s) diff --git a/src/detect-ike-exch-type.c b/src/detect-ike-exch-type.c index 8132e0a24f..729d584864 100644 --- a/src/detect-ike-exch-type.c +++ b/src/detect-ike-exch-type.c @@ -68,8 +68,6 @@ void DetectIkeExchTypeRegister(void) DetectEngineInspectIkeExchTypeGeneric, NULL); g_ike_exch_type_buffer_id = DetectBufferTypeGetByName("ike.exchtype"); - - DetectUintRegister(); } static int DetectEngineInspectIkeExchTypeGeneric(DetectEngineCtx *de_ctx, @@ -152,5 +150,5 @@ error: */ static void DetectIkeExchTypeFree(DetectEngineCtx *de_ctx, void *ptr) { - SCFree(ptr); + rs_detect_u8_free(ptr); } diff --git a/src/detect-ike-key-exchange-payload-length.c b/src/detect-ike-key-exchange-payload-length.c index 65580e8393..5fbb615ada 100644 --- a/src/detect-ike-key-exchange-payload-length.c +++ b/src/detect-ike-key-exchange-payload-length.c @@ -73,8 +73,6 @@ void DetectIkeKeyExchangePayloadLengthRegister(void) g_ike_key_exch_payload_length_buffer_id = DetectBufferTypeGetByName("ike.key_exchange_payload_length"); - - DetectUintRegister(); } static int DetectEngineInspectIkeKeyExchangePayloadLengthGeneric(DetectEngineCtx *de_ctx, @@ -158,5 +156,5 @@ error: */ static void DetectIkeKeyExchangePayloadLengthFree(DetectEngineCtx *de_ctx, void *ptr) { - SCFree(ptr); + rs_detect_u32_free(ptr); } diff --git a/src/detect-ike-nonce-payload-length.c b/src/detect-ike-nonce-payload-length.c index d7091b2472..7ec7b2ce0d 100644 --- a/src/detect-ike-nonce-payload-length.c +++ b/src/detect-ike-nonce-payload-length.c @@ -68,8 +68,6 @@ void DetectIkeNoncePayloadLengthRegister(void) 1, DetectEngineInspectIkeNoncePayloadLengthGeneric, NULL); g_ike_nonce_payload_length_buffer_id = DetectBufferTypeGetByName("ike.nonce_payload_length"); - - DetectUintRegister(); } static int DetectEngineInspectIkeNoncePayloadLengthGeneric(DetectEngineCtx *de_ctx, @@ -152,5 +150,5 @@ error: */ static void DetectIkeNoncePayloadLengthFree(DetectEngineCtx *de_ctx, void *ptr) { - SCFree(ptr); + rs_detect_u32_free(ptr); } diff --git a/src/detect-mqtt-protocol-version.c b/src/detect-mqtt-protocol-version.c index 8a09d36755..0483bfa3e2 100644 --- a/src/detect-mqtt-protocol-version.c +++ b/src/detect-mqtt-protocol-version.c @@ -141,7 +141,7 @@ static int DetectMQTTProtocolVersionSetup(DetectEngineCtx *de_ctx, Signature *s, error: if (de != NULL) - SCFree(de); + rs_detect_u8_free(de); if (sm != NULL) SCFree(sm); return -1; @@ -155,8 +155,7 @@ error: */ void DetectMQTTProtocolVersionFree(DetectEngineCtx *de_ctx, void *de_ptr) { - if (de_ptr != NULL) - SCFree(de_ptr); + rs_detect_u8_free(de_ptr); } /*