]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
lifetime (timeout) support for dns.resolver.query 258/head
authorShatil Rafiullah <shatil@atomtickets.com>
Sun, 4 Jun 2017 19:59:12 +0000 (12:59 -0700)
committerShatil Rafiullah <shatil@atomtickets.com>
Sun, 4 Jun 2017 19:59:12 +0000 (12:59 -0700)
Introduces `lifetime` param to `dns.resolver.query`, allowing user to
specify a timeout for querying DNS resolvers instead of being stuck on
the hardcoded default.

This doesn't modify the `Resolver` _instance_ itself, so subsequent
calls to it, without specifying `lifetime` as a param, will honor the
default value configured in `reset` (see `self.lifetime =`).

dns/resolver.py

index 00323835d33bc7003c424a7bf02ed16356326a69..8bfbc4c863c378cb80b97992414d05dd61aab778 100644 (file)
@@ -780,7 +780,8 @@ class Resolver(object):
             except WindowsError:  # pylint: disable=undefined-variable
                 return False
 
-    def _compute_timeout(self, start):
+    def _compute_timeout(self, start, lifetime=None):
+        lifetime = self.lifetime if lifetime is None else lifetime
         now = time.time()
         duration = now - start
         if duration < 0:
@@ -792,12 +793,13 @@ class Resolver(object):
                 # happen, e.g. under vmware with older linux kernels.
                 # Pretend it didn't happen.
                 now = start
-        if duration >= self.lifetime:
+        if duration >= lifetime:
             raise Timeout(timeout=duration)
-        return min(self.lifetime - duration, self.timeout)
+        return min(lifetime - duration, self.timeout)
 
     def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-              tcp=False, source=None, raise_on_no_answer=True, source_port=0):
+              tcp=False, source=None, raise_on_no_answer=True, source_port=0,
+              lifetime=None):
         """Query nameservers to find the answer to the question.
 
         The *qname*, *rdtype*, and *rdclass* parameters may be objects
@@ -820,6 +822,8 @@ class Resolver(object):
 
         *source_port*, an ``int``, the port from which to send the message.
 
+        *lifetime*, a ``float``, how long query should run before timing out.
+
         Raises ``dns.exception.Timeout`` if no answers could be found
         in the specified lifetime.
 
@@ -891,7 +895,7 @@ class Resolver(object):
                 if len(nameservers) == 0:
                     raise NoNameservers(request=request, errors=errors)
                 for nameserver in nameservers[:]:
-                    timeout = self._compute_timeout(start)
+                    timeout = self._compute_timeout(start, lifetime)
                     port = self.nameserver_ports.get(nameserver, self.port)
                     try:
                         tcp_attempt = tcp
@@ -908,7 +912,7 @@ class Resolver(object):
                             if response.flags & dns.flags.TC:
                                 # Response truncated; retry with TCP.
                                 tcp_attempt = True
-                                timeout = self._compute_timeout(start)
+                                timeout = self._compute_timeout(start, lifetime)
                                 response = \
                                     dns.query.tcp(request, nameserver,
                                                   timeout, port,
@@ -983,7 +987,7 @@ class Resolver(object):
                     # But we still have servers to try.  Sleep a bit
                     # so we don't pound them!
                     #
-                    timeout = self._compute_timeout(start)
+                    timeout = self._compute_timeout(start, lifetime)
                     sleep_time = min(timeout, backoff)
                     backoff *= 2
                     time.sleep(sleep_time)
@@ -1081,7 +1085,7 @@ def reset_default_resolver():
 
 def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
           tcp=False, source=None, raise_on_no_answer=True,
-          source_port=0):
+          source_port=0, lifetime=None):
     """Query nameservers to find the answer to the question.
 
     This is a convenience function that uses the default resolver
@@ -1092,7 +1096,8 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
     """
 
     return get_default_resolver().query(qname, rdtype, rdclass, tcp, source,
-                                        raise_on_no_answer, source_port)
+                                        raise_on_no_answer, source_port,
+                                        lifetime)
 
 
 def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):