]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Allow resolver-level control over the defaulting of search (default False).
authorBob Halley <halley@dnspython.org>
Fri, 15 May 2020 18:38:03 +0000 (11:38 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 15 May 2020 18:38:03 +0000 (11:38 -0700)
dns/resolver.py
doc/resolver-class.rst
tests/test_resolver.py

index 3c250b0c216434cd5b75b64bf764461af1582e7e..e60e2bec72259afd219393835de2df2b21d665a8 100644 (file)
@@ -490,10 +490,13 @@ class LRUCache(object):
                     node = next
                 self.data = {}
 
-
 class Resolver(object):
     """DNS stub resolver."""
 
+    # We initialize in reset()
+    #
+    # pylint: disable=attribute-defined-outside-init
+
     def __init__(self, filename='/etc/resolv.conf', configure=True):
         """*filename*, a ``str`` or file object, specifying a file
         in standard /etc/resolv.conf format.  This parameter is meaningful
@@ -506,25 +509,6 @@ class Resolver(object):
         on Windows systems.)
         """
 
-        self.domain = None
-        self.nameservers = []
-        self.nameserver_ports = None
-        self.port = None
-        self.search = None
-        self.timeout = None
-        self.lifetime = None
-        self.keyring = None
-        self.keyname = None
-        self.keyalgorithm = None
-        self.edns = None
-        self.ednsflags = None
-        self.payload = None
-        self.cache = None
-        self.flags = None
-        self.retry_servfail = False
-        self.rotate = False
-        self.ndots = None
-
         self.reset()
         if configure:
             if sys.platform == 'win32':
@@ -543,6 +527,7 @@ class Resolver(object):
         self.nameserver_ports = {}
         self.port = 53
         self.search = []
+        self.use_search_by_default = False
         self.timeout = 2.0
         self.lifetime = 30.0
         self.keyring = None
@@ -809,6 +794,8 @@ class Resolver(object):
     def _get_qnames_to_try(self, qname, search):
         # This is a separate method so we can unit test the search
         # rules without requiring the Internet.
+        if search is None:
+            search = self.use_search_by_default
         qnames_to_try = []
         if qname.is_absolute():
             qnames_to_try.append(qname)
@@ -825,7 +812,7 @@ class Resolver(object):
 
     def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
                 tcp=False, source=None, raise_on_no_answer=True, source_port=0,
-                lifetime=None, search=False):
+                lifetime=None, search=None):
         """Query nameservers to find the answer to the question.
 
         The *qname*, *rdtype*, and *rdclass* parameters may be objects
@@ -851,9 +838,10 @@ class Resolver(object):
         *lifetime*, a ``float``, how many seconds a query should run
          before timing out.
 
-        *search*, a ``bool``, determines whether search lists configured
-        in the system's resolver configuration are used.  The default is
-        ``False``.
+        *search*, a ``bool`` or ``None``, determines whether the search
+        list configured in the system's resolver configuration are
+        used.  The default is ``None``, which causes the value of
+        the resolver's ``use_search_by_default`` attribute to be used.
 
         Raises ``dns.exception.Timeout`` if no answers could be found
         in the specified lifetime.
@@ -1181,7 +1169,7 @@ def reset_default_resolver():
 
 def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
             tcp=False, source=None, raise_on_no_answer=True,
-            source_port=0, lifetime=None, search=False):
+            source_port=0, lifetime=None, search=None):
     """Query nameservers to find the answer to the question.
 
     This is a convenience function that uses the default resolver
@@ -1308,8 +1296,8 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
     # Something needs resolution!
     try:
         if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
-            v6 = _resolver.query(host, dns.rdatatype.AAAA,
-                                 raise_on_no_answer=False)
+            v6 = _resolver.resolve(host, dns.rdatatype.AAAA,
+                                   raise_on_no_answer=False)
             # Note that setting host ensures we query the same name
             # for A as we did for AAAA.
             host = v6.qname
@@ -1318,8 +1306,8 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
                 for rdata in v6.rrset:
                     v6addrs.append(rdata.address)
         if family == socket.AF_INET or family == socket.AF_UNSPEC:
-            v4 = _resolver.query(host, dns.rdatatype.A,
-                                 raise_on_no_answer=False)
+            v4 = _resolver.resolve(host, dns.rdatatype.A,
+                                   raise_on_no_answer=False)
             host = v4.qname
             canonical_name = v4.canonical_name.to_text(True)
             if v4.rrset is not None:
@@ -1394,7 +1382,7 @@ def _getnameinfo(sockaddr, flags=0):
     qname = dns.reversename.from_address(addr)
     if flags & socket.NI_NUMERICHOST == 0:
         try:
-            answer = _resolver.query(qname, 'PTR')
+            answer = _resolver.resolve(qname, 'PTR')
             hostname = answer.rrset[0].target.to_text(True)
         except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
             if flags & socket.NI_NAMEREQD:
index 58ab332ed96b10614d912c7590d3aa2be6e1582e..99cf70e317108bd0a968a391742ff06187dc6842 100644 (file)
@@ -20,6 +20,13 @@ The dns.resolver.Resolver and dns.resolver.Answer Classes
       relative name, the resolver will construct absolute query names
       to try by appending values from the search list.
 
+   .. attribute:: use_search_by_default
+
+      A ``bool``, specifes whether or not ``resolve()`` uses the
+      search list configured in the system's resolver configuration
+      when the ``search`` parameter to ``resolve()`` is ``None``.  The
+      default is ``False``.
+
    .. attribute:: port
 
       An ``int``, the default DNS port to send to if not overriden by
index f6ad7624f9eaa51d5ff45d51719ff0428391b9f1..373ce606605bd746546825ffd64b67909457b29a 100644 (file)
@@ -264,7 +264,7 @@ class BaseResolverTests(unittest.TestCase):
         for a in answer:
             pass
 
-    def testSearchLists(self):
+    def testSearchListsRelative(self):
         res = dns.resolver.Resolver()
         res.domain = dns.name.from_text('example')
         res.search = [dns.name.from_text(x) for x in
@@ -277,11 +277,27 @@ class BaseResolverTests(unittest.TestCase):
         qnames = res._get_qnames_to_try(qname, False)
         self.assertEqual(qnames,
                          [dns.name.from_text('www.example.')])
+        qnames = res._get_qnames_to_try(qname, None)
+        self.assertEqual(qnames,
+                         [dns.name.from_text('www.example.')])
+        #
+        # Now change search default on resolver to True
+        #
+        res.use_search_by_default = True
+        qnames = res._get_qnames_to_try(qname, None)
+        self.assertEqual(qnames,
+                         [dns.name.from_text(x) for x in
+                          ['www.dnspython.org', 'www.dnspython.net']])
+
+    def testSearchListsAbsolute(self):
+        res = dns.resolver.Resolver()
         qname = dns.name.from_text('absolute')
         qnames = res._get_qnames_to_try(qname, True)
         self.assertEqual(qnames, [qname])
         qnames = res._get_qnames_to_try(qname, False)
         self.assertEqual(qnames, [qname])
+        qnames = res._get_qnames_to_try(qname, None)
+        self.assertEqual(qnames, [qname])
 
 class PollingMonkeyPatchMixin(object):
     def setUp(self):