]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
detect/asn1: Simplify errors and checks
authorEmmanuel Thompson <eet6646@gmail.com>
Wed, 3 Jun 2020 18:37:30 +0000 (14:37 -0400)
committerVictor Julien <victor@inliniac.net>
Wed, 8 Jul 2020 14:50:38 +0000 (16:50 +0200)
rust/src/asn1/mod.rs

index 6fe211afe24299e4ce9de0bd75b758df0730ac77..6f8364898dc9c5f65e2247f2a75a1218d0a2a3e2 100644 (file)
@@ -15,9 +15,7 @@
  * 02110-1301, USA.
  */
 
-use crate::log::*;
 use der_parser::ber::{parse_ber_recursive, BerObject, BerObjectContent, BerTag};
-use der_parser::error::BerError;
 use std::convert::TryFrom;
 
 mod parse_rules;
@@ -32,20 +30,7 @@ pub struct Asn1<'a>(Vec<BerObject<'a>>);
 enum Asn1DecodeError {
     InvalidKeywordParameter,
     MaxFrames,
-    InvalidStructure,
-    BerTypeError,
-    BerValueError,
-    InvalidTag,
-    InvalidLength,
-    InvalidClass,
-    ConstructExpected,
-    ConstructUnexpected,
-    IntegerTooLarge,
-    BerMaxDepth,
-    ObjectTooShort,
-    DerConstraintFailed,
-    UnknownTag,
-    Unsupported,
+    BerError(nom::Err<der_parser::error::BerError>),
 }
 
 /// Enumeration of Asn1 checks
@@ -54,61 +39,46 @@ enum Asn1Check {
     OversizeLength,
     BitstringOverflow,
     DoubleOverflow,
-}
-
-/// Errors possible during Asn1 checks
-#[derive(Debug)]
-enum Asn1CheckError {
     MaxDepth,
 }
 
