From: Brian Wellington Date: Sun, 18 Aug 2024 13:54:16 +0000 (-0700) Subject: Refactor xfr. (#1122) X-Git-Tag: v2.7.0rc1~23 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=e03819b1cee50bc2813141db2becb801447db716;p=thirdparty%2Fdnspython.git Refactor xfr. (#1122) * Refactor xfr. Internally refactors the zone transfer code to separate the message processing from the socket management, allowing the (internal) callers to pass a socket in. This should allow a future interface that accepts a socket, which would mean that xfr over DoT would just work, and xfr over DoQ would be closer to working. Adds some necessary functionality to the asyncbackend Socket class to allow the async zone transfer code to be more similar to the sync code (specifically, adds a type field to Socket, and updates the trio backend to connect UDP sockets when requested). In asyncquery.py, reorder the inbound_xfr() and quic() methods for consistency. * Run black. * Fix typing. --- diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index 49f14fed..f6760fd0 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -26,6 +26,10 @@ class NullContext: class Socket: # pragma: no cover + def __init__(self, family: int, type: int): + self.family = family + self.type = type + async def close(self): pass @@ -46,9 +50,6 @@ class Socket: # pragma: no cover class DatagramSocket(Socket): # pragma: no cover - def __init__(self, family: int): - self.family = family - async def sendto(self, what, destination, timeout): raise NotImplementedError diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 9d9ed369..de18c401 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -64,7 +64,7 @@ async def _maybe_wait_for(awaitable, timeout): class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, transport, protocol): - super().__init__(family) + super().__init__(family, socket.SOCK_DGRAM) self.transport = transport self.protocol = protocol @@ -99,7 +99,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, af, reader, writer): - self.family = af + super().__init__(af, socket.SOCK_STREAM) self.reader = reader self.writer = writer diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 398e3276..1d2bdda9 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -30,13 +30,16 @@ _lltuple = dns.inet.low_level_address_tuple class DatagramSocket(dns._asyncbackend.DatagramSocket): - def __init__(self, socket): - super().__init__(socket.family) - self.socket = socket + def __init__(self, sock): + super().__init__(sock.family, socket.SOCK_DGRAM) + self.socket = sock async def sendto(self, what, destination, timeout): with _maybe_timeout(timeout): - return await self.socket.sendto(what, destination) + if destination is None: + return await self.socket.send(what) + else: + return await self.socket.sendto(what, destination) raise dns.exception.Timeout( timeout=timeout ) # pragma: no cover lgtm[py/unreachable-statement] @@ -61,7 +64,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, family, stream, tls=False): - self.family = family + super().__init__(family, socket.SOCK_STREAM) self.stream = stream self.tls = tls @@ -205,7 +208,7 @@ class Backend(dns._asyncbackend.Backend): try: if source: await s.bind(_lltuple(source, af)) - if socktype == socket.SOCK_STREAM: + if socktype == socket.SOCK_STREAM or destination is not None: connected = False with _maybe_timeout(timeout): await s.connect(_lltuple(destination, af)) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 717f43b4..622c9d52 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -24,7 +24,7 @@ import socket import struct import time import urllib.parse -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union, cast import dns.asyncbackend import dns.exception @@ -716,107 +716,6 @@ async def _http3( return r -async def inbound_xfr( - where: str, - txn_manager: dns.transaction.TransactionManager, - query: Optional[dns.message.Message] = None, - port: int = 53, - timeout: Optional[float] = None, - lifetime: Optional[float] = None, - source: Optional[str] = None, - source_port: int = 0, - udp_mode: UDPMode = UDPMode.NEVER, - backend: Optional[dns.asyncbackend.Backend] = None, -) -> None: - """Conduct an inbound transfer and apply it via a transaction from the - txn_manager. - - *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, - the default, then dnspython will use the default backend. - - See :py:func:`dns.query.inbound_xfr()` for the documentation of - the other parameters, exceptions, and return type of this method. - """ - if query is None: - (query, serial) = dns.xfr.make_query(txn_manager) - else: - serial = dns.xfr.extract_serial_from_query(query) - rdtype = query.question[0].rdtype - is_ixfr = rdtype == dns.rdatatype.IXFR - origin = txn_manager.from_wire_origin() - wire = query.to_wire() - af = dns.inet.af_for_address(where) - stuple = _source_tuple(af, source, source_port) - dtuple = (where, port) - (_, expiration) = _compute_times(lifetime) - retry = True - while retry: - retry = False - if is_ixfr and udp_mode != UDPMode.NEVER: - sock_type = socket.SOCK_DGRAM - is_udp = True - else: - sock_type = socket.SOCK_STREAM - is_udp = False - if not backend: - backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket( - af, sock_type, 0, stuple, dtuple, _timeout(expiration) - ) - async with s: - if is_udp: - await s.sendto(wire, dtuple, _timeout(expiration)) - else: - tcpmsg = struct.pack("!H", len(wire)) + wire - await s.sendall(tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: - done = False - tsig_ctx = None - while not done: - (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or ( - expiration is not None and mexpiration > expiration - ): - mexpiration = expiration - if is_udp: - destination = _lltuple((where, port), af) - while True: - timeout = _timeout(mexpiration) - (rwire, from_address) = await s.recvfrom(65535, timeout) - if _matches_destination( - af, from_address, destination, True - ): - break - else: - ldata = await _read_exactly(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - rwire = await _read_exactly(s, l, mexpiration) - is_ixfr = rdtype == dns.rdatatype.IXFR - r = dns.message.from_wire( - rwire, - keyring=query.keyring, - request_mac=query.mac, - xfr=True, - origin=origin, - tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr, - ) - try: - done = inbound.process_message(r) - except dns.xfr.UseTCP: - assert is_udp # should not happen if we used TCP! - if udp_mode == UDPMode.ONLY: - raise - done = True - retry = True - udp_mode = UDPMode.NEVER - continue - tsig_ctx = r.tsig_ctx - if not retry and query.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") - - async def quic( q: dns.message.Message, where: str, @@ -883,3 +782,112 @@ async def quic( if not q.is_response(r): raise BadResponse return r + + +async def _inbound_xfr( + txn_manager: dns.transaction.TransactionManager, + s: dns.asyncbackend.Socket, + query: dns.message.Message, + serial: Optional[int], + timeout: Optional[float], + expiration: float, +) -> Any: + """Given a socket, does the zone transfer.""" + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + is_udp = s.type == socket.SOCK_DGRAM + if is_udp: + udp_sock = cast(dns.asyncbackend.DatagramSocket, s) + await udp_sock.sendto(wire, None, _timeout(expiration)) + else: + tcp_sock = cast(dns.asyncbackend.StreamSocket, s) + tcpmsg = struct.pack("!H", len(wire)) + wire + await tcp_sock.sendall(tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration + if is_udp: + timeout = _timeout(mexpiration) + (rwire, _) = await udp_sock.recvfrom(65535, timeout) + else: + ldata = await _read_exactly(tcp_sock, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = await _read_exactly(tcp_sock, l, mexpiration) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) + done = inbound.process_message(r) + yield r + tsig_ctx = r.tsig_ctx + if query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") + + +async def inbound_xfr( + where: str, + txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message] = None, + port: int = 53, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + udp_mode: UDPMode = UDPMode.NEVER, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> None: + """Conduct an inbound transfer and apply it via a transaction from the + txn_manager. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.inbound_xfr()` for the documentation of + the other parameters, exceptions, and return type of this method. + """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + if not backend: + backend = dns.asyncbackend.get_default_backend() + (_, expiration) = _compute_times(lifetime) + if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER: + s = await backend.make_socket( + af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration) + ) + async with s: + try: + async for _ in _inbound_xfr( + txn_manager, s, query, serial, timeout, expiration + ): + pass + return + except dns.xfr.UseTCP: + if udp_mode == UDPMode.ONLY: + raise + pass + + s = await backend.make_socket( + af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration) + ) + async with s: + async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration): + pass diff --git a/dns/query.py b/dns/query.py index 050aeca4..8e21ed2f 100644 --- a/dns/query.py +++ b/dns/query.py @@ -1405,119 +1405,54 @@ class UDPMode(enum.IntEnum): def _inbound_xfr( - where: str, txn_manager: dns.transaction.TransactionManager, - query: Optional[dns.message.Message] = None, - port: int = 53, - timeout: Optional[float] = None, - lifetime: Optional[float] = None, - source: Optional[str] = None, - source_port: int = 0, - udp_mode: UDPMode = UDPMode.NEVER, + s: socket.socket, + query: dns.message.Message, + serial: Optional[int], + timeout: Optional[float], + expiration: float, ) -> Any: - """Conduct an inbound transfer and apply it via a transaction from the - txn_manager. - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager - for this transfer (typically a ``dns.zone.Zone``). - - *query*, the query to send. If not supplied, a default query is - constructed using information from the *txn_manager*. - - *port*, an ``int``, the port send the message to. The default is 53. - - *timeout*, a ``float``, the number of seconds to wait for each - response message. If None, the default, wait forever. - - *lifetime*, a ``float``, the total number of seconds to spend - doing the transfer. If ``None``, the default, then there is no - limit on the time the transfer may take. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used - for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use - TCP. Other possibilities are ``dns.UDPMode.TRY_FIRST``, which - means "try UDP but fallback to TCP if needed", and - ``dns.UDPMode.ONLY``, which means "try UDP and raise - ``dns.xfr.UseTCP`` if it does not succeed. - - Raises on errors. - """ - if query is None: - (query, serial) = dns.xfr.make_query(txn_manager) - else: - serial = dns.xfr.extract_serial_from_query(query) + """Given a socket, does the zone transfer.""" rdtype = query.question[0].rdtype is_ixfr = rdtype == dns.rdatatype.IXFR origin = txn_manager.from_wire_origin() wire = query.to_wire() - (af, destination, source) = _destination_and_source( - where, port, source, source_port - ) - (_, expiration) = _compute_times(lifetime) - retry = True - while retry: - retry = False - if is_ixfr and udp_mode != UDPMode.NEVER: - sock_type = socket.SOCK_DGRAM - is_udp = True - else: - sock_type = socket.SOCK_STREAM - is_udp = False - with _make_socket(af, sock_type, source) as s: - _connect(s, destination, expiration) + is_udp = s.type == socket.SOCK_DGRAM + if is_udp: + _udp_send(s, wire, None, expiration) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + _net_write(s, tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration if is_udp: - _udp_send(s, wire, None, expiration) + (rwire, _) = _udp_recv(s, 65535, mexpiration) else: - tcpmsg = struct.pack("!H", len(wire)) + wire - _net_write(s, tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: - done = False - tsig_ctx = None - while not done: - (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or ( - expiration is not None and mexpiration > expiration - ): - mexpiration = expiration - if is_udp: - (rwire, _) = _udp_recv(s, 65535, mexpiration) - else: - ldata = _net_read(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - rwire = _net_read(s, l, mexpiration) - r = dns.message.from_wire( - rwire, - keyring=query.keyring, - request_mac=query.mac, - xfr=True, - origin=origin, - tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr, - ) - try: - done = inbound.process_message(r) - except dns.xfr.UseTCP: - assert is_udp # should not happen if we used TCP! - if udp_mode == UDPMode.ONLY: - raise - done = True - retry = True - udp_mode = UDPMode.NEVER - continue - yield r - tsig_ctx = r.tsig_ctx - if not retry and query.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = _net_read(s, l, mexpiration) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) + done = inbound.process_message(r) + yield r + tsig_ctx = r.tsig_ctx + if query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") def xfr( @@ -1624,13 +1559,17 @@ def xfr( rrset.add(soa, 0) if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + (_, expiration) = _compute_times(lifetime) tm = DummyTransactionManager(zone, relativize) if use_udp and rdtype != dns.rdatatype.IXFR: raise ValueError("cannot do a UDP AXFR") - udp_mode = UDPMode.ONLY if use_udp else UDPMode.NEVER - yield from _inbound_xfr( - where, tm, q, port, timeout, lifetime, source, source_port, udp_mode - ) + sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM + with _make_socket(af, sock_type, source) as s: + _connect(s, destination, expiration) + yield from _inbound_xfr(tm, s, q, serial, timeout, expiration) def inbound_xfr( @@ -1680,15 +1619,30 @@ def inbound_xfr( Raises on errors. """ - for _ in _inbound_xfr( - where, - txn_manager, - query, - port, - timeout, - lifetime, - source, - source_port, - udp_mode, - ): - pass + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) + + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + (_, expiration) = _compute_times(lifetime) + if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER: + with _make_socket(af, socket.SOCK_DGRAM, source) as s: + _connect(s, destination, expiration) + try: + for _ in _inbound_xfr( + txn_manager, s, query, serial, timeout, expiration + ): + pass + return + except dns.xfr.UseTCP: + if udp_mode == UDPMode.ONLY: + raise + pass + + with _make_socket(af, socket.SOCK_STREAM, source) as s: + _connect(s, destination, expiration) + for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration): + pass