]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
transforms: move base64 to rust 13361/head
authorPhilippe Antoine <pantoine@oisf.net>
Tue, 3 Jun 2025 11:32:47 +0000 (13:32 +0200)
committerVictor Julien <victor@inliniac.net>
Wed, 4 Jun 2025 07:39:53 +0000 (09:39 +0200)
Ticket: 7733

rust/src/detect/mod.rs
rust/src/detect/transforms/base64.rs
src/Makefile.am
src/detect-engine-register.c
src/detect-engine-register.h
src/detect-transform-base64.c [deleted file]
src/detect-transform-base64.h [deleted file]

index d1b26e6950fc82438767bae5dbd08b8f6251d428..9ca6d426fc45fa1cf09eeb072da769140cc4d26d 100644 (file)
@@ -113,7 +113,9 @@ pub unsafe extern "C" fn SCDetectSigMatchNamesFree(kw: &mut SCSigTableNamesElmt)
     let _ = CString::from_raw(kw.url);
 }
 
+// TODO bindgen these
 pub const SIGMATCH_NOOPT: u16 = 1; // BIT_U16(0) in detect.h
+pub(crate) const SIGMATCH_OPTIONAL_OPT: u16 = 0x10; // BIT_U16(4) in detect.h
 pub(crate) const SIGMATCH_QUOTES_MANDATORY: u16 = 0x40; // BIT_U16(6) in detect.h
 pub const SIGMATCH_INFO_STICKY_BUFFER: u16 = 0x200; // BIT_U16(9)
 
index 9ac974ac31b0a2627e34259ec961654d974fd5bc..ee3ec1ba155692d004e6413642a8eb1aeb4e85c3 100644 (file)
 
 use crate::detect::error::RuleParseError;
 use crate::detect::parser::{parse_var, take_until_whitespace, ResultValue};
-use crate::ffi::base64::SCBase64Mode;
-use std::ffi::{CStr, CString};
-use std::os::raw::c_char;
+use crate::detect::SIGMATCH_OPTIONAL_OPT;
+use crate::ffi::base64::{SCBase64Decode, SCBase64Mode};
+use crate::utils::base64::get_decoded_buffer_size;
+
+#[cfg(test)]
+use crate::detect::transforms::base64::tests::{
+    SCInspectionBufferCheckAndExpand, SCInspectionBufferTruncate,
+};
+use suricata_sys::sys::{
+    DetectEngineCtx, DetectEngineThreadCtx, InspectionBuffer, SCDetectHelperTransformRegister,
+    SCDetectSignatureAddTransform, SCTransformTableElmt, Signature,
+};
+#[cfg(not(test))]
+use suricata_sys::sys::{SCInspectionBufferCheckAndExpand, SCInspectionBufferTruncate};
 
 use nom7::bytes::complete::tag;
 use nom7::character::complete::multispace0;
 use nom7::sequence::preceded;
 use nom7::{Err, IResult};
+
+use std::ffi::CStr;
+use std::os::raw::{c_char, c_int, c_void};
 use std::str;
 
-pub const TRANSFORM_FROM_BASE64_MODE_DEFAULT: SCBase64Mode = SCBase64Mode::SCBase64ModeRFC4648;
+const TRANSFORM_FROM_BASE64_MODE_DEFAULT: SCBase64Mode = SCBase64Mode::SCBase64ModeRFC4648;
 
 const DETECT_TRANSFORM_BASE64_MAX_PARAM_COUNT: usize = 3;
-pub const DETECT_TRANSFORM_BASE64_FLAG_MODE: u8 = 0x01;
-pub const DETECT_TRANSFORM_BASE64_FLAG_NBYTES: u8 = 0x02;
-pub const DETECT_TRANSFORM_BASE64_FLAG_OFFSET: u8 = 0x04;
-pub const DETECT_TRANSFORM_BASE64_FLAG_OFFSET_VAR: u8 = 0x08;
-pub const DETECT_TRANSFORM_BASE64_FLAG_NBYTES_VAR: u8 = 0x10;
+const DETECT_TRANSFORM_BASE64_FLAG_MODE: u8 = 0x01;
+const DETECT_TRANSFORM_BASE64_FLAG_NBYTES: u8 = 0x02;
+const DETECT_TRANSFORM_BASE64_FLAG_OFFSET: u8 = 0x04;
 
+// repr C to ensure a stable layout
+// good field ordering to avoid padding as rust does not have stable zeroed allocs
 #[repr(C)]
