]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
smb: Add rust registration function
authorShivani Bhardwaj <shivanib134@gmail.com>
Sat, 19 Jun 2021 07:53:14 +0000 (13:23 +0530)
committerShivani Bhardwaj <shivanib134@gmail.com>
Thu, 5 Aug 2021 15:30:39 +0000 (21:00 +0530)
Get rid of the C glue code and move registration completely to Rust.

rust/src/smb/smb.rs
src/app-layer-smb.c

index d5576e70308fab0cf8bc598320d62d18039e22b5..4237a6400945f000ca592b5ef4be7a51ca6528a0 100644 (file)
@@ -28,7 +28,7 @@
 use std;
 use std::mem::transmute;
 use std::str;
-use std::ffi::{self, CStr};
+use std::ffi::{self, CStr, CString};
 
 use std::collections::HashMap;
 
@@ -37,6 +37,7 @@ use nom;
 use crate::core::*;
 use crate::applayer;
 use crate::applayer::*;
+use crate::conf::*;
 use crate::filecontainer::*;
 
 use crate::smb::nbss_records::*;
@@ -2189,3 +2190,143 @@ pub extern "C" fn rs_smb_state_get_event_info(event_name: *const std::os::raw::c
     }
     0
 }
+
+pub extern "C" fn smb3_probe_tcp(f: *const Flow, dir: u8, input: *const u8, len: u32, rdir: *mut u8) -> u16 {
+    let retval = rs_smb_probe_tcp(f, dir, input, len, rdir);
+    let f = cast_pointer!(f, Flow);
+        if unsafe { retval != ALPROTO_SMB } {
+            return retval;
+        }
+        let (sp, dp) = f.get_ports();
+        let flags = f.get_flags();
+        let fsp = if (flags & FLOW_DIR_REVERSED) != 0 { dp } else { sp };
+        let fdp = if (flags & FLOW_DIR_REVERSED) != 0 { sp } else { dp };
+        if fsp == 445 && fdp != 445 {
+            unsafe {
+            if dir & STREAM_TOSERVER != 0 {
+                *rdir = STREAM_TOCLIENT;
+            } else {
+                *rdir = STREAM_TOSERVER;
+            }
+            }
+        }
+    return unsafe { ALPROTO_SMB };
+}
+
+fn register_pattern_probe() -> i8 {
+    let mut r = 0;
+    unsafe {
+        // SMB1
+        r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
+                                                     b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
+                                                     STREAM_TOSERVER, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
+        r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
+                                                     b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
+                                                     STREAM_TOCLIENT, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
+        // SMB2/3
+        r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
+                                                     b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
+                                                     STREAM_TOSERVER, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
+        r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
+                                                     b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
+                                                     STREAM_TOCLIENT, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
+        // SMB3 encrypted records
+        r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
+                                                     b"|fd|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
+                                                     STREAM_TOSERVER, smb3_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
+        r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB,
+                                                     b"|fd|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4,
+                                                     STREAM_TOCLIENT, smb3_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE);
+    }
+
+    if r == 0 {
+        return 0;
+    } else {
+        return -1;
+    }
+}
+
+// Parser name as a C style string.
+const PARSER_NAME: &'static [u8] = b"smb\0";
+
+#[no_mangle]
+pub unsafe extern "C" fn rs_smb_register_parser() {
+    let default_port = CString::new("445").unwrap();
+    let mut stream_depth = SMB_CONFIG_DEFAULT_STREAM_DEPTH;
+    let parser = RustParser {
+        name: PARSER_NAME.as_ptr() as *const std::os::raw::c_char,
+        default_port: default_port.as_ptr(),
+        ipproto: IPPROTO_TCP,
+        probe_ts: None,
+        probe_tc: None,
+        min_depth: 0,
+        max_depth: 16,
+        state_new: rs_smb_state_new,
+        state_free: rs_smb_state_free,
+        tx_free: rs_smb_state_tx_free,
+        parse_ts: rs_smb_parse_request_tcp,
+        parse_tc: rs_smb_parse_response_tcp,
+        get_tx_count: rs_smb_state_get_tx_count,
+        get_tx: rs_smb_state_get_tx,
+        tx_comp_st_ts: 1,
+        tx_comp_st_tc: 1,
+        tx_get_progress: rs_smb_tx_get_alstate_progress,
+        get_de_state: rs_smb_state_get_tx_detect_state,
+        set_de_state: rs_smb_state_set_tx_detect_state,
+        get_events: Some(rs_smb_state_get_events),
+        get_eventinfo: Some(rs_smb_state_get_event_info),
+        get_eventinfo_byid : Some(rs_smb_state_get_event_info_by_id),
+        localstorage_new: None,
+        localstorage_free: None,
+        get_files: Some(rs_smb_getfiles),
+        get_tx_iterator: Some(rs_smb_state_get_tx_iterator),
+        get_tx_data: rs_smb_get_tx_data,
+        apply_tx_config: None,
+        flags: APP_LAYER_PARSER_OPT_ACCEPT_GAPS,
+        truncate: Some(rs_smb_state_truncate),
+    };
+
+    let ip_proto_str = CString::new("tcp").unwrap();
+
+    if AppLayerProtoDetectConfProtoDetectionEnabled(
+        ip_proto_str.as_ptr(),
+        parser.name,
+    ) != 0
+    {
+        let alproto = AppLayerRegisterProtocolDetection(&parser, 1);
+        ALPROTO_SMB = alproto;
+        if register_pattern_probe() < 0 {
+            return;
+        }
+
+        let have_cfg = AppLayerProtoDetectPPParseConfPorts(ip_proto_str.as_ptr(),
+                    IPPROTO_TCP as u8, parser.name, ALPROTO_SMB, 0,
+                    MIN_REC_SIZE, rs_smb_probe_tcp, rs_smb_probe_tcp);
+
+        if have_cfg == 0 {
+            AppLayerProtoDetectPPRegister(IPPROTO_TCP as u8, parser.default_port, ALPROTO_SMB,
+                                          0, MIN_REC_SIZE, STREAM_TOSERVER, rs_smb_probe_tcp, rs_smb_probe_tcp);
+        }
+
+        if AppLayerParserConfParserEnabled(
+            ip_proto_str.as_ptr(),
+            parser.name,
+        ) != 0
+        {
+            let _ = AppLayerRegisterParser(&parser, alproto);
+        }
+        SCLogDebug!("Rust SMB parser registered.");
+        let retval = conf_get("app-layer.protocols.smb.stream-depth");
+        if let Some(val) = retval {
+            let val = val.parse::<i32>().unwrap();
+            if val < 0 {
+                SCLogError!("invalid value for stream-depth");
+            } else {
+                stream_depth = val as u32;
+           }
+            AppLayerParserSetStreamDepth(IPPROTO_TCP as u8, ALPROTO_SMB, stream_depth);
+        }
+    } else {
+        SCLogDebug!("Protocol detector and parser disabled for SMB.");
+    }
+}
index 206902add07b8ea7803c80c1b8e8231f2d47ef26..f9063f7c28433c5af0b5d280e3efb736627f1877 100644 (file)
 #include "app-layer-smb.h"
 #include "util-misc.h"
 
