]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Introduce pytest kasp library
authorMatthijs Mekking <matthijs@isc.org>
Mon, 2 Sep 2024 15:51:32 +0000 (17:51 +0200)
committerMatthijs Mekking <matthijs@isc.org>
Wed, 16 Oct 2024 15:16:02 +0000 (17:16 +0200)
Write initial pytest kasp library. This contains everything that is
required for testing Offline KSK functionality with pytest.

This includes:
- addtime: adding a value to a timing metadata
- get_timing_metdata: retrieve timing metadata from keyfile
- get_metadata/get_keystate: retrieve metadata from statefile
- get_keytag: retrieve keytag from base keyfile string
- get_keyrole: get key role from statefile
- dnskey_equals: compare DNSKEY record from file against a string
- cds_equals: compare CDS derived from file against a string
- zone_is_signed: wait until a zone is completely signed
- dnssec_verify: verify a DNSSEC signed zone with dnssec-verify
- check_dnssecstatus: check rndc dnssec -status output
- check_signatures: check that signatures for a given RRset are correct
- check_dnskeys: check that the published DNSKEY RRset is correct
- check_cds: check that the published CDS RRset is correct
- check_apex: check SOA, DNSKEY, CDNSKEY, and CDS RRset
- check_subdomain: check an RRset below the apex

bin/tests/system/isctest/__init__.py
bin/tests/system/isctest/kasp.py [new file with mode: 0644]

index 5426e450feaca3c3dc929f97272e4431c53208d1..756f6b6b381004e03b3a7c908d49195815a8af2a 100644 (file)
@@ -12,6 +12,7 @@
 from . import check
 from . import instance
 from . import query
+from . import kasp
 from . import name
 from . import rndc
 from . import run