-#[derive(Debug)]
-pub struct SCDetectTransformFromBase64Data {
-    flags: u8,
+#[derive(Debug, PartialEq)]
+struct DetectTransformFromBase64Data {
     nbytes: u32,
-    nbytes_str: *const c_char,
     offset: u32,
-    offset_str: *const c_char,
-    mode: SCBase64Mode,
-
-    // serialized data for hashing
-    serialized: *const u8,
-    serialized_len: u32,
+    mode: SCBase64Mode, // repr u8
+    flags: u8,
 }
 
-impl Drop for SCDetectTransformFromBase64Data {
-    fn drop(&mut self) {
-        unsafe {
-            if !self.serialized.is_null() {
-                drop(Vec::from_raw_parts(
-                    self.serialized as *mut u8,
-                    self.serialized_len as usize,
-                    self.serialized_len as usize,
-                ));
-            }
-            if !self.offset_str.is_null() {
-                let _ = CString::from_raw(self.offset_str as *mut c_char);
-            }
-            if !self.nbytes_str.is_null() {
-                let _ = CString::from_raw(self.nbytes_str as *mut c_char);
-            }
-        }
-    }
-}
-impl Default for SCDetectTransformFromBase64Data {
+impl Default for DetectTransformFromBase64Data {
     fn default() -> Self {
-        SCDetectTransformFromBase64Data {
-            flags: 0,
+        DetectTransformFromBase64Data {
+            mode: TRANSFORM_FROM_BASE64_MODE_DEFAULT,
             nbytes: 0,
-            nbytes_str: std::ptr::null_mut(),
             offset: 0,
-            offset_str: std::ptr::null_mut(),
-            mode: TRANSFORM_FROM_BASE64_MODE_DEFAULT,
-            serialized: std::ptr::null_mut(),
-            serialized_len: 0,
-        }
-    }
-}
-
-impl SCDetectTransformFromBase64Data {
-    pub fn new() -> Self {
-        Self {
-            ..Default::default()
+            flags: 0,
         }
     }
-
-    pub fn serialize(&mut self, nbytes: &str, offset: &str) {
-        let mut r = Vec::with_capacity(12 + nbytes.len() + offset.len());
-        r.push(self.flags);
-        r.push(self.mode as u8);
-        r.extend_from_slice(&self.nbytes.to_le_bytes());
-        r.extend_from_slice(&self.offset.to_le_bytes());
-        r.push(nbytes.len() as u8);
-        r.extend_from_slice(nbytes.as_bytes());
-        r.push(offset.len() as u8);
-        r.extend_from_slice(offset.as_bytes());
-        self.serialized_len = r.len() as u32;
-        let ptr = r.as_mut_ptr();
-        std::mem::forget(r);
-        self.serialized = ptr;
-    }
 }
 
 fn get_mode_value(value: &str) -> Option<SCBase64Mode> {
@@ -124,12 +85,12 @@ fn get_mode_value(value: &str) -> Option<SCBase64Mode> {
 
 fn parse_transform_base64(
     input: &str,
-) -> IResult<&str, SCDetectTransformFromBase64Data, RuleParseError<&str>> {
+) -> IResult<&str, DetectTransformFromBase64Data, RuleParseError<&str>> {
     // Inner utility function for easy error creation.
     fn make_error(reason: String) -> nom7::Err<RuleParseError<&'static str>> {
         Err::Error(RuleParseError::InvalidTransformBase64(reason))
     }
-    let mut transform_base64 = SCDetectTransformFromBase64Data::new();
+    let mut transform_base64 = DetectTransformFromBase64Data::default();
 
     // No options so return defaults
     if input.is_empty() {
@@ -146,8 +107,6 @@ fn parse_transform_base64(
             DETECT_TRANSFORM_BASE64_MAX_PARAM_COUNT, input)));
     }
 
-    let mut nbytes_str = String::new();
-    let mut offset_str = String::new();
     for value in values {
         let (mut val, mut name) = take_until_whitespace(value)?;
         val = val.trim();
@@ -183,9 +142,11 @@ fn parse_transform_base64(
                             )));
                         }
                     }
-                    ResultValue::String(val) => {
-                        offset_str = val.clone();
-                        transform_base64.flags |= DETECT_TRANSFORM_BASE64_FLAG_OFFSET_VAR;
+                    ResultValue::String(_val) => {
+                        SCLogError!("offset value must be a value, not a variable name");
+                        return Err(make_error(
+                            "offset value must be a value, not a variable name".to_string(),
+                        ));
                     }
                 }
 
@@ -210,9 +171,11 @@ fn parse_transform_base64(
                             )));
                         }
                     }
-                    ResultValue::String(val) => {
-                        nbytes_str = val.clone();
-                        transform_base64.flags |= DETECT_TRANSFORM_BASE64_FLAG_NBYTES_VAR;
+                    ResultValue::String(_val) => {
+                        SCLogError!("byte value must be a value, not a variable name");
+                        return Err(make_error(
+                            "byte value must be a value, not a variable name".to_string(),
+                        ));
                     }
                 }
                 transform_base64.flags |= DETECT_TRANSFORM_BASE64_FLAG_NBYTES;
@@ -223,215 +186,287 @@ fn parse_transform_base64(
         };
     }
 
-    transform_base64.serialize(&nbytes_str, &offset_str);
-    if (transform_base64.flags & DETECT_TRANSFORM_BASE64_FLAG_NBYTES_VAR) != 0 {
-        if let Ok(newval) = CString::new(nbytes_str) {
-            transform_base64.nbytes_str = newval.into_raw();
-        } else {
-            return Err(make_error(
-                "parse string not safely convertible to C".to_string(),
-            ));
-        }
-    }
-    if (transform_base64.flags & DETECT_TRANSFORM_BASE64_FLAG_OFFSET_VAR) != 0 {
-        if let Ok(newval) = CString::new(offset_str) {
-            transform_base64.offset_str = newval.into_raw();
-        } else {
-            return Err(make_error(
-                "parse string not safely convertible to C".to_string(),
-            ));
-        }
-    }
     Ok((input, transform_base64))
 }
 
