From a992480818d9e26052ac4854f642b76da8870ee7 Mon Sep 17 00:00:00 2001 From: Jason Ish Date: Mon, 7 Jul 2025 00:01:06 -0600 Subject: [PATCH] runner: convert to multi-threading from multi-processing Multi-processing has issues on Windows; moving to threading does allow it to work on Windows (with a few issues to still figure out). This removes the single-thread runner; instead, for Windows and Mac we'll just use one job for now, and change that default as we prove it's reliable. Update ctrl-c handling as well, for reliable ctrl-c cancellation. --- run.py | 112 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 69 insertions(+), 43 deletions(-) diff --git a/run.py b/run.py index 30f694313..61ee2f986 100755 --- a/run.py +++ b/run.py @@ -37,7 +37,7 @@ import glob import re import json import unittest -import multiprocessing as mp +from concurrent.futures import ThreadPoolExecutor from collections import namedtuple import threading import filecmp @@ -45,12 +45,10 @@ import subprocess import yaml import traceback import platform +import signal VALIDATE_EVE = False -# Windows and macOS don't support the mp logic below. WIN32 = sys.platform == "win32" -DARWIN = sys.platform == "darwin" -MP = not WIN32 and not DARWIN suricata_yaml = "suricata.yaml" if WIN32 else "./suricata.yaml" # Determine the Suricata binary @@ -61,25 +59,20 @@ else: PROC_TIMEOUT=300 -if MP: - manager = mp.Manager() - lock = mp.Lock() - failedLogs = manager.list() - count_dict = manager.dict() - check_args = manager.dict() -else: - failedLogs = [] - count_dict = {} - check_args = {} - # Bring in a lock from threading to satisfy the MP semantics when - # not using MP. 
- lock = threading.Lock() +lock = threading.Lock() +failedLogs = [] +count_dict = {} +check_args = {} count_dict['passed'] = 0 count_dict['failed'] = 0 count_dict['skipped'] = 0 check_args['fail'] = 0 +# Global flag for shutdown signal +shutdown_requested = False +executor_instance = None + class SelfTest(unittest.TestCase): def test_parse_suricata_version(self): @@ -831,8 +824,22 @@ class TestRunner: r.terminate() try: - r = p.wait(timeout=PROC_TIMEOUT) - except: + # Check for shutdown request periodically during wait + while p.poll() is None: + if shutdown_requested: + print("\\nTerminating test process due to shutdown request") + p.terminate() + p.wait(timeout=5) # Give it 5 seconds to terminate gracefully + raise TestError("interrupted by user") + # Wait a short time before checking again + try: + r = p.wait(timeout=1) + break + except subprocess.TimeoutExpired: + continue + else: + r = p.returncode + except subprocess.TimeoutExpired: print("Suricata timed out, terminating") p.terminate() raise TestError("timed out when expected exit code %d" % ( @@ -1058,7 +1065,7 @@ def check_deps(): def run_test(dirpath, args, cwd, suricata_config): with lock: - if check_args['fail'] == 1: + if check_args['fail'] == 1 or shutdown_requested: raise TerminatePoolError() name = os.path.basename(dirpath) @@ -1110,23 +1117,42 @@ def run_test(dirpath, args, cwd, suricata_config): failedLogs.append(dirpath) raise TerminatePoolError() -def run_mp(jobs, tests, dirpath, args, cwd, suricata_config): +def signal_handler(signum, frame): + global shutdown_requested, executor_instance + print("\nReceived interrupt signal, shutting down gracefully...") + shutdown_requested = True + if executor_instance: + executor_instance.shutdown(wait=False) + sys.exit(1) + +def run_parallel(jobs, tests, args, cwd, suricata_config): + global executor_instance print("Number of concurrent jobs: %d" % jobs) - pool = mp.Pool(jobs) - try: - for dirpath in tests: - pool.apply_async(run_test, args=(dirpath, args, 
cwd, suricata_config)) - except TerminatePoolError: - pool.terminate() - pool.close() - pool.join() - -def run_single(tests, dirpath, args, cwd, suricata_config): - try: - for dirpath in tests: - run_test(dirpath, args, cwd, suricata_config) - except TerminatePoolError: - sys.exit(1) + + # Set up signal handler + signal.signal(signal.SIGINT, signal_handler) + + with ThreadPoolExecutor(max_workers=jobs) as executor: + executor_instance = executor + try: + futures = [] + for dirpath in tests: + if shutdown_requested: + break + future = executor.submit(run_test, dirpath, args, cwd, suricata_config) + futures.append(future) + + # Wait for all futures to complete and re-raise any exceptions + for future in futures: + if shutdown_requested: + break + future.result() + except (TerminatePoolError, KeyboardInterrupt): + executor.shutdown(wait=False) + # Don't exit immediately - let the function return so summary can be printed + return + finally: + executor_instance = None def build_eve_validator(): env = os.environ.copy() @@ -1169,9 +1195,8 @@ def main(): help="Clean up output directories of passing tests") parser.add_argument("--no-validation", action="store_true", help="Disable EVE validation") parser.add_argument("patterns", nargs="*", default=[]) - if MP: - parser.add_argument("-j", type=int, default=min(8, mp.cpu_count()), - help="Number of jobs to run") + parser.add_argument("-j", type=int, default=min(8, os.cpu_count()), + help="Number of jobs to run (threads)") args = parser.parse_args() if args.self_test: @@ -1255,10 +1280,11 @@ def main(): # Sort alphabetically. tests.sort() - if MP and args.j > 1: - run_mp(args.j, tests, dirpath, args, cwd, suricata_config) - else: - run_single(tests, dirpath, args, cwd, suricata_config) + try: + run_parallel(args.j, tests, args, cwd, suricata_config) + except KeyboardInterrupt: + print("\nInterrupted by user") + return 1 passed = count_dict["passed"] failed = count_dict["failed"] -- 2.47.3