git.ipfire.org Git - thirdparty/suricata-verify.git/commitdiff
run: parallel tests capability
author: Tharushi Jayasekara <tharushi68@gmail.com>
Mon, 14 Dec 2020 04:45:02 +0000 (10:15 +0530)
committer: Victor Julien <victor@inliniac.net>
Thu, 11 Feb 2021 18:55:22 +0000 (19:55 +0100)
Used the Python multiprocessing module to add parallel test execution
capability.

Added a -j parameter to control the number of jobs defaulting to the
number of cores found by the mp module.

run.py

diff --git a/run.py b/run.py
index eb622719239f59741dac9a26ecd0266089939065..e5879b821e1a80b9e7a592ce00b77044c724bb73 100755 (executable)
--- a/run.py
+++ b/run.py
@@ -37,6 +37,7 @@ import glob
 import re
 import json
 import unittest
+import multiprocessing as mp
 from collections import namedtuple
 
 import yaml
@@ -45,6 +46,17 @@ WIN32 = sys.platform == "win32"
 suricata_bin = "src\suricata.exe" if WIN32 else "./src/suricata"
 suricata_yaml = "suricata.yaml" if WIN32 else "./suricata.yaml"
 
+manager = mp.Manager()
+lock = mp.Lock()
+failedLogs = manager.list()
+count_dict = manager.dict()
+check_args = manager.dict()
+
+count_dict['passed'] = 0
+count_dict['failed'] = 0
+count_dict['skipped'] = 0
+check_args['fail'] = 0
+
 class SelfTest(unittest.TestCase):
 
     def test_parse_suricata_version(self):
@@ -87,6 +99,9 @@ class TestError(Exception):
 class UnsatisfiedRequirementError(Exception):
     pass
 
+class TerminatePoolError(Exception):
+    pass
+
 SuricataVersion = namedtuple(
     "SuricataVersion", ["major", "minor", "patch"])
 
@@ -123,17 +138,17 @@ def handle_exceptions(func):
         try:
             result = func(*args,**kwargs)
         except TestError as te:
-            print("Sub test #{}: FAIL : {}".format(kwargs["test_num"], te))
+            print("===> {}: Sub test #{}: FAIL : {}".format(kwargs["test_name"], kwargs["test_num"], te))
             check_args_fail()
             kwargs["count"]["failure"] += 1
         except UnsatisfiedRequirementError as ue:
-            print("Sub test #{}: SKIPPED : {}".format(kwargs["test_num"], ue))
+            print("===> {}: Sub test #{}: SKIPPED : {}".format(kwargs["test_name"], kwargs["test_num"], ue))
             kwargs["count"]["skipped"] += 1
         else:
             if result:
               kwargs["count"]["success"] += 1
             else:
-              print("\nSub test #{}: FAIL : {}".format(kwargs["test_num"], kwargs["check"]["args"]))
+              print("\n===> {}: Sub test #{}: FAIL : {}".format(kwargs["test_name"], kwargs["test_num"], kwargs["check"]["args"]))
               kwargs["count"]["failure"] += 1
         return kwargs["count"]
     return applicator
@@ -541,9 +556,6 @@ class TestRunner:
 
     def run(self):
 
-        sys.stdout.write("===> %s: " % os.path.basename(self.directory))
-        sys.stdout.flush()
-
         if not self.force:
             self.check_requires()
             self.check_skip()
@@ -626,9 +638,9 @@ class TestRunner:
                 return check_value
 
         if not check_value["failure"] and not check_value["skipped"]:
-            print("OK%s" % (" (%dx)" % count if count > 1 else ""))
+            print("===> %s: OK%s" % (os.path.basename(self.directory), " (%dx)" % count if count > 1 else ""))
         elif not check_value["failure"]:
-            print("OK (checks: {}, skipped: {})".format(sum(check_value.values()), check_value["skipped"]))
+            print("===> {}: OK (checks: {}, skipped: {})".format(os.path.basename(self.directory), sum(check_value.values()), check_value["skipped"]))
         return check_value
 
     def pre_check(self):
@@ -636,18 +648,18 @@ class TestRunner:
             subprocess.call(self.config["pre-check"], shell=True)
 
     @handle_exceptions
-    def perform_filter_checks(self, check, count, test_num):
+    def perform_filter_checks(self, check, count, test_num, test_name):
         count = FilterCheck(check, self.output,
                 self.suricata_config).run()
         return count
 
     @handle_exceptions
