]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
rust/smb: convert parser to nom7 functions (NBSS records)
authorPierre Chifflier <chifflier@wzdftpd.net>
Fri, 12 Nov 2021 13:32:09 +0000 (14:32 +0100)
committerVictor Julien <vjulien@oisf.net>
Mon, 13 Dec 2021 12:06:21 +0000 (13:06 +0100)
rust/src/smb/nbss_records.rs
rust/src/smb/smb.rs

index 7256c8a91da64fee7ef517b71dcce1d1ff1cc10e..e63862cdcb1bbb7c815a60e0cae5086b3f30e7a2 100644 (file)
  * 02110-1301, USA.
  */
 
-use nom::combinator::rest;
+use nom7::bytes::streaming::take;
+use nom7::combinator::rest;
+use nom7::number::streaming::be_u32;
+use nom7::IResult;
 
 pub const NBSS_MSGTYPE_SESSION_MESSAGE:         u8 = 0x00;
 pub const NBSS_MSGTYPE_SESSION_REQUEST:         u8 = 0x81;
@@ -62,36 +65,37 @@ impl<'a> NbssRecord<'a> {
     }
 }
 
-named!(pub parse_nbss_record<NbssRecord>,
-   do_parse!(
-       type_and_len: bits!(tuple!(
-               take_bits!(8u8),
-               take_bits!(24u32)))
-       >> data: take!(type_and_len.1 as usize)
-       >> (NbssRecord {
-            message_type:type_and_len.0,
-            length:type_and_len.1,
-            data:data,
-        })
-));
-
-named!(pub parse_nbss_record_partial<NbssRecord>,
-   do_parse!(
-       type_and_len: bits!(tuple!(
-               take_bits!(8u8),
-               take_bits!(24u32)))
-       >> data: rest
-       >> (NbssRecord {
-            message_type:type_and_len.0,
-            length:type_and_len.1,
-            data:data,
-        })
-));
+pub fn parse_nbss_record(i: &[u8]) -> IResult<&[u8], NbssRecord> {
+    let (i, buf) = be_u32(i)?;
+    let message_type = (buf >> 24) as u8;
+    let length = buf & 0xff_ffff;
+    let (i, data) = take(length as usize)(i)?;
+    let record = NbssRecord {
+        message_type,
+        length,
+        data,
+    };
+    Ok((i, record))
+}
+
+pub fn parse_nbss_record_partial(i: &[u8]) -> IResult<&[u8], NbssRecord> {
+    let (i, buf) = be_u32(i)?;
+    let message_type = (buf >> 24) as u8;
+    let length = buf & 0xff_ffff;
+    let (i, data) = rest(i)?;
+    let record = NbssRecord {
+        message_type,
+        length,
+        data,
+    };
+    Ok((i, record))
+}
 
 #[cfg(test)]
 mod tests {
 
     use super::*;
+    use nom7::Err;
 
     #[test]
     fn test_parse_nbss_record() {
@@ -126,10 +130,10 @@ mod tests {
                 // there should be nothing left
                 assert_eq!(remainder.len(), 0);
             }
-            Err(nom::Err::Error((_remainder, err))) => {
-                panic!("Result should not be an error: {:?}.", err);
+            Err(Err::Error(err)) => {
+                panic!("Result should not be an error: {:?}.", err.code);
             }
-            Err(nom::Err::Incomplete(_)) => {
+            Err(Err::Incomplete(_)) => {
                 panic!("Result should not have been incomplete.");
             }
             _ => {
@@ -170,10 +174,10 @@ mod tests {
                 // there should be nothing left
                 assert_eq!(remainder.len(), 0);
             }
-            Err(nom::Err::Error((_remainder, err))) => {
-                panic!("Result should not be an error: {:?}.", err);
+            Err(Err::Error(err)) => {
+                panic!("Result should not be an error: {:?}.", err.code);
             }
-            Err(nom::Err::Incomplete(_)) => {
+            Err(Err::Incomplete(_)) => {
                 panic!("Result should not have been incomplete.");
             }
             _ => {
@@ -210,10 +214,10 @@ mod tests {
                 // there should be nothing left
                 assert_eq!(remainder.len(), 0);
             }
-            Err(nom::Err::Error((_remainder, err))) => {
-                panic!("Result should not be an error: {:?}.", err);
+            Err(Err::Error(err)) => {
+                panic!("Result should not be an error: {:?}.", err.code);
             }
-            Err(nom::Err::Incomplete(_)) => {
+            Err(Err::Incomplete(_)) => {
                 panic!("Result should not have returned as incomplete.");
             }
             _ => {
index 2fedf97e2be40044592856d9ed5f63d76f0a3a65..393c7868ed215bc83add5d9cfd6a735aa151475b 100644 (file)
@@ -32,6 +32,7 @@ use std::ffi::{self, CString};
 use std::collections::HashMap;
 
 use nom;
+use nom7::{Err, Needed};
 
 use crate::core::*;
 use crate::applayer;
@@ -1420,8 +1421,9 @@ impl SMBState {
                     }
                     cur_i = rem;
                 },
-                Err(nom::Err::Incomplete(needed)) => {
-                    if let nom::Needed::Size(n) = needed {
+                Err(Err::Incomplete(needed)) => {
+                    if let Needed::Size(n) = needed {
+                        let n = usize::from(n) + cur_i.len();
                         // 512 is the minimum for parse_tcp_data_ts_partial
                         if n >= 512 && cur_i.len() < 512 {
                             let total_consumed = i.len() - cur_i.len();
@@ -1433,7 +1435,7 @@ impl SMBState {
                             let total_consumed = i.len() - cur_i.len();
                             SCLogDebug!("setting consumed {} need {} needed {:?} total input {}",
                                     total_consumed, n, needed, i.len());
-                            let need = n + 4; // Incomplete returns size of data minus NBSS header
+                            let need = n;
                             return AppLayerResult::incomplete(total_consumed as u32, need as u32);
                         }
                         // tracking a write record, which we don't need to
@@ -1661,9 +1663,10 @@ impl SMBState {
                     }
                     cur_i = rem;
                 },
-                Err(nom::Err::Incomplete(needed)) => {
+                Err(Err::Incomplete(needed)) => {
                     SCLogDebug!("INCOMPLETE have {} needed {:?}", cur_i.len(), needed);
-                    if let nom::Needed::Size(n) = needed {
+                    if let Needed::Size(n) = needed {
+                        let n = usize::from(n) + cur_i.len();
                         // 512 is the minimum for parse_tcp_data_tc_partial
                         if n >= 512 && cur_i.len() < 512 {
                             let total_consumed = i.len() - cur_i.len();
@@ -1675,7 +1678,7 @@ impl SMBState {
                             let total_consumed = i.len() - cur_i.len();
                             SCLogDebug!("setting consumed {} need {} needed {:?} total input {}",
                                     total_consumed, n, needed, i.len());
-                            let need = n + 4; // Incomplete returns size of data minus NBSS header
+                            let need = n;
                             return AppLayerResult::incomplete(total_consumed as u32, need as u32);
                         }
                         // tracking a read record, which we don't need to