]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
mqtt: limit size of variable integer 5388/head
authorPhilippe Antoine <contact@catenacyber.fr>
Tue, 8 Sep 2020 19:13:07 +0000 (21:13 +0200)
committerVictor Julien <victor@inliniac.net>
Tue, 8 Sep 2020 20:11:49 +0000 (22:11 +0200)
rust/src/mqtt/parser.rs

index 65eb780fec05ed4cf8c734f11e14eea9310edec5..fdab86f3a86a004de12227236c7f2875fad1fbbd 100644 (file)
@@ -52,17 +52,6 @@ fn convert_varint(continued: Vec<u8>, last: u8) -> u32 {
     return value;
 }
 
-#[inline]
-fn varint_length(val: usize) -> usize {
-  match val {
-      0 ..= 127 => 1,
-      128 ..= 16383 => 2,
-      16384 ..= 2097151 => 3,
-      2097152 ..= 268435455 => 4,
-      _ => 0,
-  }
-}
-
 // DATA TYPES
 
 named!(#[inline], pub parse_mqtt_string<String>,
@@ -76,8 +65,9 @@ named!(#[inline], pub parse_mqtt_string<String>,
 
 named!(#[inline], pub parse_mqtt_variable_integer<u32>,
        do_parse!(
-           continued_part: take_while!(is_continuation_bit_set)
-           >> non_continued_part: be_u8
+           // take at most 4 bytes in total, so as not to overflow u32
+           continued_part: take_while_m_n!(0, 3, is_continuation_bit_set)
+           >> non_continued_part: verify!(be_u8, |&val| !is_continuation_bit_set(val))
            >>  (
                  convert_varint(continued_part.to_vec(), non_continued_part)
                )
@@ -495,9 +485,8 @@ pub fn parse_message(input: &[u8], protocol_version: u8, max_msg_size: usize) ->
             // before returning the remainder. It is the sum of the length
             // of the flag byte (1) and the length of the message length
             // varint.
-            let skiplen = 1 + varint_length(len);
+            let skiplen = input.len() - fullrem.len();
             let message_type = header.message_type;
-            assert_eq!(skiplen, input.len() - fullrem.len());
 
             // If the remaining length (message length) exceeds the specified
             // limit, we return a special truncation message type, containing
@@ -667,3 +656,26 @@ pub fn parse_message(input: &[u8], protocol_version: u8, max_msg_size: usize) ->
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[test]
+    fn test_mqtt_parse_variable_integer() {
+        let buf0: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF];
+        let r0 = parse_mqtt_variable_integer(buf0);
+        match r0 {
+            Ok((_, _)) => {
+                panic!("Result should not have been ok.");
+            }
+            Err(Err::Error(err)) => {
+                assert_eq!(err.1, error::ErrorKind::Verify);
+            }
+            _ => {
+                panic!("Result should be an error.");
+            }
+        }
+    }
+}