-    def perform_shell_checks(self, check, count, test_num):
+    def perform_shell_checks(self, check, count, test_num, test_name):
         count = ShellCheck(check).run()
         return count
 
     @handle_exceptions
-    def perform_stats_checks(self, check, count, test_num):
+    def perform_stats_checks(self, check, count, test_num, test_name):
         count = StatsCheck(check, self.output).run()
         return count
 
@@ -673,7 +685,7 @@ class TestRunner:
                         if key in ["filter", "shell", "stats"]:
                             func = getattr(self, "perform_{}_checks".format(key))
                             count = func(check=check[key], count=count,
-                                    test_num=check_count + 1)
+                                    test_num=check_count + 1, test_name=os.path.basename(self.directory))
                         else:
                             print("FAIL: Unknown check type: {}".format(key))
         finally:
@@ -805,7 +817,8 @@ class TestRunner:
 
 def check_args_fail():
     if args.fail:
-        sys.exit(1)
+        with lock:
+            check_args['fail'] = 1
 
 
 def check_deps():
@@ -825,6 +838,41 @@ def check_deps():
 
     return True
 
+def run_test(dirpath, args, cwd, suricata_config):
+    with lock:
+        if check_args['fail'] == 1:
+            raise TerminatePoolError()
+
+    name = os.path.basename(dirpath)
+
+    outdir = os.path.join(dirpath, "output")
+    if args.outdir:
+        outdir = os.path.join(os.path.realpath(args.outdir), name, "output")
+
+    test_runner = TestRunner(
+        cwd, dirpath, outdir, suricata_config, args.verbose, args.force)
+    try:
+        results = test_runner.run()
+        if results["failure"] > 0:
+            with lock:
+                count_dict["failed"] += 1
+                failedLogs.append(dirpath)
+        elif results["skipped"] > 0 and results["success"] == 0:
+            with lock:
+                count_dict["skipped"] += 1
+        elif results["success"] > 0:
+            with lock:
+                count_dict["passed"] += 1
+    except UnsatisfiedRequirementError as ue:
+        print("===> {}: SKIPPED: {}".format(os.path.basename(dirpath), ue))
+        with lock:
+            count_dict["skipped"] += 1
+    except TestError as te:
+        print("===> {}: FAILED: {}".format(os.path.basename(dirpath), te))
+        check_args_fail()
+        with lock:
+            count_dict["failed"] += 1
+
 def main():
     global TOPDIR
     global args
@@ -834,6 +882,8 @@ def main():
 
     parser = argparse.ArgumentParser(description="Verification test runner.")
     parser.add_argument("-v", dest="verbose", action="store_true")
+    parser.add_argument("-j", type=int, default=mp.cpu_count(),
+                        help="Number of jobs to run")
     parser.add_argument("--force", dest="force", action="store_true",
                         help="Force running of skipped tests")
     parser.add_argument("--fail", action="store_true",
@@ -921,33 +971,24 @@ def main():
 
     # Sort alphabetically.
     tests.sort()
-    failedLogs = []
 
-    for dirpath in tests:
-        name = os.path.basename(dirpath)
+    jobs = args.j
+    print("Number of concurrent jobs: %d" % jobs)
 
-        outdir = os.path.join(dirpath, "output")
-        if args.outdir:
-            outdir = os.path.join(os.path.realpath(args.outdir), name, "output")
+    pool = mp.Pool(jobs)
 
-        test_runner = TestRunner(
-            cwd, dirpath, outdir, suricata_config, args.verbose, args.force)
-        try:
-            results = test_runner.run()
-            if results["failure"] > 0:
-                failed += 1
-                failedLogs.append(dirpath)
-            elif results["skipped"] > 0 and results["success"] == 0:
-                skipped += 1
-            elif results["success"] > 0:
-                passed += 1
-        except UnsatisfiedRequirementError as ue:
-            print("SKIPPED: {}".format(ue))
-            skipped += 1
-        except TestError as te:
-            print("FAILED: {}".format(te))
-            check_args_fail()
-            failed += 1
+    try:
+        for dirpath in tests:
+            pool.apply_async(run_test, args=(dirpath, args, cwd, suricata_config))
+    except TerminatePoolError:
+        pool.terminate()
+
+    pool.close()
+    pool.join()
+
+    passed = count_dict["passed"]
+    failed = count_dict["failed"]
+    skipped = count_dict["skipped"]
 
     print("")
     print("PASSED:  %d" % (passed))