]> git.ipfire.org Git - people/ms/suricata.git/commitdiff
dcerpc: use Direction enum
authorShivani Bhardwaj <shivanib134@gmail.com>
Wed, 11 Aug 2021 11:29:48 +0000 (16:59 +0530)
committerVictor Julien <vjulien@oisf.net>
Fri, 19 Nov 2021 16:20:01 +0000 (17:20 +0100)
rust/src/dcerpc/dcerpc.rs
rust/src/dcerpc/dcerpc_udp.rs

index 6dec6e22cbae4f2a0658a74377b6fa3737d274c5..30d493562634403ecc0cdc29df8d3a9aa93fba6b 100644 (file)
@@ -321,8 +321,8 @@ pub struct DCERPCState {
     pub bytes_consumed: u16,
     pub tx_id: u64,
     pub query_completed: bool,
-    pub data_needed_for_dir: u8,
-    pub prev_dir: u8,
+    pub data_needed_for_dir: Direction,
+    pub prev_dir: Direction,
     pub prev_tx_call_id: u32,
     pub clear_bind_cache: bool,
     pub ts_gap: bool,
@@ -337,8 +337,8 @@ pub struct DCERPCState {
 impl DCERPCState {
     pub fn new() -> Self {
         return Self {
-            data_needed_for_dir: core::STREAM_TOSERVER,
-            prev_dir: core::STREAM_TOSERVER,
+            data_needed_for_dir: Direction::ToServer,
+            prev_dir: Direction::ToServer,
             ..Default::default()
         }
     }
@@ -450,13 +450,13 @@ impl DCERPCState {
         return 0;
     }
 
-    pub fn clean_buffer(&mut self, direction: u8) {
+    pub fn clean_buffer(&mut self, direction: Direction) {
         match direction {
-            core::STREAM_TOSERVER => {
+            Direction::ToServer => {
                 self.buffer_ts.clear();
                 self.ts_gap = false;
             }
-            _ => {
+            Direction::ToClient => {
                 self.buffer_tc.clear();
                 self.tc_gap = false;
             }
@@ -464,23 +464,23 @@ impl DCERPCState {
         self.bytes_consumed = 0;
     }
 
-    pub fn extend_buffer(&mut self, buffer: &[u8], direction: u8) {
+    pub fn extend_buffer(&mut self, buffer: &[u8], direction: Direction) {
         match direction {
-            core::STREAM_TOSERVER => {
+            Direction::ToServer => {
                 self.buffer_ts.extend_from_slice(buffer);
             }
-            _ => {
+            Direction::ToClient => {
                 self.buffer_tc.extend_from_slice(buffer);
             }
         }
         self.data_needed_for_dir = direction;
     }
 
-    pub fn reset_direction(&mut self, direction: u8) {
-        if direction == core::STREAM_TOSERVER {
-            self.data_needed_for_dir = core::STREAM_TOCLIENT;
+    pub fn reset_direction(&mut self, direction: Direction) {
+        if direction == Direction::ToServer {
+            self.data_needed_for_dir = Direction::ToClient;
         } else {
-            self.data_needed_for_dir = core::STREAM_TOSERVER;
+            self.data_needed_for_dir = Direction::ToServer;
         }
     }
 
@@ -513,24 +513,24 @@ impl DCERPCState {
     ///             type: unsigned 32 bit integer
     ///      description: call_id param derived from TCP Header
     /// * `dir`:
-    ///         type: unsigned 8 bit integer
+    ///         type: enum Direction
     ///  description: direction of the flow
     ///
     /// Return value:
     /// Option mutable reference to DCERPCTransaction
-    pub fn get_tx_by_call_id(&mut self, call_id: u32, dir: u8) -> Option<&mut DCERPCTransaction> {
+    pub fn get_tx_by_call_id(&mut self, call_id: u32, dir: Direction) -> Option<&mut DCERPCTransaction> {
         let cmd = self.get_hdr_type().unwrap_or(0);
         for tx in &mut self.transactions {
             let found = tx.call_id == call_id;
             if found {
                 match dir {
-                    core::STREAM_TOSERVER => {
+                    Direction::ToServer => {
                         let resp_cmd = get_resp_type_for_req(cmd);
                         if resp_cmd != tx.resp_cmd {
                             continue;
                         }
                     }
-                    _ => {
+                    Direction::ToClient => {
                         let req_cmd = get_req_type_for_resp(cmd);
                         if req_cmd != tx.req_cmd {
                             continue;
@@ -556,13 +556,13 @@ impl DCERPCState {
         self.prev_tx_call_id = call_id;
     }
 
-    pub fn parse_data_gap(&mut self, direction: u8) -> AppLayerResult {
+    pub fn parse_data_gap(&mut self, direction: Direction) -> AppLayerResult {
         match direction {
-            core::STREAM_TOSERVER => {
+            Direction::ToServer => {
                 self.ts_gap = true;
                 self.ts_ssn_gap = true;
             },
-            _ => {
+            Direction::ToClient => {
                 self.tc_gap = true;
                 self.tc_ssn_gap = true;
             },
@@ -570,9 +570,9 @@ impl DCERPCState {
         AppLayerResult::ok()
     }
 
-    pub fn post_gap_housekeeping(&mut self, dir: u8) {
+    pub fn post_gap_housekeeping(&mut self, dir: Direction) {
         SCLogDebug!("ts ssn gap: {:?}, tc ssn gap: {:?}, dir: {:?}", self.ts_ssn_gap, self.tc_ssn_gap, dir);
-        if self.ts_ssn_gap && dir == core::STREAM_TOSERVER {
+        if self.ts_ssn_gap && dir == Direction::ToServer {
             for tx in &mut self.transactions {
                 if tx.id >= self.tx_id {
                     SCLogDebug!("post_gap_housekeeping: done");
@@ -583,10 +583,10 @@ impl DCERPCState {
                 }
                 tx.req_done = true;
                 if let Some(flow) = self.flow {
-                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir.into());
+                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir as i32);
                 }
             }
-        } else if self.tc_ssn_gap && dir == core::STREAM_TOCLIENT {
+        } else if self.tc_ssn_gap && dir == Direction::ToClient {
             for tx in &mut self.transactions {
                 if tx.id >= self.tx_id {
                     SCLogDebug!("post_gap_housekeeping: done");
@@ -601,7 +601,7 @@ impl DCERPCState {
                 tx.req_done = true;
                 tx.resp_done = true;
                 if let Some(flow) = self.flow {
-                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir.into());
+                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, dir as i32);
                 }
             }
         }
@@ -716,7 +716,7 @@ impl DCERPCState {
                 tx.req_cmd = self.get_hdr_type().unwrap_or(0);
                 tx.req_done = true;
                 if let Some(flow) = self.flow {
-                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOSERVER.into());
+                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32);
                 }
                 tx.frag_cnt_ts = 1;
                 self.transactions.push(tx);
@@ -772,7 +772,7 @@ impl DCERPCState {
         }
     }
 
-    pub fn handle_stub_data(&mut self, input: &[u8], input_len: u16, dir: u8) -> u16 {
+    pub fn handle_stub_data(&mut self, input: &[u8], input_len: u16, dir: Direction) -> u16 {
         let retval;
         let hdrpfcflags = self.get_hdr_pfcflags().unwrap_or(0);
         let padleft = self.padleft;
@@ -801,7 +801,7 @@ impl DCERPCState {
                     tx.req_done = true;
                     tx.frag_cnt_ts = 1;
                     if let Some(flow) = self.flow {
-                        sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOSERVER.into());
+                        sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32);
                     }
                 }
                 DCERPC_TYPE_RESPONSE => {
@@ -816,7 +816,7 @@ impl DCERPCState {
                     tx.resp_done = true;
                     tx.frag_cnt_tc = 1;
                     if let Some(flow) = self.flow {
-                        sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOCLIENT.into());
+                        sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32);
                     }
                 }
                 _ => {
@@ -844,13 +844,13 @@ impl DCERPCState {
     ///           type: 16 bit unsigned integer.
     ///    description: bytes consumed *after* parsing header.
     /// * `dir`:
-    ///           type: 8 bit unsigned integer.
+    ///           type: enum Direction.
     ///    description: direction whose stub is supposed to be handled.
     ///
     /// Return value:
     /// * Success: Number of bytes successfully parsed.
     /// * Failure: -1 in case fragment length defined by header mismatches the data.
-    pub fn handle_common_stub(&mut self, input: &[u8], bytes_consumed: u16, dir: u8) -> i32 {
+    pub fn handle_common_stub(&mut self, input: &[u8], bytes_consumed: u16, dir: Direction) -> i32 {
         let fraglen = self.get_hdr_fraglen().unwrap_or(0);
         if fraglen < bytes_consumed as u16 + DCERPC_HDR_LEN {
             return -1;
@@ -866,7 +866,7 @@ impl DCERPCState {
             } else if input_left > 0 {
                 SCLogDebug!(
                     "Error parsing DCERPC {} stub data",
-                    if dir == core::STREAM_TOSERVER {
+                    if dir == Direction::ToServer {
                         "request"
                     } else {
                         "response"
@@ -885,7 +885,7 @@ impl DCERPCState {
             Ok((leftover_input, request)) => {
                 let call_id = self.get_hdr_call_id().unwrap_or(0);
                 let hdr_type = self.get_hdr_type().unwrap_or(0);
-                let mut transaction = self.get_tx_by_call_id(call_id, core::STREAM_TOSERVER);
+                let mut transaction = self.get_tx_by_call_id(call_id, Direction::ToServer);
                 match transaction {
                     Some(ref mut tx) => {
                         tx.req_cmd = hdr_type;
@@ -905,7 +905,7 @@ impl DCERPCState {
                 let parsed = self.handle_common_stub(
                     input,
                     (input.len() - leftover_input.len()) as u16,
-                    core::STREAM_TOSERVER,
+                    Direction::ToServer,
                 );
                 parsed
             }
@@ -922,7 +922,7 @@ impl DCERPCState {
         }
     }
 
-    pub fn handle_input_data(&mut self, input: &[u8], direction: u8) -> AppLayerResult {
+    pub fn handle_input_data(&mut self, input: &[u8], direction: Direction) -> AppLayerResult {
         let mut parsed;
         let retval;
         let mut cur_i = input;
@@ -932,7 +932,7 @@ impl DCERPCState {
         self.query_completed = false;
 
         // Skip the record since this means that its in the middle of a known length record
-        if (self.ts_gap && direction == core::STREAM_TOSERVER) || (self.tc_gap && direction == core::STREAM_TOCLIENT) {
+        if (self.ts_gap && direction == Direction::ToServer) || (self.tc_gap && direction == Direction::ToClient) {
             SCLogDebug!("Trying to catch up after GAP (input {})", cur_i.len());
             match self.search_dcerpc_record(cur_i) {
                 Ok((_, pg)) => {
@@ -940,10 +940,10 @@ impl DCERPCState {
                     let offset = cur_i.len() - pg.len();
                     cur_i = &cur_i[offset..];
                     match direction {
-                        core::STREAM_TOSERVER => {
+                        Direction::ToServer => {
                             self.ts_gap = false;
                         },
-                        _ => {
+                        Direction::ToClient => {
                             self.tc_gap = false;
                         }
                     }
@@ -969,7 +969,7 @@ impl DCERPCState {
         }
 
         let buffer = match direction {
-            core::STREAM_TOSERVER => {
+            Direction::ToServer => {
                 if self.buffer_ts.len() + input_len > 1024 * 1024 {
                     SCLogDebug!("DCERPC TOSERVER stream: Buffer Overflow");
                     return AppLayerResult::err();
@@ -978,7 +978,7 @@ impl DCERPCState {
                 v.extend_from_slice(cur_i);
                 v.as_slice()
             }
-            _ => {
+            Direction::ToClient => {
                 if self.buffer_tc.len() + input_len > 1024 * 1024 {
                     SCLogDebug!("DCERPC TOCLIENT stream: Buffer Overflow");
                     return AppLayerResult::err();
@@ -1037,7 +1037,7 @@ impl DCERPCState {
                     if retval == -1 {
                         return AppLayerResult::err();
                     }
-                    let tx = if let Some(tx) = self.get_tx_by_call_id(current_call_id, core::STREAM_TOCLIENT) {
+                    let tx = if let Some(tx) = self.get_tx_by_call_id(current_call_id, Direction::ToClient) {
                         tx.resp_cmd = x;
                         tx
                     } else {
@@ -1049,7 +1049,7 @@ impl DCERPCState {
                     tx.resp_done = true;
                     tx.frag_cnt_tc = 1;
                     if let Some(flow) = self.flow {
-                        sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOCLIENT.into());
+                        sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32);
                     }
                     self.handle_bind_cache(current_call_id, false);
                 }
@@ -1063,7 +1063,7 @@ impl DCERPCState {
                     self.handle_bind_cache(current_call_id, false);
                 }
                 DCERPC_TYPE_RESPONSE => {
-                    let transaction = self.get_tx_by_call_id(current_call_id, core::STREAM_TOCLIENT);
+                    let transaction = self.get_tx_by_call_id(current_call_id, Direction::ToClient);
                     match transaction {
                         Some(tx) => {
                             tx.resp_cmd = x;
@@ -1077,7 +1077,7 @@ impl DCERPCState {
                     retval = self.handle_common_stub(
                         &buffer[parsed as usize..],
                         0,
-                        core::STREAM_TOCLIENT,
+                        Direction::ToClient,
                     );
                     if retval < 0 {
                         return AppLayerResult::err();
@@ -1132,7 +1132,7 @@ pub extern "C" fn rs_parse_dcerpc_request_gap(
     state: &mut DCERPCState,
     _input_len: u32,
 ) -> AppLayerResult {
-    state.parse_data_gap(core::STREAM_TOSERVER)
+    state.parse_data_gap(Direction::ToServer)
 }
 
 #[no_mangle]
@@ -1140,7 +1140,7 @@ pub extern "C" fn rs_parse_dcerpc_response_gap(
     state: &mut DCERPCState,
     _input_len: u32,
 ) -> AppLayerResult {
-    state.parse_data_gap(core::STREAM_TOCLIENT)
+    state.parse_data_gap(Direction::ToClient)
 }
 
 #[no_mangle]
@@ -1161,7 +1161,7 @@ pub unsafe extern "C" fn rs_dcerpc_parse_request(
     if input_len > 0 && !input.is_null() {
         let buf = build_slice!(input, input_len as usize);
         state.flow = Some(flow);
-        return state.handle_input_data(buf, core::STREAM_TOSERVER);
+        return state.handle_input_data(buf, Direction::ToServer);
     }
     AppLayerResult::err()
 }
@@ -1183,7 +1183,7 @@ pub unsafe extern "C" fn rs_dcerpc_parse_response(
         if !input.is_null() {
             let buf = build_slice!(input, input_len as usize);
             state.flow = Some(flow);
-            return state.handle_input_data(buf, core::STREAM_TOCLIENT);
+            return state.handle_input_data(buf, Direction::ToClient);
         }
     }
     AppLayerResult::err()
@@ -1211,24 +1211,27 @@ pub unsafe extern "C" fn rs_dcerpc_state_transaction_free(state: *mut std::os::r
 #[no_mangle]
 pub unsafe extern "C" fn rs_dcerpc_state_trunc(state: *mut std::os::raw::c_void, direction: u8) {
     let dce_state = cast_pointer!(state, DCERPCState);
-    if direction & core::STREAM_TOSERVER != 0 {
-        dce_state.ts_ssn_trunc = true;
-        for tx in &mut dce_state.transactions {
-            tx.req_done = true;
-            if let Some(flow) = dce_state.flow {
-                sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOSERVER.into());
+    match direction.into() {
+        Direction::ToServer =>  {
+            dce_state.ts_ssn_trunc = true;
+            for tx in &mut dce_state.transactions {
+                tx.req_done = true;
+                if let Some(flow) = dce_state.flow {
+                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32);
+                }
             }
+            SCLogDebug!("dce_state.ts_ssn_trunc = true; txs {}", dce_state.transactions.len());
         }
-        SCLogDebug!("dce_state.ts_ssn_trunc = true; txs {}", dce_state.transactions.len());
-    } else if direction & core::STREAM_TOCLIENT != 0 {
-        dce_state.tc_ssn_trunc = true;
-        for tx in &mut dce_state.transactions {
-            tx.resp_done = true;
-            if let Some(flow) = dce_state.flow {
-                sc_app_layer_parser_trigger_raw_stream_reassembly(flow, core::STREAM_TOCLIENT.into());
+        Direction::ToClient => {
+            dce_state.tc_ssn_trunc = true;
+            for tx in &mut dce_state.transactions {
+                tx.resp_done = true;
+                if let Some(flow) = dce_state.flow {
+                    sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32);
+                }
             }
+            SCLogDebug!("dce_state.tc_ssn_trunc = true; txs {}", dce_state.transactions.len());
         }
-        SCLogDebug!("dce_state.tc_ssn_trunc = true; txs {}", dce_state.transactions.len());
     }
 }
 
@@ -1273,10 +1276,10 @@ pub unsafe extern "C" fn rs_dcerpc_get_tx_cnt(vtx: *mut std::os::raw::c_void) ->
 pub unsafe extern "C" fn rs_dcerpc_get_alstate_progress(tx: *mut std::os::raw::c_void, direction: u8
                                                  )-> std::os::raw::c_int {
     let tx = cast_pointer!(tx, DCERPCTransaction);
-    if direction == core::STREAM_TOSERVER && tx.req_done {
+    if direction == Direction::ToServer.into() && tx.req_done {
         SCLogDebug!("tx {} TOSERVER progress 1 => {:?}", tx.call_id, tx);
         return 1;
-    } else if direction == core::STREAM_TOCLIENT && tx.resp_done {
+    } else if direction == Direction::ToClient.into() && tx.resp_done {
         SCLogDebug!("tx {} TOCLIENT progress 1 => {:?}", tx.call_id, tx);
         return 1;
     }
@@ -1297,13 +1300,13 @@ pub unsafe extern "C" fn rs_dcerpc_get_tx_data(
 pub unsafe extern "C" fn rs_dcerpc_get_stub_data(
     tx: &mut DCERPCTransaction, buf: *mut *const u8, len: *mut u32, endianness: *mut u8, dir: u8,
 ) {
-    match dir {
-        core::STREAM_TOSERVER => {
+    match dir.into() {
+        Direction::ToServer => {
             *len = tx.stub_data_buffer_ts.len() as u32;
             *buf = tx.stub_data_buffer_ts.as_ptr();
             SCLogDebug!("DCERPC Request stub buffer: Setting buffer to: {:?}", *buf);
         }
-        _ => {
+        Direction::ToClient => {
             *len = tx.stub_data_buffer_tc.len() as u32;
             *buf = tx.stub_data_buffer_tc.as_ptr();
             SCLogDebug!("DCERPC Response stub buffer: Setting buffer to: {:?}", *buf);
@@ -1339,12 +1342,12 @@ pub unsafe extern "C" fn rs_dcerpc_probe_tcp(_f: *const core::Flow, direction: u
     let (is_dcerpc, is_request, ) = probe(slice);
     if is_dcerpc {
         let dir = if is_request {
-            core::STREAM_TOSERVER
+            Direction::ToServer
         } else {
-            core::STREAM_TOCLIENT
+            Direction::ToClient
         };
-        if direction & (core::STREAM_TOSERVER|core::STREAM_TOCLIENT) != dir {
-            *rdir = dir;
+        if (direction & DIR_BOTH) != dir as u8 {
+            *rdir = dir as u8;
         }
         return ALPROTO_DCERPC;
     }
@@ -1355,13 +1358,13 @@ fn register_pattern_probe() -> i8 {
     unsafe {
         if AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_DCERPC,
                                                      b"|05 00|\0".as_ptr() as *const std::os::raw::c_char, 2, 0,
-                                                     core::STREAM_TOSERVER, rs_dcerpc_probe_tcp, 0, 0) < 0 {
+                                                     Direction::ToServer.into(), rs_dcerpc_probe_tcp, 0, 0) < 0 {
             SCLogDebug!("TOSERVER => AppLayerProtoDetectPMRegisterPatternCSwPP FAILED");
             return -1;
         }
         if AppLayerProtoDetectPMRegisterPatternCSwPP(IPPROTO_TCP as u8, ALPROTO_DCERPC,
                                                      b"|05 00|\0".as_ptr() as *const std::os::raw::c_char, 2, 0,
-                                                     core::STREAM_TOCLIENT, rs_dcerpc_probe_tcp, 0, 0) < 0 {
+                                                     Direction::ToClient.into(), rs_dcerpc_probe_tcp, 0, 0) < 0 {
             SCLogDebug!("TOCLIENT => AppLayerProtoDetectPMRegisterPatternCSwPP FAILED");
             return -1;
         }
@@ -1437,7 +1440,7 @@ pub unsafe extern "C" fn rs_dcerpc_register_parser() {
 #[cfg(test)]
 mod tests {
     use crate::applayer::AppLayerResult;
-    use crate::core;
+    use crate::core::*;
     use crate::dcerpc::dcerpc::DCERPCState;
     use std::cmp;
 
@@ -1864,7 +1867,7 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request, Direction::ToServer)
         );
         if let Some(hdr) = dcerpc_state.header {
             assert_eq!(0, hdr.hdrtype);
@@ -1900,11 +1903,11 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind1, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(), // TODO ASK if this is correct?
-            dcerpc_state.handle_input_data(bind2, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind2, Direction::ToServer)
         );
     }
 
@@ -1970,11 +1973,11 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind1, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind2, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind2, Direction::ToServer)
         );
         if let Some(ref bind) = dcerpc_state.bind {
             assert_eq!(16, bind.numctxitems);
@@ -1994,15 +1997,15 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request1, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request2, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request3, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request3, Direction::ToServer)
         );
         let tx = &dcerpc_state.transactions[0];
         assert_eq!(20, tx.stub_data_buffer_ts.len());
@@ -2018,7 +2021,7 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request1, Direction::ToServer)
         );
     }
 
@@ -2032,7 +2035,7 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request1, Direction::ToServer)
         );
     }
 
@@ -2052,15 +2055,15 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::err(),
-            dcerpc_state.handle_input_data(fault, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(fault, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request1, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request2, Direction::ToServer)
         );
         let tx = &dcerpc_state.transactions[0];
         assert_eq!(12, tx.stub_data_buffer_ts.len());
@@ -2082,15 +2085,15 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request1, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request2, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request3, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request3, Direction::ToServer)
         );
     }
 
@@ -2107,14 +2110,14 @@ mod tests {
             0x00, 0x00,
         ];
         let mut dcerpc_state = DCERPCState::new();
-        dcerpc_state.data_needed_for_dir = core::STREAM_TOCLIENT;
+        dcerpc_state.data_needed_for_dir = Direction::ToClient;
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind_ack1, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(bind_ack1, Direction::ToClient)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind_ack2, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(bind_ack2, Direction::ToClient)
         );
     }
 
@@ -2137,7 +2140,7 @@ mod tests {
         ];
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bindbuf, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bindbuf, Direction::ToServer)
         );
         if let Some(ref bind) = dcerpc_state.bind {
             let bind_uuid = &bind.uuid_list[0].uuid;
@@ -2171,7 +2174,7 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bindbuf, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bindbuf, Direction::ToServer)
         );
     }
 
@@ -2188,10 +2191,10 @@ mod tests {
             0xFF,
         ];
         let mut dcerpc_state = DCERPCState::new();
-        dcerpc_state.data_needed_for_dir = core::STREAM_TOCLIENT;
+        dcerpc_state.data_needed_for_dir = Direction::ToClient;
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind_ack, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(bind_ack, Direction::ToClient)
         );
     }
 
@@ -2442,11 +2445,11 @@ mod tests {
         ];
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind1, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind_ack1, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(bind_ack1, Direction::ToClient)
         );
         if let Some(ref back) = dcerpc_state.bindack {
             assert_eq!(1, back.accepted_uuid_list.len());
@@ -2455,11 +2458,11 @@ mod tests {
         }
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind2, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind2, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind_ack2, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(bind_ack2, Direction::ToClient)
         );
         if let Some(ref back) = dcerpc_state.bindack {
             assert_eq!(1, back.accepted_uuid_list.len());
@@ -2468,15 +2471,15 @@ mod tests {
         }
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind3, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind3, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind_ack3, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(bind_ack3, Direction::ToClient)
         );
         if let Some(ref back) = dcerpc_state.bindack {
             assert_eq!(1, back.accepted_uuid_list.len());
-            dcerpc_state.data_needed_for_dir = core::STREAM_TOSERVER;
+            dcerpc_state.data_needed_for_dir = Direction::ToServer;
             assert_eq!(11, back.accepted_uuid_list[0].ctxid);
             assert_eq!(expected_uuid3, back.accepted_uuid_list[0].uuid);
         }
@@ -2525,11 +2528,11 @@ mod tests {
         ];
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bind, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(bind, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(bindack, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(bindack, Direction::ToClient)
         );
         if let Some(ref back) = dcerpc_state.bindack {
             assert_eq!(1, back.accepted_uuid_list.len());
@@ -2538,11 +2541,11 @@ mod tests {
         }
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(alter_context, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(alter_context, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(alter_context_resp, core::STREAM_TOCLIENT)
+            dcerpc_state.handle_input_data(alter_context_resp, Direction::ToClient)
         );
         if let Some(ref back) = dcerpc_state.bindack {
             assert_eq!(1, back.accepted_uuid_list.len());
@@ -2564,11 +2567,11 @@ mod tests {
         let mut dcerpc_state = DCERPCState::new();
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request1, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request1, Direction::ToServer)
         );
         assert_eq!(
             AppLayerResult::ok(),
-            dcerpc_state.handle_input_data(request2, core::STREAM_TOSERVER)
+            dcerpc_state.handle_input_data(request2, Direction::ToServer)
         );
         let tx = &dcerpc_state.transactions[0];
         assert_eq!(2, tx.opnum);
index 0ece8025ef8a22ea41321f00fc86a6544acf7f0e..376fc4755ffd51c0dfb0b9fe62d83572d2e03376 100644 (file)
@@ -16,7 +16,7 @@
  */
 
 use crate::applayer::*;
-use crate::core;
+use crate::core::{self, Direction, DIR_BOTH};
 use crate::dcerpc::dcerpc::{
     DCERPCTransaction, DCERPC_TYPE_REQUEST, DCERPC_TYPE_RESPONSE, PFCL1_FRAG, PFCL1_LASTFRAG,
     rs_dcerpc_get_alstate_progress, ALPROTO_DCERPC, PARSER_NAME,
@@ -314,14 +314,16 @@ pub unsafe extern "C" fn rs_dcerpc_probe_udp(_f: *const core::Flow, direction: u
     //is_incomplete is checked by caller
     let (is_dcerpc, is_request) = probe(slice);
     if is_dcerpc {
-        let dir = if is_request {
-            core::STREAM_TOSERVER
+        let dir: Direction = (direction & DIR_BOTH).into();
+        if is_request {
+            if dir != Direction::ToServer {
+                *rdir = Direction::ToServer.into();
+            }
         } else {
-            core::STREAM_TOCLIENT
+            if dir != Direction::ToClient {
+                *rdir = Direction::ToClient.into();
+            }
         };
-        if direction & (core::STREAM_TOSERVER|core::STREAM_TOCLIENT) != dir {
-            *rdir = dir;
-        }
         return ALPROTO_DCERPC;
     }
     return core::ALPROTO_FAILED;
@@ -331,7 +333,7 @@ fn register_pattern_probe() -> i8 {
     unsafe {
         if AppLayerProtoDetectPMRegisterPatternCSwPP(core::IPPROTO_UDP as u8, ALPROTO_DCERPC,
                                                      b"|04 00|\0".as_ptr() as *const std::os::raw::c_char, 2, 0,
-                                                     core::STREAM_TOSERVER, rs_dcerpc_probe_udp, 0, 0) < 0 {
+                                                     Direction::ToServer.into(), rs_dcerpc_probe_udp, 0, 0) < 0 {
             SCLogDebug!("TOSERVER => AppLayerProtoDetectPMRegisterPatternCSwPP FAILED");
             return -1;
         }