from typing import Any, Dict, Optional, Tuple, Union
import base64
+import contextlib
import socket
import struct
import time
import dns.rdatatype
import dns.transaction
+from dns._asyncbackend import NullContext
from dns.query import (
_compute_times,
_matches_destination,
"""
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
- s = None
- # After 3.6 is no longer supported, this can use an AsyncExitStack.
- try:
- af = dns.inet.af_for_address(where)
- destination = _lltuple((where, port), af)
- if sock:
- s = sock
+ af = dns.inet.af_for_address(where)
+ destination = _lltuple((where, port), af)
+ if sock:
+ cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+ else:
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ stuple = _source_tuple(af, source, source_port)
+ if backend.datagram_connection_required():
+ dtuple = (where, port)
else:
- if not backend:
- backend = dns.asyncbackend.get_default_backend()
- stuple = _source_tuple(af, source, source_port)
- if backend.datagram_connection_required():
- dtuple = (where, port)
- else:
- dtuple = None
- s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
- assert s is not None
+ dtuple = None
+ cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
+ async with cm as s:
await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(
s,
if not q.is_response(r):
raise BadResponse
return r
- finally:
- if not sock and s:
- await s.close()
async def udp_with_fallback(
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
- s = None
- # After 3.6 is no longer supported, this can use an AsyncExitStack.
- try:
- if sock:
- # Verify that the socket is connected, as if it's not connected,
- # it's not writable, and the polling in send_tcp() will time out or
- # hang forever.
- await sock.getpeername()
- s = sock
- else:
- # These are simple (address, port) pairs, not
- # family-dependent tuples you pass to lowlevel socket
- # code.
- af = dns.inet.af_for_address(where)
- stuple = _source_tuple(af, source, source_port)
- dtuple = (where, port)
- if not backend:
- backend = dns.asyncbackend.get_default_backend()
- s = await backend.make_socket(
- af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
- )
- assert s is not None
+ if sock:
+ # Verify that the socket is connected, as if it's not connected,
+ # it's not writable, and the polling in send_tcp() will time out or
+ # hang forever.
+ await sock.getpeername()
+ cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+ else:
+ # These are simple (address, port) pairs, not family-dependent tuples
+ # you pass to low-level socket code.
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ dtuple = (where, port)
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ cm = await backend.make_socket(
+ af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
+ )
+ async with cm as s:
await send_tcp(s, wire, expiration)
(r, received_time) = await receive_tcp(
s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
if not q.is_response(r):
raise BadResponse
return r
- finally:
- if not sock and s:
- await s.close()
async def tls(
See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- # After 3.6 is no longer supported, this can use an AsyncExitStack.
(begin_time, expiration) = _compute_times(timeout)
- if not sock:
+ if sock:
+ cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+ else:
if ssl_context is None:
# See the comment about ssl.create_default_context() in query.py
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
- s = await backend.make_socket(
+ cm = await backend.make_socket(
af,
socket.SOCK_STREAM,
0,
ssl_context,
server_hostname,
)
- else:
- s = sock
- try:
+ async with cm as s:
timeout = _timeout(expiration)
response = await tcp(
q,
end_time = time.time()
response.time = end_time - begin_time
return response
- finally:
- if not sock and s:
- await s.close()
async def https(
if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0])
- # After 3.6 is no longer supported, this can use an AsyncExitStack
- client_to_close = None
- try:
- if not client:
- client = httpx.AsyncClient(
- http1=True, http2=_have_http2, verify=verify, transport=transport
- )
- client_to_close = client
+ if client:
+ cm: contextlib.AbstractAsyncContextManager = NullContext(client)
+ else:
+ cm = httpx.AsyncClient(
+ http1=True, http2=_have_http2, verify=verify, transport=transport
+ )
+ async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
if post:
"content-length": str(len(wire)),
}
)
- response = await client.post(
+ response = await the_client.post(
url, headers=headers, content=wire, timeout=timeout
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
- response = await client.get(
+ response = await the_client.get(
url, headers=headers, timeout=timeout, params={"dns": twire}
)
- finally:
- if client_to_close:
- await client_to_close.aclose()
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes