]> git.ipfire.org Git - thirdparty/iptables.git/blobdiff - iptables-test.py
man: Do not escape exclamation marks
[thirdparty/iptables.git] / iptables-test.py
index 0ba3d36864fd762ea081c20a5b420820586c8f31..6f63cdbeda9af819be713f22c4c553da6f3f5d10 100755 (executable)
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 #
 # (C) 2012-2013 by Pablo Neira Ayuso <pablo@netfilter.org>
 #
@@ -54,12 +54,12 @@ def print_error(reason, filename=None, lineno=None):
         ": line %d (%s)" % (lineno, reason), file=sys.stderr)
 
 
-def delete_rule(iptables, rule, filename, lineno):
+def delete_rule(iptables, rule, filename, lineno, netns = None):
     '''
     Removes an iptables rule
     '''
     cmd = iptables + " -D " + rule
-    ret = execute_cmd(cmd, filename, lineno)
+    ret = execute_cmd(cmd, filename, lineno, netns)
     if ret == 1:
         reason = "cannot delete: " + iptables + " -I " + rule
         print_error(reason, filename, lineno)
@@ -73,26 +73,24 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns):
     Executes an unit test. Returns the output of delete_rule().
 
     Parameters:
-    :param  iptables: string with the iptables command to execute
+    :param iptables: string with the iptables command to execute
     :param rule: string with iptables arguments for the rule to test
-    :param rule_save: string to find the rule in the output of iptables -save
+    :param rule_save: string to find the rule in the output of iptables-save
     :param res: expected result of the rule. Valid values: "OK", "FAIL"
     :param filename: name of the file tested (used for print_error purposes)
     :param lineno: line number being tested (used for print_error purposes)
+    :param netns: network namespace to call commands in (or None)
     '''
     ret = 0
 
     cmd = iptables + " -A " + rule
-    if netns:
-            cmd = "ip netns exec ____iptables-container-test " + EXECUTEABLE + " " + cmd
-
-    ret = execute_cmd(cmd, filename, lineno)
+    ret = execute_cmd(cmd, filename, lineno, netns)
 
     #
     # report failed test
     #
     if ret:
-        if res == "OK":
+        if res != "FAIL":
             reason = "cannot load: " + cmd
             print_error(reason, filename, lineno)
             return -1
@@ -103,32 +101,32 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns):
         if res == "FAIL":
             reason = "should fail: " + cmd
             print_error(reason, filename, lineno)
-            delete_rule(iptables, rule, filename, lineno)
+            delete_rule(iptables, rule, filename, lineno, netns)
             return -1
 
     matching = 0
-    splitted = iptables.split(" ")
-    if len(splitted) == 2:
-        if splitted[1] == '-4':
+    tokens = iptables.split(" ")
+    if len(tokens) == 2:
+        if tokens[1] == '-4':
             command = IPTABLES_SAVE
-        elif splitted[1] == '-6':
+        elif tokens[1] == '-6':
             command = IP6TABLES_SAVE
-    elif len(splitted) == 1:
-        if splitted[0] == IPTABLES:
+    elif len(tokens) == 1:
+        if tokens[0] == IPTABLES:
             command = IPTABLES_SAVE
-        elif splitted[0] == IP6TABLES:
+        elif tokens[0] == IP6TABLES:
             command = IP6TABLES_SAVE
-        elif splitted[0] == ARPTABLES:
+        elif tokens[0] == ARPTABLES:
             command = ARPTABLES_SAVE
-        elif splitted[0] == EBTABLES:
+        elif tokens[0] == EBTABLES:
             command = EBTABLES_SAVE
 
-    command = EXECUTEABLE + " " + command
+    command = EXECUTABLE + " " + command
 
     if netns:
-            command = "ip netns exec ____iptables-container-test " + command
+            command = "ip netns exec " + netns + " " + command
 
-    args = splitted[1:]
+    args = tokens[1:]
     proc = subprocess.Popen(command, shell=True,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -138,18 +136,28 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns):
     # check for segfaults
     #
     if proc.returncode == -11:
-        reason = "iptables-save segfaults: " + cmd
+        reason = command + " segfaults!"
         print_error(reason, filename, lineno)
-        delete_rule(iptables, rule, filename, lineno)
+        delete_rule(iptables, rule, filename, lineno, netns)
         return -1
 
     # find the rule
     matching = out.find(rule_save.encode('utf-8'))
     if matching < 0:
-        reason = "cannot find: " + iptables + " -I " + rule
-        print_error(reason, filename, lineno)
-        delete_rule(iptables, rule, filename, lineno)
-        return -1
+        if res == "OK":
+            reason = "cannot find: " + iptables + " -I " + rule
+            print_error(reason, filename, lineno)
+            delete_rule(iptables, rule, filename, lineno, netns)
+            return -1
+        else:
+            # do not report this error
+            return 0
+    else:
+        if res != "OK":
+            reason = "should not match: " + cmd
+            print_error(reason, filename, lineno)
+            delete_rule(iptables, rule, filename, lineno, netns)
+            return -1
 
     # Test "ip netns del NETNS" path with rules in place
     if netns:
@@ -157,7 +165,7 @@ def run_test(iptables, rule, rule_save, res, filename, lineno, netns):
 
     return delete_rule(iptables, rule, filename, lineno)
 
-def execute_cmd(cmd, filename, lineno):
+def execute_cmd(cmd, filename, lineno = 0, netns = None):
     '''
     Executes a command, checking for segfaults and returning the command exit
     code.
