From e5c948df87541509e34445d80c5ef0242da3326f Mon Sep 17 00:00:00 2001 From: Shivani Bhardwaj Date: Sat, 19 Jun 2021 13:23:14 +0530 Subject: [PATCH] smb: Add rust registration function Get rid of the C glue code and move registration completely to Rust. --- rust/src/smb/smb.rs | 143 +++++++++++++++++++++- src/app-layer-smb.c | 281 +------------------------------------------- 2 files changed, 144 insertions(+), 280 deletions(-) diff --git a/rust/src/smb/smb.rs b/rust/src/smb/smb.rs index d5576e7030..4237a64009 100644 --- a/rust/src/smb/smb.rs +++ b/rust/src/smb/smb.rs @@ -28,7 +28,7 @@ use std; use std::mem::transmute; use std::str; -use std::ffi::{self, CStr}; +use std::ffi::{self, CStr, CString}; use std::collections::HashMap; @@ -37,6 +37,7 @@ use nom; use crate::core::*; use crate::applayer; use crate::applayer::*; +use crate::conf::*; use crate::filecontainer::*; use crate::smb::nbss_records::*; @@ -2189,3 +2190,143 @@ pub extern "C" fn rs_smb_state_get_event_info(event_name: *const std::os::raw::c } 0 } + +pub extern "C" fn smb3_probe_tcp(f: *const Flow, dir: u8, input: *const u8, len: u32, rdir: *mut u8) -> u16 { + let retval = rs_smb_probe_tcp(f, dir, input, len, rdir); + let f = cast_pointer!(f, Flow); + if unsafe { retval != ALPROTO_SMB } { + return retval; + } + let (sp, dp) = f.get_ports(); + let flags = f.get_flags(); + let fsp = if (flags & FLOW_DIR_REVERSED) != 0 { dp } else { sp }; + let fdp = if (flags & FLOW_DIR_REVERSED) != 0 { sp } else { dp }; + if fsp == 445 && fdp != 445 { + unsafe { + if dir & STREAM_TOSERVER != 0 { + *rdir = STREAM_TOCLIENT; + } else { + *rdir = STREAM_TOSERVER; + } + } + } + return unsafe { ALPROTO_SMB }; +} + +fn register_pattern_probe() -> i8 { + let mut r = 0; + unsafe { + // SMB1 + r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, + b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, + STREAM_TOSERVER, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, + b"|ff|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, + STREAM_TOCLIENT, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + // SMB2/3 + r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, + b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, + STREAM_TOSERVER, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, + b"|fe|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, + STREAM_TOCLIENT, rs_smb_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + // SMB3 encrypted records + r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, + b"|fd|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, + STREAM_TOSERVER, smb3_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_SMB, + b"|fd|SMB\0".as_ptr() as *const std::os::raw::c_char, 8, 4, + STREAM_TOCLIENT, smb3_probe_tcp, MIN_REC_SIZE, MIN_REC_SIZE); + } + + if r == 0 { + return 0; + } else { + return -1; + } +} + +// Parser name as a C style string. +const PARSER_NAME: &'static [u8] = b"smb\0"; + +#[no_mangle] +pub unsafe extern "C" fn rs_smb_register_parser() { + let default_port = CString::new("445").unwrap(); + let mut stream_depth = SMB_CONFIG_DEFAULT_STREAM_DEPTH; + let parser = RustParser { + name: PARSER_NAME.as_ptr() as *const std::os::raw::c_char, + default_port: default_port.as_ptr(), + ipproto: IPPROTO_TCP, + probe_ts: None, + probe_tc: None, + min_depth: 0, + max_depth: 16, + state_new: rs_smb_state_new, + state_free: rs_smb_state_free, + tx_free: rs_smb_state_tx_free, + parse_ts: rs_smb_parse_request_tcp, + parse_tc: rs_smb_parse_response_tcp, + get_tx_count: rs_smb_state_get_tx_count, + get_tx: rs_smb_state_get_tx, + tx_comp_st_ts: 1, + tx_comp_st_tc: 1, + tx_get_progress: rs_smb_tx_get_alstate_progress, + get_de_state: rs_smb_state_get_tx_detect_state, + set_de_state: rs_smb_state_set_tx_detect_state, + get_events: Some(rs_smb_state_get_events), + get_eventinfo: Some(rs_smb_state_get_event_info), + get_eventinfo_byid : Some(rs_smb_state_get_event_info_by_id), + localstorage_new: None, + localstorage_free: None, + get_files: Some(rs_smb_getfiles), + get_tx_iterator: Some(rs_smb_state_get_tx_iterator), + get_tx_data: rs_smb_get_tx_data, + apply_tx_config: None, + flags: APP_LAYER_PARSER_OPT_ACCEPT_GAPS, + truncate: Some(rs_smb_state_truncate), + }; + + let ip_proto_str = CString::new("tcp").unwrap(); + + if AppLayerProtoDetectConfProtoDetectionEnabled( + ip_proto_str.as_ptr(), + parser.name, + ) != 0 + { + let alproto = AppLayerRegisterProtocolDetection(&parser, 1); + ALPROTO_SMB = alproto; + if register_pattern_probe() < 0 { + return; + } + + let have_cfg = AppLayerProtoDetectPPParseConfPorts(ip_proto_str.as_ptr(), + IPPROTO_TCP as u8, parser.name, ALPROTO_SMB, 0, + MIN_REC_SIZE, rs_smb_probe_tcp, rs_smb_probe_tcp); + + if have_cfg == 0 { + AppLayerProtoDetectPPRegister(IPPROTO_TCP as u8, parser.default_port, ALPROTO_SMB, + 0, MIN_REC_SIZE, STREAM_TOSERVER, rs_smb_probe_tcp, rs_smb_probe_tcp); + } + + if AppLayerParserConfParserEnabled( + ip_proto_str.as_ptr(), + parser.name, + ) != 0 + { + let _ = AppLayerRegisterParser(&parser, alproto); + } + SCLogDebug!("Rust SMB parser registered."); + let retval = conf_get("app-layer.protocols.smb.stream-depth"); + if let Some(val) = retval { + let val = val.parse::().unwrap(); + if val < 0 { + SCLogError!("invalid value for stream-depth"); + } else { + stream_depth = val as u32; + } + AppLayerParserSetStreamDepth(IPPROTO_TCP as u8, ALPROTO_SMB, stream_depth); + } + } else { + SCLogDebug!("Protocol detector and parser disabled for SMB."); + } +} diff --git a/src/app-layer-smb.c b/src/app-layer-smb.c index 206902add0..f9063f7c28 100644 --- a/src/app-layer-smb.c +++ b/src/app-layer-smb.c @@ -28,199 +28,6 @@ #include "app-layer-smb.h" #include "util-misc.h" -#define MIN_REC_SIZE 32+4 // SMB hdr + nbss hdr - -static AppLayerResult SMBTCPParseRequest(Flow *f, void *state, - AppLayerParserState *pstate, const uint8_t *input, uint32_t input_len, - void *local_data, const uint8_t flags) -{ - SCLogDebug("SMBTCPParseRequest"); - uint16_t file_flags = FileFlowToFlags(f, STREAM_TOSERVER); - rs_smb_setfileflags(0, state, file_flags|FILE_USE_DETECT); - - if (input == NULL && input_len > 0) { - AppLayerResult res = rs_smb_parse_request_tcp_gap(state, input_len); - SCLogDebug("SMB request GAP of %u bytes, retval %d", input_len, res.status); - SCReturnStruct(res); - } else { - AppLayerResult res = rs_smb_parse_request_tcp(f, state, pstate, - input, input_len, local_data, flags); - SCLogDebug("SMB request%s of %u bytes, retval %d", - (input == NULL && input_len > 0) ? " is GAP" : "", input_len, res.status); - SCReturnStruct(res); - } -} - -static AppLayerResult SMBTCPParseResponse(Flow *f, void *state, - AppLayerParserState *pstate, const uint8_t *input, uint32_t input_len, - void *local_data, const uint8_t flags) -{ - SCLogDebug("SMBTCPParseResponse"); - uint16_t file_flags = FileFlowToFlags(f, STREAM_TOCLIENT); - rs_smb_setfileflags(1, state, file_flags|FILE_USE_DETECT); - - SCLogDebug("SMBTCPParseResponse %p/%u", input, input_len); - if (input == NULL && input_len > 0) { - AppLayerResult res = rs_smb_parse_response_tcp_gap(state, input_len); - SCLogDebug("SMB response GAP of %u bytes, retval %d", input_len, res.status); - SCReturnStruct(res); - } else { - AppLayerResult res = rs_smb_parse_response_tcp(f, state, pstate, - input, input_len, local_data, flags); - SCReturnStruct(res); - } -} - -static uint16_t SMBTCPProbe(Flow *f, uint8_t direction, - const uint8_t *input, uint32_t len, uint8_t *rdir) -{ - SCLogDebug("SMBTCPProbe"); - - if (len < MIN_REC_SIZE) { - return ALPROTO_UNKNOWN; - } - - const int r = rs_smb_probe_tcp(f, direction, input, len, rdir); - switch (r) { - case 1: - return ALPROTO_SMB; - case 0: - return ALPROTO_UNKNOWN; - case -1: - default: - return ALPROTO_FAILED; - } -} - -/** \internal - * \brief as SMB3 records have no direction indicator, fall - * back to the port numbers for a hint - */ -static uint16_t SMB3TCPProbe(Flow *f, uint8_t direction, - const uint8_t *input, uint32_t len, uint8_t *rdir) -{ - SCEnter(); - - AppProto p = SMBTCPProbe(f, direction, input, len, rdir); - if (p != ALPROTO_SMB) { - SCReturnUInt(p); - } - - uint16_t fsp = (f->flags & FLOW_DIR_REVERSED) ? f->dp : f->sp; - uint16_t fdp = (f->flags & FLOW_DIR_REVERSED) ? f->sp : f->dp; - SCLogDebug("direction %s flow sp %u dp %u fsp %u fdp %u", - (direction & STREAM_TOSERVER) ? "toserver" : "toclient", - f->sp, f->dp, fsp, fdp); - - if (fsp == 445 && fdp != 445) { - if (direction & STREAM_TOSERVER) { - *rdir = STREAM_TOCLIENT; - } else { - *rdir = STREAM_TOSERVER; - } - } - SCLogDebug("returning ALPROTO_SMB for dir %s with rdir %s", - (direction & STREAM_TOSERVER) ? "toserver" : "toclient", - (*rdir == STREAM_TOSERVER) ? "toserver" : "toclient"); - SCReturnUInt(ALPROTO_SMB); -} - -static int SMBGetAlstateProgress(void *tx, uint8_t direction) -{ - return rs_smb_tx_get_alstate_progress(tx, direction); -} - -static uint64_t SMBGetTxCnt(void *alstate) -{ - return rs_smb_state_get_tx_count(alstate); -} - -static void *SMBGetTx(void *alstate, uint64_t tx_id) -{ - return rs_smb_state_get_tx(alstate, tx_id); -} - -static AppLayerGetTxIterTuple SMBGetTxIterator( - const uint8_t ipproto, const AppProto alproto, - void *alstate, uint64_t min_tx_id, uint64_t max_tx_id, - AppLayerGetTxIterState *istate) -{ - return rs_smb_state_get_tx_iterator( - ipproto, alproto, alstate, min_tx_id, max_tx_id, (uint64_t *)istate); -} - - -static void SMBStateTransactionFree(void *state, uint64_t tx_id) -{ - rs_smb_state_tx_free(state, tx_id); -} - -static DetectEngineState *SMBGetTxDetectState(void *tx) -{ - return rs_smb_state_get_tx_detect_state(tx); -} - -static int SMBSetTxDetectState(void *tx, DetectEngineState *s) -{ - rs_smb_state_set_tx_detect_state(tx, s); - return 0; -} - -static FileContainer *SMBGetFiles(void *state, uint8_t direction) -{ - return rs_smb_getfiles(state, direction); -} - -static AppLayerDecoderEvents *SMBGetEvents(void *tx) -{ - return rs_smb_state_get_events(tx); -} - -static int SMBGetEventInfoById(int event_id, const char **event_name, - AppLayerEventType *event_type) -{ - return rs_smb_state_get_event_info_by_id(event_id, event_name, event_type); -} - -static int SMBGetEventInfo(const char *event_name, int *event_id, - AppLayerEventType *event_type) -{ - return rs_smb_state_get_event_info(event_name, event_id, event_type); -} - -static void SMBStateTruncate(void *state, uint8_t direction) -{ - return rs_smb_state_truncate(state, direction); -} - -static int SMBRegisterPatternsForProtocolDetection(void) -{ - int r = 0; - /* SMB1 */ - r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB, - "|ff|SMB", 8, 4, STREAM_TOSERVER, SMBTCPProbe, - MIN_REC_SIZE, MIN_REC_SIZE); - r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB, - "|ff|SMB", 8, 4, STREAM_TOCLIENT, SMBTCPProbe, - MIN_REC_SIZE, MIN_REC_SIZE); - - /* SMB2/3 */ - r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB, - "|fe|SMB", 8, 4, STREAM_TOSERVER, SMBTCPProbe, - MIN_REC_SIZE, MIN_REC_SIZE); - r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB, - "|fe|SMB", 8, 4, STREAM_TOCLIENT, SMBTCPProbe, - MIN_REC_SIZE, MIN_REC_SIZE); - - /* SMB3 encrypted records */ - r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB, - "|fd|SMB", 8, 4, STREAM_TOSERVER, SMB3TCPProbe, - MIN_REC_SIZE, MIN_REC_SIZE); - r |= AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP, ALPROTO_SMB, - "|fd|SMB", 8, 4, STREAM_TOCLIENT, SMB3TCPProbe, - MIN_REC_SIZE, MIN_REC_SIZE); - return r == 0 ? 0 : -1; -} static StreamingBufferConfig sbcfg = STREAMING_BUFFER_CONFIG_INITIALIZER; static SuricataFileContext sfc = { &sbcfg }; @@ -231,95 +38,11 @@ static SuricataFileContext sfc = { &sbcfg }; static void SMBParserRegisterTests(void); #endif -static uint32_t stream_depth = SMB_CONFIG_DEFAULT_STREAM_DEPTH; - void RegisterSMBParsers(void) { - const char *proto_name = "smb"; - - /** SMB */ - if (AppLayerProtoDetectConfProtoDetectionEnabled("tcp", proto_name)) { - AppLayerProtoDetectRegisterProtocol(ALPROTO_SMB, proto_name); - if (SMBRegisterPatternsForProtocolDetection() < 0) - return; - - rs_smb_init(&sfc); - - if (RunmodeIsUnittests()) { - AppLayerProtoDetectPPRegister(IPPROTO_TCP, "445", ALPROTO_SMB, 0, - MIN_REC_SIZE, STREAM_TOSERVER, SMBTCPProbe, - SMBTCPProbe); - } else { - int have_cfg = AppLayerProtoDetectPPParseConfPorts("tcp", - IPPROTO_TCP, proto_name, ALPROTO_SMB, 0, - MIN_REC_SIZE, SMBTCPProbe, SMBTCPProbe); - /* if we have no config, we enable the default port 445 */ - if (!have_cfg) { - SCLogConfig("no SMB TCP config found, enabling SMB detection " - "on port 445."); - AppLayerProtoDetectPPRegister(IPPROTO_TCP, "445", ALPROTO_SMB, 0, - MIN_REC_SIZE, STREAM_TOSERVER, SMBTCPProbe, - SMBTCPProbe); - } - } - } else { - SCLogConfig("Protocol detection and parser disabled for %s protocol.", - proto_name); - return; - } - - if (AppLayerParserConfParserEnabled("tcp", proto_name)) { - AppLayerParserRegisterParser(IPPROTO_TCP, ALPROTO_SMB, STREAM_TOSERVER, - SMBTCPParseRequest); - AppLayerParserRegisterParser(IPPROTO_TCP , ALPROTO_SMB, STREAM_TOCLIENT, - SMBTCPParseResponse); - AppLayerParserRegisterStateFuncs(IPPROTO_TCP, ALPROTO_SMB, - rs_smb_state_new, rs_smb_state_free); - AppLayerParserRegisterTxFreeFunc(IPPROTO_TCP, ALPROTO_SMB, - SMBStateTransactionFree); - - AppLayerParserRegisterGetEventsFunc(IPPROTO_TCP, ALPROTO_SMB, - SMBGetEvents); - AppLayerParserRegisterGetEventInfo(IPPROTO_TCP, ALPROTO_SMB, - SMBGetEventInfo); - AppLayerParserRegisterGetEventInfoById(IPPROTO_TCP, ALPROTO_SMB, - SMBGetEventInfoById); - - AppLayerParserRegisterDetectStateFuncs(IPPROTO_TCP, ALPROTO_SMB, - SMBGetTxDetectState, SMBSetTxDetectState); - AppLayerParserRegisterGetTx(IPPROTO_TCP, ALPROTO_SMB, SMBGetTx); - AppLayerParserRegisterGetTxIterator(IPPROTO_TCP, ALPROTO_SMB, SMBGetTxIterator); - AppLayerParserRegisterGetTxCnt(IPPROTO_TCP, ALPROTO_SMB, - SMBGetTxCnt); - AppLayerParserRegisterGetStateProgressFunc(IPPROTO_TCP, ALPROTO_SMB, - SMBGetAlstateProgress); - AppLayerParserRegisterStateProgressCompletionStatus(ALPROTO_SMB, 1, 1); - AppLayerParserRegisterTruncateFunc(IPPROTO_TCP, ALPROTO_SMB, - SMBStateTruncate); - AppLayerParserRegisterGetFilesFunc(IPPROTO_TCP, ALPROTO_SMB, SMBGetFiles); - - AppLayerParserRegisterTxDataFunc(IPPROTO_TCP, ALPROTO_SMB, rs_smb_get_tx_data); - - /* This parser accepts gaps. */ - AppLayerParserRegisterOptionFlags(IPPROTO_TCP, ALPROTO_SMB, - APP_LAYER_PARSER_OPT_ACCEPT_GAPS); - - ConfNode *p = ConfGetNode("app-layer.protocols.smb.stream-depth"); - if (p != NULL) { - uint32_t value; - if (ParseSizeStringU32(p->val, &value) < 0) { - SCLogError(SC_ERR_SMB_CONFIG, "invalid value for stream-depth %s", p->val); - } else { - stream_depth = value; - } - } - SCLogConfig("SMB stream depth: %u", stream_depth); + rs_smb_init(&sfc); + rs_smb_register_parser(); - AppLayerParserSetStreamDepth(IPPROTO_TCP, ALPROTO_SMB, stream_depth); - } else { - SCLogConfig("Parsed disabled for %s protocol. Protocol detection" - "still on.", proto_name); - } #ifdef UNITTESTS AppLayerParserRegisterProtocolUnittests(IPPROTO_TCP, ALPROTO_SMB, SMBParserRegisterTests); #endif -- 2.47.2