]> git.ipfire.org Git - thirdparty/suricata-verify.git/commitdiff
runner: remove os.chdir, use full paths as needed
authorJason Ish <jason.ish@oisf.net>
Mon, 7 Jul 2025 05:19:30 +0000 (23:19 -0600)
committerVictor Julien <victor@inliniac.net>
Tue, 15 Jul 2025 14:40:39 +0000 (16:40 +0200)
In preparation for multi-threading, where we can't chdir as that would
affect other threads.

run.py

diff --git a/run.py b/run.py
index f9967eaa38028f1d46b2fc557e3e932fae12594d..30f694313c4f7f3f7a956cc536823914a68a1337 100755 (executable)
--- a/run.py
+++ b/run.py
@@ -336,7 +336,7 @@ def check_filter_test_version_compat(requires, test_version):
                     raise UnnecessaryRequirementError(
                         "test already requires min {} not needed for the check {}".format(test_version["min"], requires["min-version"]))
 
-def check_requires(requires, suricata_config: SuricataConfig):
+def check_requires(requires, suricata_config: SuricataConfig, test_dir=None):
     suri_version = suricata_config.version
     for key in requires:
         if key == "min-version":
@@ -375,10 +375,14 @@ def check_requires(requires, suricata_config: SuricataConfig):
                         "requires env var %s" % (env))
         elif key == "files":
             for filename in requires["files"]:
+                if test_dir and not os.path.isabs(filename):
+                    filename = os.path.join(test_dir, filename)
                 if not os.path.exists(filename):
                     raise UnsatisfiedRequirementError(
                         "requires file %s" % (filename))
         elif key == "script":
+            # This is run for the current directory (the Suricata
+            # source directory).
             for script in requires["script"]:
                 try:
                     subprocess.check_call("%s" % script, shell=True)
@@ -463,15 +467,18 @@ def rule_is_version_compatible(rulefile, suri_version):
 
 class FileCompareCheck:
 
-    def __init__(self, config, directory):
+    def __init__(self, config, directory, cwd):
         self.config = config
         self.directory = directory
+        self.cwd = cwd
 
     def run(self):
         if WIN32:
             raise UnsatisfiedRequirementError("shell check not supported on Windows")
         expected = os.path.join(self.directory, self.config["expected"])
         filename = self.config["filename"]
+        if self.cwd and not os.path.isabs(filename):
+            filename = os.path.join(self.cwd, filename)
         try:
             if filecmp.cmp(expected, filename):
                 return True
@@ -482,10 +489,12 @@ class FileCompareCheck:
 
 class ShellCheck:
 
-    def __init__(self, config, env, suricata_config):
+    def __init__(self, config, env, suricata_config, output_dir, test_dir):
         self.config = config
         self.env = env
         self.suricata_config = suricata_config
+        self.cwd = output_dir
+        self.script_cwd = test_dir
 
     def run(self):
         shell_args = {}
@@ -500,12 +509,12 @@ class ShellCheck:
             shell_args["min-version"] = min_version
         if lt_version is not None:
             shell_args["lt-version"] = lt_version
-        check_requires(shell_args, self.suricata_config)
+        check_requires(shell_args, self.suricata_config, self.script_cwd)
 
         try:
             if WIN32:
                 raise UnsatisfiedRequirementError("shell check not supported on Windows")
-            output = subprocess.check_output(self.config["args"], shell=True, env=self.env)
+            output = subprocess.check_output(self.config["args"], shell=True, env=self.env, cwd=self.cwd)
             if "expect" in self.config:
                 return str(self.config["expect"]) == output.decode().strip()
             return True
@@ -521,7 +530,8 @@ class StatsCheck:
 
     def run(self):
         stats = None
-        with open("eve.json", "r") as fileobj:
+        eve_json_path = os.path.join(self.outdir, "eve.json")
+        with open(eve_json_path, "r") as fileobj:
             for line in fileobj:
                 event = json.loads(line)
                 if event["event_type"] == "stats":
@@ -535,12 +545,13 @@ class StatsCheck:
 
 class FilterCheck:
 
-    def __init__(self, config, outdir, suricata_config, test_version):
+    def __init__(self, config, outdir, suricata_config, test_version, script_cwd=None):
         self.config = config
         self.outdir = outdir
         self.suricata_config = suricata_config
         self.suri_version = suricata_config.version
         self.test_version = test_version
+        self.script_cwd = script_cwd
 
     def run(self):
         requires = self.config.get("requires", {})
