]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
dns.message.make_query() now interprets any setting that implies
authorBob Halley <halley@dnspython.org>
Fri, 13 May 2016 13:43:36 +0000 (06:43 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 13 May 2016 13:43:36 +0000 (06:43 -0700)
EDNS as a request to turn on EDNS, if use_edns has not been set
explicitly.

dns/message.py
tests/test_message.py

index 1b310869f439b6db34fd737eb5eab11105f7e707..184100dab9b88cee176ba45af57b26c07d22d5ac 100644 (file)
@@ -1043,7 +1043,7 @@ def from_file(f):
 
 
 def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
-               want_dnssec=False, ednsflags=0, payload=1280,
+               want_dnssec=False, ednsflags=None, payload=None,
                request_payload=None, options=None):
     """Make a query message.
 
@@ -1088,7 +1088,28 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
     m.flags |= dns.flags.RD
     m.find_rrset(m.question, qname, rdclass, rdtype, create=True,
                  force_unique=True)
-    m.use_edns(use_edns, ednsflags, payload, request_payload, options)
+    # only pass keywords on to use_edns if they have been set to a
+    # non-None value.  Setting a field will turn EDNS on if it hasn't
+    # been configured.
+    kwargs = {}
+    if ednsflags is not None:
+        kwargs['ednsflags'] = ednsflags
+        if use_edns is None:
+            use_edns = 0
+    if payload is not None:
+        kwargs['payload'] = payload
+        if use_edns is None:
+            use_edns = 0
+    if request_payload is not None:
+        kwargs['request_payload'] = request_payload
+        if use_edns is None:
+            use_edns = 0
+    if options is not None:
+        kwargs['options'] = options
+        if use_edns is None:
+            use_edns = 0
+    kwargs['edns'] = use_edns
+    m.use_edns(**kwargs)
     m.want_dnssec(want_dnssec)
     return m
 
index 795626541f64c79c2a0ef811b7eb73e1e7d771cf..4513b959615611a78b1beb691c777e82dcbcb368 100644 (file)
@@ -21,6 +21,7 @@ except ImportError:
 import binascii
 
 import dns.exception
+import dns.flags
 import dns.message
 from dns._compat import xrange
 
@@ -179,5 +180,25 @@ class MessageTestCase(unittest.TestCase):
         m.use_edns(1)
         self.failUnless((m.ednsflags >> 16) & 0xFF == 1)
 
+    def test_SettingNoEDNSOptionsImpliesNoEDNS(self):
+        m = dns.message.make_query('foo', 'A')
+        self.failUnless(m.edns == -1)
+
+    def test_SettingEDNSFlagsImpliesEDNS(self):
+        m = dns.message.make_query('foo', 'A', ednsflags=dns.flags.DO)
+        self.failUnless(m.edns == 0)
+
+    def test_SettingEDNSPayloadImpliesEDNS(self):
+        m = dns.message.make_query('foo', 'A', payload=4096)
+        self.failUnless(m.edns == 0)
+
+    def test_SettingEDNSRequestPayloadImpliesEDNS(self):
+        m = dns.message.make_query('foo', 'A', request_payload=4096)
+        self.failUnless(m.edns == 0)
+
+    def test_SettingOptionsImpliesEDNS(self):
+        m = dns.message.make_query('foo', 'A', options=[])
+        self.failUnless(m.edns == 0)
+
 if __name__ == '__main__':
     unittest.main()