]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
tests/py: Don't read expected payload for each table
authorPhil Sutter <phil@nwl.cc>
Tue, 8 May 2018 11:08:43 +0000 (13:08 +0200)
committerPablo Neira Ayuso <pablo@netfilter.org>
Fri, 11 May 2018 10:17:45 +0000 (12:17 +0200)
When testing rule adding to different table families, expected payload
was read for each tested family again. Instead, read it just once and
just try to read a family-specific payload for each tested family.

Signed-off-by: Phil Sutter <phil@nwl.cc>
Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
tests/py/nft-test.py

index 4a01fc23b7d55b64056a357e17ce7dd1999d0761..6684bcd1081d785c4e9251f6fc69711ae4a887ca 100755 (executable)
@@ -648,23 +648,25 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
         print_error(reason, filename, lineno)
         return [-1, warning, error, unit_tests]
 
-    payload_expected = []
-
-    for table in table_list:
+    if rule[1].strip() == "ok":
         try:
-            payload_log = open("%s.payload.%s" % (filename_path, table.family))
-        except IOError:
             payload_log = open("%s.payload" % filename_path)
+            payload_expected = payload_find_expected(payload_log, rule[0])
+        except:
+            payload_expected = None
 
+    for table in table_list:
         if rule[1].strip() == "ok":
+            table_payload_expected = None
             try:
-                payload_expected.index(rule[0])
-            except ValueError:
-                payload_expected = payload_find_expected(payload_log, rule[0])
-
+                payload_log = open("%s.payload.%s" % (filename_path, table.family))
+                table_payload_expected = payload_find_expected(payload_log, rule[0])
+            except:
                 if not payload_expected:
                     print_error("did not find payload information for "
                                 "rule '%s'" % rule[0], payload_log.name, 1)
+            if not table_payload_expected:
+                table_payload_expected = payload_expected
 
         for table_chain in table.chains:
             chain = chain_get_by_name(table_chain)
@@ -697,7 +699,7 @@ def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
                 continue
 
             # Check for matching payload
-            if state == "ok" and not payload_check(payload_expected,
+            if state == "ok" and not payload_check(table_payload_expected,
                                                    payload_log, cmd):
                 error += 1
                 gotf = open("%s.payload.got" % filename_path, 'a')