import re
import json
import unittest
-import multiprocessing as mp
+from concurrent.futures import ThreadPoolExecutor
from collections import namedtuple
import threading
import filecmp
import yaml
import traceback
import platform
+import signal
VALIDATE_EVE = False
-# Windows and macOS don't support the mp logic below.
WIN32 = sys.platform == "win32"
-DARWIN = sys.platform == "darwin"
-MP = not WIN32 and not DARWIN
suricata_yaml = "suricata.yaml" if WIN32 else "./suricata.yaml"
# Determine the Suricata binary
PROC_TIMEOUT=300
-if MP:
- manager = mp.Manager()
- lock = mp.Lock()
- failedLogs = manager.list()
- count_dict = manager.dict()
- check_args = manager.dict()
-else:
- failedLogs = []
- count_dict = {}
- check_args = {}
- # Bring in a lock from threading to satisfy the MP semantics when
- # not using MP.
- lock = threading.Lock()
+lock = threading.Lock()
+failedLogs = []
+count_dict = {}
+check_args = {}
count_dict['passed'] = 0
count_dict['failed'] = 0
count_dict['skipped'] = 0
check_args['fail'] = 0
+# Global flag for shutdown signal
+shutdown_requested = False
+executor_instance = None
+
class SelfTest(unittest.TestCase):
def test_parse_suricata_version(self):
r.terminate()
try:
- r = p.wait(timeout=PROC_TIMEOUT)
- except:
+ # Check for shutdown request periodically during wait
+ while p.poll() is None:
+ if shutdown_requested:
+ print("\\nTerminating test process due to shutdown request")
+ p.terminate()
+ p.wait(timeout=5) # Give it 5 seconds to terminate gracefully
+ raise TestError("interrupted by user")
+ # Wait a short time before checking again
+ try:
+ r = p.wait(timeout=1)
+ break
+ except subprocess.TimeoutExpired:
+ continue
+ else:
+ r = p.returncode
+ except subprocess.TimeoutExpired:
print("Suricata timed out, terminating")
p.terminate()
raise TestError("timed out when expected exit code %d" % (
def run_test(dirpath, args, cwd, suricata_config):
with lock:
- if check_args['fail'] == 1:
+ if check_args['fail'] == 1 or shutdown_requested:
raise TerminatePoolError()
name = os.path.basename(dirpath)
failedLogs.append(dirpath)
raise TerminatePoolError()
-def run_mp(jobs, tests, dirpath, args, cwd, suricata_config):
+def signal_handler(signum, frame):
+ global shutdown_requested, executor_instance
+ print("\nReceived interrupt signal, shutting down gracefully...")
+ shutdown_requested = True
+ if executor_instance:
+ executor_instance.shutdown(wait=False)
+ sys.exit(1)
+
+def run_parallel(jobs, tests, args, cwd, suricata_config):
+ global executor_instance
print("Number of concurrent jobs: %d" % jobs)
- pool = mp.Pool(jobs)
- try:
- for dirpath in tests:
- pool.apply_async(run_test, args=(dirpath, args, cwd, suricata_config))
- except TerminatePoolError:
- pool.terminate()
- pool.close()
- pool.join()
-
-def run_single(tests, dirpath, args, cwd, suricata_config):
- try:
- for dirpath in tests:
- run_test(dirpath, args, cwd, suricata_config)
- except TerminatePoolError:
- sys.exit(1)
+
+ # Set up signal handler
+ signal.signal(signal.SIGINT, signal_handler)
+
+ with ThreadPoolExecutor(max_workers=jobs) as executor:
+ executor_instance = executor
+ try:
+ futures = []
+ for dirpath in tests:
+ if shutdown_requested:
+ break
+ future = executor.submit(run_test, dirpath, args, cwd, suricata_config)
+ futures.append(future)
+
+ # Wait for all futures to complete and re-raise any exceptions
+ for future in futures:
+ if shutdown_requested:
+ break
+ future.result()
+ except (TerminatePoolError, KeyboardInterrupt):
+ executor.shutdown(wait=False)
+ # Don't exit immediately - let the function return so summary can be printed
+ return
+ finally:
+ executor_instance = None
def build_eve_validator():
env = os.environ.copy()
help="Clean up output directories of passing tests")
parser.add_argument("--no-validation", action="store_true", help="Disable EVE validation")
parser.add_argument("patterns", nargs="*", default=[])
- if MP:
- parser.add_argument("-j", type=int, default=min(8, mp.cpu_count()),
- help="Number of jobs to run")
+ parser.add_argument("-j", type=int, default=min(8, os.cpu_count()),
+ help="Number of jobs to run (threads)")
args = parser.parse_args()
if args.self_test:
# Sort alphabetically.
tests.sort()
- if MP and args.j > 1:
- run_mp(args.j, tests, dirpath, args, cwd, suricata_config)
- else:
- run_single(tests, dirpath, args, cwd, suricata_config)
+ try:
+ run_parallel(args.j, tests, args, cwd, suricata_config)
+ except KeyboardInterrupt:
+ print("\nInterrupted by user")
+ return 1
passed = count_dict["passed"]
failed = count_dict["failed"]