* Refactor xfr.
Internally refactors the zone transfer code to separate the message
processing from the socket management, allowing the (internal) callers
to pass a socket in. This should allow a future interface that accepts
a socket, which would mean that xfr over DoT would just work, and xfr
over DoQ would be closer to working.
Adds some necessary functionality to the asyncbackend Socket class to
allow the async zone transfer code to be more similar to the sync code
(specifically, adds a type field to Socket, and updates the trio backend
to connect UDP sockets when requested).
In asyncquery.py, reorder the inbound_xfr() and quic() methods for
consistency.
* Run black.
* Fix typing.
class Socket: # pragma: no cover
+ def __init__(self, family: int, type: int):
+ self.family = family
+ self.type = type
+
async def close(self):
pass
class DatagramSocket(Socket): # pragma: no cover
- def __init__(self, family: int):
- self.family = family
-
async def sendto(self, what, destination, timeout):
raise NotImplementedError
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
- super().__init__(family)
+ super().__init__(family, socket.SOCK_DGRAM)
self.transport = transport
self.protocol = protocol
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer):
- self.family = af
+ super().__init__(af, socket.SOCK_STREAM)
self.reader = reader
self.writer = writer
class DatagramSocket(dns._asyncbackend.DatagramSocket):
- def __init__(self, socket):
- super().__init__(socket.family)
- self.socket = socket
+ def __init__(self, sock):
+ super().__init__(sock.family, socket.SOCK_DGRAM)
+ self.socket = sock
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
- return await self.socket.sendto(what, destination)
+ if destination is None:
+ return await self.socket.send(what)
+ else:
+ return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False):
- self.family = family
+ super().__init__(family, socket.SOCK_STREAM)
self.stream = stream
self.tls = tls
try:
if source:
await s.bind(_lltuple(source, af))
- if socktype == socket.SOCK_STREAM:
+ if socktype == socket.SOCK_STREAM or destination is not None:
connected = False
with _maybe_timeout(timeout):
await s.connect(_lltuple(destination, af))
import struct
import time
import urllib.parse
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, cast
import dns.asyncbackend
import dns.exception
return r
-async def inbound_xfr(
- where: str,
- txn_manager: dns.transaction.TransactionManager,
- query: Optional[dns.message.Message] = None,
- port: int = 53,
- timeout: Optional[float] = None,
- lifetime: Optional[float] = None,
- source: Optional[str] = None,
- source_port: int = 0,
- udp_mode: UDPMode = UDPMode.NEVER,
- backend: Optional[dns.asyncbackend.Backend] = None,
-) -> None:
- """Conduct an inbound transfer and apply it via a transaction from the
- txn_manager.
-
- *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
- the default, then dnspython will use the default backend.
-
- See :py:func:`dns.query.inbound_xfr()` for the documentation of
- the other parameters, exceptions, and return type of this method.
- """
- if query is None:
- (query, serial) = dns.xfr.make_query(txn_manager)
- else:
- serial = dns.xfr.extract_serial_from_query(query)
- rdtype = query.question[0].rdtype
- is_ixfr = rdtype == dns.rdatatype.IXFR
- origin = txn_manager.from_wire_origin()
- wire = query.to_wire()
- af = dns.inet.af_for_address(where)
- stuple = _source_tuple(af, source, source_port)
- dtuple = (where, port)
- (_, expiration) = _compute_times(lifetime)
- retry = True
- while retry:
- retry = False
- if is_ixfr and udp_mode != UDPMode.NEVER:
- sock_type = socket.SOCK_DGRAM
- is_udp = True
- else:
- sock_type = socket.SOCK_STREAM
- is_udp = False
- if not backend:
- backend = dns.asyncbackend.get_default_backend()
- s = await backend.make_socket(
- af, sock_type, 0, stuple, dtuple, _timeout(expiration)
- )
- async with s:
- if is_udp:
- await s.sendto(wire, dtuple, _timeout(expiration))
- else:
- tcpmsg = struct.pack("!H", len(wire)) + wire
- await s.sendall(tcpmsg, expiration)
- with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
- done = False
- tsig_ctx = None
- while not done:
- (_, mexpiration) = _compute_times(timeout)
- if mexpiration is None or (
- expiration is not None and mexpiration > expiration
- ):
- mexpiration = expiration
- if is_udp:
- destination = _lltuple((where, port), af)
- while True:
- timeout = _timeout(mexpiration)
- (rwire, from_address) = await s.recvfrom(65535, timeout)
- if _matches_destination(
- af, from_address, destination, True
- ):
- break
- else:
- ldata = await _read_exactly(s, 2, mexpiration)
- (l,) = struct.unpack("!H", ldata)
- rwire = await _read_exactly(s, l, mexpiration)
- is_ixfr = rdtype == dns.rdatatype.IXFR
- r = dns.message.from_wire(
- rwire,
- keyring=query.keyring,
- request_mac=query.mac,
- xfr=True,
- origin=origin,
- tsig_ctx=tsig_ctx,
- multi=(not is_udp),
- one_rr_per_rrset=is_ixfr,
- )
- try:
- done = inbound.process_message(r)
- except dns.xfr.UseTCP:
- assert is_udp # should not happen if we used TCP!
- if udp_mode == UDPMode.ONLY:
- raise
- done = True
- retry = True
- udp_mode = UDPMode.NEVER
- continue
- tsig_ctx = r.tsig_ctx
- if not retry and query.keyring and not r.had_tsig:
- raise dns.exception.FormError("missing TSIG")
-
-
async def quic(
q: dns.message.Message,
where: str,
if not q.is_response(r):
raise BadResponse
return r
+
+
+async def _inbound_xfr(
+ txn_manager: dns.transaction.TransactionManager,
+ s: dns.asyncbackend.Socket,
+ query: dns.message.Message,
+ serial: Optional[int],
+ timeout: Optional[float],
+ expiration: float,
+) -> Any:
+ """Given a socket, does the zone transfer."""
+ rdtype = query.question[0].rdtype
+ is_ixfr = rdtype == dns.rdatatype.IXFR
+ origin = txn_manager.from_wire_origin()
+ wire = query.to_wire()
+ is_udp = s.type == socket.SOCK_DGRAM
+ if is_udp:
+ udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
+ await udp_sock.sendto(wire, None, _timeout(expiration))
+ else:
+ tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
+ tcpmsg = struct.pack("!H", len(wire)) + wire
+ await tcp_sock.sendall(tcpmsg, expiration)
+ with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
+ done = False
+ tsig_ctx = None
+ while not done:
+ (_, mexpiration) = _compute_times(timeout)
+ if mexpiration is None or (
+ expiration is not None and mexpiration > expiration
+ ):
+ mexpiration = expiration
+ if is_udp:
+ timeout = _timeout(mexpiration)
+ (rwire, _) = await udp_sock.recvfrom(65535, timeout)
+ else:
+ ldata = await _read_exactly(tcp_sock, 2, mexpiration)
+ (l,) = struct.unpack("!H", ldata)
+ rwire = await _read_exactly(tcp_sock, l, mexpiration)
+ r = dns.message.from_wire(
+ rwire,
+ keyring=query.keyring,
+ request_mac=query.mac,
+ xfr=True,
+ origin=origin,
+ tsig_ctx=tsig_ctx,
+ multi=(not is_udp),
+ one_rr_per_rrset=is_ixfr,
+ )
+ done = inbound.process_message(r)
+ yield r
+ tsig_ctx = r.tsig_ctx
+ if query.keyring and not r.had_tsig:
+ raise dns.exception.FormError("missing TSIG")
+
+
+async def inbound_xfr(
+ where: str,
+ txn_manager: dns.transaction.TransactionManager,
+ query: Optional[dns.message.Message] = None,
+ port: int = 53,
+ timeout: Optional[float] = None,
+ lifetime: Optional[float] = None,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ udp_mode: UDPMode = UDPMode.NEVER,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+) -> None:
+ """Conduct an inbound transfer and apply it via a transaction from the
+ txn_manager.
+
+ *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+ the default, then dnspython will use the default backend.
+
+ See :py:func:`dns.query.inbound_xfr()` for the documentation of
+ the other parameters, exceptions, and return type of this method.
+ """
+ if query is None:
+ (query, serial) = dns.xfr.make_query(txn_manager)
+ else:
+ serial = dns.xfr.extract_serial_from_query(query)
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ dtuple = (where, port)
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ (_, expiration) = _compute_times(lifetime)
+ if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
+ s = await backend.make_socket(
+ af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
+ )
+ async with s:
+ try:
+ async for _ in _inbound_xfr(
+ txn_manager, s, query, serial, timeout, expiration
+ ):
+ pass
+ return
+ except dns.xfr.UseTCP:
+ if udp_mode == UDPMode.ONLY:
+ raise
+ pass
+
+ s = await backend.make_socket(
+ af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
+ )
+ async with s:
+ async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
+ pass
def _inbound_xfr(
- where: str,
txn_manager: dns.transaction.TransactionManager,
- query: Optional[dns.message.Message] = None,
- port: int = 53,
- timeout: Optional[float] = None,
- lifetime: Optional[float] = None,
- source: Optional[str] = None,
- source_port: int = 0,
- udp_mode: UDPMode = UDPMode.NEVER,
+ s: socket.socket,
+ query: dns.message.Message,
+ serial: Optional[int],
+ timeout: Optional[float],
+ expiration: float,
) -> Any:
- """Conduct an inbound transfer and apply it via a transaction from the
- txn_manager.
-
- *where*, a ``str`` containing an IPv4 or IPv6 address, where
- to send the message.
-
- *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
- for this transfer (typically a ``dns.zone.Zone``).
-
- *query*, the query to send. If not supplied, a default query is
- constructed using information from the *txn_manager*.
-
- *port*, an ``int``, the port send the message to. The default is 53.
-
- *timeout*, a ``float``, the number of seconds to wait for each
- response message. If None, the default, wait forever.
-
- *lifetime*, a ``float``, the total number of seconds to spend
- doing the transfer. If ``None``, the default, then there is no
- limit on the time the transfer may take.
-
- *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
- the source address. The default is the wildcard address.
-
- *source_port*, an ``int``, the port from which to send the message.
- The default is 0.
-
- *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
- for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use
- TCP. Other possibilities are ``dns.UDPMode.TRY_FIRST``, which
- means "try UDP but fallback to TCP if needed", and
- ``dns.UDPMode.ONLY``, which means "try UDP and raise
- ``dns.xfr.UseTCP`` if it does not succeed.
-
- Raises on errors.
- """
- if query is None:
- (query, serial) = dns.xfr.make_query(txn_manager)
- else:
- serial = dns.xfr.extract_serial_from_query(query)
+ """Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
- (af, destination, source) = _destination_and_source(
- where, port, source, source_port
- )
- (_, expiration) = _compute_times(lifetime)
- retry = True
- while retry:
- retry = False
- if is_ixfr and udp_mode != UDPMode.NEVER:
- sock_type = socket.SOCK_DGRAM
- is_udp = True
- else:
- sock_type = socket.SOCK_STREAM
- is_udp = False
- with _make_socket(af, sock_type, source) as s:
- _connect(s, destination, expiration)
+ is_udp = s.type == socket.SOCK_DGRAM
+ if is_udp:
+ _udp_send(s, wire, None, expiration)
+ else:
+ tcpmsg = struct.pack("!H", len(wire)) + wire
+ _net_write(s, tcpmsg, expiration)
+ with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
+ done = False
+ tsig_ctx = None
+ while not done:
+ (_, mexpiration) = _compute_times(timeout)
+ if mexpiration is None or (
+ expiration is not None and mexpiration > expiration
+ ):
+ mexpiration = expiration
if is_udp:
- _udp_send(s, wire, None, expiration)
+ (rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
- tcpmsg = struct.pack("!H", len(wire)) + wire
- _net_write(s, tcpmsg, expiration)
- with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
- done = False
- tsig_ctx = None
- while not done:
- (_, mexpiration) = _compute_times(timeout)
- if mexpiration is None or (
- expiration is not None and mexpiration > expiration
- ):
- mexpiration = expiration
- if is_udp:
- (rwire, _) = _udp_recv(s, 65535, mexpiration)
- else:
- ldata = _net_read(s, 2, mexpiration)
- (l,) = struct.unpack("!H", ldata)
- rwire = _net_read(s, l, mexpiration)
- r = dns.message.from_wire(
- rwire,
- keyring=query.keyring,
- request_mac=query.mac,
- xfr=True,
- origin=origin,
- tsig_ctx=tsig_ctx,
- multi=(not is_udp),
- one_rr_per_rrset=is_ixfr,
- )
- try:
- done = inbound.process_message(r)
- except dns.xfr.UseTCP:
- assert is_udp # should not happen if we used TCP!
- if udp_mode == UDPMode.ONLY:
- raise
- done = True
- retry = True
- udp_mode = UDPMode.NEVER
- continue
- yield r
- tsig_ctx = r.tsig_ctx
- if not retry and query.keyring and not r.had_tsig:
- raise dns.exception.FormError("missing TSIG")
+ ldata = _net_read(s, 2, mexpiration)
+ (l,) = struct.unpack("!H", ldata)
+ rwire = _net_read(s, l, mexpiration)
+ r = dns.message.from_wire(
+ rwire,
+ keyring=query.keyring,
+ request_mac=query.mac,
+ xfr=True,
+ origin=origin,
+ tsig_ctx=tsig_ctx,
+ multi=(not is_udp),
+ one_rr_per_rrset=is_ixfr,
+ )
+ done = inbound.process_message(r)
+ yield r
+ tsig_ctx = r.tsig_ctx
+ if query.keyring and not r.had_tsig:
+ raise dns.exception.FormError("missing TSIG")
def xfr(
rrset.add(soa, 0)
if keyring is not None:
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
+ (af, destination, source) = _destination_and_source(
+ where, port, source, source_port
+ )
+ (_, expiration) = _compute_times(lifetime)
tm = DummyTransactionManager(zone, relativize)
if use_udp and rdtype != dns.rdatatype.IXFR:
raise ValueError("cannot do a UDP AXFR")
- udp_mode = UDPMode.ONLY if use_udp else UDPMode.NEVER
- yield from _inbound_xfr(
- where, tm, q, port, timeout, lifetime, source, source_port, udp_mode
- )
+ sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
+ with _make_socket(af, sock_type, source) as s:
+ _connect(s, destination, expiration)
+ yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
def inbound_xfr(
Raises on errors.
"""
- for _ in _inbound_xfr(
- where,
- txn_manager,
- query,
- port,
- timeout,
- lifetime,
- source,
- source_port,
- udp_mode,
- ):
- pass
+ if query is None:
+ (query, serial) = dns.xfr.make_query(txn_manager)
+ else:
+ serial = dns.xfr.extract_serial_from_query(query)
+
+ (af, destination, source) = _destination_and_source(
+ where, port, source, source_port
+ )
+ (_, expiration) = _compute_times(lifetime)
+ if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
+ with _make_socket(af, socket.SOCK_DGRAM, source) as s:
+ _connect(s, destination, expiration)
+ try:
+ for _ in _inbound_xfr(
+ txn_manager, s, query, serial, timeout, expiration
+ ):
+ pass
+ return
+ except dns.xfr.UseTCP:
+ if udp_mode == UDPMode.ONLY:
+ raise
+ pass
+
+ with _make_socket(af, socket.SOCK_STREAM, source) as s:
+ _connect(s, destination, expiration)
+ for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
+ pass