]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
tests: py: Implement payload_record()
authorPhil Sutter <phil@nwl.cc>
Wed, 10 Sep 2025 13:14:23 +0000 (15:14 +0200)
committerPhil Sutter <phil@nwl.cc>
Thu, 23 Oct 2025 20:48:51 +0000 (22:48 +0200)
This is a helper function to store payload records (and JSON
equivalents) in .got files. The code it replaces missed to insert a
newline before the new entry and also did not check for existing records
in all spots.

Signed-off-by: Phil Sutter <phil@nwl.cc>
tests/py/nft-test.py

index 019c828f957a5e862d48b4c4aae9a91d8a5d2c8e..dc074d4c3872a3363d84534c10b1fb4d31bff00f 100755 (executable)
@@ -16,6 +16,7 @@
 from __future__ import print_function
 import sys
 import os
+import io
 import argparse
 import signal
 import json
@@ -741,6 +742,66 @@ def payload_check(payload_buffer, file, cmd):
     return i > 0
 
 
+def payload_record(path, rule, payload, desc="payload"):
+    '''
+    Record payload for @rule in file at @path
+
+    - @payload may be a file handle, a string or an array of strings
+    - Avoid duplicate entries by searching for a match first
+    - Separate entries by a single empty line, so check for trailing newlines
+      before writing
+    - @return False if already existing, True otherwise
+    '''
+    try:
+        with open(path, 'r') as f:
+            lines = f.readlines()
+    except:
+        lines = []
+
+    plines = []
+    if isinstance(payload, io.TextIOWrapper):
+        payload.seek(0, 0)
+        while True:
+            line = payload.readline()
+            if line.startswith("family "):
+                continue
+            if line == "":
+                break
+            plines.append(line)
+    elif isinstance(payload, str):
+        plines = [l + "\n" for l in payload.split("\n")]
+    elif isinstance(payload, list):
+        plines = payload
+    else:
+        raise Exception
+
+    found = False
+    for i in range(len(lines)):
+        if lines[i] == rule + "\n":
+            found = True
+            for pline in plines:
+                i += 1
+                if lines[i] != pline:
+                    found = False
+                    break
+            if found:
+                return False
+
+    try:
+        with open(path, 'a') as f:
+            if len(lines) > 0 and lines[-1] != "\n":
+                f.write("\n")
+            f.write("# %s\n" % rule)
+            f.writelines(plines)
+    except:
+        warnfmt = "Failed to write %s for rule %s"
+    else:
+        warnfmt = "Wrote %s for rule %s"
+
+    print_warning(warnfmt % (desc, rule[0]), os.path.basename(path), 1)
+    return True
+
+
 def json_dump_normalize(json_string, human_readable = False):
     json_obj = json.loads(json_string)
 
@@ -867,28 +928,8 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
             if state == "ok" and not payload_check(table_payload_expected,
                                                    payload_log, cmd):
                 error += 1
-
-                try:
-                    gotf = open("%s.got" % table_payload_path)
-                    gotf_payload_expected = payload_find_expected(gotf, rule[0])
-                    gotf.close()
-                except:
-                    gotf_payload_expected = None
-                payload_log.seek(0, 0)
-                if not payload_check(gotf_payload_expected, payload_log, cmd):
-                    gotf = open("%s.got" % table_payload_path, 'a')
-                    payload_log.seek(0, 0)
-                    gotf.write("# %s\n" % rule[0])
-                    while True:
-                        line = payload_log.readline()
-                        if line.startswith("family "):
-                            continue
-                        if line == "":
-                            break
-                        gotf.write(line)
-                    gotf.close()
-                    print_warning("Wrote payload for rule %s" % rule[0],
-                                  gotf.name, 1)
+                payload_record("%s.got" % table_payload_path,
+                               rule[0], payload_log)
 
             # Check for matching ruleset listing
             numeric_proto_old = nftables.set_numeric_proto_output(True)
@@ -979,13 +1020,9 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
                         json_output = item["rule"]
                         break
                 json_input = json.dumps(json_output["expr"], sort_keys = True)
-
-                gotf = open("%s.json.got" % filename_path, 'a')
-                jdump = json_dump_normalize(json_input, True)
-                gotf.write("# %s\n%s\n\n" % (rule[0], jdump))
-                gotf.close()
-                print_warning("Wrote JSON equivalent for rule %s" % rule[0],
-                              gotf.name, 1)
+                payload_record("%s.json.got" % filename_path, rule[0],
+                               json_dump_normalize(json_input, True),
+                               "JSON equivalent")
 
             table_flush(table, filename, lineno)
             payload_log = tempfile.TemporaryFile(mode="w+")
@@ -1013,17 +1050,8 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
             # Check for matching payload
             if not payload_check(table_payload_expected, payload_log, cmd):
                 error += 1
-                gotf = open("%s.json.payload.got" % filename_path, 'a')
-                payload_log.seek(0, 0)
-                gotf.write("# %s\n" % rule[0])
-                while True:
-                    line = payload_log.readline()
-                    if line == "":
-                        break
-                    gotf.write(line)
-                gotf.close()
-                print_warning("Wrote JSON payload for rule %s" % rule[0],
-                              gotf.name, 1)
+                payload_record("%s.json.payload.got" % filename_path,
+                               rule[0], payload_log, "JSON payload")
 
             # Check for matching ruleset listing
             numeric_proto_old = nftables.set_numeric_proto_output(True)
@@ -1049,12 +1077,9 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
                 print_differences_warning(filename, lineno,
                                           json_input, json_output, cmd)
                 error += 1
-                gotf = open("%s.json.output.got" % filename_path, 'a')
-                jdump = json_dump_normalize(json_output, True)
-                gotf.write("# %s\n%s\n\n" % (rule[0], jdump))
-                gotf.close()
-                print_warning("Wrote JSON output for rule %s" % rule[0],
-                              gotf.name, 1)
+                payload_record("%s.json.output.got" % filename_path, rule[0],
+                               json_dump_normalize(json_output, True),
+                               "JSON output")
                 # prevent further warnings and .got file updates
                 json_expected = json_output
             elif json_expected and json_output != json_expected: