]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Implement xfr() in terms of inbound_xfr(). (#1062)
authorBrian Wellington <bwelling@xbill.org>
Thu, 29 Feb 2024 23:19:58 +0000 (15:19 -0800)
committerGitHub <noreply@github.com>
Thu, 29 Feb 2024 23:19:58 +0000 (15:19 -0800)
This moves the implementation of inbound_xfr() to an internal generator
function, and implements both inbound_xfr() and xfr() using it.

dns/query.py

index 8f82ab676789c46ae137ca25a6409fb86476e70b..ed92ee3a51eb677f58aa27f3e3e2bf920a1d1d2b 100644 (file)
@@ -29,7 +29,7 @@ import socket
 import struct
 import time
 import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 import dns._features
 import dns.exception
@@ -1367,6 +1367,135 @@ def quic(
     return r
 
 
+class UDPMode(enum.IntEnum):
+    """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
+
+    NEVER means "never use UDP; always use TCP"
+    TRY_FIRST means "try to use UDP but fall back to TCP if needed"
+    ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
+    """
+
+    NEVER = 0
+    TRY_FIRST = 1
+    ONLY = 2
+
+
+def _inbound_xfr(
+    where: str,
+    txn_manager: dns.transaction.TransactionManager,
+    query: Optional[dns.message.Message] = None,
+    port: int = 53,
+    timeout: Optional[float] = None,
+    lifetime: Optional[float] = None,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    udp_mode: UDPMode = UDPMode.NEVER,
+) -> Any:
+    """Conduct an inbound transfer and apply it via a transaction from the
+    txn_manager.
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
+    for this transfer (typically a ``dns.zone.Zone``).
+
+    *query*, the query to send.  If not supplied, a default query is
+    constructed using information from the *txn_manager*.
+
+    *port*, an ``int``, the port send the message to.  The default is 53.
+
+    *timeout*, a ``float``, the number of seconds to wait for each
+    response message.  If None, the default, wait forever.
+
+    *lifetime*, a ``float``, the total number of seconds to spend
+    doing the transfer.  If ``None``, the default, then there is no
+    limit on the time the transfer may take.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
+    for IXFRs.  The default is ``dns.UDPMode.NEVER``, i.e. only use
+    TCP.  Other possibilities are ``dns.UDPMode.TRY_FIRST``, which
+    means "try UDP but fallback to TCP if needed", and
+    ``dns.UDPMode.ONLY``, which means "try UDP and raise
+    ``dns.xfr.UseTCP`` if it does not succeed.
+
+    Raises on errors.
+    """
+    if query is None:
+        (query, serial) = dns.xfr.make_query(txn_manager)
+    else:
+        serial = dns.xfr.extract_serial_from_query(query)
+    rdtype = query.question[0].rdtype
+    is_ixfr = rdtype == dns.rdatatype.IXFR
+    origin = txn_manager.from_wire_origin()
+    wire = query.to_wire()
+    (af, destination, source) = _destination_and_source(
+        where, port, source, source_port
+    )
+    (_, expiration) = _compute_times(lifetime)
+    retry = True
+    while retry:
+        retry = False
+        if is_ixfr and udp_mode != UDPMode.NEVER:
+            sock_type = socket.SOCK_DGRAM
+            is_udp = True
+        else:
+            sock_type = socket.SOCK_STREAM
+            is_udp = False
+        with _make_socket(af, sock_type, source) as s:
+            _connect(s, destination, expiration)
+            if is_udp:
+                _udp_send(s, wire, None, expiration)
+            else:
+                tcpmsg = struct.pack("!H", len(wire)) + wire
+                _net_write(s, tcpmsg, expiration)
+            with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
+                done = False
+                tsig_ctx = None
+                while not done:
+                    (_, mexpiration) = _compute_times(timeout)
+                    if mexpiration is None or (
+                        expiration is not None and mexpiration > expiration
+                    ):
+                        mexpiration = expiration
+                    if is_udp:
+                        (rwire, _) = _udp_recv(s, 65535, mexpiration)
+                    else:
+                        ldata = _net_read(s, 2, mexpiration)
+                        (l,) = struct.unpack("!H", ldata)
+                        rwire = _net_read(s, l, mexpiration)
+                    r = dns.message.from_wire(
+                        rwire,
+                        keyring=query.keyring,
+                        request_mac=query.mac,
+                        xfr=True,
+                        origin=origin,
+                        tsig_ctx=tsig_ctx,
+                        multi=(not is_udp),
+                        one_rr_per_rrset=is_ixfr,
+                    )
+                    try:
+                        done = inbound.process_message(r)
+                    except dns.xfr.UseTCP:
+                        assert is_udp  # should not happen if we used TCP!
+                        if udp_mode == UDPMode.ONLY:
+                            raise
+                        done = True
+                        retry = True
+                        udp_mode = UDPMode.NEVER
+                        continue
+                    yield r
+                    tsig_ctx = r.tsig_ctx
+                if not retry and query.keyring and not r.had_tsig:
+                    raise dns.exception.FormError("missing TSIG")
+
+
 def xfr(
     where: str,
     zone: Union[dns.name.Name, str],
@@ -1436,134 +1565,42 @@ def xfr(
     Returns a generator of ``dns.message.Message`` objects.
     """
 
+    class DummyTransactionManager(dns.transaction.TransactionManager):
+        def __init__(self, origin, relativize):
+            self.info = (origin, relativize, dns.name.empty if relativize else origin)
+
+        def origin_information(self):
+            return self.info
+
+        def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
+            class DummyTransaction(object):
+                def nop(*args, **kw):
+                    pass
+
+                def __getattr__(self, _):
+                    return self.nop
+
+            return cast(dns.transaction.Transaction, DummyTransaction())
+
     if isinstance(zone, str):
         zone = dns.name.from_text(zone)
     rdtype = dns.rdatatype.RdataType.make(rdtype)
     q = dns.message.make_query(zone, rdtype, rdclass)
     if rdtype == dns.rdatatype.IXFR:
-        rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial)
-        q.authority.append(rrset)
+        rrset = q.find_rrset(
+            q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
+        )
+        soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial)
+        rrset.add(soa, 0)
     if keyring is not None:
         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
-    wire = q.to_wire()
-    (af, destination, source) = _destination_and_source(
-        where, port, source, source_port
-    )
+    tm = DummyTransactionManager(zone, relativize)
     if use_udp and rdtype != dns.rdatatype.IXFR:
         raise ValueError("cannot do a UDP AXFR")
-    sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
-    with _make_socket(af, sock_type, source) as s:
-        (_, expiration) = _compute_times(lifetime)
-        _connect(s, destination, expiration)
-        l = len(wire)
-        if use_udp:
-            _udp_send(s, wire, None, expiration)
-        else:
-            tcpmsg = struct.pack("!H", l) + wire
-            _net_write(s, tcpmsg, expiration)
-        done = False
-        delete_mode = True
-        expecting_SOA = False
-        soa_rrset = None
-        if relativize:
-            origin = zone
-            oname = dns.name.empty
-        else:
-            origin = None
-            oname = zone
-        tsig_ctx = None
-        while not done:
-            (_, mexpiration) = _compute_times(timeout)
-            if mexpiration is None or (
-                expiration is not None and mexpiration > expiration
-            ):
-                mexpiration = expiration
-            if use_udp:
-                (wire, _) = _udp_recv(s, 65535, mexpiration)
-            else:
-                ldata = _net_read(s, 2, mexpiration)
-                (l,) = struct.unpack("!H", ldata)
-                wire = _net_read(s, l, mexpiration)
-            is_ixfr = rdtype == dns.rdatatype.IXFR
-            r = dns.message.from_wire(
-                wire,
-                keyring=q.keyring,
-                request_mac=q.mac,
-                xfr=True,
-                origin=origin,
-                tsig_ctx=tsig_ctx,
-                multi=True,
-                one_rr_per_rrset=is_ixfr,
-            )
-            rcode = r.rcode()
-            if rcode != dns.rcode.NOERROR:
-                raise TransferError(rcode)
-            tsig_ctx = r.tsig_ctx
-            answer_index = 0
-            if soa_rrset is None:
-                if not r.answer or r.answer[0].name != oname:
-                    raise dns.exception.FormError("No answer or RRset not for qname")
-                rrset = r.answer[0]
-                if rrset.rdtype != dns.rdatatype.SOA:
-                    raise dns.exception.FormError("first RRset is not an SOA")
-                answer_index = 1
-                soa_rrset = rrset.copy()
-                if rdtype == dns.rdatatype.IXFR:
-                    if dns.serial.Serial(soa_rrset[0].serial) <= serial:
-                        #
-                        # We're already up-to-date.
-                        #
-                        done = True
-                    else:
-                        expecting_SOA = True
-            #
-            # Process SOAs in the answer section (other than the initial
-            # SOA in the first message).
-            #
-            for rrset in r.answer[answer_index:]:
-                if done:
-                    raise dns.exception.FormError("answers after final SOA")
-                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
-                    if expecting_SOA:
-                        if rrset[0].serial != serial:
-                            raise dns.exception.FormError("IXFR base serial mismatch")
-                        expecting_SOA = False
-                    elif rdtype == dns.rdatatype.IXFR:
-                        delete_mode = not delete_mode
-                    #
-                    # If this SOA RRset is equal to the first we saw then we're
-                    # finished. If this is an IXFR we also check that we're
-                    # seeing the record in the expected part of the response.
-                    #
-                    if rrset == soa_rrset and (
-                        rdtype == dns.rdatatype.AXFR
-                        or (rdtype == dns.rdatatype.IXFR and delete_mode)
-                    ):
-                        done = True
-                elif expecting_SOA:
-                    #
-                    # We made an IXFR request and are expecting another
-                    # SOA RR, but saw something else, so this must be an
-                    # AXFR response.
-                    #
-                    rdtype = dns.rdatatype.AXFR
-                    expecting_SOA = False
-            if done and q.keyring and not r.had_tsig:
-                raise dns.exception.FormError("missing TSIG")
-            yield r
-
-
-class UDPMode(enum.IntEnum):
-    """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
-
-    NEVER means "never use UDP; always use TCP"
-    TRY_FIRST means "try to use UDP but fall back to TCP if needed"
-    ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
-    """
-
-    NEVER = 0
-    TRY_FIRST = 1
-    ONLY = 2
+    udp_mode = UDPMode.ONLY if use_udp else UDPMode.NEVER
+    yield from _inbound_xfr(
+        where, tm, q, port, timeout, lifetime, source, source_port, udp_mode
+    )
 
 
 def inbound_xfr(
@@ -1613,69 +1650,15 @@ def inbound_xfr(
 
     Raises on errors.
     """
-    if query is None:
-        (query, serial) = dns.xfr.make_query(txn_manager)
-    else:
-        serial = dns.xfr.extract_serial_from_query(query)
-    rdtype = query.question[0].rdtype
-    is_ixfr = rdtype == dns.rdatatype.IXFR
-    origin = txn_manager.from_wire_origin()
-    wire = query.to_wire()
-    (af, destination, source) = _destination_and_source(
-        where, port, source, source_port
-    )
-    (_, expiration) = _compute_times(lifetime)
-    retry = True
-    while retry:
-        retry = False
-        if is_ixfr and udp_mode != UDPMode.NEVER:
-            sock_type = socket.SOCK_DGRAM
-            is_udp = True
-        else:
-            sock_type = socket.SOCK_STREAM
-            is_udp = False
-        with _make_socket(af, sock_type, source) as s:
-            _connect(s, destination, expiration)
-            if is_udp:
-                _udp_send(s, wire, None, expiration)
-            else:
-                tcpmsg = struct.pack("!H", len(wire)) + wire
-                _net_write(s, tcpmsg, expiration)
-            with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
-                done = False
-                tsig_ctx = None
-                while not done:
-                    (_, mexpiration) = _compute_times(timeout)
-                    if mexpiration is None or (
-                        expiration is not None and mexpiration > expiration
-                    ):
-                        mexpiration = expiration
-                    if is_udp:
-                        (rwire, _) = _udp_recv(s, 65535, mexpiration)
-                    else:
-                        ldata = _net_read(s, 2, mexpiration)
-                        (l,) = struct.unpack("!H", ldata)
-                        rwire = _net_read(s, l, mexpiration)
-                    r = dns.message.from_wire(
-                        rwire,
-                        keyring=query.keyring,
-                        request_mac=query.mac,
-                        xfr=True,
-                        origin=origin,
-                        tsig_ctx=tsig_ctx,
-                        multi=(not is_udp),
-                        one_rr_per_rrset=is_ixfr,
-                    )
-                    try:
-                        done = inbound.process_message(r)
-                    except dns.xfr.UseTCP:
-                        assert is_udp  # should not happen if we used TCP!
-                        if udp_mode == UDPMode.ONLY:
-                            raise
-                        done = True
-                        retry = True
-                        udp_mode = UDPMode.NEVER
-                        continue
-                    tsig_ctx = r.tsig_ctx
-                if not retry and query.keyring and not r.had_tsig:
-                    raise dns.exception.FormError("missing TSIG")
+    for msg in _inbound_xfr(
+        where,
+        txn_manager,
+        query,
+        port,
+        timeout,
+        lifetime,
+        source,
+        source_port,
+        udp_mode,
+    ):
+        pass