@@ -557,12 +568,14 @@ class FilterCheck:
         feature = self.config.get("feature")
         if feature is not None:
             requires["features"] = [feature]
-        check_requires(requires, self.suricata_config)
+        check_requires(requires, self.suricata_config, self.script_cwd)
 
         if "filename" in self.config:
             json_filename = self.config["filename"]
+            if not os.path.isabs(json_filename):
+                json_filename = os.path.join(self.outdir, json_filename)
         else:
-            json_filename = "eve.json"
+            json_filename = os.path.join(self.outdir, "eve.json")
         if not os.path.exists(json_filename):
             raise TestError("%s does not exist" % (json_filename))
 
@@ -695,7 +708,7 @@ class TestRunner:
 
     def check_requires(self):
         requires = self.config.get("requires", {})
-        check_requires(requires, self.suricata_config)
+        check_requires(requires, self.suricata_config, self.directory)
         for key in requires:
             if key == "min-version":
                 self.version["min"] = requires["min-version"]
@@ -857,17 +870,17 @@ class TestRunner:
 
     def pre_check(self):
         if "pre-check" in self.config:
-            subprocess.call(self.config["pre-check"], shell=True)
+            subprocess.call(self.config["pre-check"], shell=True, cwd=self.output)
 
     @handle_exceptions
     def perform_filter_checks(self, check, count, test_num, test_name):
         count = FilterCheck(check, self.output,
-                self.suricata_config, self.version).run()
+                self.suricata_config, self.version, self.directory).run()
         return count
 
     @handle_exceptions
     def perform_shell_checks(self, check, count, test_num, test_name):
-        count = ShellCheck(check, self.build_env(), self.suricata_config).run()
+        count = ShellCheck(check, self.build_env(), self.suricata_config, self.output, self.directory).run()
         return count
 
     @handle_exceptions
@@ -877,7 +890,7 @@ class TestRunner:
 
     @handle_exceptions
     def perform_file_compare_checks(self, check, count, test_num, test_name):
-        count = FileCompareCheck(check, self.directory).run()
+        count = FileCompareCheck(check, self.directory, self.output).run()
         return count
 
     def reset_count(self, dictionary):
@@ -885,27 +898,22 @@ class TestRunner:
             dictionary[k] = 0
 
     def check(self):
-        pdir = os.getcwd()
-        os.chdir(self.output)
         count = {
             "success": 0,
             "failure": 0,
             "skipped": 0,
                 }
-        try:
-            self.pre_check()
-            if "checks" in self.config:
-                self.reset_count(count)
-                for check_count, check in enumerate(self.config["checks"]):
-                    for key in check:
-                        if key in ["filter", "shell", "stats", "file-compare"]:
-                            func = getattr(self, "perform_{}_checks".format(key.replace("-","_")))
-                            count = func(check=check[key], count=count,
-                                    test_num=check_count + 1, test_name=os.path.basename(self.directory))
-                        else:
-                            print("FAIL: Unknown check type: {}".format(key))
-        finally:
-            os.chdir(pdir)
+        self.pre_check()
+        if "checks" in self.config:
+            self.reset_count(count)
+            for check_count, check in enumerate(self.config["checks"]):
+                for key in check:
+                    if key in ["filter", "shell", "stats", "file-compare"]:
+                        func = getattr(self, "perform_{}_checks".format(key.replace("-","_")))
+                        count = func(check=check[key], count=count,
+                                test_num=check_count + 1, test_name=os.path.basename(self.directory))
+                    else:
+                        print("FAIL: Unknown check type: {}".format(key))
 
         if count["failure"] or count["skipped"]:
             return count
@@ -968,14 +976,10 @@ class TestRunner:
 
         # Find pcaps.
         if "pcap" in self.config:
-            try:
-                curdir = os.getcwd()
-                os.chdir(self.directory)
-                if not os.path.exists(self.config["pcap"]):
-                    raise TestError("PCAP filename does not exist: {}".format(self.config["pcap"]))
-                args += ["-r", os.path.realpath(os.path.join(self.directory, self.config["pcap"]))]
-            finally:
-                os.chdir(curdir)
+            pcap_path = os.path.join(self.directory, self.config["pcap"])
+            if not os.path.exists(pcap_path):
+                raise TestError("PCAP filename does not exist: {}".format(self.config["pcap"]))
+            args += ["-r", os.path.realpath(pcap_path)]
         else:
             pcaps = glob.glob(os.path.join(self.directory, "*.pcap"))
             pcaps += glob.glob(os.path.join(self.directory, "*.pcapng"))