]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add source_port support to resolver; fix source_port in query code
authorBob Halley <halley@nominum.com>
Sun, 8 Apr 2012 12:25:36 +0000 (13:25 +0100)
committerBob Halley <halley@nominum.com>
Sun, 8 Apr 2012 12:25:36 +0000 (13:25 +0100)
ChangeLog
README
dns/query.py
dns/resolver.py

index 9b5d4e7dd52f20a34e4de6f64f38945575438e29..173699932d57a32c83b7c76d1e7cf411ac8deec3 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,12 @@
 2012-04-08  Bob Halley  <halley@dnspython.org>
 
+       * dns/query.py: Specifying source_port had no effect if source was
+         not specified.  We now use the appropriate wildcard source in
+         that case.
+
+       * dns/resolver.py (Resolver.query): source_port may now be
+         specified.
+
        * dns/resolver.py (Resolver.query): Switch to TCP when a UDP
          response is truncated.  Handle nameservers that serve on UDP
          but not TCP.
diff --git a/README b/README
index 039ea32d5f93dec9487dcb6741fca3b811241d2e..76faeffae69813f095a5f5747fb686bf2be506c9 100644 (file)
--- a/README
+++ b/README
@@ -47,6 +47,8 @@ New since 1.9.4:
 
        Trailing junk checking can be disabled.
 
+       A source port can be specified when creating a resolver query.
+
 Bugs fixed since 1.9.4:
 
        IPv4 and IPv6 address processing is now stricter.
@@ -56,6 +58,10 @@ Bugs fixed since 1.9.4:
        expected) now raise dns.exception.FormError rather than
        IndexError.
 
+       Specifying a source port without specifying source used to
+       have no effect, but now uses the wildcard address and the
+       specified port.
+
 New since 1.9.3:
 
         Nothing.
index 7bba3529d5e9d3a88b1b304b601f9521b272898d..0e6eb9293cf53ea36f73286eee9b83ec94fb6dc4 100644 (file)
@@ -144,6 +144,28 @@ def _addresses_equal(af, a1, a2):
     n2 = dns.inet.inet_pton(af, a2[0])
     return n1 == n2 and a1[1:] == a2[1:]
 
+def _destination_and_source(af, where, port, source, source_port):
+    # Apply defaults and compute destination and source tuples
+    # suitable for use in connect(), sendto(), or bind().
+    if af is None:
+        try:
+            af = dns.inet.af_for_address(where)
+        except:
+            af = dns.inet.AF_INET
+    if af == dns.inet.AF_INET:
+        destination = (where, port)
+        if source is not None or source_port != 0:
+            if source is None:
+                source = '0.0.0.0'
+            source = (source, source_port)
+    elif af == dns.inet.AF_INET6:
+        destination = (where, port, 0, 0)
+        if source is not None or source_port != 0:
+            if source is None:
+                source = '::'
+            source = (source, source_port, 0, 0)
+    return (af, destination, source)
+
 def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
         ignore_unexpected=False, one_rr_per_rrset=False):
     """Return the response obtained after sending a query via UDP.
@@ -162,7 +184,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     If the inference attempt fails, AF_INET is used.
     @type af: int
     @rtype: dns.message.Message object
-    @param source: source address.  The default is the IPv4 wildcard address.
+    @param source: source address.  The default is the wildcard address.
     @type source: string
     @param source_port: The port from which to send the message.
     The default is 0.
@@ -175,19 +197,8 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     """
 
     wire = q.to_wire()
-    if af is None:
-        try:
-            af = dns.inet.af_for_address(where)
-        except:
-            af = dns.inet.AF_INET
-    if af == dns.inet.AF_INET:
-        destination = (where, port)
-        if source is not None:
-            source = (source, source_port)
-    elif af == dns.inet.AF_INET6:
-        destination = (where, port, 0, 0)
-        if source is not None:
-            source = (source, source_port, 0, 0)
+    (af, destination, source) = _destination_and_source(af, where, port, source,
+                                                        source_port)
     s = socket.socket(af, socket.SOCK_DGRAM, 0)
     try:
         expiration = _compute_expiration(timeout)
@@ -270,7 +281,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     If the inference attempt fails, AF_INET is used.
     @type af: int
     @rtype: dns.message.Message object
-    @param source: source address.  The default is the IPv4 wildcard address.
+    @param source: source address.  The default is the wildcard address.
     @type source: string
     @param source_port: The port from which to send the message.
     The default is 0.
