]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Move proof checking into a NSEC3Checker class
authorPetr Špaček <pspacek@isc.org>
Mon, 9 Jun 2025 08:46:34 +0000 (10:46 +0200)
committerPetr Špaček <pspacek@isc.org>
Tue, 29 Jul 2025 08:00:46 +0000 (10:00 +0200)
bin/tests/system/dnssec/tests_nsec3.py

index 76e0faf599cde087f1548f56b26f9a42ca364a09..efd6619e2a7418a91424c1b887cf0727b4899e9a 100755 (executable)
 # See the COPYRIGHT file distributed with this work for additional
 # information regarding copyright ownership.
 
+from dataclasses import dataclass
 import os
 from pathlib import Path
+from typing import Optional, Tuple
 
 import pytest
 
 pytest.importorskip("dns", minversion="2.5.0")
-from dns.dnssectypes import NSEC3Hash
 import dns.dnssec
 import dns.message
 import dns.name
@@ -44,13 +45,14 @@ ZONE = isctest.name.ZoneAnalyzer.read_path(
 )
 
 
-def do_test_query(qname, qtype, server, named_port) -> dns.message.Message:
+def do_test_query(
+    qname, qtype, server, named_port
+) -> Tuple[dns.message.Message, "NSEC3Checker"]:
     query = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True)
     response = isctest.query.tcp(query, server, named_port, timeout=TIMEOUT)
     isctest.check.is_response_to(response, query)
     assert response.rcode() in (dns.rcode.NOERROR, dns.rcode.NXDOMAIN)
-    NSEC3Checker(response)
-    return response
+    return response, NSEC3Checker(response)
 
 
 def assume_nx_and_no_delegation(qname):
@@ -67,56 +69,6 @@ def assume_nx_and_no_delegation(qname):
     )
 
 
-def nsec3_covers(rrset: dns.rrset.RRset, hashed_name: dns.name.Name) -> bool:
-    """
-    Test if 'hashed_name' is covered by an NSEC3 record in 'rrset', i.e. the name does not exist.
-    """
-    prev_name = rrset.name
-
-    for nsec3 in rrset:
-        assert nsec3.flags == 0, "opt-out not supported by test logic"
-        next_name = nsec3.next_name(SUFFIX)
-
-        # Single name case.
-        if prev_name == next_name:
-            return prev_name != hashed_name
-
-        # Standard case.
-        if prev_name < next_name:
-            if prev_name < hashed_name < next_name:
-                return True
-
-        # The cover wraps.
-        if next_name < prev_name:
-            # Case 1: The covered name is at the end of the chain.
-            if hashed_name > prev_name:
-                return True
-            # Case 2: The covered name is at the start of the chain.
-            if hashed_name < next_name:
-                return True
-    return False
-
-
-def check_nsec3_covers(name: dns.name.Name, response: dns.message.Message) -> None:
-    """Given name provably does not exist"""
-    name_is_covered = False
-
-    nhash = dns.dnssec.nsec3_hash(
-        name, salt=None, iterations=0, algorithm=NSEC3Hash.SHA1
-    )
-    hashed_name = dns.name.from_text(nhash, SUFFIX)
-
-    for rrset in response.authority:
-        if rrset.match(dns.rdataclass.IN, dns.rdatatype.NSEC3, dns.rdatatype.NONE):
-            name_is_covered = nsec3_covers(rrset, hashed_name)
-            if name_is_covered:
-                break
-
-    assert (
-        name_is_covered
-    ), f"Expected covering NSEC3 for {name} (hash={nhash}) not found:\n {response}"
-
-
 @pytest.mark.parametrize(
     "server", [pytest.param(AUTH, id="ns3"), pytest.param(RESOLVER, id="ns4")]
 )
@@ -172,52 +124,34 @@ def test_wildcard_nodata(server, qname: dns.name.Name, named_port: int) -> None:
     check_wildcard_nodata(server, named_port, qname)
 
 
-def check_nsec3_owner(owner: dns.name.Name, response):
-    """Check response has NSEC3 RR matching given owner name, i.e. the name exists."""
-    name_hash = dns.dnssec.nsec3_hash(
-        owner, salt=None, iterations=0, algorithm=NSEC3Hash.SHA1
-    )
-    nsec3_owner = dns.name.from_text(name_hash, SUFFIX)
-
-    nsec3_found = False
-    for rrset in response.authority:
-        if rrset.match(
-            nsec3_owner, dns.rdataclass.IN, dns.rdatatype.NSEC3, dns.rdatatype.NONE
-        ):
-            nsec3_found = True
-    assert (
-        nsec3_found
-    ), f"Expected matching NSEC3 for {owner} (hash={name_hash}) not found:\n{response}"
-
-
 def check_wildcard_nodata(server, named_port: int, qname: dns.name.Name) -> None:
-    response = do_test_query(qname, dns.rdatatype.AAAA, server, named_port)
+    response, nsec3check = do_test_query(qname, dns.rdatatype.AAAA, server, named_port)
     assert response.rcode() is dns.rcode.NOERROR
 
     ce, nce = ZONE.closest_encloser(qname)
-    check_nsec3_owner(ce, response)
-    check_nsec3_covers(nce, response)
+    nsec3check.prove_name_exists(ce)
+    nsec3check.prove_name_does_not_exist(nce)
 
     wname = ZONE.source_of_synthesis(qname)
     # expecting proof that wildcard owner does not have rdatatype requested
