]> git.ipfire.org Git - thirdparty/suricata-verify.git/commitdiff
runner: fix error looking for filter comment
authorJason Ish <ish@unx.ca>
Fri, 12 Jan 2018 14:02:48 +0000 (08:02 -0600)
committerJason Ish <ish@unx.ca>
Fri, 12 Jan 2018 14:02:48 +0000 (08:02 -0600)
run.py

diff --git a/run.py b/run.py
index 35a6d69a72959fa7afc02378f6353b7e1a316af6..bd85ece4bee0655ded854b84e162529bd3f128d1 100755 (executable)
--- a/run.py
+++ b/run.py
@@ -39,6 +39,9 @@ from collections import namedtuple
 
 import yaml
 
+class TestError(Exception):
+    pass
+
 class UnsatisfiedRequirementError(Exception):
     pass
 
@@ -202,7 +205,7 @@ class StatsCheck:
         for key in self.config:
             val = find_value(key, stats)
             if val != self.config[key]:
-                raise Exception("stats.%s: expected %s; got %s" % (
+                raise TestError("stats.%s: expected %s; got %s" % (
                     key, str(self.config[key]), str(val)))
         return True
 
@@ -220,10 +223,10 @@ class FilterCheck:
                     count += 1
         if count == self.config["count"]:
             return True
-        if self.config["comment"]:
-            raise Exception("%s: expected %d, got %d" % (
+        if "comment" in self.config:
+            raise TestError("%s: expected %d, got %d" % (
                 self.config["comment"], self.config["count"], count))
-        raise Exception("expected %d matches; got %d for filter %s" % (
+        raise TestError("expected %d matches; got %d for filter %s" % (
             self.config["count"], count, str(self.config)))
 
     def match(self, event):
@@ -319,18 +322,18 @@ class TestRunner:
                     for key in check:
                         if key == "filter":
                             if not FilterCheck(check[key]).run():
-                                raise Exception("filter did not match: %s" % (
+                                raise TestError("filter did not match: %s" % (
                                     str(check[key])))
                         elif key == "shell":
                             if not ShellCheck(check[key]).run():
-                                raise Exception(
+                                raise TestError(
                                     "shell output did not match: %s" % (
                                         str(check[key])))
                         elif key == "stats":
                             if not StatsCheck(check[key]).run():
-                                raise Exception("stats check did not pass")
+                                raise TestError("stats check did not pass")
                         else:
-                            raise Exception("Unknown check type: %s" % (key))
+                            raise TestError("Unknown check type: %s" % (key))
         finally:
             os.chdir(pdir)
 
@@ -366,9 +369,9 @@ class TestRunner:
         # Find pcaps.
         pcaps = glob.glob(os.path.join(self.directory, "*.pcap"))
         if not pcaps:
-            raise Exception("No pcap file found")
+            raise TestError("No pcap file found")
         elif len(pcaps) > 1:
-            raise Exception("More than 1 pcap file found")
+            raise TestError("More than 1 pcap file found")
         args += ["-r", pcaps[0]]
 
         # Find rules.
@@ -378,7 +381,7 @@ class TestRunner:
         elif len(rules) == 1:
             args += ["-S", rules[0]]
         else:
-            raise Exception("More than 1 rule file found")
+            raise TestError("More than 1 rule file found")
 
         return args
 
@@ -469,11 +472,13 @@ def main():
             except UnsatisfiedRequirementError as err:
                 print("SKIPPED: %s" % (str(err)))
                 skipped += 1
-            except Exception as err:
-                print("FAIL: exception: %s" % (str(err)))
+            except TestError as err:
+                print("FAIL: %s" % (str(err)))
                 failed += 1
                 if args.fail:
                     return 1
+            except Exception as err:
+                raise
 
     print("")
     print("PASSED:  %d" % (passed))