-impl std::fmt::Display for Asn1CheckError {
-    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        match self {
-            Asn1CheckError::MaxDepth => write!(f, "MaxDepth"),
-        }
-    }
-}
-
 impl<'a> Asn1<'a> {
     /// Checks each BerObject contained in self with the provided detection
     /// data, returns the first successful match if one occurs
-    fn check(&self, ad: &DetectAsn1Data) -> Result<Option<Asn1Check>, Asn1CheckError> {
+    fn check(&self, ad: &DetectAsn1Data) -> Option<Asn1Check> {
         for obj in &self.0 {
-            let res = Asn1::check_object_recursive(obj, ad, ad.max_frames as usize)?;
+            let res = Asn1::check_object_recursive(obj, ad, ad.max_frames as usize);
             if res.is_some() {
-                return Ok(res);
+                return res;
             }
         }
 
-        Ok(None)
+        None
     }
 
     fn check_object_recursive(
-        obj: &BerObject,
-        ad: &DetectAsn1Data,
-        max_depth: usize,
-    ) -> Result<Option<Asn1Check>, Asn1CheckError> {
+        obj: &BerObject, ad: &DetectAsn1Data, max_depth: usize,
+    ) -> Option<Asn1Check> {
         // Check stack depth
         if max_depth == 0 {
-            return Err(Asn1CheckError::MaxDepth);
+            return Some(Asn1Check::MaxDepth);
         }
 
         // Check current object
         let res = Asn1::check_object(obj, ad);
         if res.is_some() {
-            return Ok(res);
+            return res;
         }
 
         // Check sub-nodes
         for node in obj.ref_iter() {
-            let res = Asn1::check_object_recursive(node, ad, max_depth - 1)?;
+            let res = Asn1::check_object_recursive(node, ad, max_depth - 1);
             if res.is_some() {
-                return Ok(res);
+                return res;
             }
         }
 
-        Ok(None)
+        None
     }
 
     /// Checks a BerObject and subnodes against the Asn1 checks
@@ -197,9 +167,7 @@ impl<'a> Asn1<'a> {
 /// Decodes Asn1 objects from an input + length while applying the offset
 /// defined in the asn1 keyword options
 fn asn1_decode<'a>(
-    buffer: &'a [u8],
-    buffer_offset: u32,
-    ad: &DetectAsn1Data,
+    buffer: &'a [u8], buffer_offset: u32, ad: &DetectAsn1Data,
 ) -> Result<Asn1<'a>, Asn1DecodeError> {
     // Get offset
     let offset = if let Some(absolute_offset) = ad.absolute_offset {
@@ -244,10 +212,7 @@ fn asn1_decode<'a>(
 /// pointer must be freed using `rs_asn1_free`
 #[no_mangle]
 pub extern "C" fn rs_asn1_decode(
-    input: *const u8,
-    input_len: u16,
-    buffer_offset: u32,
-    ad_ptr: *const DetectAsn1Data,
+    input: *const u8, input_len: u16, buffer_offset: u32, ad_ptr: *const DetectAsn1Data,
 ) -> *mut Asn1<'static> {
     if input.is_null() || input_len == 0 || ad_ptr.is_null() {
         return std::ptr::null_mut();
@@ -290,10 +255,7 @@ pub unsafe extern "C" fn rs_asn1_free(ptr: *mut Asn1) {
 ///
 /// Returns 1 if any of the options match, 0 if not
 #[no_mangle]
-pub unsafe extern "C" fn rs_asn1_checks(
-    ptr: *const Asn1,
-    ad_ptr: *const DetectAsn1Data,
-) -> u8 {
+pub unsafe extern "C" fn rs_asn1_checks(ptr: *const Asn1, ad_ptr: *const DetectAsn1Data) -> u8 {
     if ptr.is_null() || ad_ptr.is_null() {
         return 0;
     }
@@ -302,12 +264,8 @@ pub unsafe extern "C" fn rs_asn1_checks(
     let ad = &*ad_ptr;
 
     match asn1.check(ad) {
-        Ok(Some(_check)) => 1,
-        Ok(None) => 0,
-        Err(e) => {
-            SCLogError!("error during asn1 checks: {}", e.to_string());
-            0
-        }
+        Some(_check) => 1,
+        None => 0,
     }
 }
 
@@ -319,25 +277,7 @@ impl From<std::num::TryFromIntError> for Asn1DecodeError {
 
 impl From<nom::Err<der_parser::error::BerError>> for Asn1DecodeError {
     fn from(e: nom::Err<der_parser::error::BerError>) -> Asn1DecodeError {
-        match e {
-            nom::Err::Incomplete(_) => Asn1DecodeError::InvalidLength,
-            nom::Err::Error(e) | nom::Err::Failure(e) => match e {
-                BerError::BerTypeError => Asn1DecodeError::BerTypeError,
-                BerError::BerValueError => Asn1DecodeError::BerValueError,
-                BerError::InvalidTag => Asn1DecodeError::InvalidTag,
-                BerError::InvalidClass => Asn1DecodeError::InvalidClass,
-                BerError::InvalidLength => Asn1DecodeError::InvalidLength,
-                BerError::ConstructExpected => Asn1DecodeError::ConstructExpected,
-                BerError::ConstructUnexpected => Asn1DecodeError::ConstructUnexpected,
-                BerError::IntegerTooLarge => Asn1DecodeError::IntegerTooLarge,
-                BerError::BerMaxDepth => Asn1DecodeError::BerMaxDepth,
-                BerError::ObjectTooShort => Asn1DecodeError::ObjectTooShort,
-                BerError::DerConstraintFailed => Asn1DecodeError::DerConstraintFailed,
-                BerError::UnknownTag => Asn1DecodeError::UnknownTag,
-                BerError::Unsupported => Asn1DecodeError::Unsupported,
-                _ => Asn1DecodeError::InvalidStructure,
-            },
-        }
+        Asn1DecodeError::BerError(e)
     }
 }
 
@@ -430,9 +370,7 @@ mod tests {
             ..Default::default()
         }, None; "Test double_overflow rule (non-match)" )]
     fn test_checks(
-        rule: &str,
-        asn1_buf: &'static [u8],
-        expected_data: DetectAsn1Data,
+        rule: &str, asn1_buf: &'static [u8], expected_data: DetectAsn1Data,
         expected_check: Option<Asn1Check>,
     ) {
         // Parse rule
@@ -443,7 +381,7 @@ mod tests {
         let asn1 = Asn1::from_slice(asn1_buf, &ad).unwrap();
 
         // Run checks
-        let result = asn1.check(&ad).unwrap();
+        let result = asn1.check(&ad);
         assert_eq!(expected_check, result);
     }
 }