]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add a copy mode to dns.message.make_response(). (#1131)
authorBob Halley <halley@dnspython.org>
Tue, 17 Sep 2024 12:56:14 +0000 (05:56 -0700)
committerGitHub <noreply@github.com>
Tue, 17 Sep 2024 12:56:14 +0000 (05:56 -0700)
Add a copy mode to dns.message.make_response().

If the mode is none, then a default copy mode appropriate for the opcode will
be used.  This is currently always dns.message.CopyMode.QUESTION.

If the mode is dns.message.CopyMode.QUESTION then only the question
section is copied.

If the mode is dns.message.CopyMode.EVERYTHING, then all sections are
copied other than OPT or TSIG records which are created appropriately
if needed instead of being copied.

If the mode is dns.message.CopyMode.NOTHING then no sections are
copied.

dns/message.py
tests/test_constants.py
tests/test_message.py

index f320f82698607f4a3d3f34ece7cf87caa58d681d..e978a0a2e1d8ce681b4d353dc9d4dcb74c97006d 100644 (file)
@@ -18,6 +18,7 @@
 """DNS Messages"""
 
 import contextlib
+import enum
 import io
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -1825,6 +1826,16 @@ def make_query(
     return m
 
 
+class CopyMode(enum.Enum):
+    """
+    How should sections be copied when making an update response?
+    """
+
+    NOTHING = 0
+    QUESTION = 1
+    EVERYTHING = 2
+
+
 def make_response(
     query: Message,
     recursion_available: bool = False,
@@ -1832,13 +1843,14 @@ def make_response(
     fudge: int = 300,
     tsig_error: int = 0,
     pad: Optional[int] = None,
+    copy_mode: Optional[CopyMode] = None,
 ) -> Message:
     """Make a message which is a response for the specified query.
     The message returned is really a response skeleton; it has all of the infrastructure
     required of a response, but none of the content.
 
-    The response's question section is a shallow copy of the query's question section,
-    so the query's question RRsets should not be changed.
+    Response section(s) which are copied are shallow copies of the matching section(s)
+    in the query, so the query's RRsets should not be changed.
 
     *query*, a ``dns.message.Message``, the query to respond to.
 
@@ -1851,25 +1863,44 @@ def make_response(
     *tsig_error*, an ``int``, the TSIG error.
 
     *pad*, a non-negative ``int`` or ``None``.  If 0, the default, do not pad; otherwise
-    if not ``None`` add padding bytes to make the message size a multiple of *pad*.
-    Note that if padding is non-zero, an EDNS PADDING option will always be added to the
+    if not ``None`` add padding bytes to make the message size a multiple of *pad*. Note
+    that if padding is non-zero, an EDNS PADDING option will always be added to the
     message.  If ``None``, add padding following RFC 8467, namely if the request is
     padded, pad the response to 468 otherwise do not pad.
 
+    *copy_mode*, a ``dns.message.CopyMode`` or ``None``, determines how sections are
+    copied.  The default, ``None`` copies sections according to the default for the
+    message's opcode, which is currently ``dns.message.CopyMode.QUESTION`` for all
+    opcodes.   ``dns.message.CopyMode.QUESTION`` copies only the question section.
+    ``dns.message.CopyMode.EVERYTHING`` copies all sections other than OPT or TSIG
+    records, which are created appropriately if needed. ``dns.message.CopyMode.NOTHING``
+    copies no sections; note that this mode is for server testing purposes and is
+    otherwise not recommended for use.  In particular, ``dns.message.is_response()``
+    will be ``False`` if you create a response this way and the rcode is not
+    ``FORMERR``, ``SERVFAIL``, ``NOTIMP``, or ``REFUSED``.
+
     Returns a ``dns.message.Message`` object whose specific class is appropriate for the
-    query.  For example, if query is a ``dns.update.UpdateMessage``, response will be
-    too.
+    query.  For example, if query is a ``dns.update.UpdateMessage``, the response will
+    be one too.
     """
 
     if query.flags & dns.flags.QR:
         raise dns.exception.FormError("specified query message is not a query")
-    factory = _message_factory_from_opcode(query.opcode())
+    opcode = query.opcode()
+    factory = _message_factory_from_opcode(opcode)
     response = factory(id=query.id)
     response.flags = dns.flags.QR | (query.flags & dns.flags.RD)
     if recursion_available:
         response.flags |= dns.flags.RA
-    response.set_opcode(query.opcode())
-    response.question = list(query.question)
+    response.set_opcode(opcode)
+    if copy_mode is None:
+        copy_mode = CopyMode.QUESTION
+    if copy_mode != CopyMode.NOTHING:
+        response.question = list(query.question)
+    if copy_mode == CopyMode.EVERYTHING:
+        response.answer = list(query.answer)
+        response.authority = list(query.authority)
+        response.additional = list(query.additional)
     if query.edns >= 0:
         if pad is None:
             # Set response padding per RFC 8467
index bf0d97091955f6bdd2179547c508154b0cb6c2b6..43e9863b80c06489d56858eb782245d29d06b2e7 100644 (file)
@@ -3,14 +3,13 @@
 import unittest
 
 import dns.dnssec
-import dns.rdtypes.dnskeybase
+import dns.edns
 import dns.flags
-import dns.rcode
-import dns.opcode
 import dns.message
+import dns.opcode
+import dns.rcode
+import dns.rdtypes.dnskeybase
 import dns.update
-import dns.edns
-
 import tests.util
 
 
@@ -27,7 +26,9 @@ class ConstantsTestCase(unittest.TestCase):
         tests.util.check_enum_exports(dns.opcode, self.assertEqual)
 
     def test_message_constants(self):
-        tests.util.check_enum_exports(dns.message, self.assertEqual)
+        tests.util.check_enum_exports(
+            dns.message, self.assertEqual, only={dns.message.MessageSection}
+        )
         tests.util.check_enum_exports(dns.update, self.assertEqual)
 
     def test_rdata_constants(self):
index 89db217e3917eb2a55ce739f07eb7513b49d6653..1d592657145eab717f3a6b1bf212b592701cedc0 100644 (file)
@@ -233,6 +233,46 @@ class MessageTestCase(unittest.TestCase):
         self.assertTrue(r.flags & dns.flags.RA != 0)
         self.assertEqual(r.edns, 0)
 
+    def _make_copy_query(self):
+        q = dns.message.make_query("foo", "A")
+        # These are nonsensical, but all we care about is they get copied.
+        q.answer.append(dns.rrset.from_text("foo", 300, "IN", "A", "10.0.0.1"))
+        q.authority.append(dns.rrset.from_text("foo2", 300, "IN", "A", "10.0.0.2"))
+        q.additional.append(dns.rrset.from_text("foo3", 300, "IN", "A", "10.0.0.3"))
+        return q
+
+    def test_MakeResponseCopyNothing(self):
+        q = self._make_copy_query()
+        r = dns.message.make_response(q, copy_mode=dns.message.CopyMode.NOTHING)
+        self.assertEqual(len(r.question), 0)
+        self.assertEqual(len(r.answer), 0)
+        self.assertEqual(len(r.authority), 0)
+        self.assertEqual(len(r.additional), 0)
+
+    def test_MakeResponseCopyDefault(self):
+        q = self._make_copy_query()
+        r = dns.message.make_response(q)
+        self.assertTrue(len(r.question) == 1 and q.question[0] == r.question[0])
+        self.assertEqual(len(r.answer), 0)
+        self.assertEqual(len(r.authority), 0)
+        self.assertEqual(len(r.additional), 0)
+
+    def test_MakeResponseCopyQuestion(self):
+        q = self._make_copy_query()
+        r = dns.message.make_response(q, copy_mode=dns.message.CopyMode.QUESTION)
+        self.assertTrue(len(r.question) == 1 and q.question[0] == r.question[0])
+        self.assertEqual(len(r.answer), 0)
+        self.assertEqual(len(r.authority), 0)
+        self.assertEqual(len(r.additional), 0)
+
+    def test_MakeResponseCopyEverything(self):
+        q = self._make_copy_query()
+        r = dns.message.make_response(q, copy_mode=dns.message.CopyMode.EVERYTHING)
+        self.assertTrue(len(r.question) == 1 and q.question == r.question)
+        self.assertTrue(len(r.answer) == 1 and q.answer == r.answer)
+        self.assertTrue(len(r.authority) == 1 and q.authority == r.authority)
+        self.assertTrue(len(r.additional) == 1 and q.additional == r.additional)
+
     def test_ExtendedRcodeSetting(self):
         m = dns.message.make_query("foo", "A")
         m.set_rcode(4095)