]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Use nullcontext for async code, as well. 794/head
authorBrian Wellington <bwelling@xbill.org>
Fri, 18 Mar 2022 20:41:28 +0000 (13:41 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 18 Mar 2022 20:41:28 +0000 (13:41 -0700)
We can't use contextlib.nullcontext(), as it doesn't support async
context managers until 3.10, but we can use
dns._asyncbackend.NullContext.

dns/asyncquery.py

index 28e124d704c35a07741f35b7991623a009f6d04a..885b8bf5e9ea84e5028e91e32738e206e3e70310 100644 (file)
@@ -20,6 +20,7 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
 import base64
+import contextlib
 import socket
 import struct
 import time
@@ -34,6 +35,7 @@ import dns.rdataclass
 import dns.rdatatype
 import dns.transaction
 
+from dns._asyncbackend import NullContext
 from dns.query import (
     _compute_times,
     _matches_destination,
@@ -173,23 +175,20 @@ async def udp(
     """
     wire = q.to_wire()
     (begin_time, expiration) = _compute_times(timeout)
-    s = None
-    # After 3.6 is no longer supported, this can use an AsyncExitStack.
-    try:
-        af = dns.inet.af_for_address(where)
-        destination = _lltuple((where, port), af)
-        if sock:
-            s = sock
+    af = dns.inet.af_for_address(where)
+    destination = _lltuple((where, port), af)
+    if sock:
+        cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+    else:
+        if not backend:
+            backend = dns.asyncbackend.get_default_backend()
+        stuple = _source_tuple(af, source, source_port)
+        if backend.datagram_connection_required():
+            dtuple = (where, port)
         else:
-            if not backend:
-                backend = dns.asyncbackend.get_default_backend()
-            stuple = _source_tuple(af, source, source_port)
-            if backend.datagram_connection_required():
-                dtuple = (where, port)
-            else:
-                dtuple = None
-            s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
-        assert s is not None
+            dtuple = None
+        cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
+    async with cm as s:
         await send_udp(s, wire, destination, expiration)
         (r, received_time, _) = await receive_udp(
             s,
@@ -206,9 +205,6 @@ async def udp(
         if not q.is_response(r):
             raise BadResponse
         return r
-    finally:
-        if not sock and s:
-            await s.close()
 
 
 async def udp_with_fallback(
@@ -376,28 +372,24 @@ async def tcp(
 
     wire = q.to_wire()
     (begin_time, expiration) = _compute_times(timeout)
-    s = None
-    # After 3.6 is no longer supported, this can use an AsyncExitStack.
-    try:
-        if sock:
-            # Verify that the socket is connected, as if it's not connected,
-            # it's not writable, and the polling in send_tcp() will time out or
-            # hang forever.
-            await sock.getpeername()
-            s = sock
-        else:
-            # These are simple (address, port) pairs, not
-            # family-dependent tuples you pass to lowlevel socket
-            # code.
-            af = dns.inet.af_for_address(where)
-            stuple = _source_tuple(af, source, source_port)
-            dtuple = (where, port)
-            if not backend:
-                backend = dns.asyncbackend.get_default_backend()
-            s = await backend.make_socket(
-                af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
-            )
-        assert s is not None
+    if sock:
+        # Verify that the socket is connected, as if it's not connected,
+        # it's not writable, and the polling in send_tcp() will time out or
+        # hang forever.
+        await sock.getpeername()
+        cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+    else:
+        # These are simple (address, port) pairs, not family-dependent tuples
+        # you pass to low-level socket code.
+        af = dns.inet.af_for_address(where)
+        stuple = _source_tuple(af, source, source_port)
+        dtuple = (where, port)
+        if not backend:
+            backend = dns.asyncbackend.get_default_backend()
+        cm = await backend.make_socket(
+            af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
+        )
+    async with cm as s:
         await send_tcp(s, wire, expiration)
         (r, received_time) = await receive_tcp(
             s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
@@ -406,9 +398,6 @@ async def tcp(
         if not q.is_response(r):
             raise BadResponse
         return r
-    finally:
-        if not sock and s:
-            await s.close()
 
 
 async def tls(
@@ -440,9 +429,10 @@ async def tls(
     See :py:func:`dns.query.tls()` for the documentation of the other
     parameters, exceptions, and return type of this method.
     """
-    # After 3.6 is no longer supported, this can use an AsyncExitStack.
     (begin_time, expiration) = _compute_times(timeout)
-    if not sock:
+    if sock:
+        cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+    else:
         if ssl_context is None:
             # See the comment about ssl.create_default_context() in query.py
             ssl_context = ssl.create_default_context()  # lgtm[py/insecure-protocol]
@@ -457,7 +447,7 @@ async def tls(
         dtuple = (where, port)
         if not backend:
             backend = dns.asyncbackend.get_default_backend()
-        s = await backend.make_socket(
+        cm = await backend.make_socket(
             af,
             socket.SOCK_STREAM,
             0,
@@ -467,9 +457,7 @@ async def tls(
             ssl_context,
             server_hostname,
         )
-    else:
-        s = sock
-    try:
+    async with cm as s:
         timeout = _timeout(expiration)
         response = await tcp(
             q,
@@ -486,9 +474,6 @@ async def tls(
         end_time = time.time()
         response.time = end_time - begin_time
         return response
-    finally:
-        if not sock and s:
-            await s.close()
 
 
 async def https(
@@ -537,15 +522,14 @@ async def https(
     if source is not None:
         transport = httpx.AsyncHTTPTransport(local_address=source[0])
 
-    # After 3.6 is no longer supported, this can use an AsyncExitStack
-    client_to_close = None
-    try:
-        if not client:
-            client = httpx.AsyncClient(
-                http1=True, http2=_have_http2, verify=verify, transport=transport
-            )
-            client_to_close = client
+    if client:
+        cm: contextlib.AbstractAsyncContextManager = NullContext(client)
+    else:
+        cm = httpx.AsyncClient(
+            http1=True, http2=_have_http2, verify=verify, transport=transport
+        )
 
+    async with cm as the_client:
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
         # GET and POST examples
         if post:
@@ -555,18 +539,15 @@ async def https(
                     "content-length": str(len(wire)),
                 }
             )
-            response = await client.post(
+            response = await the_client.post(
                 url, headers=headers, content=wire, timeout=timeout
             )
         else:
             wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
             twire = wire.decode()  # httpx does a repr() if we give it bytes
-            response = await client.get(
+            response = await the_client.get(
                 url, headers=headers, timeout=timeout, params={"dns": twire}
             )
-    finally:
-        if client_to_close:
-            await client_to_close.aclose()
 
     # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
     # status codes