]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
initial implemention of async https query
authorBob Halley <halley@dnspython.org>
Wed, 17 Nov 2021 14:15:09 +0000 (06:15 -0800)
committerBob Halley <halley@dnspython.org>
Wed, 17 Nov 2021 14:15:09 +0000 (06:15 -0800)
dns/asyncquery.py
dns/query.py

index deeff2741926cfda1e2523d8736a1a251ac62457..58766e36497a9fcd26108615b79d4b601b2c3e71 100644 (file)
 
 """Talk to a DNS server."""
 
+import base64
 import socket
 import struct
 import time
+import urllib
 
 import dns.asyncbackend
 import dns.exception
@@ -31,8 +33,10 @@ import dns.rdataclass
 import dns.rdatatype
 
 from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
-    UDPMode
+    UDPMode, have_doh, _have_httpx, _have_http2, NoDOH
 
+if _have_httpx:
+    import httpx
 
 # for brevity
 _lltuple = dns.inet.low_level_address_tuple
@@ -354,6 +358,88 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
         if not sock and s:
             await s.close()
 
+async def https(q, where, timeout=None, port=443, source=None, source_port=0,
+                one_rr_per_rrset=False, ignore_trailing=False, client=None,
+                path='/dns-query', post=True, verify=True, backend=None):
+    """Return the response obtained after sending a query via DNS-over-HTTPS.
+
+    *client*, a ``httpx.AsyncClient``.  If provided, the client to use for
+    the query.
+
+    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
+    the default, then dnspython will use the default backend.
+
+    See :py:func:`dns.query.https()` for the documentation of the other
+    parameters, exceptions, and return type of this method.
+    """
+
+    if not _have_httpx:
+        raise NoDOH('httpx is not available.')  # pragma: no cover
+
+    _httpx_ok = True
+
+    wire = q.to_wire()
+    try:
+        af = dns.inet.af_for_address(where)
+    except ValueError:
+        af = None
+    transport = None
+    headers = {
+        "accept": "application/dns-message"
+    }
+    if af is not None:
+        if af == socket.AF_INET:
+            url = 'https://{}:{}{}'.format(where, port, path)
+        elif af == socket.AF_INET6:
+            url = 'https://[{}]:{}{}'.format(where, port, path)
+    else:
+        url = where
+    if source is not None:
+        transport = httpx.AsyncHTTPTransport(local_address=source[0])
+
+    # After 3.6 is no longer supported, this can use an AsyncExitStack
+    client_to_close = None
+    try:
+        if not client:
+            client = httpx.AsyncClient(http1=True, http2=_have_http2,
+                                       verify=verify, transport=transport)
+            client_to_close = client
+
+        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
+        # GET and POST examples
+        if post:
+            headers.update({
+                "content-type": "application/dns-message",
+                "content-length": str(len(wire))
+            })
+            response = await client.post(url, headers=headers, content=wire,
+                                         timeout=timeout)
+        else:
+            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
+            wire = wire.decode()  # httpx does a repr() if we give it bytes
+            response = await client.get(url, headers=headers, timeout=timeout,
+                                        params={"dns": wire})
+    finally:
+        if client_to_close:
+            await client.aclose()
+
+    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
+    # status codes
+    if response.status_code < 200 or response.status_code > 299:
+        raise ValueError('{} responded with status code {}'
+                         '\nResponse body: {}'.format(where,
+                                                      response.status_code,
+                                                      response.content))
+    r = dns.message.from_wire(response.content,
+                              keyring=q.keyring,
+                              request_mac=q.request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset,
+                              ignore_trailing=ignore_trailing)
+    r.time = response.elapsed
+    if not q.is_response(r):
+        raise BadResponse
+    return r
+
 async def inbound_xfr(where, txn_manager, query=None,
                       port=53, timeout=None, lifetime=None, source=None,
                       source_port=0, udp_mode=UDPMode.NEVER, backend=None):
index 314d8d83d9ac2781296a6edaf4cb3214c745269f..fbf76d8bcce49d2a67cbf46cb2420c0deb37e12a 100644 (file)
@@ -274,8 +274,8 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
     *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
     junk at end of the received message.
 
-    *session*, a ``requests.session.Session``.  If provided, the session to use
-    to send the queries.
+    *session*, an ``httpx.Client`` or ``requests.session.Session``.  If
+    provided, the client/session to use to send the queries.
 
     *path*, a ``str``. If *where* is an IP address, then *path* will be used to
     construct the URL to send the DNS query to.