]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
smb: Change fn sign as per rust registration requirement
authorShivani Bhardwaj <shivanib134@gmail.com>
Fri, 18 Jun 2021 11:40:37 +0000 (17:10 +0530)
committerShivani Bhardwaj <shivanib134@gmail.com>
Thu, 5 Aug 2021 15:16:54 +0000 (20:46 +0530)
Registering parsers in Rust requires signatures to be a certain way and
compatible with C. Change signatures of all the functions.

rust/src/smb/files.rs
rust/src/smb/smb.rs
src/app-layer-smb.c

index 50b96ae8b940d26ca5b998c385115767011c1228..d02bca8e2d8edc3077b63971477347eef61fa195 100644 (file)
@@ -15,6 +15,7 @@
  * 02110-1301, USA.
  */
 
+use std;
 use crate::core::*;
 use crate::filetracker::*;
 use crate::filecontainer::*;
@@ -189,9 +190,10 @@ impl SMBState {
 }
 
 #[no_mangle]
-pub extern "C" fn rs_smb_getfiles(direction: u8, ptr: *mut SMBState) -> * mut FileContainer {
+pub extern "C" fn rs_smb_getfiles(ptr: *mut std::ffi::c_void, direction: u8) -> * mut FileContainer {
     if ptr.is_null() { panic!("NULL ptr"); };
-    let parser = unsafe { &mut *ptr };
+    let ptr = cast_pointer!(ptr, SMBState);
+    let parser = &mut *ptr;
     parser.getfiles(direction)
 }
 
index b2013e9736ff8d42b5263e1c2e5510d3c117f852..7837d71873922440804ee276784074c9771ba22f 100644 (file)
@@ -28,7 +28,7 @@
 use std;
 use std::mem::transmute;
 use std::str;
-use std::ffi::CStr;
+use std::ffi::{self, CStr};
 
 use std::collections::HashMap;
 
@@ -36,7 +36,7 @@ use nom;
 
 use crate::core::*;
 use crate::applayer;
-use crate::applayer::{AppLayerResult, AppLayerTxData};
+use crate::applayer::*;
 use crate::filecontainer::*;
 
 use crate::smb::nbss_records::*;
@@ -55,6 +55,8 @@ use crate::smb::smb2_ioctl::*;
 pub const MIN_REC_SIZE: u16 = 32 + 4; // SMB hdr + nbss hdr
 pub const SMB_CONFIG_DEFAULT_STREAM_DEPTH: u32 = 0;
 
+static mut ALPROTO_SMB: AppProto = ALPROTO_UNKNOWN;
+
 pub static mut SURICATA_SMB_FILE_CONFIG: Option<&'static SuricataFileContext> = None;
 
 #[no_mangle]
@@ -1810,16 +1812,18 @@ pub extern "C" fn rs_smb_state_free(state: *mut std::os::raw::c_void) {
 
 /// C binding parse a SMB request. Returns 1 on success, -1 on failure.
 #[no_mangle]
-pub extern "C" fn rs_smb_parse_request_tcp(flow: &mut Flow,
-                                       state: &mut SMBState,
+pub extern "C" fn rs_smb_parse_request_tcp(flow: *const Flow,
+                                       state: *mut ffi::c_void,
                                        _pstate: *mut std::os::raw::c_void,
                                        input: *const u8,
                                        input_len: u32,
-                                       _data: *mut std::os::raw::c_void,
+                                       _data: *const std::os::raw::c_void,
                                        flags: u8)
                                        -> AppLayerResult
 {
     let buf = unsafe{std::slice::from_raw_parts(input, input_len as usize)};
+    let mut state = cast_pointer!(state, SMBState);
+    let flow = cast_pointer!(flow, Flow);
     SCLogDebug!("parsing {} bytes of request data", input_len);
 
     /* START with MISTREAM set: record might be starting the middle. */
@@ -1842,15 +1846,17 @@ pub extern "C" fn rs_smb_parse_request_tcp_gap(
 
 
 #[no_mangle]
-pub extern "C" fn rs_smb_parse_response_tcp(flow: &mut Flow,
-                                        state: &mut SMBState,
+pub extern "C" fn rs_smb_parse_response_tcp(flow: *const Flow,
+                                        state: *mut ffi::c_void,
                                         _pstate: *mut std::os::raw::c_void,
                                         input: *const u8,
                                         input_len: u32,
-                                        _data: *mut std::os::raw::c_void,
+                                        _data: *const ffi::c_void,
                                         flags: u8)
                                         -> AppLayerResult
 {
+    let mut state = cast_pointer!(state, SMBState);
+    let flow = cast_pointer!(flow, Flow);
     SCLogDebug!("parsing {} bytes of response data", input_len);
     let buf = unsafe{std::slice::from_raw_parts(input, input_len as usize)};
 
@@ -1872,7 +1878,7 @@ pub extern "C" fn rs_smb_parse_response_tcp_gap(
     state.parse_tcp_data_tc_gap(input_len as u32)
 }
 
-fn rs_smb_probe_tcp_midstream(direction: u8, slice: &[u8], rdir: *mut u8) -> i8
+fn smb_probe_tcp_midstream(direction: u8, slice: &[u8], rdir: *mut u8) -> i8
 {
     match search_smb_record(slice) {
         Ok((_, ref data)) => {
@@ -1938,22 +1944,21 @@ fn rs_smb_probe_tcp_midstream(direction: u8, slice: &[u8], rdir: *mut u8) -> i8
 // probing parser
 // return 1 if found, 0 is not found
 #[no_mangle]
-pub extern "C" fn rs_smb_probe_tcp(flags: u8,
-        input: *const u8, len: u32,
-        rdir: *mut u8)
-    -> i8
+pub extern "C" fn rs_smb_probe_tcp(_f: *const Flow,
+                                   flags: u8, input: *const u8, len: u32, rdir: *mut u8)
+    -> AppProto
 {
     let slice = build_slice!(input, len as usize);
     if flags & STREAM_MIDSTREAM == STREAM_MIDSTREAM {
-        if rs_smb_probe_tcp_midstream(flags, slice, rdir) == 1 {
-            return 1;
+        if smb_probe_tcp_midstream(flags, slice, rdir) == 1 {
+            return unsafe { ALPROTO_SMB };
         }
     }
     match parse_nbss_record_partial(slice) {
         Ok((_, ref hdr)) => {
             if hdr.is_smb() {
                 SCLogDebug!("smb found");
-                return 1;
+                return unsafe { ALPROTO_SMB };
             } else if hdr.needs_more(){
                 return 0;
             } else if hdr.is_valid() &&
@@ -1966,7 +1971,7 @@ pub extern "C" fn rs_smb_probe_tcp(flags: u8,
                         Ok((_, ref hdr2)) => {
                             if hdr2.is_smb() {
                                 SCLogDebug!("smb found");
-                                return 1;
+                                return unsafe { ALPROTO_SMB };
                             }
                         }
                         _ => {}
@@ -1981,22 +1986,24 @@ pub extern "C" fn rs_smb_probe_tcp(flags: u8,
         _ => { },
     }
     SCLogDebug!("no smb");
-    return -1
+    return unsafe { ALPROTO_FAILED };
 }
 
 #[no_mangle]
-pub extern "C" fn rs_smb_state_get_tx_count(state: &mut SMBState)
+pub extern "C" fn rs_smb_state_get_tx_count(state: *mut ffi::c_void)
                                             -> u64
 {
+    let state = cast_pointer!(state, SMBState);
     SCLogDebug!("rs_smb_state_get_tx_count: returning {}", state.tx_id);
     return state.tx_id;
 }
 
 #[no_mangle]
-pub extern "C" fn rs_smb_state_get_tx(state: &mut SMBState,
+pub extern "C" fn rs_smb_state_get_tx(state: *mut ffi::c_void,
                                       tx_id: u64)
-                                      -> *mut SMBTransaction
+                                      -> *mut ffi::c_void
 {
+    let state = cast_pointer!(state, SMBState);
     match state.get_tx_by_id(tx_id) {
         Some(tx) => {
             return unsafe{transmute(tx)};
@@ -2010,11 +2017,15 @@ pub extern "C" fn rs_smb_state_get_tx(state: &mut SMBState,
 // for use with the C API call StateGetTxIterator
 #[no_mangle]
 pub extern "C" fn rs_smb_state_get_tx_iterator(
-                                      state: &mut SMBState,
-                                      min_tx_id: u64,
-                                      istate: &mut u64)
-                                      -> applayer::AppLayerGetTxIterTuple
+                                               _ipproto: u8,
+                                               _alproto: AppProto,
+                                               state: *mut std::os::raw::c_void,
+                                               min_tx_id: u64,
+                                               _max_tx_id: u64,
+                                               istate: &mut u64,
+                                               ) -> applayer::AppLayerGetTxIterTuple
 {
+    let state = cast_pointer!(state, SMBState);
     match state.get_tx_iterator(min_tx_id, istate) {
         Some((tx, out_tx_id, has_next)) => {
             let c_tx = unsafe { transmute(tx) };
@@ -2028,18 +2039,21 @@ pub extern "C" fn rs_smb_state_get_tx_iterator(
 }
 
 #[no_mangle]
-pub extern "C" fn rs_smb_state_tx_free(state: &mut SMBState,
+pub extern "C" fn rs_smb_state_tx_free(state: *mut ffi::c_void,
                                        tx_id: u64)
 {
+    let state = cast_pointer!(state, SMBState);
     SCLogDebug!("freeing tx {}", tx_id as u64);
     state.free_tx(tx_id);
 }
 
 #[no_mangle]
-pub extern "C" fn rs_smb_tx_get_alstate_progress(tx: &mut SMBTransaction,
+pub extern "C" fn rs_smb_tx_get_alstate_progress(tx: *mut ffi::c_void,
                                                   direction: u8)
-                                                  -> u8
+                                                  -> i32
 {
+    let tx = cast_pointer!(tx, SMBTransaction);
+
     if direction == STREAM_TOSERVER && tx.request_done {
         SCLogDebug!("tx {} TOSERVER progress 1 => {:?}", tx.id, tx);
         return 1;
@@ -2052,6 +2066,7 @@ pub extern "C" fn rs_smb_tx_get_alstate_progress(tx: &mut SMBTransaction,
     }
 }
 
+
 #[no_mangle]
 pub extern "C" fn rs_smb_get_tx_data(
     tx: *mut std::os::raw::c_void)
@@ -2061,19 +2076,12 @@ pub extern "C" fn rs_smb_get_tx_data(
     return &mut tx.tx_data;
 }
 
-#[no_mangle]
-pub extern "C" fn rs_smb_state_set_tx_detect_state(
-    tx: &mut SMBTransaction,
-    de_state: &mut DetectEngineState)
-{
-    tx.de_state = Some(de_state);
-}
-
 #[no_mangle]
 pub extern "C" fn rs_smb_state_get_tx_detect_state(
-    tx: &mut SMBTransaction)
+    tx: *mut std::os::raw::c_void)
     -> *mut DetectEngineState
 {
+    let tx = cast_pointer!(tx, SMBTransaction);
     match tx.de_state {
         Some(ds) => {
             return ds;
@@ -2084,11 +2092,22 @@ pub extern "C" fn rs_smb_state_get_tx_detect_state(
     }
 }
 
+#[no_mangle]
+pub extern "C" fn rs_smb_state_set_tx_detect_state(
+    tx: *mut std::os::raw::c_void,
+    de_state: &mut DetectEngineState) -> std::os::raw::c_int
+{
+    let tx = cast_pointer!(tx, SMBTransaction);
+    tx.de_state = Some(de_state);
+    0
+}
+
 #[no_mangle]
 pub extern "C" fn rs_smb_state_truncate(
-        state: &mut SMBState,
+        state: *mut std::ffi::c_void,
         direction: u8)
 {
+    let state = cast_pointer!(state, SMBState);
     if (direction & STREAM_TOSERVER) != 0 {
         state.trunc_ts();
     } else {
@@ -2135,7 +2154,7 @@ pub extern "C" fn rs_smb_state_get_event_info_by_id(event_id: std::os::raw::c_in
 pub extern "C" fn rs_smb_state_get_event_info(event_name: *const std::os::raw::c_char,
                                               event_id: *mut std::os::raw::c_int,
                                               event_type: *mut AppLayerEventType)
-                                              -> i8
+                                              -> i32
 {
     if event_name == std::ptr::null() {
         return -1;
index a1037c0bbc300f8cfae8efa4bc66b8e01a139489..206902add07b8ea7803c80c1b8e8231f2d47ef26 100644 (file)
@@ -80,7 +80,7 @@ static uint16_t SMBTCPProbe(Flow *f, uint8_t direction,
         return ALPROTO_UNKNOWN;
     }
 
-    const int r = rs_smb_probe_tcp(direction, input, len, rdir);
+    const int r = rs_smb_probe_tcp(f, direction, input, len, rdir);
     switch (r) {
         case 1:
             return ALPROTO_SMB;
@@ -145,7 +145,8 @@ static AppLayerGetTxIterTuple SMBGetTxIterator(
         void *alstate, uint64_t min_tx_id, uint64_t max_tx_id,
         AppLayerGetTxIterState *istate)
 {
-    return rs_smb_state_get_tx_iterator(alstate, min_tx_id, (uint64_t *)istate);
+    return rs_smb_state_get_tx_iterator(
+            ipproto, alproto, alstate, min_tx_id, max_tx_id, (uint64_t *)istate);
 }
 
 
@@ -167,7 +168,7 @@ static int SMBSetTxDetectState(void *tx, DetectEngineState *s)
 
 static FileContainer *SMBGetFiles(void *state, uint8_t direction)
 {
-    return rs_smb_getfiles(direction, state);
+    return rs_smb_getfiles(state, direction);
 }
 
 static AppLayerDecoderEvents *SMBGetEvents(void *tx)