-    check_nsec3_owner(wname, response)
+    nsec3check.prove_name_exists(wname)
 
 
 def check_nxdomain(server, named_port: int, qname: dns.name.Name) -> None:
-    response = do_test_query(qname, dns.rdatatype.A, server, named_port)
+    response, nsec3check = do_test_query(qname, dns.rdatatype.A, server, named_port)
     assert response.rcode() is dns.rcode.NXDOMAIN
 
     ce, nce = ZONE.closest_encloser(qname)
-    check_nsec3_owner(ce, response)
-    check_nsec3_covers(nce, response)
+    nsec3check.prove_name_exists(ce)
+    nsec3check.prove_name_does_not_exist(nce)
 
     wname = ZONE.source_of_synthesis(qname)
-    check_nsec3_covers(wname, response)
+    nsec3check.prove_name_does_not_exist(wname)
 
 
 def check_wildcard_synthesis(server, named_port: int, qname: dns.name.Name) -> None:
     """Expect wildcard response with a signed A RRset"""
-    response = do_test_query(qname, dns.rdatatype.A, server, named_port)
+    response, nsec3check = do_test_query(qname, dns.rdatatype.A, server, named_port)
     assert response.rcode() is dns.rcode.NOERROR
 
     answer_sig = response.get_rrset(
@@ -250,8 +184,17 @@ def check_wildcard_synthesis(server, named_port: int, qname: dns.name.Name) -> N
     assert ce == qname.split(wildcard_parent_labels)[1]
     # ce is proven to exist by the RRSIG
     assert nce == qname.split(wildcard_parent_labels + 1)[1]
-    # nce must be proven to NOT exist
-    check_nsec3_covers(nce, response)
+    nsec3check.prove_name_does_not_exist(nce)
+
+
+@dataclass(kw_only=True, frozen=True)
+class NSEC3Params:
+    """Common values from a single DNS response"""
+
+    algorithm: int
+    flags: int
+    iterations: int
+    salt: Optional[bytes]
 
 
 class NSEC3Checker:
@@ -273,6 +216,7 @@ class NSEC3Checker:
         }
         first = True
         owners_seen = set()
+        self.rrsets = []
         for rrset in response.authority:
             if not rrset.match(
                 dns.rdataclass.IN, dns.rdatatype.NSEC3, dns.rdatatype.NONE
@@ -303,7 +247,73 @@ class NSEC3Checker:
                         current == value_seen
                     ), f"inconsistent {attr_name}\n{response}"
             first = False
+            self.rrsets.append(rrset)
 
         assert attrs_seen["algorithm"] is not None, f"no NSEC3 found\n{response}"
-        self.attrs = attrs_seen
-        self.response = response
+        self.params = NSEC3Params(**attrs_seen)  # type: NSEC3Params
+        self.response = response  # type: dns.message.Message
+
+    @staticmethod
+    def nsec3_covers(rrset: dns.rrset.RRset, hashed_name: dns.name.Name) -> bool:
+        """
+        Test if 'hashed_name' is covered by an NSEC3 record in 'rrset', i.e. the name does not exist.
+        """
+        prev_name = rrset.name
+
+        assert len(rrset) == 1
+        nsec3 = rrset[0]
+        assert isinstance(nsec3, dns.rdtypes.ANY.NSEC3.NSEC3)
+        assert nsec3.flags == 0, "opt-out not supported by test logic"
+        next_name = nsec3.next_name(SUFFIX)
+
+        # Single name case.
+        if prev_name == next_name:
+            return prev_name != hashed_name
+
+        # Standard case.
+        if prev_name < next_name:
+            if prev_name < hashed_name < next_name:
+                return True
+
+        # The cover wraps.
+        if next_name < prev_name:
+            # Case 1: The covered name is at the end of the chain.
+            if hashed_name > prev_name:
+                return True
+            # Case 2: The covered name is at the start of the chain.
+            if hashed_name < next_name:
+                return True
+        return False
+
+    def hash_name(self, name: dns.name.Name) -> dns.name.Name:
+        nhash = dns.dnssec.nsec3_hash(
+            name,
+            salt=self.params.salt,
+            iterations=self.params.iterations,
+            algorithm=self.params.algorithm,
+        )
+        return dns.name.from_text(nhash, SUFFIX)
+
+    def prove_name_does_not_exist(self, name: dns.name.Name) -> dns.rrset.RRset:
+        """Hash of a given name must fall between an NSEC3 owner and 'next' name"""
+        hashed_name = self.hash_name(name)
+        for rrset in self.rrsets:
+            name_is_covered = self.nsec3_covers(rrset, hashed_name)
+            if name_is_covered:
+                return rrset
+
+        assert (
+            False
+        ), f"Expected covering NSEC3 for {name} (hash={hashed_name}) not found:\n{self.response}"
+
+    def prove_name_exists(self, owner: dns.name.Name) -> dns.rrset.RRset:
+        """Check response has NSEC3 RR matching given owner name, i.e. the name exists."""
+        nsec3_owner = self.hash_name(owner)
+        for rrset in self.rrsets:
+            if rrset.match(
+                nsec3_owner, dns.rdataclass.IN, dns.rdatatype.NSEC3, dns.rdatatype.NONE
+            ):
+                return rrset
+        assert (
+            False
+        ), f"Expected matching NSEC3 for {owner} (hash={nsec3_owner}) not found:\n{self.response}"