]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
add helper functions to isctest
authorEvan Hunt <each@isc.org>
Thu, 26 Jun 2025 22:19:45 +0000 (15:19 -0700)
committerEvan Hunt <each@isc.org>
Tue, 29 Jul 2025 22:58:11 +0000 (22:58 +0000)
added some helper functions in isctest to reduce code repetition
in dnssec-related tests:

- isctest.check.adflag() - checks that a response contains AD=1
- isctest.check.noadflag() - checks that a response contains AD=0

- isctest.check.rdflag() - checks that a response contains RD=1
- isctest.check.nordflag() - checks that a response contains RD=0

- isctest.check.answer_count_eq() - checks the answer count is correct
- isctest.check.additional_count_eq() - same for authority count
- isctest.check.authority_count_eq() - same for additional count

- isctest.check.same_data() - check that two message have the
                              same rcode and data
- isctest.check.same_answer() - check that two message have the same
                                rcode and answer

- isctest.dnssec.msg() - a wrapper for dns.message.make_query() that
                         creates a query message similar to dig +dnssec:
                         use_edns=True, want_dnssec=True,
                         and flags are set to (RD|AD) by default, but
                         options exist to disable AD or enable CD.
                         (to generate non-DNSSEC queries, use
                         message.make_query() directly.)

(cherry picked from commit b69097f139154ca0d2177f35632400200d220bdc)

bin/tests/system/isctest/__init__.py
bin/tests/system/isctest/check.py
bin/tests/system/isctest/dnssec.py [new file with mode: 0644]
bin/tests/system/isctest/instance.py

index 08428c9372dfc43e9dc0a2cdb30547578f87c7ac..0da6b5bb51247722c810888bea173bc37c3dc4fb 100644 (file)
@@ -10,6 +10,7 @@
 # information regarding copyright ownership.
 
 from . import check
+from . import dnssec
 from . import instance
 from . import query
 from . import name
index afcc2db6ff44ffa71cfba85df49d5e22d6708d17..d8cff71ea4803cdfcc9e400ba191699072a8c37b 100644 (file)
@@ -12,6 +12,7 @@
 import shutil
 from typing import Optional
 
+import dns.flags
 import dns.rcode
 import dns.message
 import dns.zone
@@ -40,6 +41,53 @@ def servfail(message: dns.message.Message) -> None:
     rcode(message, dns_rcode.SERVFAIL)
 
 
+def adflag(message: dns.message.Message) -> None:
+    assert (message.flags & dns.flags.AD) != 0, str(message)
+
+
+def noadflag(message: dns.message.Message) -> None:
+    assert (message.flags & dns.flags.AD) == 0, str(message)
+
+
+def rdflag(message: dns.message.Message) -> None:
+    assert (message.flags & dns.flags.RD) != 0, str(message)
+
+
+def nordflag(message: dns.message.Message) -> None:
+    assert (message.flags & dns.flags.RD) == 0, str(message)
+
+
+def section_equal(sec1: list, sec2: list) -> None:
+    # convert an RRset to a normalized string (lower case, TTL=0)
+    # so it can be used as a set member.
+    def normalized(rrset):
+        ttl = rrset.ttl
+        rrset.ttl = 0
+        s = str(rrset).lower()
+        rrset.ttl = ttl
+        return s
+
+    # convert the section contents to sets before comparison,
+    # in case they aren't in the same sort order.
+    set1 = {normalized(item) for item in sec1}
+    set2 = {normalized(item) for item in sec2}
+    assert set1 == set2
+
+
+def same_data(res1: dns.message.Message, res2: dns.message.Message):
+    assert res1.question == res2.question
+    section_equal(res1.answer, res2.answer)
+    section_equal(res1.authority, res2.authority)
+    section_equal(res1.additional, res2.additional)
+    assert res1.rcode() == res2.rcode()
+
+
+def same_answer(res1: dns.message.Message, res2: dns.message.Message):
+    assert res1.question == res2.question
+    section_equal(res1.answer, res2.answer)
+    assert res1.rcode() == res2.rcode()
+
+
 def rrsets_equal(
     first_rrset: dns.rrset.RRset,
     second_rrset: dns.rrset.RRset,
@@ -114,6 +162,30 @@ def empty_answer(message: dns.message.Message) -> None:
     assert not message.answer, str(message)
 
 
+def answer_count_eq(m: dns.message.Message, expected: int):
+    count = sum(max(1, len(rrs)) for rrs in m.answer)
+    assert count == expected, str(m)
+
+
+def authority_count_eq(m: dns.message.Message, expected: int):
+    count = sum(max(1, len(rrs)) for rrs in m.authority)
+    assert count == expected, str(m)
+
+
+def additional_count_eq(m: dns.message.Message, expected: int):
+    count = sum(max(1, len(rrs)) for rrs in m.additional)
+
+    # add one for the OPT?
+    opt = bool(m.opt) if hasattr(m, "opt") else bool(m.edns >= 0)
+    count += 1 if opt else 0
+
+    # add one for the TSIG?
+    tsig = bool(m.tsig) if hasattr(m, "tsig") else m.had_tsig
+    count += 1 if tsig else 0
+
+    assert count == expected, str(m)
+
+
 def is_response_to(response: dns.message.Message, query: dns.message.Message) -> None:
     single_question(response)
     single_question(query)
diff --git a/bin/tests/system/isctest/dnssec.py b/bin/tests/system/isctest/dnssec.py
new file mode 100644 (file)
index 0000000..2096711
--- /dev/null
@@ -0,0 +1,25 @@
+# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
+#
+# SPDX-License-Identifier: MPL-2.0
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0.  If a copy of the MPL was not distributed with this
+# file, you can obtain one at https://mozilla.org/MPL/2.0/.
+#
+# See the COPYRIGHT file distributed with this work for additional
+# information regarding copyright ownership.
+
+from dns import flags, message
+
+
+def msg(qname: str, qtype: str, **kwargs):
+    headerflags = flags.RD
+    # "ad" is on by default
+    if "ad" not in kwargs or not kwargs["ad"]:
+        headerflags |= flags.AD
+    # "cd" is off by default
+    if "cd" in kwargs and kwargs["cd"]:
+        headerflags |= flags.CD
+    return message.make_query(
+        qname, qtype, use_edns=True, want_dnssec=True, flags=headerflags
+    )
index ea3a8fadfbd4288c0adedbdc785ae88aecdb30c2..763a8dbb596925b2b678c6d2faf6598dcac7b534 100644 (file)
@@ -137,13 +137,13 @@ class NamedInstance:
         """
         return WatchLogFromHere(self.log.path, timeout)
 
-    def reconfigure(self) -> None:
+    def reconfigure(self, **kwargs) -> None:
         """
         Reconfigure this named `instance` and wait until reconfiguration is
         finished.  Raise an `RNDCException` if reconfiguration fails.
         """
         with self.watch_log_from_here() as watcher:
-            self.rndc("reconfig")
+            self.rndc("reconfig", **kwargs)
             watcher.wait_for_line("any newly configured zones are now loaded")
 
     def _rndc_log(self, command: str, response: str) -> None: