From: Bob Halley Date: Tue, 17 Sep 2024 12:56:14 +0000 (-0700) Subject: Add a copy mode to dns.message.make_response(). (#1131) X-Git-Tag: v2.7.0rc1~11 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aad6d8174d9767becba3c348dc469d680a55feb7;p=thirdparty%2Fdnspython.git Add a copy mode to dns.message.make_response(). (#1131) Add a copy mode to dns.message.make_response(). If the mode is none, then a default copy mode appropriate for the opcode will be used. This is currently always dns.message.CopyMode.QUESTION. If the mode is dns.message.CopyMode.QUESTION then only the question section is copied. If the mode is dns.message.CopyMode.EVERYTHING, then all sections are copied other than OPT or TSIG records which are created appropriately if needed instead of being copied. If the mode is dns.message.CopyMode.NOTHING then no sections are copied. --- diff --git a/dns/message.py b/dns/message.py index f320f826..e978a0a2 100644 --- a/dns/message.py +++ b/dns/message.py @@ -18,6 +18,7 @@ """DNS Messages""" import contextlib +import enum import io import time from typing import Any, Dict, List, Optional, Tuple, Union, cast @@ -1825,6 +1826,16 @@ def make_query( return m +class CopyMode(enum.Enum): + """ + How should sections be copied when making an update response? + """ + + NOTHING = 0 + QUESTION = 1 + EVERYTHING = 2 + + def make_response( query: Message, recursion_available: bool = False, @@ -1832,13 +1843,14 @@ def make_response( fudge: int = 300, tsig_error: int = 0, pad: Optional[int] = None, + copy_mode: Optional[CopyMode] = None, ) -> Message: """Make a message which is a response for the specified query. The message returned is really a response skeleton; it has all of the infrastructure required of a response, but none of the content. - The response's question section is a shallow copy of the query's question section, - so the query's question RRsets should not be changed. + Response section(s) which are copied are shallow copies of the matching section(s) + in the query, so the query's RRsets should not be changed. *query*, a ``dns.message.Message``, the query to respond to. @@ -1851,25 +1863,44 @@ def make_response( *tsig_error*, an ``int``, the TSIG error. *pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise - if not ``None`` add padding bytes to make the message size a multiple of *pad*. - Note that if padding is non-zero, an EDNS PADDING option will always be added to the + if not ``None`` add padding bytes to make the message size a multiple of *pad*. Note + that if padding is non-zero, an EDNS PADDING option will always be added to the message. If ``None``, add padding following RFC 8467, namely if the request is padded, pad the response to 468 otherwise do not pad. + *copy_mode*, a ``dns.message.CopyMode`` or ``None``, determines how sections are + copied. The default, ``None`` copies sections according to the default for the + message's opcode, which is currently ``dns.message.CopyMode.QUESTION`` for all + opcodes. ``dns.message.CopyMode.QUESTION`` copies only the question section. + ``dns.message.CopyMode.EVERYTHING`` copies all sections other than OPT or TSIG + records, which are created appropriately if needed. ``dns.message.CopyMode.NOTHING`` + copies no sections; note that this mode is for server testing purposes and is + otherwise not recommended for use. In particular, ``dns.message.is_response()`` + will be ``False`` if you create a response this way and the rcode is not + ``FORMERR``, ``SERVFAIL``, ``NOTIMP``, or ``REFUSED``. + Returns a ``dns.message.Message`` object whose specific class is appropriate for the - query. For example, if query is a ``dns.update.UpdateMessage``, response will be - too. + query. For example, if query is a ``dns.update.UpdateMessage``, the response will + be one too. """ if query.flags & dns.flags.QR: raise dns.exception.FormError("specified query message is not a query") - factory = _message_factory_from_opcode(query.opcode()) + opcode = query.opcode() + factory = _message_factory_from_opcode(opcode) response = factory(id=query.id) response.flags = dns.flags.QR | (query.flags & dns.flags.RD) if recursion_available: response.flags |= dns.flags.RA - response.set_opcode(query.opcode()) - response.question = list(query.question) + response.set_opcode(opcode) + if copy_mode is None: + copy_mode = CopyMode.QUESTION + if copy_mode != CopyMode.NOTHING: + response.question = list(query.question) + if copy_mode == CopyMode.EVERYTHING: + response.answer = list(query.answer) + response.authority = list(query.authority) + response.additional = list(query.additional) if query.edns >= 0: if pad is None: # Set response padding per RFC 8467 diff --git a/tests/test_constants.py b/tests/test_constants.py index bf0d9709..43e9863b 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -3,14 +3,13 @@ import unittest import dns.dnssec -import dns.rdtypes.dnskeybase +import dns.edns import dns.flags -import dns.rcode -import dns.opcode import dns.message +import dns.opcode +import dns.rcode +import dns.rdtypes.dnskeybase import dns.update -import dns.edns - import tests.util @@ -27,7 +26,9 @@ class ConstantsTestCase(unittest.TestCase): tests.util.check_enum_exports(dns.opcode, self.assertEqual) def test_message_constants(self): - tests.util.check_enum_exports(dns.message, self.assertEqual) + tests.util.check_enum_exports( + dns.message, self.assertEqual, only={dns.message.MessageSection} + ) tests.util.check_enum_exports(dns.update, self.assertEqual) def test_rdata_constants(self): diff --git a/tests/test_message.py b/tests/test_message.py index 89db217e..1d592657 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -233,6 +233,46 @@ class MessageTestCase(unittest.TestCase): self.assertTrue(r.flags & dns.flags.RA != 0) self.assertEqual(r.edns, 0) + def _make_copy_query(self): + q = dns.message.make_query("foo", "A") + # These are nonsensical, but all we care about is they get copied. + q.answer.append(dns.rrset.from_text("foo", 300, "IN", "A", "10.0.0.1")) + q.authority.append(dns.rrset.from_text("foo2", 300, "IN", "A", "10.0.0.2")) + q.additional.append(dns.rrset.from_text("foo3", 300, "IN", "A", "10.0.0.3")) + return q + + def test_MakeResponseCopyNothing(self): + q = self._make_copy_query() + r = dns.message.make_response(q, copy_mode=dns.message.CopyMode.NOTHING) + self.assertEqual(len(r.question), 0) + self.assertEqual(len(r.answer), 0) + self.assertEqual(len(r.authority), 0) + self.assertEqual(len(r.additional), 0) + + def test_MakeResponseCopyDefault(self): + q = self._make_copy_query() + r = dns.message.make_response(q) + self.assertTrue(len(r.question) == 1 and q.question[0] == r.question[0]) + self.assertEqual(len(r.answer), 0) + self.assertEqual(len(r.authority), 0) + self.assertEqual(len(r.additional), 0) + + def test_MakeResponseCopyQuestion(self): + q = self._make_copy_query() + r = dns.message.make_response(q, copy_mode=dns.message.CopyMode.QUESTION) + self.assertTrue(len(r.question) == 1 and q.question[0] == r.question[0]) + self.assertEqual(len(r.answer), 0) + self.assertEqual(len(r.authority), 0) + self.assertEqual(len(r.additional), 0) + + def test_MakeResponseCopyEverything(self): + q = self._make_copy_query() + r = dns.message.make_response(q, copy_mode=dns.message.CopyMode.EVERYTHING) + self.assertTrue(len(r.question) == 1 and q.question == r.question) + self.assertTrue(len(r.answer) == 1 and q.answer == r.answer) + self.assertTrue(len(r.authority) == 1 and q.authority == r.authority) + self.assertTrue(len(r.additional) == 1 and q.additional == r.additional) + def test_ExtendedRcodeSetting(self): m = dns.message.make_query("foo", "A") m.set_rcode(4095)