]> git.ipfire.org Git - thirdparty/suricata-verify.git/commitdiff
runner: convert to multi-threading from multi-processing
authorJason Ish <jason.ish@oisf.net>
Mon, 7 Jul 2025 06:01:06 +0000 (00:01 -0600)
committerVictor Julien <victor@inliniac.net>
Tue, 15 Jul 2025 14:40:39 +0000 (16:40 +0200)
Multi-processing has issues on Windows, moving to threading does allow
it to work on Windows (with a few issues to still figure out).

This removes the single thread runner, instead for Windows and Mac
we'll just use one job for now, and change that default as we prove
its reliaable.

Update ctrl-c handling as well, for reliable ctrl-c cancellation.

run.py

diff --git a/run.py b/run.py
index 30f694313c4f7f3f7a956cc536823914a68a1337..61ee2f9862889f0d1799995316b8d7fa96329aae 100755 (executable)
--- a/run.py
+++ b/run.py
@@ -37,7 +37,7 @@ import glob
 import re
 import json
 import unittest
-import multiprocessing as mp
+from concurrent.futures import ThreadPoolExecutor
 from collections import namedtuple
 import threading
 import filecmp
@@ -45,12 +45,10 @@ import subprocess
 import yaml
 import traceback
 import platform
+import signal
 
 VALIDATE_EVE = False
-# Windows and macOS don't support the mp logic below.
 WIN32 = sys.platform == "win32"
-DARWIN = sys.platform == "darwin"
-MP = not WIN32 and not DARWIN
 suricata_yaml = "suricata.yaml" if WIN32 else "./suricata.yaml"
 
 # Determine the Suricata binary
@@ -61,25 +59,20 @@ else:
 
 PROC_TIMEOUT=300
 
-if MP:
-    manager = mp.Manager()
-    lock = mp.Lock()
-    failedLogs = manager.list()
-    count_dict = manager.dict()
-    check_args = manager.dict()
-else:
-    failedLogs = []
-    count_dict = {}
-    check_args = {}
-    # Bring in a lock from threading to satisfy the MP semantics when
-    # not using MP.
-    lock = threading.Lock()
+lock = threading.Lock()
+failedLogs = []
+count_dict = {}
+check_args = {}
 
 count_dict['passed'] = 0
 count_dict['failed'] = 0
 count_dict['skipped'] = 0
 check_args['fail'] = 0
 
+# Global flag for shutdown signal
+shutdown_requested = False
+executor_instance = None
+
 class SelfTest(unittest.TestCase):
 
     def test_parse_suricata_version(self):
@@ -831,8 +824,22 @@ class TestRunner:
                         r.terminate()
 
                 try:
-                    r = p.wait(timeout=PROC_TIMEOUT)
-                except:
+                    # Check for shutdown request periodically during wait
+                    while p.poll() is None:
+                        if shutdown_requested:
+                            print("\\nTerminating test process due to shutdown request")
+                            p.terminate()
+                            p.wait(timeout=5)  # Give it 5 seconds to terminate gracefully
+                            raise TestError("interrupted by user")
+                        # Wait a short time before checking again
+                        try:
+                            r = p.wait(timeout=1)
+                            break
+                        except subprocess.TimeoutExpired:
+                            continue
+                    else:
+                        r = p.returncode
+                except subprocess.TimeoutExpired:
                     print("Suricata timed out, terminating")
                     p.terminate()
                     raise TestError("timed out when expected exit code %d" % (
@@ -1058,7 +1065,7 @@ def check_deps():
 
 def run_test(dirpath, args, cwd, suricata_config):
     with lock:
-        if check_args['fail'] == 1:
+        if check_args['fail'] == 1 or shutdown_requested:
             raise TerminatePoolError()
 
     name = os.path.basename(dirpath)
@@ -1110,23 +1117,42 @@ def run_test(dirpath, args, cwd, suricata_config):
             failedLogs.append(dirpath)
             raise TerminatePoolError()
 
-def run_mp(jobs, tests, dirpath, args, cwd, suricata_config):
+def signal_handler(signum, frame):
+    global shutdown_requested, executor_instance
+    print("\nReceived interrupt signal, shutting down gracefully...")
+    shutdown_requested = True
+    if executor_instance:
+        executor_instance.shutdown(wait=False)
+    sys.exit(1)
+
+def run_parallel(jobs, tests, args, cwd, suricata_config):
+    global executor_instance
     print("Number of concurrent jobs: %d" % jobs)
-    pool = mp.Pool(jobs)
-    try:
-        for dirpath in tests:
-            pool.apply_async(run_test, args=(dirpath, args, cwd, suricata_config))
-    except TerminatePoolError:
-        pool.terminate()
-    pool.close()
-    pool.join()
-
-def run_single(tests, dirpath, args, cwd, suricata_config):
-    try:
-        for dirpath in tests:
-            run_test(dirpath, args, cwd, suricata_config)
-    except TerminatePoolError:
-        sys.exit(1)
+    
+    # Set up signal handler
+    signal.signal(signal.SIGINT, signal_handler)
+    
+    with ThreadPoolExecutor(max_workers=jobs) as executor:
+        executor_instance = executor
+        try:
+            futures = []
+            for dirpath in tests:
+                if shutdown_requested:
+                    break
+                future = executor.submit(run_test, dirpath, args, cwd, suricata_config)
+                futures.append(future)
+            
+            # Wait for all futures to complete and re-raise any exceptions
+            for future in futures:
+                if shutdown_requested:
+                    break
+                future.result()
+        except (TerminatePoolError, KeyboardInterrupt):
+            executor.shutdown(wait=False)
+            # Don't exit immediately - let the function return so summary can be printed
+            return
+        finally:
+            executor_instance = None
 
 def build_eve_validator():
     env = os.environ.copy()
@@ -1169,9 +1195,8 @@ def main():
                         help="Clean up output directories of passing tests")
     parser.add_argument("--no-validation", action="store_true", help="Disable EVE validation")
     parser.add_argument("patterns", nargs="*", default=[])
-    if MP:
-        parser.add_argument("-j", type=int, default=min(8, mp.cpu_count()),
-                        help="Number of jobs to run")
+    parser.add_argument("-j", type=int, default=min(8, os.cpu_count()),
+                        help="Number of jobs to run (threads)")
     args = parser.parse_args()
 
     if args.self_test:
@@ -1255,10 +1280,11 @@ def main():
     # Sort alphabetically.
     tests.sort()
 
-    if MP and args.j > 1:
-        run_mp(args.j, tests, dirpath, args, cwd, suricata_config)
-    else:
-        run_single(tests, dirpath, args, cwd, suricata_config)
+    try:
+        run_parallel(args.j, tests, args, cwd, suricata_config)
+    except KeyboardInterrupt:
+        print("\nInterrupted by user")
+        return 1
 
     passed = count_dict["passed"]
     failed = count_dict["failed"]