-#define MIN_REC_SIZE 32+4 // SMB hdr + nbss hdr
-
-static AppLayerResult SMBTCPParseRequest(Flow *f, void *state,
-        AppLayerParserState *pstate, const uint8_t *input, uint32_t input_len,
-        void *local_data, const uint8_t flags)
-{
-    SCLogDebug("SMBTCPParseRequest");
-    uint16_t file_flags = FileFlowToFlags(f, STREAM_TOSERVER);
-    rs_smb_setfileflags(0, state, file_flags|FILE_USE_DETECT);
-
-    if (input == NULL && input_len > 0) {
-        AppLayerResult res = rs_smb_parse_request_tcp_gap(state, input_len);
-        SCLogDebug("SMB request GAP of %u bytes, retval %d", input_len, res.status);
-        SCReturnStruct(res);
-    } else {
-        AppLayerResult res = rs_smb_parse_request_tcp(f, state, pstate,
-                input, input_len, local_data, flags);
-        SCLogDebug("SMB request%s of %u bytes, retval %d",
-                (input == NULL && input_len > 0) ? " is GAP" : "", input_len, res.status);
-        SCReturnStruct(res);
-    }
-}
-
-static AppLayerResult SMBTCPParseResponse(Flow *f, void *state,
-        AppLayerParserState *pstate, const uint8_t *input, uint32_t input_len,
-        void *local_data, const uint8_t flags)
-{
-    SCLogDebug("SMBTCPParseResponse");
-    uint16_t file_flags = FileFlowToFlags(f, STREAM_TOCLIENT);
-    rs_smb_setfileflags(1, state, file_flags|FILE_USE_DETECT);
-
-    SCLogDebug("SMBTCPParseResponse %p/%u", input, input_len);
-    if (input == NULL && input_len > 0) {
-        AppLayerResult res = rs_smb_parse_response_tcp_gap(state, input_len);
-        SCLogDebug("SMB response GAP of %u bytes, retval %d", input_len, res.status);
-        SCReturnStruct(res);
-    } else {
-        AppLayerResult res = rs_smb_parse_response_tcp(f, state, pstate,
-                input, input_len, local_data, flags);
-        SCReturnStruct(res);
-    }
-}
-
-static uint16_t SMBTCPProbe(Flow *f, uint8_t direction,
-        const uint8_t *input, uint32_t len, uint8_t *rdir)
-{
-    SCLogDebug("SMBTCPProbe");
-
-    if (len < MIN_REC_SIZE) {
-        return ALPROTO_UNKNOWN;
-    }
-
-    const int r = rs_smb_probe_tcp(f, direction, input, len, rdir);
-    switch (r) {
-        case 1:
-            return ALPROTO_SMB;
-        case 0:
-            return ALPROTO_UNKNOWN;
-        case -1:
-        default:
-            return ALPROTO_FAILED;
-    }
-}
-
-/** \internal
- *  \brief as SMB3 records have no direction indicator, fall
- *         back to the port numbers for a hint
- */
-static uint16_t SMB3TCPProbe(Flow *f, uint8_t direction,
-        const uint8_t *input, uint32_t len, uint8_t *rdir)
-{
-    SCEnter();
-
-    AppProto p = SMBTCPProbe(f, direction, input, len, rdir);
-    if (p != ALPROTO_SMB) {
-        SCReturnUInt(p);
-    }
-
-    uint16_t fsp = (f->flags & FLOW_DIR_REVERSED) ? f->dp : f->sp;
-    uint16_t fdp = (f->flags & FLOW_DIR_REVERSED) ? f->sp : f->dp;
-    SCLogDebug("direction %s flow sp %u dp %u fsp %u fdp %u",
-            (direction & STREAM_TOSERVER) ? "toserver" : "toclient",
-            f->sp, f->dp, fsp, fdp);
-
-    if (fsp == 445 && fdp != 445) {
-        if (direction & STREAM_TOSERVER) {
-            *rdir = STREAM_TOCLIENT;
-        } else {
-            *rdir = STREAM_TOSERVER;
-        }
-    }
-    SCLogDebug("returning ALPROTO_SMB for dir %s with rdir %s",
-            (direction & STREAM_TOSERVER) ? "toserver" : "toclient",
-            (*rdir == STREAM_TOSERVER) ? "toserver" : "toclient");
-    SCReturnUInt(ALPROTO_SMB);
-}
-
-static int SMBGetAlstateProgress(void *tx, uint8_t direction)
-{
-    return rs_smb_tx_get_alstate_progress(tx, direction);
-}
-
-static uint64_t SMBGetTxCnt(void *alstate)
-{
-    return rs_smb_state_get_tx_count(alstate);
-}
-
-static void *SMBGetTx(void *alstate, uint64_t tx_id)
-{
-    return rs_smb_state_get_tx(alstate, tx_id);
-}
-
-static AppLayerGetTxIterTuple SMBGetTxIterator(
-        const uint8_t ipproto, const AppProto alproto,
-        void *alstate, uint64_t min_tx_id, uint64_t max_tx_id,
-        AppLayerGetTxIterState *istate)
-{
-    return rs_smb_state_get_tx_iterator(
-            ipproto, alproto, alstate, min_tx_id, max_tx_id, (uint64_t *)istate);
-}
-
-
-static void SMBStateTransactionFree(void *state, uint64_t tx_id)
-{
-    rs_smb_state_tx_free(state, tx_id);
-}
-
-static DetectEngineState *SMBGetTxDetectState(void *tx)
-{
-    return rs_smb_state_get_tx_detect_state(tx);
-}
-
-static int SMBSetTxDetectState(void *tx, DetectEngineState *s)
-{
-    rs_smb_state_set_tx_detect_state(tx, s);
-    return 0;
-}
-
-static FileContainer *SMBGetFiles(void *state, uint8_t direction)
-{
-    return rs_smb_getfiles(state, direction);
-}
-
-static AppLayerDecoderEvents *SMBGetEvents(void *tx)
-{
-    return rs_smb_state_get_events(tx);
-}
-
-static int SMBGetEventInfoById(int event_id, const char **event_name,
-    AppLayerEventType *event_type)
-{
-    return rs_smb_state_get_event_info_by_id(event_id, event_name, event_type);
-}
-
-static int SMBGetEventInfo(const char *event_name, int *event_id,
-    AppLayerEventType *event_type)
-{
-    return rs_smb_state_get_event_info(event_name, event_id, event_type);
-}
-
-static void SMBStateTruncate(void *state, uint8_t direction)
-{
-    return rs_smb_state_truncate(state, direction);
-}
-
-static int SMBRegisterPatternsForProtocolDetection(void)
-{
-    int r = 0;
-    /* SMB1 */
-    r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
-            "|ff|SMB", 8, 4, STREAM_TOSERVER, SMBTCPProbe,
-            MIN_REC_SIZE, MIN_REC_SIZE);
-    r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
-            "|ff|SMB", 8, 4, STREAM_TOCLIENT, SMBTCPProbe,
-            MIN_REC_SIZE, MIN_REC_SIZE);
-
-    /* SMB2/3 */
-    r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
-            "|fe|SMB", 8, 4, STREAM_TOSERVER, SMBTCPProbe,
-            MIN_REC_SIZE, MIN_REC_SIZE);
-    r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
-            "|fe|SMB", 8, 4, STREAM_TOCLIENT, SMBTCPProbe,
-            MIN_REC_SIZE, MIN_REC_SIZE);
-
-    /* SMB3 encrypted records */
-    r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
-            "|fd|SMB", 8, 4, STREAM_TOSERVER, SMB3TCPProbe,
-            MIN_REC_SIZE, MIN_REC_SIZE);
-    r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB,
-            "|fd|SMB", 8, 4, STREAM_TOCLIENT, SMB3TCPProbe,
-            MIN_REC_SIZE, MIN_REC_SIZE);
-    return r == 0 ? 0 : -1;
-}
 
 static StreamingBufferConfig sbcfg = STREAMING_BUFFER_CONFIG_INITIALIZER;
 static SuricataFileContext sfc = { &sbcfg };
