]> git.ipfire.org Git - people/ms/suricata.git/commitdiff
mqtt: use generic tx iterator
authorJason Ish <jason.ish@oisf.net>
Wed, 10 Nov 2021 17:20:13 +0000 (11:20 -0600)
committerVictor Julien <vjulien@oisf.net>
Mon, 22 Nov 2021 09:20:22 +0000 (10:20 +0100)
rust/src/mqtt/mqtt.rs

index 9df1ae3a5e88b97d724a9f46f9636eb7dc608195..a5f1b2b31502072fd0274104ea3a7251b121124b 100644 (file)
@@ -99,6 +99,12 @@ impl Drop for MQTTTransaction {
     }
 }
 
+impl Transaction for MQTTTransaction {
+    fn id(&self) -> u64 {
+        self.tx_id
+    }
+}
+
 pub struct MQTTState {
     tx_id: u64,
     pub protocol_version: u8,
@@ -109,6 +115,12 @@ pub struct MQTTState {
     max_msg_len: usize,
 }
 
+impl State<MQTTTransaction> for MQTTState {
+    fn get_transactions(&self) -> &[MQTTTransaction] {
+        &self.transactions
+    }
+}
+
 impl MQTTState {
     pub fn new() -> Self {
         Self {
@@ -520,27 +532,6 @@ impl MQTTState {
         let ev = event as u8;
         core::sc_app_layer_decoder_events_set_event_raw(&mut tx.events, ev);
     }
-
-    fn tx_iterator(
-        &mut self,
-        min_tx_id: u64,
-        state: &mut u64,
-    ) -> Option<(&MQTTTransaction, u64, bool)> {
-        let mut index = *state as usize;
-        let len = self.transactions.len();
-
-        while index < len {
-            let tx = &self.transactions[index];
-            if tx.tx_id < min_tx_id + 1 {
-                index += 1;
-                continue;
-            }
-            *state = index as u64;
-            return Some((tx, tx.tx_id - 1, (len - index) > 1));
-        }
-
-        return None;
-    }
 }
 
 // C exports.
@@ -700,28 +691,6 @@ pub unsafe extern "C" fn rs_mqtt_state_get_events(
     return tx.events;
 }
 
-#[no_mangle]
-pub unsafe extern "C" fn rs_mqtt_state_get_tx_iterator(
-    _ipproto: u8,
-    _alproto: AppProto,
-    state: *mut std::os::raw::c_void,
-    min_tx_id: u64,
-    _max_tx_id: u64,
-    istate: &mut u64,
-) -> applayer::AppLayerGetTxIterTuple {
-    let state = cast_pointer!(state, MQTTState);
-    match state.tx_iterator(min_tx_id, istate) {
-        Some((tx, out_tx_id, has_next)) => {
-            let c_tx = tx as *const _ as *mut _;
-            let ires = applayer::AppLayerGetTxIterTuple::with_values(c_tx, out_tx_id, has_next);
-            return ires;
-        }
-        None => {
-            return applayer::AppLayerGetTxIterTuple::not_found();
-        }
-    }
-}
-
 // Parser name as a C style string.
 const PARSER_NAME: &'static [u8] = b"mqtt\0";
 
@@ -758,7 +727,7 @@ pub unsafe extern "C" fn rs_mqtt_register_parser(cfg_max_msg_len: u32) {
         localstorage_new: None,
         localstorage_free: None,
         get_files: None,
-        get_tx_iterator: Some(rs_mqtt_state_get_tx_iterator),
+        get_tx_iterator: Some(crate::applayer::state_get_tx_iterator::<MQTTState, MQTTTransaction>),
         get_tx_data: rs_mqtt_get_tx_data,
         apply_tx_config: None,
         flags: APP_LAYER_PARSER_OPT_UNIDIR_TXS,