From 18b7473e6bd9b5c46753b005db8c85ffcde1f248 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Thu, 29 Feb 2024 15:19:58 -0800 Subject: [PATCH] Implement xfr() in terms of inbound_xfr(). (#1062) This moves the implementation of inbound_xfr() to an internal generator function, and implements both inbound_xfr() and xfr() using it. --- dns/query.py | 355 ++++++++++++++++++++++++--------------------------- 1 file changed, 169 insertions(+), 186 deletions(-) diff --git a/dns/query.py b/dns/query.py index 8f82ab67..ed92ee3a 100644 --- a/dns/query.py +++ b/dns/query.py @@ -29,7 +29,7 @@ import socket import struct import time import urllib.parse -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast import dns._features import dns.exception @@ -1367,6 +1367,135 @@ def quic( return r +class UDPMode(enum.IntEnum): + """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`? + + NEVER means "never use UDP; always use TCP" + TRY_FIRST means "try to use UDP but fall back to TCP if needed" + ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" + """ + + NEVER = 0 + TRY_FIRST = 1 + ONLY = 2 + + +def _inbound_xfr( + where: str, + txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message] = None, + port: int = 53, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + udp_mode: UDPMode = UDPMode.NEVER, +) -> Any: + """Conduct an inbound transfer and apply it via a transaction from the + txn_manager. + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager + for this transfer (typically a ``dns.zone.Zone``). + + *query*, the query to send. If not supplied, a default query is + constructed using information from the *txn_manager*. + + *port*, an ``int``, the port send the message to. The default is 53. + + *timeout*, a ``float``, the number of seconds to wait for each + response message. If None, the default, wait forever. + + *lifetime*, a ``float``, the total number of seconds to spend + doing the transfer. If ``None``, the default, then there is no + limit on the time the transfer may take. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used + for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use + TCP. Other possibilities are ``dns.UDPMode.TRY_FIRST``, which + means "try UDP but fallback to TCP if needed", and + ``dns.UDPMode.ONLY``, which means "try UDP and raise + ``dns.xfr.UseTCP`` if it does not succeed. + + Raises on errors. + """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + (_, expiration) = _compute_times(lifetime) + retry = True + while retry: + retry = False + if is_ixfr and udp_mode != UDPMode.NEVER: + sock_type = socket.SOCK_DGRAM + is_udp = True + else: + sock_type = socket.SOCK_STREAM + is_udp = False + with _make_socket(af, sock_type, source) as s: + _connect(s, destination, expiration) + if is_udp: + _udp_send(s, wire, None, expiration) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + _net_write(s, tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration + if is_udp: + (rwire, _) = _udp_recv(s, 65535, mexpiration) + else: + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = _net_read(s, l, mexpiration) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) + try: + done = inbound.process_message(r) + except dns.xfr.UseTCP: + assert is_udp # should not happen if we used TCP! + if udp_mode == UDPMode.ONLY: + raise + done = True + retry = True + udp_mode = UDPMode.NEVER + continue + yield r + tsig_ctx = r.tsig_ctx + if not retry and query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") + + def xfr( where: str, zone: Union[dns.name.Name, str], @@ -1436,134 +1565,42 @@ def xfr( Returns a generator of ``dns.message.Message`` objects. """ + class DummyTransactionManager(dns.transaction.TransactionManager): + def __init__(self, origin, relativize): + self.info = (origin, relativize, dns.name.empty if relativize else origin) + + def origin_information(self): + return self.info + + def writer(self, replacement: bool = False) -> dns.transaction.Transaction: + class DummyTransaction(object): + def nop(*args, **kw): + pass + + def __getattr__(self, _): + return self.nop + + return cast(dns.transaction.Transaction, DummyTransaction()) + if isinstance(zone, str): zone = dns.name.from_text(zone) rdtype = dns.rdatatype.RdataType.make(rdtype) q = dns.message.make_query(zone, rdtype, rdclass) if rdtype == dns.rdatatype.IXFR: - rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial) - q.authority.append(rrset) + rrset = q.find_rrset( + q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True + ) + soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial) + rrset.add(soa, 0) if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) - wire = q.to_wire() - (af, destination, source) = _destination_and_source( - where, port, source, source_port - ) + tm = DummyTransactionManager(zone, relativize) if use_udp and rdtype != dns.rdatatype.IXFR: raise ValueError("cannot do a UDP AXFR") - sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM - with _make_socket(af, sock_type, source) as s: - (_, expiration) = _compute_times(lifetime) - _connect(s, destination, expiration) - l = len(wire) - if use_udp: - _udp_send(s, wire, None, expiration) - else: - tcpmsg = struct.pack("!H", l) + wire - _net_write(s, tcpmsg, expiration) - done = False - delete_mode = True - expecting_SOA = False - soa_rrset = None - if relativize: - origin = zone - oname = dns.name.empty - else: - origin = None - oname = zone - tsig_ctx = None - while not done: - (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or ( - expiration is not None and mexpiration > expiration - ): - mexpiration = expiration - if use_udp: - (wire, _) = _udp_recv(s, 65535, mexpiration) - else: - ldata = _net_read(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - wire = _net_read(s, l, mexpiration) - is_ixfr = rdtype == dns.rdatatype.IXFR - r = dns.message.from_wire( - wire, - keyring=q.keyring, - request_mac=q.mac, - xfr=True, - origin=origin, - tsig_ctx=tsig_ctx, - multi=True, - one_rr_per_rrset=is_ixfr, - ) - rcode = r.rcode() - if rcode != dns.rcode.NOERROR: - raise TransferError(rcode) - tsig_ctx = r.tsig_ctx - answer_index = 0 - if soa_rrset is None: - if not r.answer or r.answer[0].name != oname: - raise dns.exception.FormError("No answer or RRset not for qname") - rrset = r.answer[0] - if rrset.rdtype != dns.rdatatype.SOA: - raise dns.exception.FormError("first RRset is not an SOA") - answer_index = 1 - soa_rrset = rrset.copy() - if rdtype == dns.rdatatype.IXFR: - if dns.serial.Serial(soa_rrset[0].serial) <= serial: - # - # We're already up-to-date. - # - done = True - else: - expecting_SOA = True - # - # Process SOAs in the answer section (other than the initial - # SOA in the first message). - # - for rrset in r.answer[answer_index:]: - if done: - raise dns.exception.FormError("answers after final SOA") - if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname: - if expecting_SOA: - if rrset[0].serial != serial: - raise dns.exception.FormError("IXFR base serial mismatch") - expecting_SOA = False - elif rdtype == dns.rdatatype.IXFR: - delete_mode = not delete_mode - # - # If this SOA RRset is equal to the first we saw then we're - # finished. If this is an IXFR we also check that we're - # seeing the record in the expected part of the response. - # - if rrset == soa_rrset and ( - rdtype == dns.rdatatype.AXFR - or (rdtype == dns.rdatatype.IXFR and delete_mode) - ): - done = True - elif expecting_SOA: - # - # We made an IXFR request and are expecting another - # SOA RR, but saw something else, so this must be an - # AXFR response. - # - rdtype = dns.rdatatype.AXFR - expecting_SOA = False - if done and q.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") - yield r - - -class UDPMode(enum.IntEnum): - """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`? - - NEVER means "never use UDP; always use TCP" - TRY_FIRST means "try to use UDP but fall back to TCP if needed" - ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" - """ - - NEVER = 0 - TRY_FIRST = 1 - ONLY = 2 + udp_mode = UDPMode.ONLY if use_udp else UDPMode.NEVER + yield from _inbound_xfr( + where, tm, q, port, timeout, lifetime, source, source_port, udp_mode + ) def inbound_xfr( @@ -1613,69 +1650,15 @@ def inbound_xfr( Raises on errors. """ - if query is None: - (query, serial) = dns.xfr.make_query(txn_manager) - else: - serial = dns.xfr.extract_serial_from_query(query) - rdtype = query.question[0].rdtype - is_ixfr = rdtype == dns.rdatatype.IXFR - origin = txn_manager.from_wire_origin() - wire = query.to_wire() - (af, destination, source) = _destination_and_source( - where, port, source, source_port - ) - (_, expiration) = _compute_times(lifetime) - retry = True - while retry: - retry = False - if is_ixfr and udp_mode != UDPMode.NEVER: - sock_type = socket.SOCK_DGRAM - is_udp = True - else: - sock_type = socket.SOCK_STREAM - is_udp = False - with _make_socket(af, sock_type, source) as s: - _connect(s, destination, expiration) - if is_udp: - _udp_send(s, wire, None, expiration) - else: - tcpmsg = struct.pack("!H", len(wire)) + wire - _net_write(s, tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: - done = False - tsig_ctx = None - while not done: - (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or ( - expiration is not None and mexpiration > expiration - ): - mexpiration = expiration - if is_udp: - (rwire, _) = _udp_recv(s, 65535, mexpiration) - else: - ldata = _net_read(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - rwire = _net_read(s, l, mexpiration) - r = dns.message.from_wire( - rwire, - keyring=query.keyring, - request_mac=query.mac, - xfr=True, - origin=origin, - tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr, - ) - try: - done = inbound.process_message(r) - except dns.xfr.UseTCP: - assert is_udp # should not happen if we used TCP! - if udp_mode == UDPMode.ONLY: - raise - done = True - retry = True - udp_mode = UDPMode.NEVER - continue - tsig_ctx = r.tsig_ctx - if not retry and query.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") + for msg in _inbound_xfr( + where, + txn_manager, + query, + port, + timeout, + lifetime, + source, + source_port, + udp_mode, + ): + pass -- 2.47.3