@@ -165,10 +173,14 @@ def execute_cmd(cmd, filename, lineno):
     :param cmd: string with the command to be executed
     :param filename: name of the file tested (used for print_error purposes)
     :param lineno: line number being tested (used for print_error purposes)
+    :param netns: network namespace to run command in
     '''
     global log_file
     if cmd.startswith('iptables ') or cmd.startswith('ip6tables ') or cmd.startswith('ebtables ') or cmd.startswith('arptables '):
-        cmd = EXECUTEABLE + " " + cmd
+        cmd = EXECUTABLE + " " + cmd
+
+    if netns:
+        cmd = "ip netns exec " + netns + " " + cmd
 
     print("command: {}".format(cmd), file=log_file)
     ret = subprocess.call(cmd, shell=True, universal_newlines=True,
@@ -176,17 +188,200 @@ def execute_cmd(cmd, filename, lineno):
     log_file.flush()
 
     # generic check for segfaults
-    if ret  == -11:
+    if ret == -11:
         reason = "command segfaults: " + cmd
         print_error(reason, filename, lineno)
     return ret
 
 
+def variant_res(res, variant, alt_res=None):
+    '''
+    Adjust expected result with given variant
+
+    If expected result is scoped to a variant, the other one yields a different
+    result. Therefore map @res to itself if given variant is current, use the
+    alternate result, @alt_res, if specified, invert @res otherwise.
+
+    :param res: expected result from test spec ("OK", "FAIL" or "NOMATCH")
+    :param variant: variant @res is scoped to by test spec ("NFT" or "LEGACY")
+    :param alt_res: optional expected result for the alternate variant.
+    '''
+    variant_executable = {
+        "NFT": "xtables-nft-multi",
+        "LEGACY": "xtables-legacy-multi"
+    }
+    res_inverse = {
+        "OK": "FAIL",
+        "FAIL": "OK",
+        "NOMATCH": "OK"
+    }
+
+    if variant_executable[variant] == EXECUTABLE:
+        return res
+    if alt_res is not None:
+        return alt_res
+    return res_inverse[res]
+
+def fast_run_possible(filename):
+    '''
+    Keep things simple, run only for simple test files:
+    - no external commands
+    - no multiple tables
+    - no variant-specific results
+    '''
+    table = None
+    rulecount = 0
+    for line in open(filename):
+        if line[0] in ["#", ":"] or len(line.strip()) == 0:
+            continue
+        if line[0] == "*":
+            if table or rulecount > 0:
+                return False
+            table = line.rstrip()[1:]
+        if line[0] in ["@", "%"]:
+            return False
+        if len(line.split(";")) > 3:
+            return False
+        rulecount += 1
+
+    return True
+
+def run_test_file_fast(iptables, filename, netns):
+    '''
+    Run a test file, but fast
+
+    :param filename: name of the file with the test rules
+    :param netns: network namespace to perform test run in
+    '''
+
+    f = open(filename)
+
+    rules = {}
+    table = "filter"
+    chain_array = []
+    tests = 0
+
+    for lineno, line in enumerate(f):
+        if line[0] == "#" or len(line.strip()) == 0:
+            continue
+
+        if line[0] == "*":
+            table = line.rstrip()[1:]
+            continue
+
+        if line[0] == ":":
+            chain_array = line.rstrip()[1:].split(",")
+            continue
+
+        if len(chain_array) == 0:
+            return -1
+
+        tests += 1
+
+        for chain in chain_array:
+            item = line.split(";")
+            rule = chain + " " + item[0]
+
+            if item[1] == "=":
+                rule_save = chain + " " + item[0]
+            else:
+                rule_save = chain + " " + item[1]
+
+            if iptables == EBTABLES and rule_save.find('-j') < 0:
+                rule_save += " -j CONTINUE"
+
+            res = item[2].rstrip()
+            if res != "OK":
+                rule = chain + " -t " + table + " " + item[0]
+                ret = run_test(iptables, rule, rule_save,
+                               res, filename, lineno + 1, netns)
+
+                if ret < 0:
+                    return -1
+                continue
+
+            if not chain in rules.keys():
+                rules[chain] = []
+            rules[chain].append((rule, rule_save))
+
+    restore_data = ["*" + table]
+    out_expect = []
+    for chain in ["PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING"]:
+        if not chain in rules.keys():
+            continue
+        for rule in rules[chain]:
+            restore_data.append("-A " + rule[0])
+            out_expect.append("-A " + rule[1])
+    restore_data.append("COMMIT")
+
+    out_expect = "\n".join(out_expect)
+
+    # load all rules via iptables_restore
+
+    command = EXECUTABLE + " " + iptables + "-restore"
+    if netns:
+        command = "ip netns exec " + netns + " " + command
+
+    for line in restore_data:
+        print(iptables + "-restore: " + line, file=log_file)
+
+    proc = subprocess.Popen(command, shell = True, text = True,
+                            stdin = subprocess.PIPE,
+                            stdout = subprocess.PIPE,
+                            stderr = subprocess.PIPE)
+    restore_data = "\n".join(restore_data) + "\n"
+    out, err = proc.communicate(input = restore_data)
+
+    if proc.returncode == -11:
+        reason = iptables + "-restore segfaults!"
+        print_error(reason, filename, lineno)
+        msg = [iptables + "-restore segfault from:"]
+        msg.extend(["input: " + l for l in restore_data.split("\n")])
+        print("\n".join(msg), file=log_file)
+        return -1
+
+    if proc.returncode != 0:
+        print("%s-restore returned %d: %s" % (iptables, proc.returncode, err),
+              file=log_file)
+        return -1
+
+    # find all rules in iptables_save output
+
+    command = EXECUTABLE + " " + iptables + "-save"
+    if netns:
+        command = "ip netns exec " + netns + " " + command
+
+    proc = subprocess.Popen(command, shell = True,
+                            stdin = subprocess.PIPE,
+                            stdout = subprocess.PIPE,
+                            stderr = subprocess.PIPE)
+    out, err = proc.communicate()
+
+    if proc.returncode == -11:
+        reason = iptables + "-save segfaults!"
+        print_error(reason, filename, lineno)
+        return -1
+
+    cmd = iptables + " -F -t " + table
+    execute_cmd(cmd, filename, 0, netns)
+
+    out = out.decode('utf-8').rstrip()
+    if out.find(out_expect) < 0:
+        msg = ["dumps differ!"]
+        msg.extend(["expect: " + l for l in out_expect.split("\n")])
+        msg.extend(["got: " + l for l in out.split("\n")
+                                if not l[0] in ['*', ':', '#']])
+        print("\n".join(msg), file=log_file)
+        return -1
+
+    return tests
+
 def run_test_file(filename, netns):
     '''
     Runs a test file
 
     :param filename: name of the file with the test rules
