From: Shivani Bhardwaj Date: Wed, 11 Aug 2021 11:29:48 +0000 (+0530) Subject: dcerpc: use Direction enum X-Git-Tag: suricata-7.0.0-beta1~1226 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a866499bcaa676768f99a2bd36f5193fd4a4e07e;p=thirdparty%2Fsuricata.git dcerpc: use Direction enum --- diff --git a/rust/src/dcerpc/dcerpc.rs b/rust/src/dcerpc/dcerpc.rs index 6dec6e22cb..30d4935626 100644 --- a/rust/src/dcerpc/dcerpc.rs +++ b/rust/src/dcerpc/dcerpc.rs @@ -321,8 +321,8 @@ pub struct DCERPCState { pub bytes_consumed: u16, pub tx_id: u64, pub query_completed: bool, - pub data_needed_for_dir: u8, - pub prev_dir: u8, + pub data_needed_for_dir: Direction, + pub prev_dir: Direction, pub prev_tx_call_id: u32, pub clear_bind_cache: bool, pub ts_gap: bool, @@ -337,8 +337,8 @@ pub struct DCERPCState { impl DCERPCState { pub fn new() -> Self { return Self { - data_needed_for_dir: core::STREAM_TOSERVER, - prev_dir: core::STREAM_TOSERVER, + data_needed_for_dir: Direction::ToServer, + prev_dir: Direction::ToServer, ..Default::default() } } @@ -450,13 +450,13 @@ impl DCERPCState { return 0; } - pub fn clean_buffer(&mut self, direction: u8) { + pub fn clean_buffer(&mut self, direction: Direction) { match direction { - core::STREAM_TOSERVER => { + Direction::ToServer => { self.buffer_ts.clear(); self.ts_gap = false; } - _ => { + Direction::ToClient => { self.buffer_tc.clear(); self.tc_gap = false; } @@ -464,23 +464,23 @@ impl DCERPCState { self.bytes_consumed = 0; } - pub fn extend_buffer(&mut self, buffer: &[u8], direction: u8) { + pub fn extend_buffer(&mut self, buffer: &[u8], direction: Direction) { match direction { - core::STREAM_TOSERVER => { + Direction::ToServer => { self.buffer_ts.extend_from_slice(buffer); } - _ => { + Direction::ToClient => { self.buffer_tc.extend_from_slice(buffer); } } self.data_needed_for_dir = direction; } - pub fn reset_direction(&mut self, direction: u8) { - if direction == core::STREAM_TOSERVER { - self.data_needed_for_dir = core::STREAM_TOCLIENT; + pub fn reset_direction(&mut self, direction: Direction) { + if direction == Direction::ToServer { + self.data_needed_for_dir = Direction::ToClient; } else { - self.data_needed_for_dir = core::STREAM_TOSERVER; + self.data_needed_for_dir = Direction::ToServer; } } @@ -513,24 +513,24 @@ impl DCERPCState { /// type: unsigned 32 bit integer /// description: call_id param derived from TCP Header /// * `dir`: - /// type: unsigned 8 bit integer + /// type: enum Direction /// description: direction of the flow /// /// Return value: /// Option mutable reference to DCERPCTransaction - pub fn get_tx_by_call_id(&mut self, call_id: u32, dir: u8) -> Option<&mut DCERPCTransaction> { + pub fn get_tx_by_call_id(&mut self, call_id: u32, dir: Direction) -> Option<&mut DCERPCTransaction> { let cmd = self.get_hdr_type().unwrap_or(0); for tx in &mut self.transactions { let found = tx.call_id == call_id; if found { match dir { - core::STREAM_TOSERVER => { + Direction::ToServer => { let resp_cmd = get_resp_type_for_req(cmd); if resp_cmd != tx.resp_cmd { continue; } } - _ => { + Direction::ToClient => { let req_cmd = get_req_type_for_resp(cmd); if req_cmd != tx.req_cmd { continue; @@ -556,13 +556,13 @@ impl DCERPCState { self.prev_tx_call_id = call_id; } - pub fn parse_data_gap(&mut self, direction: u8) -> AppLayerResult { + pub fn parse_data_gap(&mut self, direction: Direction) -> AppLayerResult { match direction { - core::STREAM_TOSERVER => { + Direction::ToServer => { self.ts_gap = true; self.ts_ssn_gap = true; }, - _ => { + Direction::ToClient => { self.tc_gap = true; self.tc_ssn_gap = true; }, @@ -570,9 +570,9 @@ impl DCERPCState { AppLayerResult::ok() } - pub fn post_gap_housekeeping(&mut self, dir: u8) { + pub fn post_gap_housekeeping(&mut self, dir: Direction) { SCLogDebug!("ts ssn gap: {:?}, tc ssn gap: {:?}, dir: {:?}", self.ts_ssn_gap, self.tc_ssn_gap, dir); - if self.ts_ssn_gap && dir == core::STREAM_TOSERVER { + if self.ts_ssn_gap && dir == Direction::ToServer { for tx in &mut self.transactions { if tx.id >= self.tx_id { SCLogDebug!("post_gap_housekeeping: done"); @@ -583,10 +583,10 @@ impl DCERPCState { } tx.req_done = true; if let Some(flow) = self.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir.into()); + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir as i32); } } - } else if self.tc_ssn_gap && dir == core::STREAM_TOCLIENT { + } else if self.tc_ssn_gap && dir == Direction::ToClient { for tx in &mut self.transactions { if tx.id >= self.tx_id { SCLogDebug!("post_gap_housekeeping: done"); @@ -601,7 +601,7 @@ impl DCERPCState { tx.req_done = true; tx.resp_done = true; if let Some(flow) = self.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir.into()); + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir as i32); } } } @@ -716,7 +716,7 @@ impl DCERPCState { tx.req_cmd = self.get_hdr_type().unwrap_or(0); tx.req_done = true; if let Some(flow) = self.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOSERVER.into()); + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32); } tx.frag_cnt_ts = 1; self.transactions.push(tx); @@ -772,7 +772,7 @@ impl DCERPCState { } } - pub fn handle_stub_data(&mut self, input: &[u8], input_len: u16, dir: u8) -> u16 { + pub fn handle_stub_data(&mut self, input: &[u8], input_len: u16, dir: Direction) -> u16 { let retval; let hdrpfcflags = self.get_hdr_pfcflags().unwrap_or(0); let padleft = self.padleft; @@ -801,7 +801,7 @@ impl DCERPCState { tx.req_done = true; tx.frag_cnt_ts = 1; if let Some(flow) = self.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOSERVER.into()); + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32); } } DCERPC_TYPE_RESPONSE => { @@ -816,7 +816,7 @@ impl DCERPCState { tx.resp_done = true; tx.frag_cnt_tc = 1; if let Some(flow) = self.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOCLIENT.into()); + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32); } } _ => { @@ -844,13 +844,13 @@ impl DCERPCState { /// type: 16 bit unsigned integer. /// description: bytes consumed *after* parsing header. /// * `dir`: - /// type: 8 bit unsigned integer. + /// type: enum Direction. /// description: direction whose stub is supposed to be handled. /// /// Return value: /// * Success: Number of bytes successfully parsed. /// * Failure: -1 in case fragment length defined by header mismatches the data. - pub fn handle_common_stub(&mut self, input: &[u8], bytes_consumed: u16, dir: u8) -> i32 { + pub fn handle_common_stub(&mut self, input: &[u8], bytes_consumed: u16, dir: Direction) -> i32 { let fraglen = self.get_hdr_fraglen().unwrap_or(0); if fraglen < bytes_consumed as u16 + DCERPC_HDR_LEN { return -1; @@ -866,7 +866,7 @@ impl DCERPCState { } else if input_left > 0 { SCLogDebug!( "Error parsing DCERPC {} stub data", - if dir == core::STREAM_TOSERVER { + if dir == Direction::ToServer { "request" } else { "response" @@ -885,7 +885,7 @@ impl DCERPCState { Ok((leftover_input, request)) => { let call_id = self.get_hdr_call_id().unwrap_or(0); let hdr_type = self.get_hdr_type().unwrap_or(0); - let mut transaction = self.get_tx_by_call_id(call_id, core::STREAM_TOSERVER); + let mut transaction = self.get_tx_by_call_id(call_id, Direction::ToServer); match transaction { Some(ref mut tx) => { tx.req_cmd = hdr_type; @@ -905,7 +905,7 @@ impl DCERPCState { let parsed = self.handle_common_stub( input, (input.len() - leftover_input.len()) as u16, - core::STREAM_TOSERVER, + Direction::ToServer, ); parsed } @@ -922,7 +922,7 @@ impl DCERPCState { } } - pub fn handle_input_data(&mut self, input: &[u8], direction: u8) -> AppLayerResult { + pub fn handle_input_data(&mut self, input: &[u8], direction: Direction) -> AppLayerResult { let mut parsed; let retval; let mut cur_i = input; @@ -932,7 +932,7 @@ impl DCERPCState { self.query_completed = false; // Skip the record since this means that its in the middle of a known length record - if (self.ts_gap && direction == core::STREAM_TOSERVER) || (self.tc_gap && direction == core::STREAM_TOCLIENT) { + if (self.ts_gap && direction == Direction::ToServer) || (self.tc_gap && direction == Direction::ToClient) { SCLogDebug!("Trying to catch up after GAP (input {})", cur_i.len()); match self.search_dcerpc_record(cur_i) { Ok((_, pg)) => { @@ -940,10 +940,10 @@ impl DCERPCState { let offset = cur_i.len() - pg.len(); cur_i = &cur_i[offset..]; match direction { - core::STREAM_TOSERVER => { + Direction::ToServer => { self.ts_gap = false; }, - _ => { + Direction::ToClient => { self.tc_gap = false; } } @@ -969,7 +969,7 @@ impl DCERPCState { } let buffer = match direction { - core::STREAM_TOSERVER => { + Direction::ToServer => { if self.buffer_ts.len() + input_len > 1024 * 1024 { SCLogDebug!("DCERPC TOSERVER stream: Buffer Overflow"); return AppLayerResult::err(); @@ -978,7 +978,7 @@ impl DCERPCState { v.extend_from_slice(cur_i); v.as_slice() } - _ => { + Direction::ToClient => { if self.buffer_tc.len() + input_len > 1024 * 1024 { SCLogDebug!("DCERPC TOCLIENT stream: Buffer Overflow"); return AppLayerResult::err(); @@ -1037,7 +1037,7 @@ impl DCERPCState { if retval == -1 { return AppLayerResult::err(); } - let tx = if let Some(tx) = self.get_tx_by_call_id(current_call_id, core::STREAM_TOCLIENT) { + let tx = if let Some(tx) = self.get_tx_by_call_id(current_call_id, Direction::ToClient) { tx.resp_cmd = x; tx } else { @@ -1049,7 +1049,7 @@ impl DCERPCState { tx.resp_done = true; tx.frag_cnt_tc = 1; if let Some(flow) = self.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOCLIENT.into()); + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32); } self.handle_bind_cache(current_call_id, false); } @@ -1063,7 +1063,7 @@ impl DCERPCState { self.handle_bind_cache(current_call_id, false); } DCERPC_TYPE_RESPONSE => { - let transaction = self.get_tx_by_call_id(current_call_id, core::STREAM_TOCLIENT); + let transaction = self.get_tx_by_call_id(current_call_id, Direction::ToClient); match transaction { Some(tx) => { tx.resp_cmd = x; @@ -1077,7 +1077,7 @@ impl DCERPCState { retval = self.handle_common_stub( &buffer[parsed as usize..], 0, - core::STREAM_TOCLIENT, + Direction::ToClient, ); if retval < 0 { return AppLayerResult::err(); @@ -1132,7 +1132,7 @@ pub extern "C" fn rs_parse_dcerpc_request_gap( state: &mut DCERPCState, _input_len: u32, ) -> AppLayerResult { - state.parse_data_gap(core::STREAM_TOSERVER) + state.parse_data_gap(Direction::ToServer) } #[no_mangle] @@ -1140,7 +1140,7 @@ pub extern "C" fn rs_parse_dcerpc_response_gap( state: &mut DCERPCState, _input_len: u32, ) -> AppLayerResult { - state.parse_data_gap(core::STREAM_TOCLIENT) + state.parse_data_gap(Direction::ToClient) } #[no_mangle] @@ -1161,7 +1161,7 @@ pub unsafe extern "C" fn rs_dcerpc_parse_request( if input_len > 0 && !input.is_null() { let buf = build_slice!(input, input_len as usize); state.flow = Some(flow); - return state.handle_input_data(buf, core::STREAM_TOSERVER); + return state.handle_input_data(buf, Direction::ToServer); } AppLayerResult::err() } @@ -1183,7 +1183,7 @@ pub unsafe extern "C" fn rs_dcerpc_parse_response( if !input.is_null() { let buf = build_slice!(input, input_len as usize); state.flow = Some(flow); - return state.handle_input_data(buf, core::STREAM_TOCLIENT); + return state.handle_input_data(buf, Direction::ToClient); } } AppLayerResult::err() @@ -1211,24 +1211,27 @@ pub unsafe extern "C" fn rs_dcerpc_state_transaction_free(state: *mut std::os::r #[no_mangle] pub unsafe extern "C" fn rs_dcerpc_state_trunc(state: *mut std::os::raw::c_void, direction: u8) { let dce_state = cast_pointer!(state, DCERPCState); - if direction & core::STREAM_TOSERVER != 0 { - dce_state.ts_ssn_trunc = true; - for tx in &mut dce_state.transactions { - tx.req_done = true; - if let Some(flow) = dce_state.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOSERVER.into()); + match direction.into() { + Direction::ToServer => { + dce_state.ts_ssn_trunc = true; + for tx in &mut dce_state.transactions { + tx.req_done = true; + if let Some(flow) = dce_state.flow { + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32); + } } + SCLogDebug!("dce_state.ts_ssn_trunc = true; txs {}", dce_state.transactions.len()); } - SCLogDebug!("dce_state.ts_ssn_trunc = true; txs {}", dce_state.transactions.len()); - } else if direction & core::STREAM_TOCLIENT != 0 { - dce_state.tc_ssn_trunc = true; - for tx in &mut dce_state.transactions { - tx.resp_done = true; - if let Some(flow) = dce_state.flow { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOCLIENT.into()); + Direction::ToClient => { + dce_state.tc_ssn_trunc = true; + for tx in &mut dce_state.transactions { + tx.resp_done = true; + if let Some(flow) = dce_state.flow { + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32); + } } + SCLogDebug!("dce_state.tc_ssn_trunc = true; txs {}", dce_state.transactions.len()); } - SCLogDebug!("dce_state.tc_ssn_trunc = true; txs {}", dce_state.transactions.len()); } } @@ -1273,10 +1276,10 @@ pub unsafe extern "C" fn rs_dcerpc_get_tx_cnt(vtx: *mut std::os::raw::c_void) -> pub unsafe extern "C" fn rs_dcerpc_get_alstate_progress(tx: *mut std::os::raw::c_void, direction: u8 )-> std::os::raw::c_int { let tx = cast_pointer!(tx, DCERPCTransaction); - if direction == core::STREAM_TOSERVER && tx.req_done { + if direction == Direction::ToServer.into() && tx.req_done { SCLogDebug!("tx {} TOSERVER progress 1 => {:?}", tx.call_id, tx); return 1; - } else if direction == core::STREAM_TOCLIENT && tx.resp_done { + } else if direction == Direction::ToClient.into() && tx.resp_done { SCLogDebug!("tx {} TOCLIENT progress 1 => {:?}", tx.call_id, tx); return 1; } @@ -1297,13 +1300,13 @@ pub unsafe extern "C" fn rs_dcerpc_get_tx_data( pub unsafe extern "C" fn rs_dcerpc_get_stub_data( tx: &mut DCERPCTransaction, buf: *mut *const u8, len: *mut u32, endianness: *mut u8, dir: u8, ) { - match dir { - core::STREAM_TOSERVER => { + match dir.into() { + Direction::ToServer => { *len = tx.stub_data_buffer_ts.len() as u32; *buf = tx.stub_data_buffer_ts.as_ptr(); SCLogDebug!("DCERPC Request stub buffer: Setting buffer to: {:?}", *buf); } - _ => { + Direction::ToClient => { *len = tx.stub_data_buffer_tc.len() as u32; *buf = tx.stub_data_buffer_tc.as_ptr(); SCLogDebug!("DCERPC Response stub buffer: Setting buffer to: {:?}", *buf); @@ -1339,12 +1342,12 @@ pub unsafe extern "C" fn rs_dcerpc_probe_tcp(_f: *const core::Flow, direction: u let (is_dcerpc, is_request, ) = probe(slice); if is_dcerpc { let dir = if is_request { - core::STREAM_TOSERVER + Direction::ToServer } else { - core::STREAM_TOCLIENT + Direction::ToClient }; - if direction & (core::STREAM_TOSERVER|core::STREAM_TOCLIENT) != dir { - *rdir = dir; + if (direction & DIR_BOTH) != dir as u8 { + *rdir = dir as u8; } return ALPROTO_DCERPC; } @@ -1355,13 +1358,13 @@ fn register_pattern_probe() -> i8 { unsafe { if AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_DCERPC, b"|05 00|\0".as_ptr() as *const std::os::raw::c_char, 2, 0, - core::STREAM_TOSERVER, rs_dcerpc_probe_tcp, 0, 0) < 0 { + Direction::ToServer.into(), rs_dcerpc_probe_tcp, 0, 0) < 0 { SCLogDebug!("TOSERVER => AppLayerProtoDetectPMRegisterPatternCSwPP FAILED"); return -1; } if AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_DCERPC, b"|05 00|\0".as_ptr() as *const std::os::raw::c_char, 2, 0, - core::STREAM_TOCLIENT, rs_dcerpc_probe_tcp, 0, 0) < 0 { + Direction::ToClient.into(), rs_dcerpc_probe_tcp, 0, 0) < 0 { SCLogDebug!("TOCLIENT => AppLayerProtoDetectPMRegisterPatternCSwPP FAILED"); return -1; } @@ -1437,7 +1440,7 @@ pub unsafe extern "C" fn rs_dcerpc_register_parser() { #[cfg(test)] mod tests { use crate::applayer::AppLayerResult; - use crate::core; + use crate::core::*; use crate::dcerpc::dcerpc::DCERPCState; use std::cmp; @@ -1864,7 +1867,7 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request, Direction::ToServer) ); if let Some(hdr) = dcerpc_state.header { assert_eq!(0, hdr.hdrtype); @@ -1900,11 +1903,11 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind1, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), // TODO ASK if this is correct? - dcerpc_state.handle_input_data(bind2, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind2, Direction::ToServer) ); } @@ -1970,11 +1973,11 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind1, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind2, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind2, Direction::ToServer) ); if let Some(ref bind) = dcerpc_state.bind { assert_eq!(16, bind.numctxitems); @@ -1994,15 +1997,15 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request1, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request2, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request3, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request3, Direction::ToServer) ); let tx = &dcerpc_state.transactions[0]; assert_eq!(20, tx.stub_data_buffer_ts.len()); @@ -2018,7 +2021,7 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request1, Direction::ToServer) ); } @@ -2032,7 +2035,7 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request1, Direction::ToServer) ); } @@ -2052,15 +2055,15 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::err(), - dcerpc_state.handle_input_data(fault, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(fault, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request1, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request2, Direction::ToServer) ); let tx = &dcerpc_state.transactions[0]; assert_eq!(12, tx.stub_data_buffer_ts.len()); @@ -2082,15 +2085,15 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request1, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request2, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request3, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request3, Direction::ToServer) ); } @@ -2107,14 +2110,14 @@ mod tests { 0x00, 0x00, ]; let mut dcerpc_state = DCERPCState::new(); - dcerpc_state.data_needed_for_dir = core::STREAM_TOCLIENT; + dcerpc_state.data_needed_for_dir = Direction::ToClient; assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind_ack1, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(bind_ack1, Direction::ToClient) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind_ack2, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(bind_ack2, Direction::ToClient) ); } @@ -2137,7 +2140,7 @@ mod tests { ]; assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bindbuf, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bindbuf, Direction::ToServer) ); if let Some(ref bind) = dcerpc_state.bind { let bind_uuid = &bind.uuid_list[0].uuid; @@ -2171,7 +2174,7 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bindbuf, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bindbuf, Direction::ToServer) ); } @@ -2188,10 +2191,10 @@ mod tests { 0xFF, ]; let mut dcerpc_state = DCERPCState::new(); - dcerpc_state.data_needed_for_dir = core::STREAM_TOCLIENT; + dcerpc_state.data_needed_for_dir = Direction::ToClient; assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind_ack, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(bind_ack, Direction::ToClient) ); } @@ -2442,11 +2445,11 @@ mod tests { ]; assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind1, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind_ack1, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(bind_ack1, Direction::ToClient) ); if let Some(ref back) = dcerpc_state.bindack { assert_eq!(1, back.accepted_uuid_list.len()); @@ -2455,11 +2458,11 @@ mod tests { } assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind2, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind2, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind_ack2, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(bind_ack2, Direction::ToClient) ); if let Some(ref back) = dcerpc_state.bindack { assert_eq!(1, back.accepted_uuid_list.len()); @@ -2468,15 +2471,15 @@ mod tests { } assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind3, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind3, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind_ack3, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(bind_ack3, Direction::ToClient) ); if let Some(ref back) = dcerpc_state.bindack { assert_eq!(1, back.accepted_uuid_list.len()); - dcerpc_state.data_needed_for_dir = core::STREAM_TOSERVER; + dcerpc_state.data_needed_for_dir = Direction::ToServer; assert_eq!(11, back.accepted_uuid_list[0].ctxid); assert_eq!(expected_uuid3, back.accepted_uuid_list[0].uuid); } @@ -2525,11 +2528,11 @@ mod tests { ]; assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bind, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(bind, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(bindack, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(bindack, Direction::ToClient) ); if let Some(ref back) = dcerpc_state.bindack { assert_eq!(1, back.accepted_uuid_list.len()); @@ -2538,11 +2541,11 @@ mod tests { } assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(alter_context, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(alter_context, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(alter_context_resp, core::STREAM_TOCLIENT) + dcerpc_state.handle_input_data(alter_context_resp, Direction::ToClient) ); if let Some(ref back) = dcerpc_state.bindack { assert_eq!(1, back.accepted_uuid_list.len()); @@ -2564,11 +2567,11 @@ mod tests { let mut dcerpc_state = DCERPCState::new(); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request1, Direction::ToServer) ); assert_eq!( AppLayerResult::ok(), - dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER) + dcerpc_state.handle_input_data(request2, Direction::ToServer) ); let tx = &dcerpc_state.transactions[0]; assert_eq!(2, tx.opnum); diff --git a/rust/src/dcerpc/dcerpc_udp.rs b/rust/src/dcerpc/dcerpc_udp.rs index 0ece8025ef..376fc4755f 100644 --- a/rust/src/dcerpc/dcerpc_udp.rs +++ b/rust/src/dcerpc/dcerpc_udp.rs @@ -16,7 +16,7 @@ */ use crate::applayer::*; -use crate::core; +use crate::core::{self, Direction, DIR_BOTH}; use crate::dcerpc::dcerpc::{ DCERPCTransaction, DCERPC_TYPE_REQUEST, DCERPC_TYPE_RESPONSE, PFCL1_FRAG, PFCL1_LASTFRAG, rs_dcerpc_get_alstate_progress, ALPROTO_DCERPC, PARSER_NAME, @@ -314,14 +314,16 @@ pub unsafe extern "C" fn rs_dcerpc_probe_udp(_f: *const core::Flow, direction: u //is_incomplete is checked by caller let (is_dcerpc, is_request) = probe(slice); if is_dcerpc { - let dir = if is_request { - core::STREAM_TOSERVER + let dir: Direction = (direction & DIR_BOTH).into(); + if is_request { + if dir != Direction::ToServer { + *rdir = Direction::ToServer.into(); + } } else { - core::STREAM_TOCLIENT + if dir != Direction::ToClient { + *rdir = Direction::ToClient.into(); + } }; - if direction & (core::STREAM_TOSERVER|core::STREAM_TOCLIENT) != dir { - *rdir = dir; - } return ALPROTO_DCERPC; } return core::ALPROTO_FAILED; @@ -331,7 +333,7 @@ fn register_pattern_probe() -> i8 { unsafe { if AppLayerProtoDetectPMRegisterPatternCSwPP(core::IPPROTO_UDP as u8, ALPROTO_DCERPC, b"|04 00|\0".as_ptr() as *const std::os::raw::c_char, 2, 0, - core::STREAM_TOSERVER, rs_dcerpc_probe_udp, 0, 0) < 0 { + Direction::ToServer.into(), rs_dcerpc_probe_udp, 0, 0) < 0 { SCLogDebug!("TOSERVER => AppLayerProtoDetectPMRegisterPatternCSwPP FAILED"); return -1; }