From d52b840088a2c4e496a0e68f6db441c8c1340c41 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Tue, 2 Nov 2021 10:41:15 -0700 Subject: [PATCH] Add flags to dns.message.make_query(). --- dns/message.py | 7 +++++-- tests/test_message.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/dns/message.py b/dns/message.py index 8e6f5cc4..1e67a17b 100644 --- a/dns/message.py +++ b/dns/message.py @@ -1425,7 +1425,7 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False): def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, want_dnssec=False, ednsflags=None, payload=None, request_payload=None, options=None, idna_codec=None, - id=None): + id=None, flags=dns.flags.RD): """Make a query message. The query name, type, and class may all be specified either @@ -1470,6 +1470,9 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, *id*, an ``int`` or ``None``, the desired query id. The default is ``None``, which generates a random query id. + *flags*, an ``int``, the desired query flags. The default is + ``dns.flags.RD``. + Returns a ``dns.message.QueryMessage`` """ @@ -1478,7 +1481,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, rdtype = dns.rdatatype.RdataType.make(rdtype) rdclass = dns.rdataclass.RdataClass.make(rdclass) m = QueryMessage(id=id) - m.flags |= dns.flags.RD + m.flags = dns.flags.Flag(flags) m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) # only pass keywords on to use_edns if they have been set to a diff --git a/tests/test_message.py b/tests/test_message.py index 190385af..ad302984 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -441,6 +441,12 @@ class MessageTestCase(unittest.TestCase): q = dns.message.make_query('www.dnspython.org.', 'a', id=12345) self.assertEqual(q.id, 12345) + def test_setting_flags(self): + q = dns.message.make_query('www.dnspython.org.', 'a', + flags=dns.flags.RD|dns.flags.CD) + self.assertEqual(q.flags, dns.flags.RD|dns.flags.CD) + self.assertEqual(q.flags, 0x0110) + def test_generic_message_class(self): q1 = dns.message.Message(id=1) q1.set_opcode(dns.opcode.NOTIFY) -- 2.47.3