@@ -231,95 +38,11 @@ static SuricataFileContext sfc = { &sbcfg };
 static void SMBParserRegisterTests(void);
 #endif
 
-static uint32_t stream_depth = SMB_CONFIG_DEFAULT_STREAM_DEPTH;
-
 void RegisterSMBParsers(void)
 {
-    const char *proto_name = "smb";
-
-    /** SMB */
-    if (AppLayerProtoDetectConfProtoDetectionEnabled("tcp", proto_name)) {
-        AppLayerProtoDetectRegisterProtocol(ALPROTO_SMB, proto_name);
-        if (SMBRegisterPatternsForProtocolDetection() < 0)
-            return;
-
-        rs_smb_init(&sfc);
-
-        if (RunmodeIsUnittests()) {
-            AppLayerProtoDetectPPRegister(IPPROTO_TCP, "445", ALPROTO_SMB, 0,
-                    MIN_REC_SIZE, STREAM_TOSERVER, SMBTCPProbe,
-                    SMBTCPProbe);
-        } else {
-            int have_cfg = AppLayerProtoDetectPPParseConfPorts("tcp",
-                    IPPROTO_TCP, proto_name, ALPROTO_SMB, 0,
-                    MIN_REC_SIZE, SMBTCPProbe, SMBTCPProbe);
-            /* if we have no config, we enable the default port 445 */
-            if (!have_cfg) {
-                SCLogConfig("no SMB TCP config found, enabling SMB detection "
-                            "on port 445.");
-                AppLayerProtoDetectPPRegister(IPPROTO_TCP, "445", ALPROTO_SMB, 0,
-                        MIN_REC_SIZE, STREAM_TOSERVER, SMBTCPProbe,
-                        SMBTCPProbe);
-            }
-        }
-    } else {
-        SCLogConfig("Protocol detection and parser disabled for %s protocol.",
-                  proto_name);
-        return;
-    }
-
-    if (AppLayerParserConfParserEnabled("tcp", proto_name)) {
-        AppLayerParserRegisterParser(IPPROTO_TCP, ALPROTO_SMB, STREAM_TOSERVER,
-                SMBTCPParseRequest);
-        AppLayerParserRegisterParser(IPPROTO_TCP , ALPROTO_SMB, STREAM_TOCLIENT,
-                SMBTCPParseResponse);
-        AppLayerParserRegisterStateFuncs(IPPROTO_TCP, ALPROTO_SMB,
-                rs_smb_state_new, rs_smb_state_free);
-        AppLayerParserRegisterTxFreeFunc(IPPROTO_TCP, ALPROTO_SMB,
-                SMBStateTransactionFree);
-
-        AppLayerParserRegisterGetEventsFunc(IPPROTO_TCP, ALPROTO_SMB,
-                SMBGetEvents);
-        AppLayerParserRegisterGetEventInfo(IPPROTO_TCP, ALPROTO_SMB,
-                SMBGetEventInfo);
-        AppLayerParserRegisterGetEventInfoById(IPPROTO_TCP, ALPROTO_SMB,
-                SMBGetEventInfoById);
-
-        AppLayerParserRegisterDetectStateFuncs(IPPROTO_TCP, ALPROTO_SMB,
-                SMBGetTxDetectState, SMBSetTxDetectState);
-        AppLayerParserRegisterGetTx(IPPROTO_TCP, ALPROTO_SMB, SMBGetTx);
-        AppLayerParserRegisterGetTxIterator(IPPROTO_TCP, ALPROTO_SMB, SMBGetTxIterator);
-        AppLayerParserRegisterGetTxCnt(IPPROTO_TCP, ALPROTO_SMB,
-                SMBGetTxCnt);
-        AppLayerParserRegisterGetStateProgressFunc(IPPROTO_TCP, ALPROTO_SMB,
-                SMBGetAlstateProgress);
-        AppLayerParserRegisterStateProgressCompletionStatus(ALPROTO_SMB, 1, 1);
-        AppLayerParserRegisterTruncateFunc(IPPROTO_TCP, ALPROTO_SMB,
-                                          SMBStateTruncate);
-        AppLayerParserRegisterGetFilesFunc(IPPROTO_TCP, ALPROTO_SMB, SMBGetFiles);
-
-        AppLayerParserRegisterTxDataFunc(IPPROTO_TCP, ALPROTO_SMB, rs_smb_get_tx_data);
-
-        /* This parser accepts gaps. */
-        AppLayerParserRegisterOptionFlags(IPPROTO_TCP, ALPROTO_SMB,
-                APP_LAYER_PARSER_OPT_ACCEPT_GAPS);
-
-        ConfNode *p = ConfGetNode("app-layer.protocols.smb.stream-depth");
-        if (p != NULL) {
-            uint32_t value;
-            if (ParseSizeStringU32(p->val, &value) < 0) {
-                SCLogError(SC_ERR_SMB_CONFIG, "invalid value for stream-depth %s", p->val);
-            } else {
-                stream_depth = value;
-            }
-        }
-        SCLogConfig("SMB stream depth: %u", stream_depth);
+    rs_smb_init(&sfc);
+    rs_smb_register_parser();
 
-        AppLayerParserSetStreamDepth(IPPROTO_TCP, ALPROTO_SMB, stream_depth);
-    } else {
-        SCLogConfig("Parsed disabled for %s protocol. Protocol detection"
-                  "still on.", proto_name);
-    }
 #ifdef UNITTESTS
     AppLayerParserRegisterProtocolUnittests(IPPROTO_TCP, ALPROTO_SMB, SMBParserRegisterTests);
 #endif