@@ -280,19 +291,8 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     """
 
     wire = q.to_wire()
-    if af is None:
-        try:
-            af = dns.inet.af_for_address(where)
-        except:
-            af = dns.inet.AF_INET
-    if af == dns.inet.AF_INET:
-        destination = (where, port)
-        if source is not None:
-            source = (source, source_port)
-    elif af == dns.inet.AF_INET6:
-        destination = (where, port, 0, 0)
-        if source is not None:
-            source = (source, source_port, 0, 0)
+    (af, destination, source) = _destination_and_source(af, where, port, source,
+                                                        source_port)
     s = socket.socket(af, socket.SOCK_STREAM, 0)
     try:
         expiration = _compute_expiration(timeout)
@@ -357,7 +357,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
     take.
     @type lifetime: float
     @rtype: generator of dns.message.Message objects.
-    @param source: source address.  The default is the IPv4 wildcard address.
+    @param source: source address.  The default is the wildcard address.
     @type source: string
     @param source_port: The port from which to send the message.
     The default is 0.
@@ -384,19 +384,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
     if not keyring is None:
         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
     wire = q.to_wire()
-    if af is None:
-        try:
-            af = dns.inet.af_for_address(where)
-        except:
-            af = dns.inet.AF_INET
-    if af == dns.inet.AF_INET:
-        destination = (where, port)
-        if source is not None:
-            source = (source, source_port)
-    elif af == dns.inet.AF_INET6:
-        destination = (where, port, 0, 0)
-        if source is not None:
-            source = (source, source_port, 0, 0)
+    (af, destination, source) = _destination_and_source(af, where, port, source,
+                                                        source_port)
     if use_udp:
         if rdtype != dns.rdatatype.IXFR:
             raise ValueError('cannot do a UDP AXFR')
index 9f9b438a2cc2140fbc8aaa1b05082fb9f0bb2924..4fb13d369be6d139685bb69f7056599a10890416 100644 (file)
@@ -688,7 +688,7 @@ class Resolver(object):
         return min(self.lifetime - duration, self.timeout)
 
     def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-              tcp=False, source=None, raise_on_no_answer=True):
+              tcp=False, source=None, raise_on_no_answer=True, source_port=0):
         """Query nameservers to find the answer to the question.
 
         The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects
@@ -709,6 +709,9 @@ class Resolver(object):
         @param raise_on_no_answer: raise NoAnswer if there's no answer
         (defaults is True).
         @type raise_on_no_answer: bool
+        @param source_port: The port from which to send the message.
+        The default is 0.
+        @type source_port: int
         @rtype: dns.resolver.Answer instance
         @raises Timeout: no answers could be found in the specified lifetime
         @raises NXDOMAIN: the query name does not exist
@@ -768,17 +771,20 @@ class Resolver(object):
                         if tcp:
                             response = dns.query.tcp(request, nameserver,
                                                      timeout, self.port,
-                                                     source=source)
+                                                     source=source,
+                                                     source_port=source_port)
                         else:
                             response = dns.query.udp(request, nameserver,
                                                      timeout, self.port,
-                                                     source=source)
+                                                     source=source,
+                                                     source_port=source_port)
                             if response.flags & dns.flags.TC:
                                 # Response truncated; retry with TCP.
                                 timeout = self._compute_timeout(start)
                                 response = dns.query.tcp(request, nameserver,
-                                                         timeout, self.port,
-                                                         source=source)
+                                                       timeout, self.port,
+                                                       source=source,
+                                                       source_port=source_port)
 
                     except (socket.error, dns.exception.Timeout):
                         #
@@ -898,7 +904,8 @@ def get_default_resolver():
     return default_resolver
 
 def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-          tcp=False, source=None, raise_on_no_answer=True):
+          tcp=False, source=None, raise_on_no_answer=True,
+          source_port=0):
     """Query nameservers to find the answer to the question.
 
     This is a convenience function that uses the default resolver
@@ -906,7 +913,7 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
     @see: L{dns.resolver.Resolver.query} for more information on the
     parameters."""
     return get_default_resolver().query(qname, rdtype, rdclass, tcp, source,
-                                        raise_on_no_answer)
+                                        raise_on_no_answer, source_port)
 
 def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
     """Find the name of the zone which contains the specified name.