import os
from typing import Optional
+import dns.flags
import dns.rcode
import dns.message
import dns.zone
rcode(message, dns_rcode.SERVFAIL)
+def adflag(message: dns.message.Message) -> None:
+ assert (message.flags & dns.flags.AD) != 0, str(message)
+
+
+def noadflag(message: dns.message.Message) -> None:
+ assert (message.flags & dns.flags.AD) == 0, str(message)
+
+
+def rdflag(message: dns.message.Message) -> None:
+ assert (message.flags & dns.flags.RD) != 0, str(message)
+
+
+def nordflag(message: dns.message.Message) -> None:
+ assert (message.flags & dns.flags.RD) == 0, str(message)
+
+
+def section_equal(sec1: list, sec2: list) -> None:
+ # convert an RRset to a normalized string (lower case, TTL=0)
+ # so it can be used as a set member.
+ def normalized(rrset):
+ ttl = rrset.ttl
+ rrset.ttl = 0
+ s = str(rrset).lower()
+ rrset.ttl = ttl
+ return s
+
+ # convert the section contents to sets before comparison,
+ # in case they aren't in the same sort order.
+ set1 = {normalized(item) for item in sec1}
+ set2 = {normalized(item) for item in sec2}
+ assert set1 == set2
+
+
+def same_data(res1: dns.message.Message, res2: dns.message.Message):
+ assert res1.question == res2.question
+ section_equal(res1.answer, res2.answer)
+ section_equal(res1.authority, res2.authority)
+ section_equal(res1.additional, res2.additional)
+ assert res1.rcode() == res2.rcode()
+
+
+def same_answer(res1: dns.message.Message, res2: dns.message.Message):
+ assert res1.question == res2.question
+ section_equal(res1.answer, res2.answer)
+ assert res1.rcode() == res2.rcode()
+
+
def rrsets_equal(
first_rrset: dns.rrset.RRset,
second_rrset: dns.rrset.RRset,
assert not message.answer, str(message)
+def answer_count_eq(m: dns.message.Message, expected: int):
+ count = sum(max(1, len(rrs)) for rrs in m.answer)
+ assert count == expected, str(m)
+
+
+def authority_count_eq(m: dns.message.Message, expected: int):
+ count = sum(max(1, len(rrs)) for rrs in m.authority)
+ assert count == expected, str(m)
+
+
+def additional_count_eq(m: dns.message.Message, expected: int):
+ count = sum(max(1, len(rrs)) for rrs in m.additional)
+
+ # add one for the OPT?
+ opt = bool(m.opt) if hasattr(m, "opt") else bool(m.edns >= 0)
+ count += 1 if opt else 0
+
+ # add one for the TSIG?
+ tsig = bool(m.tsig) if hasattr(m, "tsig") else m.had_tsig
+ count += 1 if tsig else 0
+
+ assert count == expected, str(m)
+
+
def is_response_to(response: dns.message.Message, query: dns.message.Message) -> None:
single_question(response)
single_question(query)
--- /dev/null
+# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
+#
+# SPDX-License-Identifier: MPL-2.0
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, you can obtain one at https://mozilla.org/MPL/2.0/.
+#
+# See the COPYRIGHT file distributed with this work for additional
+# information regarding copyright ownership.
+
+from dns import flags, message
+
+
+def msg(qname: str, qtype: str, **kwargs):
+ headerflags = flags.RD
+ # "ad" is on by default
+ if "ad" not in kwargs or not kwargs["ad"]:
+ headerflags |= flags.AD
+ # "cd" is off by default
+ if "cd" in kwargs and kwargs["cd"]:
+ headerflags |= flags.CD
+ return message.make_query(
+ qname, qtype, use_edns=True, want_dnssec=True, flags=headerflags
+ )