]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add make_resolver_at() and resolve_at(). (#926)
authorBob Halley <halley@dnspython.org>
Wed, 19 Apr 2023 15:51:16 +0000 (08:51 -0700)
committerGitHub <noreply@github.com>
Wed, 19 Apr 2023 15:51:16 +0000 (08:51 -0700)
dns/asyncresolver.py
dns/resolver.py
doc/async-resolver-functions.rst
doc/resolver-functions.rst
doc/whatsnew.rst
examples/async_dns.py
examples/query_specific.py
tests/test_async.py
tests/test_resolver.py

index aa8af7cdfe4c052b8ded3880a58b6b66c9c5c788..a78f2d68e4109290590ed82d8a4274a356a4c720 100644 (file)
@@ -394,3 +394,84 @@ async def zone_for_name(
             name = name.parent()
         except dns.name.NoParent:  # pragma: no cover
             raise NoRootSOA
+
+
+async def make_resolver_at(
+    where: Union[dns.name.Name, str],
+    port: int = 53,
+    family: int = socket.AF_UNSPEC,
+    resolver: Optional[Resolver] = None,
+) -> Resolver:
+    """Make a stub resolver using the specified destination as the full resolver.
+
+    *where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
+    full resolver.
+
+    *port*, an ``int``, the port to use.  If not specified, the default is 53.
+
+    *family*, an ``int``, the address family to use.  This parameter is used if
+    *where* is not an address.  The default is ``socket.AF_UNSPEC`` in which case
+    the first address returned by ``resolve_name()`` will be used, otherwise the
+    first address of the specified family will be used.
+
+    *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
+    resolution of hostnames.  If not specified, the default resolver will be used.
+
+    Returns a ``dns.resolver.Resolver`` or raises an exception.
+    """
+    if resolver is None:
+        resolver = get_default_resolver()
+    nameservers = []
+    if isinstance(where, str) and dns.inet.is_address(where):
+        nameservers.append(dns.nameserver.Do53Nameserver(where, port))
+    else:
+        answers = await resolver.resolve_name(where, family)
+        for address in answers.addresses():
+            nameservers.append(dns.nameserver.Do53Nameserver(address, port))
+    res = dns.asyncresolver.Resolver(configure=False)
+    res.nameservers = nameservers
+    return res
+
+
+async def resolve_at(
+    where: Union[dns.name.Name, str],
+    qname: Union[dns.name.Name, str],
+    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+    tcp: bool = False,
+    source: Optional[str] = None,
+    raise_on_no_answer: bool = True,
+    source_port: int = 0,
+    lifetime: Optional[float] = None,
+    search: Optional[bool] = None,
+    backend: Optional[dns.asyncbackend.Backend] = None,
+    port: int = 53,
+    family: int = socket.AF_UNSPEC,
+    resolver: Optional[Resolver] = None,
+) -> dns.resolver.Answer:
+    """Query nameservers to find the answer to the question.
+
+    This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
+    to make a resolver, and then uses it to resolve the query.
+
+    See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
+    parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
+    resolver parameters *where*, *port*, *family*, and *resolver*.
+
+    If making more than one query, it is more efficient to call
+    ``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
+    instead of calling ``resolve_at()`` multiple times.
+    """
+    res = await make_resolver_at(where, port, family, resolver)
+    return await res.resolve(
+        qname,
+        rdtype,
+        rdclass,
+        tcp,
+        source,
+        raise_on_no_answer,
+        source_port,
+        lifetime,
+        search,
+        backend,
+    )
index c28c137cac6fb480de40dde7c6468249e2b6652f..f12d9977b4905061ed22a14a8c440b404ba9341c 100644 (file)
@@ -1736,6 +1736,83 @@ def zone_for_name(
             raise NoRootSOA
 
 
+def make_resolver_at(
+    where: Union[dns.name.Name, str],
+    port: int = 53,
+    family: int = socket.AF_UNSPEC,
+    resolver: Optional[Resolver] = None,
+) -> Resolver:
+    """Make a stub resolver using the specified destination as the full resolver.
+
+    *where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
+    full resolver.
+
+    *port*, an ``int``, the port to use.  If not specified, the default is 53.
+
+    *family*, an ``int``, the address family to use.  This parameter is used if
+    *where* is not an address.  The default is ``socket.AF_UNSPEC`` in which case
+    the first address returned by ``resolve_name()`` will be used, otherwise the
+    first address of the specified family will be used.
+
+    *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
+    resolution of hostnames.  If not specified, the default resolver will be used.
+
+    Returns a ``dns.resolver.Resolver`` or raises an exception.
+    """
+    if resolver is None:
+        resolver = get_default_resolver()
+    nameservers = []
+    if isinstance(where, str) and dns.inet.is_address(where):
+        nameservers.append(dns.nameserver.Do53Nameserver(where, port))
+    else:
+        for address in resolver.resolve_name(where, family).addresses():
+            nameservers.append(dns.nameserver.Do53Nameserver(address, port))
+    res = dns.resolver.Resolver(configure=False)
+    res.nameservers = nameservers
+    return res
+
+
+def resolve_at(
+    where: Union[dns.name.Name, str],
+    qname: Union[dns.name.Name, str],
+    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+    tcp: bool = False,
+    source: Optional[str] = None,
+    raise_on_no_answer: bool = True,
+    source_port: int = 0,
+    lifetime: Optional[float] = None,
+    search: Optional[bool] = None,
+    port: int = 53,
+    family: int = socket.AF_UNSPEC,
+    resolver: Optional[Resolver] = None,
+) -> Answer:
+    """Query nameservers to find the answer to the question.
+
+    This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to
+    make a resolver, and then uses it to resolve the query.
+
+    See ``dns.resolver.Resolver.resolve`` for more information on the resolution
+    parameters, and ``dns.resolver.make_resolver_at`` for information about the resolver
+    parameters *where*, *port*, *family*, and *resolver*.
+
+    If making more than one query, it is more efficient to call
+    ``dns.resolver.make_resolver_at()`` and then use that resolver for the queries
+    instead of calling ``resolve_at()`` multiple times.
+    """
+    return make_resolver_at(where, port, family, resolver).resolve(
+        qname,
+        rdtype,
+        rdclass,
+        tcp,
+        source,
+        raise_on_no_answer,
+        source_port,
+        lifetime,
+        search,
+    )
+
+
 #
 # Support for overriding the system resolver for all python code in the
 # running process.
index c79d58128f4baa3f43b446fa8d8d638a2be9bcfc..8e81dea29f41ebcf636b9df1a6527e8dd8c2af28 100644 (file)
@@ -9,6 +9,8 @@ Asynchronous Resolver Functions
 .. autofunction:: dns.asyncresolver.canonical_name
 .. autofunction:: dns.asyncresolver.try_ddr
 .. autofunction:: dns.asyncresolver.zone_for_name
+.. autofunction:: dns.asyncresolver.make_resolver_at
+.. autofunction:: dns.asyncresolver.resolve_at
 .. autodata:: dns.asyncresolver.default_resolver
 .. autofunction:: dns.asyncresolver.get_default_resolver
 .. autofunction:: dns.asyncresolver.reset_default_resolver
index 0399a0b97d0d63be3196ef709264a89cbd13c24b..531ca78a3a4f6e033a8a117bce62b49a8a143846 100644 (file)
@@ -10,6 +10,8 @@ Resolver Functions and The Default Resolver
 .. autofunction:: dns.resolver.try_ddr
 .. autofunction:: dns.resolver.zone_for_name
 .. autofunction:: dns.resolver.query
+.. autofunction:: dns.resolver.make_resolver_at
+.. autofunction:: dns.resolver.resolve_at
 .. autodata:: dns.resolver.default_resolver
 .. autofunction:: dns.resolver.get_default_resolver
 .. autofunction:: dns.resolver.reset_default_resolver
index 3d42c8db9e8f9c2f4519523a9e31ab69c6c30d7d..ac5548d259aba6355ffacc75f4abf68853b78f4b 100644 (file)
@@ -28,6 +28,10 @@ What's New in dnspython
   DNS-over-QUIC. This feature is currently experimental as the standard is still in
   draft stage.
 
+* The resolver and async resolver now have the ``make_resolver_at()`` and
+  ``resolve_at()`` functions, as a convenience for making queries to specific
+  recursive servers.
+
 * Curio support has been removed.
 
 2.3.0
index f7e3fe5d7df2c8c20e8de3f98896a208b6741e30..297afcb01fd3bcc0dc3f08a2c9a0e051755f2c82 100644 (file)
@@ -25,6 +25,10 @@ async def main():
     print(a.response)
     zn = await dns.asyncresolver.zone_for_name(host)
     print(zn)
+    answer = await dns.asyncresolver.resolve_at("8.8.8.8", "amazon.com", "NS")
+    print("The amazon.com nameservers are:")
+    for rr in answer:
+        print(rr.target)
 
 
 if __name__ == "__main__":
index 73dc35138e8f1aaed80e626ee09d5091821e0886..2f13b240495e71ad11046759a6d04bf3dfd52728 100644 (file)
@@ -25,16 +25,31 @@ for rr in ns_rrset:
 print("")
 print("")
 
-# A higher-level way
+# A higher-level way:
 
 import dns.resolver
 
-resolver = dns.resolver.Resolver(configure=False)
-resolver.nameservers = ["8.8.8.8"]
-answer = resolver.resolve("amazon.com", "NS")
+answer = dns.resolver.resolve_at("8.8.8.8", "amazon.com", "NS")
 print("The nameservers are:")
 for rr in answer:
     print(rr.target)
+print("")
+print("")
+
+# If you're going to make a bunch of queries to the server, make the resolver once
+# and then use it multiple times:
+
+res = dns.resolver.make_resolver_at("dns.google")
+answer = res.resolve("amazon.com", "NS")
+print("The amazon.com nameservers are:")
+for rr in answer:
+    print(rr.target)
+answer = res.resolve("google.com", "NS")
+print("The google.com nameservers are:")
+for rr in answer:
+    print(rr.target)
+print("")
+print("")
 
 # Sending a query with the all flags set to 0.  This is the easiest way
 # to make a query with the RD flag off.
index 5ae8854b586a003cc25444466694fd027a556b8b..d46f79e72aedeff949cbeeb3e1bae9235e6a75a4 100644 (file)
@@ -560,6 +560,28 @@ class AsyncTests(unittest.TestCase):
 
         self.async_run(run)
 
+    @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable")
+    def testResolveAtAddress(self):
+        async def run():
+            answer = await dns.asyncresolver.resolve_at("8.8.8.8", "dns.google.", "A")
+            seen = set([rdata.address for rdata in answer])
+            self.assertIn("8.8.8.8", seen)
+            self.assertIn("8.8.4.4", seen)
+
+        self.async_run(run)
+
+    @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable")
+    def testResolveAtName(self):
+        async def run():
+            answer = await dns.asyncresolver.resolve_at(
+                "dns.google", "dns.google.", "A", family=socket.AF_INET
+            )
+            seen = set([rdata.address for rdata in answer])
+            self.assertIn("8.8.8.8", seen)
+            self.assertIn("8.8.4.4", seen)
+
+        self.async_run(run)
+
     def testSleep(self):
         async def run():
             before = time.time()
index c73451e65ef1754f0b3ab92f6328a67ec225f1ce..903780863c2590cb30ba0b139f9b1fb3abc22d73 100644 (file)
@@ -767,6 +767,22 @@ class LiveResolverTests(unittest.TestCase):
         self.assertIn("94.140.14.14", seen)
         self.assertIn("94.140.15.15", seen)
 
+    @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable")
+    def testResolveAtAddress(self):
+        answer = dns.resolver.resolve_at("8.8.8.8", "dns.google.", "A")
+        seen = set([rdata.address for rdata in answer])
+        self.assertIn("8.8.8.8", seen)
+        self.assertIn("8.8.4.4", seen)
+
+    @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable")
+    def testResolveAtName(self):
+        answer = dns.resolver.resolve_at(
+            "dns.google", "dns.google.", "A", family=socket.AF_INET
+        )
+        seen = set([rdata.address for rdata in answer])
+        self.assertIn("8.8.8.8", seen)
+        self.assertIn("8.8.4.4", seen)
+
     def testCanonicalNameNoCNAME(self):
         cname = dns.name.from_text("www.google.com")
         self.assertEqual(dns.resolver.canonical_name("www.google.com"), cname)