diff --git a/bin/tests/system/isctest/kasp.py b/bin/tests/system/isctest/kasp.py
new file mode 100644 (file)
index 0000000..685250f
--- /dev/null
@@ -0,0 +1,541 @@
+# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
+#
+# SPDX-License-Identifier: MPL-2.0
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0.  If a copy of the MPL was not distributed with this
+# file, you can obtain one at https://mozilla.org/MPL/2.0/.
+#
+# See the COPYRIGHT file distributed with this work for additional
+# information regarding copyright ownership.
+
+import os
+import time
+
+from datetime import datetime
+from datetime import timedelta
+
+import dns
+import isctest.log
+
+
+DEFAULT_TTL = 300
+
+
+def _save_response(response, fname):
+    with open(fname, "w", encoding="utf-8") as file:
+        file.write(response.to_text())
+
+
+def _query(server, qname, qtype, outfile=None):
+    query = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True)
+    try:
+        response = dns.query.tcp(query, server.ip, port=server.ports.dns, timeout=3)
+    except dns.exception.Timeout:
+        isctest.log.debug(f"query timeout for query {qname} {qtype} to {server.ip}")
+        return None
+
+    if outfile is not None:
+        _save_response(response, outfile)
+
+    return response
+
+
+def addtime(value, plus):
+    # Get timing metadata from a value plus additional time.
+    # Convert "%Y%m%d%H%M%S" format to epoch seconds.
+    # Then, add the additional time (can be negative).
+    now = datetime.strptime(value, "%Y%m%d%H%M%S")
+    delta = timedelta(seconds=plus)
+    then = now + delta
+    return then.strftime("%Y%m%d%H%M%S")
+
+
+def get_timing_metadata(key, metadata, keydir=None, offset=0, must_exist=True):
+    value = "0"
+
+    if keydir is not None:
+        keyfile = "{}/{}.key".format(keydir, key)
+    else:
+        keyfile = "{}.key".format(key)
+
+    with open(keyfile, "r", encoding="utf-8") as file:
+        for line in file:
+            if "; {}".format(metadata) in line:
+                value = line.split()[2]
+                break
+
+    if must_exist:
+        assert int(value) > 0
+
+    if int(value) > 0:
+        return addtime(value, offset)
+
+    return "0"
+
+
+def get_metadata(key, metadata, keydir=None, must_exist=True):
+    if keydir is not None:
+        statefile = "{}/{}.state".format(keydir, key)
+    else:
+        statefile = "{}.state".format(key)
+
+    value = "undefined"
+    with open(statefile, "r", encoding="utf-8") as file:
+        for line in file:
+            if f"{metadata}: " in line:
+                value = line.split()[1]
+                break
+
+    if must_exist:
+        assert value != "undefined"
+
+    return value
+
+
+def get_keystate(key, metadata, keydir=None, must_exist=True):
+
+    return get_metadata(key, metadata, keydir, must_exist)
+
+
+def get_keytag(key):
+    return int(key[-5:])
+
+
+def get_keyrole(key, keydir=None):
+    ksk = "no"
+    zsk = "no"
+
+    if keydir is not None:
+        statefile = "{}/{}.state".format(keydir, key)
+    else:
+        statefile = "{}.state".format(key)
+
+    with open(statefile, "r", encoding="utf-8") as file:
+        for line in file:
+            if "KSK: " in line:
+                ksk = line.split()[1]
+            if "ZSK: " in line:
+                zsk = line.split()[1]
+
+    return ksk == "yes", zsk == "yes"
+
+
+def dnskey_equals(key, value, keydir=None, cdnskey=False):
+    if keydir is not None:
+        keyfile = f"{keydir}/{key}.key"
+    else:
+        keyfile = f"{key}.key"
+
+    dnskey = value.split()
+
+    if cdnskey:
+        # fourth element is the rrtype
+        assert dnskey[3] == "CDNSKEY"
+        dnskey[3] = "DNSKEY"
+
+    dnskey_fromfile = []
+    rdata = " ".join(dnskey[:7])
+
+    with open(keyfile, "r", encoding="utf-8") as file:
+        for line in file:
+            if f"{rdata}" in line:
+                dnskey_fromfile = line.split()
+
+    pubkey_fromfile = "".join(dnskey_fromfile[7:])
+    pubkey_fromwire = "".join(dnskey[7:])
+
+    return pubkey_fromfile == pubkey_fromwire
+
+
+def cds_equals(key, value, alg, keydir=None):
+    if keydir is not None:
+        keyfile = f"{keydir}/{key}.key"
+    else:
+        keyfile = f"{key}.key"
+
+    cds = value.split()
+
+    dsfromkey_command = [
+        *os.environ.get("DSFROMKEY").split(),
+        "-T",
+        "3600",
+        "-a",
+        alg,
+        "-C",
+        "-w",
+        keyfile,
+    ]
+
+    out = isctest.run.cmd(dsfromkey_command, log_stdout=True)
+    dsfromkey = out.stdout.decode("utf-8").split()
+    index = 6
+    while index < len(cds):
+        dsfromkey[index] = dsfromkey[index].lower()
+        index += 1
+
+    rdata_fromfile = " ".join(dsfromkey[:7])
+    rdata_fromwire = " ".join(cds[:7])
+    if rdata_fromfile != rdata_fromwire:
+        isctest.log.debug(f"CDS RDATA MISMATCH: {rdata_fromfile} - {rdata_fromwire}")
+        return False
+
+    digest_fromfile = "".join(cds[7:])
+    digest_fromwire = "".join(cds[7:])
+    if digest_fromfile != digest_fromwire:
+        isctest.log.debug(f"CDS DIGEST MISMATCH: {digest_fromfile} - {digest_fromwire}")
+        return False
+
+    return digest_fromfile == digest_fromwire
+
+
+def zone_is_signed(server, zone):
+    addr = server.ip
+    fqdn = f"{zone}."
+
+    # wait until zone is fully signed
+    signed = False
+    for _ in range(10):
+        response = _query(server, fqdn, dns.rdatatype.NSEC)
+        if not isinstance(response, dns.message.Message):
+            isctest.log.debug(f"no response for {fqdn} NSEC from {addr}")
+        elif response.rcode() != dns.rcode.NOERROR:
+            rcode = dns.rcode.to_text(response.rcode())
+            isctest.log.debug(f"{rcode} response for {fqdn} NSEC from {addr}")
+        else:
+            has_nsec = False
+            has_rrsig = False
+            for rr in response.answer:
+                if not has_nsec:
+                    has_nsec = rr.match(
+                        dns.name.from_text(fqdn),
+                        dns.rdataclass.IN,
+                        dns.rdatatype.NSEC,
+                        dns.rdatatype.NONE,
+                    )
+                if not has_rrsig:
+                    has_rrsig = rr.match(
+                        dns.name.from_text(fqdn),
+                        dns.rdataclass.IN,
+                        dns.rdatatype.RRSIG,
+                        dns.rdatatype.NSEC,
+                    )
+
+            if not has_nsec:
+                isctest.log.debug(
+                    f"missing apex {fqdn} NSEC record in response from {addr}"
+                )
+            if not has_rrsig:
+                isctest.log.debug(
+                    f"missing {fqdn} NSEC signature in response from {addr}"
+                )
+
+            signed = has_nsec and has_rrsig
+
+        if signed:
+            break
+
+        time.sleep(1)
+
+    assert signed
+
+
+def dnssec_verify(server, zone):
+    # Check if zone if DNSSEC valid with dnssec-verify.
+    fqdn = f"{zone}."
+    transfer = _query(server, fqdn, dns.rdatatype.AXFR)
+    if not isinstance(transfer, dns.message.Message):
+        isctest.log.debug(f"no response for {fqdn} AXFR from {server.ip}")
+    elif transfer.rcode() != dns.rcode.NOERROR:
+        rcode = dns.rcode.to_text(transfer.rcode())
+        isctest.log.debug(f"{rcode} response for {fqdn} AXFR from {server.ip}")
+    else:
+        zonefile = f"{zone}.axfr"
+        with open(zonefile, "w", encoding="utf-8") as file:
+            for rr in transfer.answer:
+                file.write(rr.to_text())
+                file.write("\n")
+
+    verify_command = [*os.environ.get("VERIFY").split(), "-z", "-o", zone, zonefile]
+
+    isctest.run.cmd(verify_command)
+
+
+def check_dnssecstatus(server, zone, keys, policy=None, view=None):
+    # Call rndc dnssec -status on 'server' for 'zone'. Expect 'policy' in
+    # the output. This is a loose verification, it just tests if the right
+    # policy name is returned, and if all expected keys are listed.
+    response = ""
+    if view is None:
+        response = server.rndc("dnssec -status {}".format(zone), log=False)
+    else:
+        response = server.rndc("dnssec -status {} in {}".format(zone, view), log=False)
+
+    if policy is None:
+        assert "Zone does not have dnssec-policy" in response
+        return
+
+    assert "dnssec-policy: {}".format(policy) in response
+
+    for key in keys:
+        keytag = get_keytag(key)
+        assert "key: {}".format(keytag) in response
+
+
+# pylint: disable=too-many-locals,too-many-branches
+def _check_signatures(signatures, covers, fqdn, keys, keydir=None):
+    now = datetime.now().strftime("%Y%m%d%H%M%S")
+    numsigs = 0
+    zrrsig = True
+    if covers in [dns.rdatatype.DNSKEY, dns.rdatatype.CDNSKEY, dns.rdatatype.CDS]:
+        zrrsig = False
+    krrsig = not zrrsig
+
+    for key in keys:
+        keytag = get_keytag(key)
+        ksk, zsk = get_keyrole(key, keydir=keydir)
+        activate = get_timing_metadata(key, "Activate", keydir=keydir)
+        inactive = get_timing_metadata(key, "Inactive", keydir=keydir, must_exist=False)
+
+        active = int(now) >= int(activate)
+        retired = int(inactive) != 0 and int(inactive) <= int(now)
+        signing = active and not retired
+
+        if not signing:
+            for rrsig in signatures:
+                assert f"{keytag} {fqdn}" not in rrsig
+            continue
+
+        if zrrsig and zsk:
+            has_rrsig = False
+            for rrsig in signatures:
+                if f"{keytag} {fqdn}" in rrsig:
+                    has_rrsig = True
+                    break
+            assert has_rrsig
+            numsigs += 1
+
+        if zrrsig and not zsk:
+            for rrsig in signatures:
+                assert f"{keytag} {fqdn}" not in rrsig
+
+        if krrsig and ksk:
+            has_rrsig = False
+            for rrsig in signatures:
+                if f"{keytag} {fqdn}" in rrsig:
+                    has_rrsig = True
+                    break
+            assert has_rrsig
+            numsigs += 1
+
+        if krrsig and not ksk:
+            for rrsig in signatures:
+                assert f"{keytag} {fqdn}" not in rrsig
+
+    return numsigs
+
+
+# pylint: disable=too-many-arguments
+def check_signatures(rrset, covers, fqdn, ksks, zsks, kskdir=None, zskdir=None):
+    # Check if signatures with covering type are signed with the right keys.
+    # The right keys are the ones that expect a signature and have the
+    # correct role.
+    numsigs = 0
+
+    signatures = []
+    for rr in rrset:
+        for rdata in rr:
+            rdclass = dns.rdataclass.to_text(rr.rdclass)
+            rdtype = dns.rdatatype.to_text(rr.rdtype)
+            rrsig = f"{rr.name} {rr.ttl} {rdclass} {rdtype} {rdata}"
+            signatures.append(rrsig)
+
+    numsigs += _check_signatures(signatures, covers, fqdn, ksks, keydir=kskdir)
+    numsigs += _check_signatures(signatures, covers, fqdn, zsks, keydir=zskdir)
+
+    assert numsigs == len(signatures)
+
+
+def _check_dnskeys(dnskeys, keys, keydir=None, cdnskey=False):
+    now = datetime.now().strftime("%Y%m%d%H%M%S")
+    numkeys = 0
+
+    publish_md = "Publish"
+    delete_md = "Delete"
+    if cdnskey:
+        publish_md = f"Sync{publish_md}"
+        delete_md = f"Sync{delete_md}"
+
+    for key in keys:
+        publish = get_timing_metadata(key, publish_md, keydir=keydir)
+        delete = get_timing_metadata(key, delete_md, keydir=keydir, must_exist=False)
+        published = int(now) >= int(publish)
+        removed = int(delete) != 0 and int(delete) <= int(now)
+
+        if not published or removed:
+            for dnskey in dnskeys:
+                assert not dnskey_equals(key, dnskey, keydir=keydir, cdnskey=cdnskey)
+            continue
+
+        has_dnskey = False
+        for dnskey in dnskeys:
+            if dnskey_equals(key, dnskey, keydir=keydir, cdnskey=cdnskey):
+                has_dnskey = True
+                break
+
+        assert has_dnskey
+        numkeys += 1
+
+    return numkeys
+
+
+# pylint: disable=too-many-arguments
+def check_dnskeys(rrset, ksks, zsks, kskdir=None, zskdir=None, cdnskey=False):
+    # Check if the correct DNSKEY records are published. If the current time
+    # is between the timing metadata 'publish' and 'delete', the key must have
+    # a DNSKEY record published. If 'cdnskey' is True, check against CDNSKEY
+    # records instead.
+    numkeys = 0
+
+    dnskeys = []
+    for rr in rrset:
+        for rdata in rr:
+            rdclass = dns.rdataclass.to_text(rr.rdclass)
+            rdtype = dns.rdatatype.to_text(rr.rdtype)
+            dnskey = f"{rr.name} {rr.ttl} {rdclass} {rdtype} {rdata}"
+            dnskeys.append(dnskey)
+
+    numkeys += _check_dnskeys(dnskeys, ksks, keydir=kskdir, cdnskey=cdnskey)
+    if not cdnskey:
+        numkeys += _check_dnskeys(dnskeys, zsks, keydir=zskdir)
+
+    assert numkeys == len(dnskeys)
+
+
+# pylint: disable=too-many-locals
+def check_cds(rrset, keys, keydir=None):
+    # Check if the correct CDS records are published. If the current time
+    # is between the timing metadata 'publish' and 'delete', the key must have
+    # a DNSKEY record published. If 'cdnskey' is True, check against CDNSKEY
+    # records instead.
+    now = datetime.now().strftime("%Y%m%d%H%M%S")
+    numcds = 0
+
+    cdss = []
+    for rr in rrset:
+        for rdata in rr:
+            rdclass = dns.rdataclass.to_text(rr.rdclass)
+            rdtype = dns.rdatatype.to_text(rr.rdtype)
+            cds = f"{rr.name} {rr.ttl} {rdclass} {rdtype} {rdata}"
+            cdss.append(cds)
+
+    for key in keys:
+        ksk, _ = get_keyrole(key, keydir=keydir)
+        assert ksk
+
+        publish = get_timing_metadata(key, "SyncPublish", keydir=keydir)
+        delete = get_timing_metadata(key, "SyncDelete", keydir=keydir, must_exist=False)
+        published = int(now) >= int(publish)
+        removed = int(delete) != 0 and int(delete) <= int(now)
+        if not published or removed:
+            for cds in cdss:
+                assert not cds_equals(key, cds, "SHA-256", keydir=keydir)
+            continue
+
+        has_cds = False
+        for cds in cdss:
+            if cds_equals(key, cds, "SHA-256", keydir=keydir):
+                has_cds = True
+                break
+
+        assert has_cds
+        numcds += 1
+
+    assert numcds == len(cdss)
+
+
+def _query_rrset(server, fqdn, qtype):
+    response = _query(server, fqdn, qtype)
+    assert response.rcode() == dns.rcode.NOERROR
+
+    rrs = []
+    rrsigs = []
+    for rrset in response.answer:
+        if rrset.match(
+            dns.name.from_text(fqdn), dns.rdataclass.IN, dns.rdatatype.RRSIG, qtype
+        ):
+            rrsigs.append(rrset)
+        elif rrset.match(
+            dns.name.from_text(fqdn), dns.rdataclass.IN, qtype, dns.rdatatype.NONE
+        ):
+            rrs.append(rrset)
+        else:
+            assert False
+
+    return rrs, rrsigs
+
+
+# pylint: disable=too-many-arguments
+def check_apex(server, zone, ksks, zsks, kskdir=None, zskdir=None):
+    # Test the apex of a zone. This checks that the SOA and DNSKEY RRsets
+    # are signed correctly and with the appropriate keys.
+    fqdn = f"{zone}."
+
+    # test dnskey query
+    dnskeys, rrsigs = _query_rrset(server, fqdn, dns.rdatatype.DNSKEY)
+    assert len(dnskeys) > 0
+    check_dnskeys(dnskeys, ksks, zsks, kskdir=kskdir, zskdir=zskdir)
+    assert len(rrsigs) > 0
+    check_signatures(
+        rrsigs, dns.rdatatype.DNSKEY, fqdn, ksks, zsks, kskdir=kskdir, zskdir=zskdir
+    )
+
+    # test soa query
+    soa, rrsigs = _query_rrset(server, fqdn, dns.rdatatype.SOA)
+    assert len(soa) == 1
+    assert f"{zone}. {DEFAULT_TTL} IN SOA" in soa[0].to_text()
+    assert len(rrsigs) > 0
+    check_signatures(
+        rrsigs, dns.rdatatype.SOA, fqdn, ksks, zsks, kskdir=kskdir, zskdir=zskdir
+    )
+
+    # test cdnskey query
+    cdnskeys, rrsigs = _query_rrset(server, fqdn, dns.rdatatype.CDNSKEY)
+    assert len(cdnskeys) > 0
+    check_dnskeys(cdnskeys, ksks, zsks, kskdir=kskdir, zskdir=zskdir, cdnskey=True)
+    assert len(rrsigs) > 0
+    check_signatures(
+        rrsigs, dns.rdatatype.CDNSKEY, fqdn, ksks, zsks, kskdir=kskdir, zskdir=zskdir
+    )
+
+    # test cds query
+    cds, rrsigs = _query_rrset(server, fqdn, dns.rdatatype.CDS)
+    assert len(cds) > 0
+    check_cds(cds, ksks, keydir=kskdir)
+    assert len(rrsigs) > 0
+    check_signatures(
+        rrsigs, dns.rdatatype.CDS, fqdn, ksks, zsks, kskdir=kskdir, zskdir=zskdir
+    )
+
+
+# pylint: disable=too-many-arguments
+def check_subdomain(server, zone, ksks, zsks, kskdir=None, zskdir=None):
+    # Test an RRset below the apex and verify it is signed correctly.
+    fqdn = f"{zone}."
+    qname = f"a.{zone}."
+    qtype = dns.rdatatype.A
+    response = _query(server, qname, qtype)
+    assert response.rcode() == dns.rcode.NOERROR
+
+    match = f"{qname} {DEFAULT_TTL} IN A 10.0.0.1"
+    rrsigs = []
+    for rrset in response.answer:
+        if rrset.match(
+            dns.name.from_text(qname), dns.rdataclass.IN, dns.rdatatype.RRSIG, qtype
+        ):
+            rrsigs.append(rrset)
+        else:
+            assert match in rrset.to_text()
+
+    assert len(rrsigs) > 0
+    check_signatures(rrsigs, qtype, fqdn, ksks, zsks, kskdir=kskdir, zskdir=zskdir)