]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add flags to dns.message.make_query(). 720/head
authorBrian Wellington <bwelling@xbill.org>
Tue, 2 Nov 2021 17:41:15 +0000 (10:41 -0700)
committerBrian Wellington <bwelling@xbill.org>
Tue, 2 Nov 2021 17:41:15 +0000 (10:41 -0700)
dns/message.py
tests/test_message.py

index 8e6f5cc4bb2530c84205633ce74fd92376d4ae9a..1e67a17b8e2e30de9099aa103c71724bd19f741f 100644 (file)
@@ -1425,7 +1425,7 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False):
 def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
                want_dnssec=False, ednsflags=None, payload=None,
                request_payload=None, options=None, idna_codec=None,
-               id=None):
+               id=None, flags=dns.flags.RD):
     """Make a query message.
 
     The query name, type, and class may all be specified either
@@ -1470,6 +1470,9 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
     *id*, an ``int`` or ``None``, the desired query id.  The default is
     ``None``, which generates a random query id.
 
+    *flags*, an ``int``, the desired query flags.  The default is
+    ``dns.flags.RD``.
+
     Returns a ``dns.message.QueryMessage``
     """
 
@@ -1478,7 +1481,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
     rdtype = dns.rdatatype.RdataType.make(rdtype)
     rdclass = dns.rdataclass.RdataClass.make(rdclass)
     m = QueryMessage(id=id)
-    m.flags |= dns.flags.RD
+    m.flags = dns.flags.Flag(flags)
     m.find_rrset(m.question, qname, rdclass, rdtype, create=True,
                  force_unique=True)
     # only pass keywords on to use_edns if they have been set to a
index 190385af932d0edf3b7bc16a5404df5aeea319d1..ad302984951e5bd134996c69be931414b30899ee 100644 (file)
@@ -441,6 +441,12 @@ class MessageTestCase(unittest.TestCase):
         q = dns.message.make_query('www.dnspython.org.', 'a', id=12345)
         self.assertEqual(q.id, 12345)
 
+    def test_setting_flags(self):
+        q = dns.message.make_query('www.dnspython.org.', 'a',
+                                   flags=dns.flags.RD|dns.flags.CD)
+        self.assertEqual(q.flags, dns.flags.RD|dns.flags.CD)
+        self.assertEqual(q.flags, 0x0110)
+
     def test_generic_message_class(self):
         q1 = dns.message.Message(id=1)
         q1.set_opcode(dns.opcode.NOTIFY)