+    :param netns: network namespace to perform test run in
     '''
     #
     # if this is not a test file, skip.
@@ -202,18 +397,26 @@ def run_test_file(filename, netns):
         iptables = IPTABLES
     elif "libarpt_" in filename:
         # only supported with nf_tables backend
-        if EXECUTEABLE != "xtables-nft-multi":
+        if EXECUTABLE != "xtables-nft-multi":
            return 0, 0
         iptables = ARPTABLES
     elif "libebt_" in filename:
         # only supported with nf_tables backend
-        if EXECUTEABLE != "xtables-nft-multi":
+        if EXECUTABLE != "xtables-nft-multi":
            return 0, 0
         iptables = EBTABLES
     else:
         # default to iptables if not known prefix
         iptables = IPTABLES
 
+    fast_failed = False
+    if fast_run_possible(filename):
+        tests = run_test_file_fast(iptables, filename, netns)
+        if tests > 0:
+            print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY))
+            return tests, tests
+        fast_failed = True
+
     f = open(filename)
 
     tests = 0
@@ -223,7 +426,7 @@ def run_test_file(filename, netns):
     total_test_passed = True
 
     if netns:
-        execute_cmd("ip netns add ____iptables-container-test", filename, 0)
+        execute_cmd("ip netns add " + netns, filename)
 
     for lineno, line in enumerate(f):
         if line[0] == "#" or len(line.strip()) == 0:
@@ -233,20 +436,11 @@ def run_test_file(filename, netns):
             chain_array = line.rstrip()[1:].split(",")
             continue
 
-        # external non-iptables invocation, executed as is.
-        if line[0] == "@":
-            external_cmd = line.rstrip()[1:]
-            if netns:
-                external_cmd = "ip netns exec ____iptables-container-test " + external_cmd
-            execute_cmd(external_cmd, filename, lineno)
-            continue
-
-        # external iptables invocation, executed as is.
-        if line[0] == "%":
+        # external command invocation, executed as is.
+        # detects iptables commands to prefix with EXECUTABLE automatically
+        if line[0] in ["@", "%"]:
             external_cmd = line.rstrip()[1:]
-            if netns:
-                external_cmd = "ip netns exec ____iptables-container-test " + EXECUTEABLE + " " + external_cmd
-            execute_cmd(external_cmd, filename, lineno)
+            execute_cmd(external_cmd, filename, lineno, netns)
             continue
 
         if line[0] == "*":
@@ -275,6 +469,14 @@ def run_test_file(filename, netns):
                 rule_save = chain + " " + item[1]
 
             res = item[2].rstrip()
+            if len(item) > 3:
+                variant = item[3].rstrip()
+                if len(item) > 4:
+                    alt_res = item[4].rstrip()
+                else:
+                    alt_res = None
+                res = variant_res(res, variant, alt_res)
+
             ret = run_test(iptables, rule, rule_save,
                            res, filename, lineno + 1, netns)
 
@@ -287,9 +489,12 @@ def run_test_file(filename, netns):
             passed += 1
 
     if netns:
-        execute_cmd("ip netns del ____iptables-container-test", filename, 0)
+        execute_cmd("ip netns del " + netns, filename)
     if total_test_passed:
-        print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY))
+        suffix = ""
+        if fast_failed:
+            suffix = maybe_colored('red', " but fast mode failed!", STDOUT_IS_TTY)
+        print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY) + suffix)
 
     f.close()
     return tests, passed
@@ -353,7 +558,8 @@ def main():
                         help='Check for missing tests')
     parser.add_argument('-n', '--nftables', action='store_true',
                         help='Test iptables-over-nftables')
-    parser.add_argument('-N', '--netns', action='store_true',
+    parser.add_argument('-N', '--netns', action='store_const',
+                        const='____iptables-container-test',
                         help='Test netnamespace path')
     parser.add_argument('--no-netns', action='store_true',
                         help='Do not run testsuite in own network namespace')
@@ -366,14 +572,17 @@ def main():
         show_missing()
         return
 
-    global EXECUTEABLE
-    EXECUTEABLE = "xtables-legacy-multi"
+    variants = []
+    if args.legacy:
+        variants.append("legacy")
     if args.nftables:
-        EXECUTEABLE = "xtables-nft-multi"
+        variants.append("nft")
+    if len(variants) == 0:
+        variants = [ "legacy", "nft" ]
 
     if os.getuid() != 0:
         print("You need to be root to run this, sorry", file=sys.stderr)
-        return
+        return 77
 
     if not args.netns and not args.no_netns and not spawn_netns():
         print("Cannot run in own namespace, connectivity might break",
@@ -384,36 +593,51 @@ def main():
         os.putenv("PATH", "%s/iptables:%s" % (os.path.abspath(os.path.curdir),
                                               os.getenv("PATH")))
 
-    test_files = 0
-    tests = 0
-    passed = 0
-
-    # setup global var log file
-    global log_file
-    try:
-        log_file = open(LOGFILE, 'w')
-    except IOError:
-        print("Couldn't open log file %s" % LOGFILE, file=sys.stderr)
-        return
-
-    if args.filename:
-        file_list = args.filename
-    else:
-        file_list = [os.path.join(EXTENSIONS_PATH, i)
-                     for i in os.listdir(EXTENSIONS_PATH)
-                     if i.endswith('.t')]
-        file_list.sort()
-
-    for filename in file_list:
-        file_tests, file_passed = run_test_file(filename, args.netns)
-        if file_tests:
-            tests += file_tests
-            passed += file_passed
-            test_files += 1
-
-    print("%d test files, %d unit tests, %d passed" % (test_files, tests, passed))
-    return passed - tests
-
+    total_test_files = 0
+    total_passed = 0
+    total_tests = 0
+    for variant in variants:
+        global EXECUTABLE
+        EXECUTABLE = "xtables-" + variant + "-multi"
+
+        test_files = 0
+        tests = 0
+        passed = 0
+
+        # setup global var log file
+        global log_file
+        try:
+            log_file = open(LOGFILE, 'w')
+        except IOError:
+            print("Couldn't open log file %s" % LOGFILE, file=sys.stderr)
+            return
+
+        if args.filename:
+            file_list = args.filename
+        else:
+            file_list = [os.path.join(EXTENSIONS_PATH, i)
+                         for i in os.listdir(EXTENSIONS_PATH)
+                         if i.endswith('.t')]
+            file_list.sort()
+
+        for filename in file_list:
+            file_tests, file_passed = run_test_file(filename, args.netns)
+            if file_tests:
+                tests += file_tests
+                passed += file_passed
+                test_files += 1
+
+        print("%s: %d test files, %d unit tests, %d passed"
+              % (variant, test_files, tests, passed))
+
+        total_passed += passed
+        total_tests += tests
+        total_test_files = max(total_test_files, test_files)
+
+    if len(variants) > 1:
+        print("total: %d test files, %d unit tests, %d passed"
+              % (total_test_files, total_tests, total_passed))
+    return total_passed - total_tests
 
 if __name__ == '__main__':
     sys.exit(main())