From 729d4f3ab56d0d25e89654dc1dd33ec6a1cf1074 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 13 May 2016 06:43:36 -0700 Subject: [PATCH] dns.message.make_query() now interprets any setting that implies EDNS as a request to turn on EDNS, if use_edns has not been set explicitly. --- dns/message.py | 25 +++++++++++++++++++++++-- tests/test_message.py | 21 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/dns/message.py b/dns/message.py index 1b310869..184100da 100644 --- a/dns/message.py +++ b/dns/message.py @@ -1043,7 +1043,7 @@ def from_file(f): def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, - want_dnssec=False, ednsflags=0, payload=1280, + want_dnssec=False, ednsflags=None, payload=None, request_payload=None, options=None): """Make a query message. @@ -1088,7 +1088,28 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, m.flags |= dns.flags.RD m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) - m.use_edns(use_edns, ednsflags, payload, request_payload, options) + # only pass keywords on to use_edns if they have been set to a + # non-None value. Setting a field will turn EDNS on if it hasn't + # been configured. + kwargs = {} + if ednsflags is not None: + kwargs['ednsflags'] = ednsflags + if use_edns is None: + use_edns = 0 + if payload is not None: + kwargs['payload'] = payload + if use_edns is None: + use_edns = 0 + if request_payload is not None: + kwargs['request_payload'] = request_payload + if use_edns is None: + use_edns = 0 + if options is not None: + kwargs['options'] = options + if use_edns is None: + use_edns = 0 + kwargs['edns'] = use_edns + m.use_edns(**kwargs) m.want_dnssec(want_dnssec) return m diff --git a/tests/test_message.py b/tests/test_message.py index 79562654..4513b959 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -21,6 +21,7 @@ except ImportError: import binascii import dns.exception +import dns.flags import dns.message from dns._compat import xrange @@ -179,5 +180,25 @@ class MessageTestCase(unittest.TestCase): m.use_edns(1) self.failUnless((m.ednsflags >> 16) & 0xFF == 1) + def test_SettingNoEDNSOptionsImpliesNoEDNS(self): + m = dns.message.make_query('foo', 'A') + self.failUnless(m.edns == -1) + + def test_SettingEDNSFlagsImpliesEDNS(self): + m = dns.message.make_query('foo', 'A', ednsflags=dns.flags.DO) + self.failUnless(m.edns == 0) + + def test_SettingEDNSPayloadImpliesEDNS(self): + m = dns.message.make_query('foo', 'A', payload=4096) + self.failUnless(m.edns == 0) + + def test_SettingEDNSRequestPayloadImpliesEDNS(self): + m = dns.message.make_query('foo', 'A', request_payload=4096) + self.failUnless(m.edns == 0) + + def test_SettingOptionsImpliesEDNS(self): + m = dns.message.make_query('foo', 'A', options=[]) + self.failUnless(m.edns == 0) + if __name__ == '__main__': unittest.main() -- 2.47.3