From: Shivani Bhardwaj Date: Tue, 29 Jun 2021 09:30:07 +0000 (+0530) Subject: smb: add missing code from rust impl of fns X-Git-Tag: suricata-7.0.0-beta1~1545 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=27af4bb0024bc2d1689ed7ed5708967e69174e66;p=thirdparty%2Fsuricata.git smb: add missing code from rust impl of fns --- diff --git a/rust/src/smb/files.rs b/rust/src/smb/files.rs index d02bca8e2d..704d60a0cf 100644 --- a/rust/src/smb/files.rs +++ b/rust/src/smb/files.rs @@ -112,7 +112,7 @@ impl SMBState { } fn setfileflags(&mut self, direction: u8, flags: u16) { SCLogDebug!("direction: {}, flags: {}", direction, flags); - if direction == 1 { + if direction == STREAM_TOCLIENT { self.files.flags_tc = flags; } else { self.files.flags_ts = flags; diff --git a/rust/src/smb/smb.rs b/rust/src/smb/smb.rs index 7837d71873..d5576e7030 100644 --- a/rust/src/smb/smb.rs +++ b/rust/src/smb/smb.rs @@ -1824,8 +1824,13 @@ pub extern "C" fn rs_smb_parse_request_tcp(flow: *const Flow, let buf = unsafe{std::slice::from_raw_parts(input, input_len as usize)}; let mut state = cast_pointer!(state, SMBState); let flow = cast_pointer!(flow, Flow); + let file_flags = unsafe { FileFlowToFlags(flow, STREAM_TOSERVER) }; + rs_smb_setfileflags(STREAM_TOSERVER, state, file_flags|FILE_USE_DETECT); SCLogDebug!("parsing {} bytes of request data", input_len); + if input.is_null() && input_len > 0 { + return rs_smb_parse_request_tcp_gap(state, input_len); + } /* START with MISTREAM set: record might be starting the middle. */ if flags & (STREAM_START|STREAM_MIDSTREAM) == (STREAM_START|STREAM_MIDSTREAM) { state.ts_gap = true; @@ -1857,6 +1862,12 @@ pub extern "C" fn rs_smb_parse_response_tcp(flow: *const Flow, { let mut state = cast_pointer!(state, SMBState); let flow = cast_pointer!(flow, Flow); + let file_flags = unsafe { FileFlowToFlags(flow, STREAM_TOCLIENT) }; + rs_smb_setfileflags(STREAM_TOCLIENT, state, file_flags|FILE_USE_DETECT); + + if input.is_null() && input_len > 0 { + return rs_smb_parse_response_tcp_gap(state, input_len); + } SCLogDebug!("parsing {} bytes of response data", input_len); let buf = unsafe{std::slice::from_raw_parts(input, input_len as usize)}; @@ -1948,6 +1959,9 @@ pub extern "C" fn rs_smb_probe_tcp(_f: *const Flow, flags: u8, input: *const u8, len: u32, rdir: *mut u8) -> AppProto { + if len < MIN_REC_SIZE as u32 { + return ALPROTO_UNKNOWN; + } let slice = build_slice!(input, len as usize); if flags & STREAM_MIDSTREAM == STREAM_MIDSTREAM { if smb_probe_tcp_midstream(flags, slice, rdir) == 1 {