]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Refactor xfr. (#1122)
authorBrian Wellington <bwelling@xbill.org>
Sun, 18 Aug 2024 13:54:16 +0000 (06:54 -0700)
committerGitHub <noreply@github.com>
Sun, 18 Aug 2024 13:54:16 +0000 (06:54 -0700)
* Refactor xfr.

Internally refactors the zone transfer code to separate the message
processing from the socket management, allowing the (internal) callers
to pass a socket in.  This should allow a future interface that accepts
a socket, which would mean that xfr over DoT would just work, and xfr
over DoQ would be closer to working.

Adds some necessary functionality to the asyncbackend Socket class to
allow the async zone transfer code to be more similar to the sync code
(specifically, adds a type field to Socket, and updates the trio backend
to connect UDP sockets when requested).

In asyncquery.py, reorder the inbound_xfr() and quic() methods for
consistency.

* Run black.

* Fix typing.

dns/_asyncbackend.py
dns/_asyncio_backend.py
dns/_trio_backend.py
dns/asyncquery.py
dns/query.py

index 49f14fed682f6088fc506ce19978fbe62da1fafe..f6760fd0da90f1f000ba45631ee5ae9ef13cf78d 100644 (file)
@@ -26,6 +26,10 @@ class NullContext:
 
 
 class Socket:  # pragma: no cover
+    def __init__(self, family: int, type: int):
+        self.family = family
+        self.type = type
+
     async def close(self):
         pass
 
@@ -46,9 +50,6 @@ class Socket:  # pragma: no cover
 
 
 class DatagramSocket(Socket):  # pragma: no cover
-    def __init__(self, family: int):
-        self.family = family
-
     async def sendto(self, what, destination, timeout):
         raise NotImplementedError
 
index 9d9ed3690c6c84aa88102df63481dec1bf51d3d4..de18c40173cd4793bb3385446dbe3e5f7843982e 100644 (file)
@@ -64,7 +64,7 @@ async def _maybe_wait_for(awaitable, timeout):
 
 class DatagramSocket(dns._asyncbackend.DatagramSocket):
     def __init__(self, family, transport, protocol):
-        super().__init__(family)
+        super().__init__(family, socket.SOCK_DGRAM)
         self.transport = transport
         self.protocol = protocol
 
@@ -99,7 +99,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
 
 class StreamSocket(dns._asyncbackend.StreamSocket):
     def __init__(self, af, reader, writer):
-        self.family = af
+        super().__init__(af, socket.SOCK_STREAM)
         self.reader = reader
         self.writer = writer
 
index 398e3276923bb6ae91d14a12de5d089cd134d7bb..1d2bdda981a0581e4a6d26b952455c0b17bb8a57 100644 (file)
@@ -30,13 +30,16 @@ _lltuple = dns.inet.low_level_address_tuple
 
 
 class DatagramSocket(dns._asyncbackend.DatagramSocket):
-    def __init__(self, socket):
-        super().__init__(socket.family)
-        self.socket = socket
+    def __init__(self, sock):
+        super().__init__(sock.family, socket.SOCK_DGRAM)
+        self.socket = sock
 
     async def sendto(self, what, destination, timeout):
         with _maybe_timeout(timeout):
-            return await self.socket.sendto(what, destination)
+            if destination is None:
+                return await self.socket.send(what)
+            else:
+                return await self.socket.sendto(what, destination)
         raise dns.exception.Timeout(
             timeout=timeout
         )  # pragma: no cover  lgtm[py/unreachable-statement]
@@ -61,7 +64,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
 
 class StreamSocket(dns._asyncbackend.StreamSocket):
     def __init__(self, family, stream, tls=False):
-        self.family = family
+        super().__init__(family, socket.SOCK_STREAM)
         self.stream = stream
         self.tls = tls
 
@@ -205,7 +208,7 @@ class Backend(dns._asyncbackend.Backend):
         try:
             if source:
                 await s.bind(_lltuple(source, af))
-            if socktype == socket.SOCK_STREAM:
+            if socktype == socket.SOCK_STREAM or destination is not None:
                 connected = False
                 with _maybe_timeout(timeout):
                     await s.connect(_lltuple(destination, af))
index 717f43b4707daf0c5247714452443789841156ee..622c9d520f422c49c003c521d6636868663df160 100644 (file)
@@ -24,7 +24,7 @@ import socket
 import struct
 import time
 import urllib.parse
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, cast
 
 import dns.asyncbackend
 import dns.exception
@@ -716,107 +716,6 @@ async def _http3(
     return r
 
 
-async def inbound_xfr(
-    where: str,
-    txn_manager: dns.transaction.TransactionManager,
-    query: Optional[dns.message.Message] = None,
-    port: int = 53,
-    timeout: Optional[float] = None,
-    lifetime: Optional[float] = None,
-    source: Optional[str] = None,
-    source_port: int = 0,
-    udp_mode: UDPMode = UDPMode.NEVER,
-    backend: Optional[dns.asyncbackend.Backend] = None,
-) -> None:
-    """Conduct an inbound transfer and apply it via a transaction from the
-    txn_manager.
-
-    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
-    the default, then dnspython will use the default backend.
-
-    See :py:func:`dns.query.inbound_xfr()` for the documentation of
-    the other parameters, exceptions, and return type of this method.
-    """
-    if query is None:
-        (query, serial) = dns.xfr.make_query(txn_manager)
-    else:
-        serial = dns.xfr.extract_serial_from_query(query)
-    rdtype = query.question[0].rdtype
-    is_ixfr = rdtype == dns.rdatatype.IXFR
-    origin = txn_manager.from_wire_origin()
-    wire = query.to_wire()
-    af = dns.inet.af_for_address(where)
-    stuple = _source_tuple(af, source, source_port)
-    dtuple = (where, port)
-    (_, expiration) = _compute_times(lifetime)
-    retry = True
-    while retry:
-        retry = False
-        if is_ixfr and udp_mode != UDPMode.NEVER:
-            sock_type = socket.SOCK_DGRAM
-            is_udp = True
-        else:
-            sock_type = socket.SOCK_STREAM
-            is_udp = False
-        if not backend:
-            backend = dns.asyncbackend.get_default_backend()
-        s = await backend.make_socket(
-            af, sock_type, 0, stuple, dtuple, _timeout(expiration)
-        )
-        async with s:
-            if is_udp:
-                await s.sendto(wire, dtuple, _timeout(expiration))
-            else:
-                tcpmsg = struct.pack("!H", len(wire)) + wire
-                await s.sendall(tcpmsg, expiration)
-            with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
-                done = False
-                tsig_ctx = None
-                while not done:
-                    (_, mexpiration) = _compute_times(timeout)
-                    if mexpiration is None or (
-                        expiration is not None and mexpiration > expiration
-                    ):
-                        mexpiration = expiration
-                    if is_udp:
-                        destination = _lltuple((where, port), af)
-                        while True:
-                            timeout = _timeout(mexpiration)
-                            (rwire, from_address) = await s.recvfrom(65535, timeout)
-                            if _matches_destination(
-                                af, from_address, destination, True
-                            ):
-                                break
-                    else:
-                        ldata = await _read_exactly(s, 2, mexpiration)
-                        (l,) = struct.unpack("!H", ldata)
-                        rwire = await _read_exactly(s, l, mexpiration)
-                    is_ixfr = rdtype == dns.rdatatype.IXFR
-                    r = dns.message.from_wire(
-                        rwire,
-                        keyring=query.keyring,
-                        request_mac=query.mac,
-                        xfr=True,
-                        origin=origin,
-                        tsig_ctx=tsig_ctx,
-                        multi=(not is_udp),
-                        one_rr_per_rrset=is_ixfr,
-                    )
-                    try:
-                        done = inbound.process_message(r)
-                    except dns.xfr.UseTCP:
-                        assert is_udp  # should not happen if we used TCP!
-                        if udp_mode == UDPMode.ONLY:
-                            raise
-                        done = True
-                        retry = True
-                        udp_mode = UDPMode.NEVER
-                        continue
-                    tsig_ctx = r.tsig_ctx
-                if not retry and query.keyring and not r.had_tsig:
-                    raise dns.exception.FormError("missing TSIG")
-
-
 async def quic(
     q: dns.message.Message,
     where: str,
@@ -883,3 +782,112 @@ async def quic(
     if not q.is_response(r):
         raise BadResponse
     return r
+
+
+async def _inbound_xfr(
+    txn_manager: dns.transaction.TransactionManager,
+    s: dns.asyncbackend.Socket,
+    query: dns.message.Message,
+    serial: Optional[int],
+    timeout: Optional[float],
+    expiration: float,
+) -> Any:
+    """Given a socket, does the zone transfer."""
+    rdtype = query.question[0].rdtype
+    is_ixfr = rdtype == dns.rdatatype.IXFR
+    origin = txn_manager.from_wire_origin()
+    wire = query.to_wire()
+    is_udp = s.type == socket.SOCK_DGRAM
+    if is_udp:
+        udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
+        await udp_sock.sendto(wire, None, _timeout(expiration))
+    else:
+        tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
+        tcpmsg = struct.pack("!H", len(wire)) + wire
+        await tcp_sock.sendall(tcpmsg, expiration)
+    with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
+        done = False
+        tsig_ctx = None
+        while not done:
+            (_, mexpiration) = _compute_times(timeout)
+            if mexpiration is None or (
+                expiration is not None and mexpiration > expiration
+            ):
+                mexpiration = expiration
+            if is_udp:
+                timeout = _timeout(mexpiration)
+                (rwire, _) = await udp_sock.recvfrom(65535, timeout)
+            else:
+                ldata = await _read_exactly(tcp_sock, 2, mexpiration)
+                (l,) = struct.unpack("!H", ldata)
+                rwire = await _read_exactly(tcp_sock, l, mexpiration)
+            r = dns.message.from_wire(
+                rwire,
+                keyring=query.keyring,
+                request_mac=query.mac,
+                xfr=True,
+                origin=origin,
+                tsig_ctx=tsig_ctx,
+                multi=(not is_udp),
+                one_rr_per_rrset=is_ixfr,
+            )
+            done = inbound.process_message(r)
+            yield r
+            tsig_ctx = r.tsig_ctx
+        if query.keyring and not r.had_tsig:
+            raise dns.exception.FormError("missing TSIG")
+
+
+async def inbound_xfr(
+    where: str,
+    txn_manager: dns.transaction.TransactionManager,
+    query: Optional[dns.message.Message] = None,
+    port: int = 53,
+    timeout: Optional[float] = None,
+    lifetime: Optional[float] = None,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    udp_mode: UDPMode = UDPMode.NEVER,
+    backend: Optional[dns.asyncbackend.Backend] = None,
+) -> None:
+    """Conduct an inbound transfer and apply it via a transaction from the
+    txn_manager.
+
+    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
+    the default, then dnspython will use the default backend.
+
+    See :py:func:`dns.query.inbound_xfr()` for the documentation of
+    the other parameters, exceptions, and return type of this method.
+    """
+    if query is None:
+        (query, serial) = dns.xfr.make_query(txn_manager)
+    else:
+        serial = dns.xfr.extract_serial_from_query(query)
+    af = dns.inet.af_for_address(where)
+    stuple = _source_tuple(af, source, source_port)
+    dtuple = (where, port)
+    if not backend:
+        backend = dns.asyncbackend.get_default_backend()
+    (_, expiration) = _compute_times(lifetime)
+    if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
+        s = await backend.make_socket(
+            af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
+        )
+        async with s:
+            try:
+                async for _ in _inbound_xfr(
+                    txn_manager, s, query, serial, timeout, expiration
+                ):
+                    pass
+                return
+            except dns.xfr.UseTCP:
+                if udp_mode == UDPMode.ONLY:
+                    raise
+                pass
+
+    s = await backend.make_socket(
+        af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
+    )
+    async with s:
+        async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
+            pass
index 050aeca44ae370060ef61103149ca325e22f1ded..8e21ed2fa989ee2bce06f71e4f554a9b7a11cebd 100644 (file)
@@ -1405,119 +1405,54 @@ class UDPMode(enum.IntEnum):
 
 
 def _inbound_xfr(
-    where: str,
     txn_manager: dns.transaction.TransactionManager,
-    query: Optional[dns.message.Message] = None,
-    port: int = 53,
-    timeout: Optional[float] = None,
-    lifetime: Optional[float] = None,
-    source: Optional[str] = None,
-    source_port: int = 0,
-    udp_mode: UDPMode = UDPMode.NEVER,
+    s: socket.socket,
+    query: dns.message.Message,
+    serial: Optional[int],
+    timeout: Optional[float],
+    expiration: float,
 ) -> Any:
-    """Conduct an inbound transfer and apply it via a transaction from the
-    txn_manager.
-
-    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
-    to send the message.
-
-    *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
-    for this transfer (typically a ``dns.zone.Zone``).
-
-    *query*, the query to send.  If not supplied, a default query is
-    constructed using information from the *txn_manager*.
-
-    *port*, an ``int``, the port send the message to.  The default is 53.
-
-    *timeout*, a ``float``, the number of seconds to wait for each
-    response message.  If None, the default, wait forever.
-
-    *lifetime*, a ``float``, the total number of seconds to spend
-    doing the transfer.  If ``None``, the default, then there is no
-    limit on the time the transfer may take.
-
-    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
-    the source address.  The default is the wildcard address.
-
-    *source_port*, an ``int``, the port from which to send the message.
-    The default is 0.
-
-    *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
-    for IXFRs.  The default is ``dns.UDPMode.NEVER``, i.e. only use
-    TCP.  Other possibilities are ``dns.UDPMode.TRY_FIRST``, which
-    means "try UDP but fallback to TCP if needed", and
-    ``dns.UDPMode.ONLY``, which means "try UDP and raise
-    ``dns.xfr.UseTCP`` if it does not succeed.
-
-    Raises on errors.
-    """
-    if query is None:
-        (query, serial) = dns.xfr.make_query(txn_manager)
-    else:
-        serial = dns.xfr.extract_serial_from_query(query)
+    """Given a socket, does the zone transfer."""
     rdtype = query.question[0].rdtype
     is_ixfr = rdtype == dns.rdatatype.IXFR
     origin = txn_manager.from_wire_origin()
     wire = query.to_wire()
-    (af, destination, source) = _destination_and_source(
-        where, port, source, source_port
-    )
-    (_, expiration) = _compute_times(lifetime)
-    retry = True
-    while retry:
-        retry = False
-        if is_ixfr and udp_mode != UDPMode.NEVER:
-            sock_type = socket.SOCK_DGRAM
-            is_udp = True
-        else:
-            sock_type = socket.SOCK_STREAM
-            is_udp = False
-        with _make_socket(af, sock_type, source) as s:
-            _connect(s, destination, expiration)
+    is_udp = s.type == socket.SOCK_DGRAM
+    if is_udp:
+        _udp_send(s, wire, None, expiration)
+    else:
+        tcpmsg = struct.pack("!H", len(wire)) + wire
+        _net_write(s, tcpmsg, expiration)
+    with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
+        done = False
+        tsig_ctx = None
+        while not done:
+            (_, mexpiration) = _compute_times(timeout)
+            if mexpiration is None or (
+                expiration is not None and mexpiration > expiration
+            ):
+                mexpiration = expiration
             if is_udp:
-                _udp_send(s, wire, None, expiration)
+                (rwire, _) = _udp_recv(s, 65535, mexpiration)
             else:
-                tcpmsg = struct.pack("!H", len(wire)) + wire
-                _net_write(s, tcpmsg, expiration)
-            with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
-                done = False
-                tsig_ctx = None
-                while not done:
-                    (_, mexpiration) = _compute_times(timeout)
-                    if mexpiration is None or (
-                        expiration is not None and mexpiration > expiration
-                    ):
-                        mexpiration = expiration
-                    if is_udp:
-                        (rwire, _) = _udp_recv(s, 65535, mexpiration)
-                    else:
-                        ldata = _net_read(s, 2, mexpiration)
-                        (l,) = struct.unpack("!H", ldata)
-                        rwire = _net_read(s, l, mexpiration)
-                    r = dns.message.from_wire(
-                        rwire,
-                        keyring=query.keyring,
-                        request_mac=query.mac,
-                        xfr=True,
-                        origin=origin,
-                        tsig_ctx=tsig_ctx,
-                        multi=(not is_udp),
-                        one_rr_per_rrset=is_ixfr,
-                    )
-                    try:
-                        done = inbound.process_message(r)
-                    except dns.xfr.UseTCP:
-                        assert is_udp  # should not happen if we used TCP!
-                        if udp_mode == UDPMode.ONLY:
-                            raise
-                        done = True
-                        retry = True
-                        udp_mode = UDPMode.NEVER
-                        continue
-                    yield r
-                    tsig_ctx = r.tsig_ctx
-                if not retry and query.keyring and not r.had_tsig:
-                    raise dns.exception.FormError("missing TSIG")
+                ldata = _net_read(s, 2, mexpiration)
+                (l,) = struct.unpack("!H", ldata)
+                rwire = _net_read(s, l, mexpiration)
+            r = dns.message.from_wire(
+                rwire,
+                keyring=query.keyring,
+                request_mac=query.mac,
+                xfr=True,
+                origin=origin,
+                tsig_ctx=tsig_ctx,
+                multi=(not is_udp),
+                one_rr_per_rrset=is_ixfr,
+            )
+            done = inbound.process_message(r)
+            yield r
+            tsig_ctx = r.tsig_ctx
+        if query.keyring and not r.had_tsig:
+            raise dns.exception.FormError("missing TSIG")
 
 
 def xfr(
@@ -1624,13 +1559,17 @@ def xfr(
         rrset.add(soa, 0)
     if keyring is not None:
         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
+    (af, destination, source) = _destination_and_source(
+        where, port, source, source_port
+    )
+    (_, expiration) = _compute_times(lifetime)
     tm = DummyTransactionManager(zone, relativize)
     if use_udp and rdtype != dns.rdatatype.IXFR:
         raise ValueError("cannot do a UDP AXFR")
-    udp_mode = UDPMode.ONLY if use_udp else UDPMode.NEVER
-    yield from _inbound_xfr(
-        where, tm, q, port, timeout, lifetime, source, source_port, udp_mode
-    )
+    sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
+    with _make_socket(af, sock_type, source) as s:
+        _connect(s, destination, expiration)
+        yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
 
 
 def inbound_xfr(
@@ -1680,15 +1619,30 @@ def inbound_xfr(
 
     Raises on errors.
     """
-    for _ in _inbound_xfr(
-        where,
-        txn_manager,
-        query,
-        port,
-        timeout,
-        lifetime,
-        source,
-        source_port,
-        udp_mode,
-    ):
-        pass
+    if query is None:
+        (query, serial) = dns.xfr.make_query(txn_manager)
+    else:
+        serial = dns.xfr.extract_serial_from_query(query)
+
+    (af, destination, source) = _destination_and_source(
+        where, port, source, source_port
+    )
+    (_, expiration) = _compute_times(lifetime)
+    if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
+        with _make_socket(af, socket.SOCK_DGRAM, source) as s:
+            _connect(s, destination, expiration)
+            try:
+                for _ in _inbound_xfr(
+                    txn_manager, s, query, serial, timeout, expiration
+                ):
+                    pass
+                return
+            except dns.xfr.UseTCP:
+                if udp_mode == UDPMode.ONLY:
+                    raise
+                pass
+
+    with _make_socket(af, socket.SOCK_STREAM, source) as s:
+        _connect(s, destination, expiration)
+        for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
+            pass