-/// Intermediary function between the C code and the parsing functions.
-#[no_mangle]
-pub unsafe extern "C" fn SCTransformBase64Parse(
-    c_arg: *const c_char,
-) -> *mut SCDetectTransformFromBase64Data {
+unsafe fn base64_parse(c_arg: *const c_char) -> *mut DetectTransformFromBase64Data {
     if c_arg.is_null() {
         return std::ptr::null_mut();
     }
 
-    let arg = CStr::from_ptr(c_arg).to_str().unwrap_or("");
+    if let Ok(arg) = CStr::from_ptr(c_arg).to_str() {
+        match parse_transform_base64(arg) {
+            Ok((_, detect)) => return Box::into_raw(Box::new(detect)),
+            Err(_) => return std::ptr::null_mut(),
+        }
+    }
+    return std::ptr::null_mut();
+}
+
+unsafe extern "C" fn base64_free(_de: *mut DetectEngineCtx, ctx: *mut c_void) {
+    std::mem::drop(Box::from_raw(ctx as *mut DetectTransformFromBase64Data));
+}
+
+static mut G_TRANSFORM_BASE64_ID: c_int = 0;
 
-    match parse_transform_base64(arg) {
-        Ok((_, detect)) => return Box::into_raw(Box::new(detect)),
-        Err(_) => return std::ptr::null_mut(),
+unsafe extern "C" fn base64_setup(
+    de: *mut DetectEngineCtx, s: *mut Signature, opt_str: *const std::os::raw::c_char,
+) -> c_int {
+    let ctx = base64_parse(opt_str) as *mut c_void;
+    if ctx.is_null() {
+        return -1;
     }
+    let r = SCDetectSignatureAddTransform(s, G_TRANSFORM_BASE64_ID, ctx);
+    if r != 0 {
+        base64_free(de, ctx);
+    }
+    return r;
 }
 
-#[no_mangle]
-pub unsafe extern "C" fn SCTransformBase64Free(ptr: *mut SCDetectTransformFromBase64Data) {
-    if !ptr.is_null() {
-        let _ = Box::from_raw(ptr);
+unsafe extern "C" fn base64_id(data: *mut *const u8, length: *mut u32, ctx: *mut c_void) {
+    if data.is_null() || length.is_null() || ctx.is_null() {
+        return;
     }
+
+    // This works because the structure is flat
+    // Once variables are really implemented, we should investigate if the structure should own
+    // its serialization or just borrow it to a caller
+    *data = ctx as *const u8;
+    *length = std::mem::size_of::<DetectTransformFromBase64Data>() as u32;
 }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-    // structure equality only used by test cases
-    impl PartialEq for SCDetectTransformFromBase64Data {
-        fn eq(&self, other: &Self) -> bool {
-            let mut res: bool = true;
-
-            if !self.nbytes_str.is_null() && !other.nbytes_str.is_null() {
-                let s_val = unsafe { CStr::from_ptr(self.nbytes_str) };
-                let o_val = unsafe { CStr::from_ptr(other.nbytes_str) };
-                res = s_val == o_val;
-            } else if !self.nbytes_str.is_null() || !other.nbytes_str.is_null() {
-                return false;
-            }
+unsafe extern "C" fn base64_transform(
+    _det: *mut DetectEngineThreadCtx, buffer: *mut InspectionBuffer, ctx: *mut c_void,
+) {
+    let input = (*buffer).inspect;
+    let input_len = (*buffer).inspect_len;
+    if input.is_null() || input_len == 0 {
+        return;
+    }
+    let mut input = build_slice!(input, input_len as usize);
 
-            if !self.offset_str.is_null() && !other.offset_str.is_null() {
-                let s_val = unsafe { CStr::from_ptr(self.offset_str) };
-                let o_val = unsafe { CStr::from_ptr(other.offset_str) };
-                res = s_val == o_val;
-            } else if !self.offset_str.is_null() || !other.offset_str.is_null() {
-                return false;
-            }
+    let ctx = cast_pointer!(ctx, DetectTransformFromBase64Data);
 
-            res && self.nbytes == other.nbytes
-                && self.flags == other.flags
-                && self.offset == other.offset
-                && self.mode == other.mode
+    if ctx.offset > 0 {
+        if ctx.offset >= input_len {
+            return;
+        }
+        input = &input[ctx.offset as usize..];
+    }
+    if ctx.nbytes > 0 {
+        if ctx.nbytes as usize >= input.len() {
+            return;
         }
+        input = &input[..ctx.nbytes as usize];
     }
 
-    fn valid_test(
-        args: &str, nbytes: u32, nbytes_str: &str, offset: u32, offset_str: &str,
-        mode: SCBase64Mode, flags: u8,
-    ) {
-        let tbd = SCDetectTransformFromBase64Data {
-            flags,
-            nbytes,
-            nbytes_str: if !nbytes_str.is_empty() {
-                CString::new(nbytes_str).unwrap().into_raw()
-            } else {
-                std::ptr::null_mut()
-            },
-            offset,
-            offset_str: if !offset_str.is_empty() {
-                CString::new(offset_str).unwrap().into_raw()
-            } else {
-                std::ptr::null_mut()
-            },
-            mode,
-            serialized: std::ptr::null_mut(),
-            serialized_len: 0,
-        };
+    let output_len = get_decoded_buffer_size(input.len() as u32);
+    // no realloc, we only can shrink
+    let output = SCInspectionBufferCheckAndExpand(buffer, output_len);
+    if output.is_null() {
+        // allocation failure
+        return;
+    }
 
-        let (_, val) = parse_transform_base64(args).unwrap();
-        assert_eq!(val, tbd);
+    let num_decoded = SCBase64Decode(input.as_ptr(), input.len(), ctx.mode, output);
+    if num_decoded > 0 {
+        SCInspectionBufferTruncate(buffer, num_decoded);
     }
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn DetectTransformFromBase64DecodeRegister() {
+    let kw = SCTransformTableElmt {
+        name: b"from_base64\0".as_ptr() as *const libc::c_char,
+        desc: b"convert the base64 decode of the buffer\0".as_ptr() as *const libc::c_char,
+        url: b"/rules/transforms.html#from_base64\0".as_ptr() as *const libc::c_char,
+        Setup: Some(base64_setup),
+        flags: SIGMATCH_OPTIONAL_OPT,
+        Transform: Some(base64_transform),
+        Free: Some(base64_free),
+        TransformValidate: None,
+        TransformId: Some(base64_id),
+    };
+    unsafe {
+        G_TRANSFORM_BASE64_ID = SCDetectHelperTransformRegister(&kw);
+        if G_TRANSFORM_BASE64_ID < 0 {
+            SCLogWarning!("Failed registering transform base64");
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
 
     #[test]
     fn test_parser_invalid() {
         assert!(parse_transform_base64("bytes 4, offset 3933, mode unknown").is_err());
         assert!(parse_transform_base64("bytes 4, offset 70000, mode strict").is_err());
-        assert!(
-            parse_transform_base64("bytes 4, offset 70000, mode strict, mode rfc2045").is_err()
-        );
+        assert!(parse_transform_base64("bytes 4, offset 3933, mode strict, mode rfc2045").is_err());
     }
 
     #[test]
     fn test_parser_parse_partial_valid() {
-        let mut tbd = SCDetectTransformFromBase64Data {
+        let (_, val) = parse_transform_base64("bytes 4").unwrap();
+        assert_eq!(
+            val,
+            DetectTransformFromBase64Data {
+                nbytes: 4,
+                offset: 0,
+                mode: TRANSFORM_FROM_BASE64_MODE_DEFAULT,
+                flags: DETECT_TRANSFORM_BASE64_FLAG_NBYTES,
+            }
+        );
+
+        let args = DetectTransformFromBase64Data {
             nbytes: 4,
-            offset: 0,
+            offset: 3933,
             mode: TRANSFORM_FROM_BASE64_MODE_DEFAULT,
-            flags: 0,
-            ..Default::default()
+            flags: DETECT_TRANSFORM_BASE64_FLAG_NBYTES | DETECT_TRANSFORM_BASE64_FLAG_OFFSET,
         };
-
-        tbd.mode = TRANSFORM_FROM_BASE64_MODE_DEFAULT;
-        tbd.flags = DETECT_TRANSFORM_BASE64_FLAG_NBYTES;
-        let (_, val) = parse_transform_base64("bytes 4").unwrap();
-        assert_eq!(val, tbd);
-
-        tbd.offset = 3933;
-        tbd.flags = DETECT_TRANSFORM_BASE64_FLAG_NBYTES | DETECT_TRANSFORM_BASE64_FLAG_OFFSET;
         let (_, val) = parse_transform_base64("bytes 4, offset 3933").unwrap();
-        assert_eq!(val, tbd);
-
-        tbd.flags = DETECT_TRANSFORM_BASE64_FLAG_NBYTES | DETECT_TRANSFORM_BASE64_FLAG_OFFSET;
+        assert_eq!(val, args);
         let (_, val) = parse_transform_base64("offset 3933, bytes 4").unwrap();
-        assert_eq!(val, tbd);
+        assert_eq!(val, args);
 
-        tbd.flags = DETECT_TRANSFORM_BASE64_FLAG_MODE;
-        tbd.mode = SCBase64Mode::SCBase64ModeRFC2045;
-        tbd.offset = 0;
-        tbd.nbytes = 0;
         let (_, val) = parse_transform_base64("mode rfc2045").unwrap();
-        assert_eq!(val, tbd);
+        assert_eq!(
+            val,
+            DetectTransformFromBase64Data {
+                nbytes: 0,
+                offset: 0,
+                mode: SCBase64Mode::SCBase64ModeRFC2045,
+                flags: DETECT_TRANSFORM_BASE64_FLAG_MODE,
+            }
+        );
     }
 
     #[test]
     fn test_parser_parse_valid() {
-        valid_test("", 0, "", 0, "", TRANSFORM_FROM_BASE64_MODE_DEFAULT, 0);
-
-        valid_test(
-            "bytes 4, offset 3933, mode strict",
-            4,
-            "",
-            3933,
-            "",
-            SCBase64Mode::SCBase64ModeStrict,
-            DETECT_TRANSFORM_BASE64_FLAG_NBYTES
-                | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
-                | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+        let (_, val) = parse_transform_base64("").unwrap();
+        assert_eq!(
+            val,
+            DetectTransformFromBase64Data {
+                mode: TRANSFORM_FROM_BASE64_MODE_DEFAULT,
+                ..Default::default()
+            }
         );
 
-        valid_test(
-            "bytes 4, offset 3933, mode rfc2045",
-            4,
-            "",
-            3933,
-            "",
-            SCBase64Mode::SCBase64ModeRFC2045,
-            DETECT_TRANSFORM_BASE64_FLAG_NBYTES
-                | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
-                | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+        let (_, val) = parse_transform_base64("bytes 4, offset 3933, mode strict").unwrap();
+        assert_eq!(
+            val,
+            DetectTransformFromBase64Data {
+                nbytes: 4,
+                offset: 3933,
+                mode: SCBase64Mode::SCBase64ModeStrict,
+                flags: DETECT_TRANSFORM_BASE64_FLAG_NBYTES
+                    | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
+                    | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+            }
         );
 
-        valid_test(
-            "bytes 4, offset 3933, mode rfc4648",
-            4,
-            "",
-            3933,
-            "",
-            SCBase64Mode::SCBase64ModeRFC4648,
-            DETECT_TRANSFORM_BASE64_FLAG_NBYTES
-                | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
-                | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+        let (_, val) = parse_transform_base64("bytes 4, offset 3933, mode rfc2045").unwrap();
+        assert_eq!(
+            val,
+            DetectTransformFromBase64Data {
+                nbytes: 4,
+                offset: 3933,
+                mode: SCBase64Mode::SCBase64ModeRFC2045,
+                flags: DETECT_TRANSFORM_BASE64_FLAG_NBYTES
+                    | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
+                    | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+            }
         );
 
-        valid_test(
-            "bytes 4, offset var, mode rfc4648",
-            4,
-            "",
-            0,
-            "var",
-            SCBase64Mode::SCBase64ModeRFC4648,
-            DETECT_TRANSFORM_BASE64_FLAG_NBYTES
-                | DETECT_TRANSFORM_BASE64_FLAG_OFFSET_VAR
-                | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
-                | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+        let (_, val) = parse_transform_base64("bytes 4, offset 3933, mode rfc4648").unwrap();
+        assert_eq!(
+            val,
+            DetectTransformFromBase64Data {
+                nbytes: 4,
+                offset: 3933,
+                mode: SCBase64Mode::SCBase64ModeRFC4648,
+                flags: DETECT_TRANSFORM_BASE64_FLAG_NBYTES
+                    | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
+                    | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+            }
         );
 
-        valid_test(
-            "bytes var, offset 3933, mode rfc4648",
-            0,
-            "var",
-            3933,
-            "",
-            SCBase64Mode::SCBase64ModeRFC4648,
-            DETECT_TRANSFORM_BASE64_FLAG_NBYTES
-                | DETECT_TRANSFORM_BASE64_FLAG_NBYTES_VAR
-                | DETECT_TRANSFORM_BASE64_FLAG_OFFSET
-                | DETECT_TRANSFORM_BASE64_FLAG_MODE,
+        assert!(parse_transform_base64("bytes 4, offset var, mode rfc4648").is_err());
+        assert!(parse_transform_base64("bytes var, offset 3933, mode rfc4648").is_err());
+    }
+
+    // Test/mock versions to keep tests in rust
+    #[allow(non_snake_case)]
+    pub(crate) unsafe fn SCInspectionBufferCheckAndExpand(
+        buffer: *mut InspectionBuffer, min_size: u32,
+    ) -> *mut u8 {
+        assert!(min_size <= (*buffer).inspect_len);
+        return (*buffer).inspect as *mut u8;
+    }
+
+    #[allow(non_snake_case)]
+    pub(crate) unsafe fn SCInspectionBufferTruncate(buffer: *mut InspectionBuffer, buf_len: u32) {
+        (*buffer).inspect_len = buf_len;
+    }
+
+    fn test_base64_sample(sig: &str, buf: &[u8], out: &[u8]) {
+        let mut ibuf: InspectionBuffer = unsafe { std::mem::zeroed() };
+        let mut input = Vec::new();
+        // we will overwrite it, so do not create it const
+        input.extend_from_slice(buf);
+        ibuf.inspect = input.as_ptr();
+        ibuf.inspect_len = input.len() as u32;
+        let (_, mut ctx) = parse_transform_base64(sig).unwrap();
+        unsafe {
+            base64_transform(
+                std::ptr::null_mut(),
+                &mut ibuf as *mut InspectionBuffer,
+                &mut ctx as *mut DetectTransformFromBase64Data as *mut c_void,
+            );
+        }
+        let ibufi = ibuf.inspect;
+        let output = unsafe { build_slice!(ibufi, ibuf.inspect_len as usize) };
+        assert_eq!(output, out);
+    }
+
+    #[test]
+    fn test_base64_transform() {
+        /* Simple success case -- check buffer */
+        test_base64_sample("", b"VGhpcyBpcyBTdXJpY2F0YQ==", b"This is Suricata");
+        /* Simple success case with RFC2045 -- check buffer */
+        test_base64_sample("mode rfc2045", b"Zm 9v Ym Fy", b"foobar");
+        /* Decode failure case -- ensure no change to buffer */
+        test_base64_sample("mode strict", b"This is Suricata\n", b"This is Suricata\n");
+        /* bytes > len so --> no transform */
+        test_base64_sample(
+            "bytes 25",
+            b"VGhpcyBpcyBTdXJpY2F0YQ==",
+            b"VGhpcyBpcyBTdXJpY2F0YQ==",
+        );
+        /* offset > len so --> no transform */
+        test_base64_sample(
+            "offset 25",
+            b"VGhpcyBpcyBTdXJpY2F0YQ==",
+            b"VGhpcyBpcyBTdXJpY2F0YQ==",
+        );
+        /* partial transform */
+        test_base64_sample("bytes 12", b"VGhpcyBpcyBTdXJpY2F0YQ==", b"This is S");
+        /* transform from non-zero offset */
+        test_base64_sample("offset 4", b"VGhpcyBpcyBTdXJpY2F0YQ==", b"s is Suricata");
+        /* partial decode */
+        test_base64_sample(
+            "mode rfc2045, bytes 15",
+            b"SGVs bG8 gV29y bGQ=",
+            b"Hello Wor",
+        );
+        /* input is not base64 encoded */
+        test_base64_sample(
+            "mode rfc2045",
+            b"This is not base64-encoded",
+            &[
+                78, 24, 172, 138, 201, 232, 181, 182, 172, 123, 174, 30, 157, 202, 29,
+            ],
         );
     }
 }
index d54c57375f42d9972c4126aa25031debc8d48408..ed279d4770f9706e9e5e27e60c0bc9d106b7a756 100755 (executable)
@@ -321,7 +321,6 @@ noinst_HEADERS = \
        detect-tls-version.h \
        detect-tls.h \
        detect-tos.h \
-       detect-transform-base64.h \
        detect-transform-luaxform.h \
        detect-transform-pcrexform.h \
        detect-ttl.h \
@@ -919,7 +918,6 @@ libsuricata_c_a_SOURCES = \
        detect-tls-version.c \
        detect-tls.c \
        detect-tos.c \
-       detect-transform-base64.c \
        detect-transform-luaxform.c \
        detect-transform-pcrexform.c \
        detect-ttl.c \
index a5899f33377c8a070c73b57be97f158eac8a0d5e..5e0023fd8982404fafc51495cbeeba53532c0456 100644 (file)
 #include "detect-engine-content-inspection.h"
 
 #include "detect-transform-pcrexform.h"
-#include "detect-transform-base64.h"
 #include "detect-transform-luaxform.h"
 
 #include "util-rule-vars.h"
index 3027b0c8389ca2299039b2808a120a8b4ad35285..65a1c192ed7978d48ba2c6386401bb9619f716bd 100644 (file)
@@ -302,7 +302,6 @@ enum DetectKeywordId {
     DETECT_PREFILTER,
 
     DETECT_TRANSFORM_PCREXFORM,
-    DETECT_TRANSFORM_FROM_BASE64,
     DETECT_TRANSFORM_LUAXFORM,
 
     DETECT_IKE_EXCH_TYPE,
diff --git a/src/detect-transform-base64.c b/src/detect-transform-base64.c
deleted file mode 100644 (file)
index fa49304..0000000
+++ /dev/null
@@ -1,380 +0,0 @@
-/* Copyright (C) 2024 Open Information Security Foundation
- *
- * You can copy, redistribute or modify this Program under the terms of
- * the GNU General Public License version 2 as published by the Free
- * Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * version 2 along with this program; if not, write to the Free Software
- * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
- * 02110-1301, USA.
- */
-
-/**
- * \file
- *
- * \author Jeff Lucovsky <jlucovsky@oisf.net>
- *
- * Implements the from_base64 transformation keyword
- */
-
-#include "suricata-common.h"
-
-#include "detect.h"
-#include "detect-engine.h"
-#include "detect-engine-buffer.h"
-#include "detect-byte.h"
-
-#include "rust.h"
-
-#include "detect-transform-base64.h"
-
-#include "util-unittest.h"
-#include "util-print.h"
-
-#ifdef UNITTESTS
-#define DETECT_TRANSFORM_FROM_BASE64_MODE_DEFAULT (uint8_t) SCBase64ModeRFC4648
-static void DetectTransformFromBase64DecodeRegisterTests(void);
-#endif
-
-static void DetectTransformFromBase64Id(const uint8_t **data, uint32_t *length, void *context)
-{
-    if (context) {
-        SCDetectTransformFromBase64Data *b64d = (SCDetectTransformFromBase64Data *)context;
-        *data = b64d->serialized;
-        *length = b64d->serialized_len;
-    }
-}
-
-static void DetectTransformFromBase64DecodeFree(DetectEngineCtx *de_ctx, void *ptr)
-{
-    if (ptr) {
-        SCTransformBase64Free(ptr);
-    }
-}
-
-static SCDetectTransformFromBase64Data *DetectTransformFromBase64DecodeParse(const char *str)
-{
-    SCDetectTransformFromBase64Data *tbd = SCTransformBase64Parse(str);
-    if (tbd == NULL) {
-        SCLogError("invalid transform_base64 values");
-    }
-    return tbd;
-}
-
-/**
- *  \internal
- *  \brief Base64 decode the input buffer
- *  \param det_ctx detection engine ctx
- *  \param s signature
- *  \param opts_str transform options, if any
- *  \retval 0 No decode
- *  \retval >0 Decoded byte count
- */
-static int DetectTransformFromBase64DecodeSetup(
-        DetectEngineCtx *de_ctx, Signature *s, const char *opts_str)
-{
-    int r = -1;
-
-    SCEnter();
-
-    SCDetectTransformFromBase64Data *b64d = DetectTransformFromBase64DecodeParse(opts_str);
-    if (b64d == NULL)
-        SCReturnInt(r);
-
-    if (b64d->flags & DETECT_TRANSFORM_BASE64_FLAG_OFFSET_VAR) {
-        SCLogError("offset value must be a value, not a variable name");
-        goto exit_path;
-    }
-
-    if (b64d->flags & DETECT_TRANSFORM_BASE64_FLAG_NBYTES_VAR) {
-        SCLogError("byte value must be a value, not a variable name");
-        goto exit_path;
-    }
-
-    r = SCDetectSignatureAddTransform(s, DETECT_TRANSFORM_FROM_BASE64, b64d);
-
-exit_path:
-    if (r != 0)
-        DetectTransformFromBase64DecodeFree(de_ctx, b64d);
-    SCReturnInt(r);
-}
-
-static void TransformFromBase64Decode(
-        DetectEngineThreadCtx *det_ctx, InspectionBuffer *buffer, void *options)
-{
-    SCDetectTransformFromBase64Data *b64d = options;
-    const uint8_t *input = buffer->inspect;
-    const uint32_t input_len = buffer->inspect_len;
-    uint32_t decode_length = input_len;
-
-    SCBase64Mode mode = b64d->mode;
-    uint32_t offset = b64d->offset;
-    uint32_t nbytes = b64d->nbytes;
-
-    if (offset) {
-        if (offset > input_len) {
-            SCLogDebug("offset %d exceeds length %d; returning", offset, input_len);
-            return;
-        }
-        input += offset;
-        decode_length -= offset;
-    }
-
-    if (nbytes) {
-        if (nbytes > decode_length) {
-            SCLogDebug("byte count %d plus offset %d exceeds length %d; returning", nbytes, offset,
-                    input_len);
-            return;
-        }
-        decode_length = nbytes;
-    }
-    if (decode_length == 0) {
-        return;
-    }
-
-    uint32_t decoded_size = SCBase64DecodeBufferSize(decode_length);
-    uint8_t decoded[decoded_size];
-    uint32_t num_decoded = SCBase64Decode((const uint8_t *)input, decode_length, mode, decoded);
-    if (num_decoded > 0) {
-        //            PrintRawDataFp(stdout, output, b64data->decoded_len);
-        InspectionBufferCopy(buffer, decoded, num_decoded);
-    }
-}
-
-void DetectTransformFromBase64DecodeRegister(void)
-{
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].name = "from_base64";
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].desc = "convert the base64 decode of the buffer";
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].url = "/rules/transforms.html#from_base64";
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].Setup = DetectTransformFromBase64DecodeSetup;
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].Transform = TransformFromBase64Decode;
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].TransformId = DetectTransformFromBase64Id;
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].Free = DetectTransformFromBase64DecodeFree;
-#ifdef UNITTESTS
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].RegisterTests =
-            DetectTransformFromBase64DecodeRegisterTests;
-#endif
-    sigmatch_table[DETECT_TRANSFORM_FROM_BASE64].flags |= SIGMATCH_OPTIONAL_OPT;
-}
-
-#ifdef UNITTESTS
-/* Simple success case -- check buffer */
-static int DetectTransformFromBase64DecodeTest01(void)
-{
-    const uint8_t *input = (const uint8_t *)"VGhpcyBpcyBTdXJpY2F0YQ==";
-    uint32_t input_len = strlen((char *)input);
-    const char *result = "This is Suricata";
-    uint32_t result_len = strlen((char *)result);
-    SCDetectTransformFromBase64Data b64d = {
-        .nbytes = input_len,
-        .mode = DETECT_TRANSFORM_FROM_BASE64_MODE_DEFAULT,
-    };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(buffer.inspect_len == result_len);
-    FAIL_IF_NOT(strncmp(result, (const char *)buffer.inspect, result_len) == 0);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* Simple success case with RFC2045 -- check buffer */
-static int DetectTransformFromBase64DecodeTest01a(void)
-{
-    const uint8_t *input = (const uint8_t *)"Zm 9v Ym Fy";
-    uint32_t input_len = strlen((char *)input);
-    const char *result = "foobar";
-    uint32_t result_len = strlen((char *)result);
-    SCDetectTransformFromBase64Data b64d = { .nbytes = input_len, .mode = SCBase64ModeRFC2045 };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(buffer.inspect_len == result_len);
-    FAIL_IF_NOT(strncmp(result, (const char *)buffer.inspect, result_len) == 0);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* Decode failure case -- ensure no change to buffer */
-static int DetectTransformFromBase64DecodeTest02(void)
-{
-    const uint8_t *input = (const uint8_t *)"This is Suricata\n";
-    uint32_t input_len = strlen((char *)input);
-    SCDetectTransformFromBase64Data b64d = { .nbytes = input_len, .mode = SCBase64ModeStrict };
-    InspectionBuffer buffer;
-    InspectionBuffer buffer_orig;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    buffer_orig = buffer;
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(buffer.inspect_offset == buffer_orig.inspect_offset);
-    FAIL_IF_NOT(buffer.inspect_len == buffer_orig.inspect_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* bytes > len so --> no transform */
-static int DetectTransformFromBase64DecodeTest03(void)
-{
-    const uint8_t *input = (const uint8_t *)"VGhpcyBpcyBTdXJpY2F0YQ==";
-    uint32_t input_len = strlen((char *)input);
-
-    SCDetectTransformFromBase64Data b64d = {
-        .nbytes = input_len + 1,
-    };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(strncmp((const char *)input, (const char *)buffer.inspect, input_len) == 0);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* offset > len so --> no transform */
-static int DetectTransformFromBase64DecodeTest04(void)
-{
-    const uint8_t *input = (const uint8_t *)"VGhpcyBpcyBTdXJpY2F0YQ==";
-    uint32_t input_len = strlen((char *)input);
-
-    SCDetectTransformFromBase64Data b64d = {
-        .offset = input_len + 1,
-    };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(strncmp((const char *)input, (const char *)buffer.inspect, input_len) == 0);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* partial transform */
-static int DetectTransformFromBase64DecodeTest05(void)
-{
-    const uint8_t *input = (const uint8_t *)"VGhpcyBpcyBTdXJpY2F0YQ==";
-    uint32_t input_len = strlen((char *)input);
-    const char *result = "This is S";
-    uint32_t result_len = strlen((char *)result);
-
-    SCDetectTransformFromBase64Data b64d = {
-        .nbytes = 12,
-        .mode = DETECT_TRANSFORM_FROM_BASE64_MODE_DEFAULT,
-    };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(buffer.inspect_len == result_len);
-    FAIL_IF_NOT(strncmp(result, (const char *)buffer.inspect, result_len) == 0);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* transform from non-zero offset */
-static int DetectTransformFromBase64DecodeTest06(void)
-{
-    const uint8_t *input = (const uint8_t *)"VGhpcyBpcyBTdXJpY2F0YQ==";
-    uint32_t input_len = strlen((char *)input);
-    const char *result = "s is Suricata";
-    uint32_t result_len = strlen((char *)result);
-
-    SCDetectTransformFromBase64Data b64d = {
-        .offset = 4,
-        .mode = DETECT_TRANSFORM_FROM_BASE64_MODE_DEFAULT,
-    };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(buffer.inspect_len == result_len);
-    FAIL_IF_NOT(strncmp(result, (const char *)buffer.inspect, result_len) == 0);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* partial decode */
-static int DetectTransformFromBase64DecodeTest07(void)
-{
-    /* Full string decodes to Hello World */
-    const uint8_t *input = (const uint8_t *)"SGVs bG8 gV29y bGQ=";
-    uint32_t input_len = strlen((char *)input);
-    const char *result = "Hello Wor";
-    uint32_t result_len = strlen((char *)result);
-
-    SCDetectTransformFromBase64Data b64d = { .nbytes = input_len - 4, /* NB: stop early */
-        .mode = SCBase64ModeRFC2045 };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(buffer.inspect_len == result_len);
-    FAIL_IF_NOT(strncmp(result, (const char *)buffer.inspect, result_len) == 0);
-    PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-
-/* input is not base64 encoded */
-static int DetectTransformFromBase64DecodeTest08(void)
-{
-    /* A portion of this string will be decoded */
-    const uint8_t *input = (const uint8_t *)"This is not base64-encoded";
-    uint32_t input_len = strlen((char *)input);
-
-    SCDetectTransformFromBase64Data b64d = { .nbytes = input_len, .mode = SCBase64ModeRFC2045 };
-
-    InspectionBuffer buffer;
-    InspectionBufferInit(&buffer, input_len);
-    InspectionBufferSetup(NULL, -1, &buffer, input, input_len);
-    // PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    TransformFromBase64Decode(NULL, &buffer, &b64d);
-    FAIL_IF_NOT(buffer.inspect_len == 15);
-    // PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len);
-    InspectionBufferFree(&buffer);
-    PASS;
-}
-static void DetectTransformFromBase64DecodeRegisterTests(void)
-{
-    UtRegisterTest("DetectTransformFromBase64DecodeTest01", DetectTransformFromBase64DecodeTest01);
-    UtRegisterTest(
-            "DetectTransformFromBase64DecodeTest01a", DetectTransformFromBase64DecodeTest01a);
-    UtRegisterTest("DetectTransformFromBase64DecodeTest02", DetectTransformFromBase64DecodeTest02);
-    UtRegisterTest("DetectTransformFromBase64DecodeTest03", DetectTransformFromBase64DecodeTest03);
-    UtRegisterTest("DetectTransformFromBase64DecodeTest04", DetectTransformFromBase64DecodeTest04);
-    UtRegisterTest("DetectTransformFromBase64DecodeTest05", DetectTransformFromBase64DecodeTest05);
-    UtRegisterTest("DetectTransformFromBase64DecodeTest06", DetectTransformFromBase64DecodeTest06);
-    UtRegisterTest("DetectTransformFromBase64DecodeTest07", DetectTransformFromBase64DecodeTest07);
-    UtRegisterTest("DetectTransformFromBase64DecodeTest08", DetectTransformFromBase64DecodeTest08);
-}
-#endif
diff --git a/src/detect-transform-base64.h b/src/detect-transform-base64.h
deleted file mode 100644 (file)
index fc0847b..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright (C) 2024 Open Information Security Foundation
- *
- * You can copy, redistribute or modify this Program under the terms of
- * the GNU General Public License version 2 as published by the Free
- * Software Foundation.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * version 2 along with this program; if not, write to the Free Software
- * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
- * 02110-1301, USA.
- */
-
-/**
- * \file
- *
- * \author Jeff Lucovsky <jlucovsky@oisf.net>
- */
-
-#ifndef SURICATA_DETECT_TRANSFORM_BASE64_H
-#define SURICATA_DETECT_TRANSFORM_BASE64_H
-
-/* prototypes */
-void DetectTransformFromBase64DecodeRegister(void);
-
-#endif /* SURICATA_DETECT_TRANSFORM_BASE64_H */