From ed42e237d19c4ae70036f24a55fa1e04bffc1a5a Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 18 Mar 2022 13:41:28 -0700 Subject: [PATCH] Use nullcontext for async code, as well. We can't use contextlib.nullcontext(), as it doesn't support async context managers until 3.10, but we can use dns._asyncbackend.NullContext. --- dns/asyncquery.py | 113 +++++++++++++++++++--------------------------- 1 file changed, 47 insertions(+), 66 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 28e124d7..885b8bf5 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -20,6 +20,7 @@ from typing import Any, Dict, Optional, Tuple, Union import base64 +import contextlib import socket import struct import time @@ -34,6 +35,7 @@ import dns.rdataclass import dns.rdatatype import dns.transaction +from dns._asyncbackend import NullContext from dns.query import ( _compute_times, _matches_destination, @@ -173,23 +175,20 @@ async def udp( """ wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - s = None - # After 3.6 is no longer supported, this can use an AsyncExitStack. - try: - af = dns.inet.af_for_address(where) - destination = _lltuple((where, port), af) - if sock: - s = sock + af = dns.inet.af_for_address(where) + destination = _lltuple((where, port), af) + if sock: + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: + if not backend: + backend = dns.asyncbackend.get_default_backend() + stuple = _source_tuple(af, source, source_port) + if backend.datagram_connection_required(): + dtuple = (where, port) else: - if not backend: - backend = dns.asyncbackend.get_default_backend() - stuple = _source_tuple(af, source, source_port) - if backend.datagram_connection_required(): - dtuple = (where, port) - else: - dtuple = None - s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple) - assert s is not None + dtuple = None + cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple) + async with cm as s: await send_udp(s, wire, destination, expiration) (r, received_time, _) = await receive_udp( s, @@ -206,9 +205,6 @@ async def udp( if not q.is_response(r): raise BadResponse return r - finally: - if not sock and s: - await s.close() async def udp_with_fallback( @@ -376,28 +372,24 @@ async def tcp( wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - s = None - # After 3.6 is no longer supported, this can use an AsyncExitStack. - try: - if sock: - # Verify that the socket is connected, as if it's not connected, - # it's not writable, and the polling in send_tcp() will time out or - # hang forever. - await sock.getpeername() - s = sock - else: - # These are simple (address, port) pairs, not - # family-dependent tuples you pass to lowlevel socket - # code. - af = dns.inet.af_for_address(where) - stuple = _source_tuple(af, source, source_port) - dtuple = (where, port) - if not backend: - backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket( - af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout - ) - assert s is not None + if sock: + # Verify that the socket is connected, as if it's not connected, + # it's not writable, and the polling in send_tcp() will time out or + # hang forever. + await sock.getpeername() + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: + # These are simple (address, port) pairs, not family-dependent tuples + # you pass to low-level socket code. + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + if not backend: + backend = dns.asyncbackend.get_default_backend() + cm = await backend.make_socket( + af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout + ) + async with cm as s: await send_tcp(s, wire, expiration) (r, received_time) = await receive_tcp( s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing @@ -406,9 +398,6 @@ async def tcp( if not q.is_response(r): raise BadResponse return r - finally: - if not sock and s: - await s.close() async def tls( @@ -440,9 +429,10 @@ async def tls( See :py:func:`dns.query.tls()` for the documentation of the other parameters, exceptions, and return type of this method. """ - # After 3.6 is no longer supported, this can use an AsyncExitStack. (begin_time, expiration) = _compute_times(timeout) - if not sock: + if sock: + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: if ssl_context is None: # See the comment about ssl.create_default_context() in query.py ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol] @@ -457,7 +447,7 @@ async def tls( dtuple = (where, port) if not backend: backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket( + cm = await backend.make_socket( af, socket.SOCK_STREAM, 0, @@ -467,9 +457,7 @@ async def tls( ssl_context, server_hostname, ) - else: - s = sock - try: + async with cm as s: timeout = _timeout(expiration) response = await tcp( q, @@ -486,9 +474,6 @@ async def tls( end_time = time.time() response.time = end_time - begin_time return response - finally: - if not sock and s: - await s.close() async def https( @@ -537,15 +522,14 @@ async def https( if source is not None: transport = httpx.AsyncHTTPTransport(local_address=source[0]) - # After 3.6 is no longer supported, this can use an AsyncExitStack - client_to_close = None - try: - if not client: - client = httpx.AsyncClient( - http1=True, http2=_have_http2, verify=verify, transport=transport - ) - client_to_close = client + if client: + cm: contextlib.AbstractAsyncContextManager = NullContext(client) + else: + cm = httpx.AsyncClient( + http1=True, http2=_have_http2, verify=verify, transport=transport + ) + async with cm as the_client: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: @@ -555,18 +539,15 @@ async def https( "content-length": str(len(wire)), } ) - response = await client.post( + response = await the_client.post( url, headers=headers, content=wire, timeout=timeout ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes - response = await client.get( + response = await the_client.get( url, headers=headers, timeout=timeout, params={"dns": twire} ) - finally: - if client_to_close: - await client_to_close.aclose() # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes -- 2.47.3