Start of pyright linting (more to come in future work).
type:
python -m mypy --install-types --non-interactive --disallow-incomplete-defs dns
+pyright:
+ pyright dns
+
lint:
pylint dns
finally:
_in__init__.reset(previous)
- nf.__signature__ = inspect.signature(f)
+ nf.__signature__ = inspect.signature(f) # pyright: ignore
return nf
import dns.rdataclass
import dns.rdatatype
import dns.transaction
+import dns.tsig
+import dns.xfr
from dns._asyncbackend import NullContext
from dns.query import (
BadResponse,
dtuple = None
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
async with cm as s:
- await send_udp(s, wire, destination, expiration)
+ await send_udp(s, wire, destination, expiration) # pyright: ignore
(r, received_time, _) = await receive_udp(
- s,
+ s, # pyright: ignore
destination,
expiration,
ignore_unexpected,
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
)
async with cm as s:
- await send_tcp(s, wire, expiration)
+ await send_tcp(s, wire, expiration) # pyright: ignore
(r, received_time) = await receive_tcp(
- s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
+ s, # pyright: ignore
+ expiration,
+ one_rr_per_rrset,
+ q.keyring,
+ q.mac,
+ ignore_trailing,
)
r.time = received_time - begin_time
if not q.is_response(r):
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
- ssl_context = _make_dot_ssl_context(server_hostname, verify)
+ ssl_context = _make_dot_ssl_context(
+ server_hostname, verify
+ ) # pyright: ignore
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
def _maybe_get_resolver(
- resolver: Optional["dns.asyncresolver.Resolver"],
-) -> "dns.asyncresolver.Resolver":
+ resolver: Optional["dns.asyncresolver.Resolver"], # pyright: ignore
+) -> "dns.asyncresolver.Resolver": # pyright: ignore
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
post: bool = True,
verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
- resolver: Optional["dns.asyncresolver.Resolver"] = None,
+ resolver: Optional["dns.asyncresolver.Resolver"] = None, # pyright: ignore
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
+ # we bind url and then override as pyright can't figure out all paths bind.
+ url = where
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
- else:
- url = where
extensions = {}
if bootstrap_address is None:
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
- assert parsed.hostname is not None # for mypy
- answers = await resolver.resolve_name(parsed.hostname, family)
+ assert parsed.hostname is not None # pyright: ignore
+ answers = await resolver.resolve_name( # pyright: ignore
+ parsed.hostname, family # pyright: ignore
+ )
bootstrap_address = random.choice(list(answers.addresses()))
return await _http3(
q,
if not have_doh:
raise NoDOH # pragma: no cover
# pylint: disable=possibly-used-before-assignment
- if client and not isinstance(client, httpx.AsyncClient):
+ if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore
raise ValueError("session parameter must be an httpx.AsyncClient")
# pylint: enable=possibly-used-before-assignment
family=family,
)
- cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
+ cm = httpx.AsyncClient( # pyright: ignore
+ http1=h1, http2=h2, verify=verify, transport=transport
+ )
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
}
)
response = await backend.wait_for(
- the_client.post(
+ the_client.post( # pyright: ignore
url,
headers=headers,
content=wire,
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await backend.wait_for(
- the_client.get(
+ the_client.get( # pyright: ignore
url,
headers=headers,
params={"dns": twire},
server_name=server_hostname,
) as the_manager:
if not connection:
- the_connection = the_manager.connect(where, port, source, source_port)
+ the_connection = the_manager.connect( # pyright: ignore
+ where, port, source, source_port
+ )
(start, expiration) = _compute_times(timeout)
- stream = await the_connection.make_stream(timeout)
+ stream = await the_connection.make_stream(timeout) # pyright: ignore
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
+ r: Optional[dns.message.Message] = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
mexpiration = expiration
if is_udp:
timeout = _timeout(mexpiration)
- (rwire, _) = await udp_sock.recvfrom(65535, timeout)
+ (rwire, _) = await udp_sock.recvfrom(65535, timeout) # pyright: ignore
else:
- ldata = await _read_exactly(tcp_sock, 2, mexpiration)
+ ldata = await _read_exactly(tcp_sock, 2, mexpiration) # pyright: ignore
(l,) = struct.unpack("!H", ldata)
- rwire = await _read_exactly(tcp_sock, l, mexpiration)
+ rwire = await _read_exactly(tcp_sock, l, mexpiration) # pyright: ignore
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
- if query.keyring and not r.had_tsig:
+ if query.keyring and r is not None and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
)
async with s:
try:
- async for _ in _inbound_xfr(
- txn_manager, s, query, serial, timeout, expiration
+ async for _ in _inbound_xfr( # pyright: ignore
+ txn_manager,
+ s,
+ query,
+ serial,
+ timeout,
+ expiration, # pyright: ignore
):
pass
return
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
- async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
+ async for _ in _inbound_xfr( # pyright: ignore
+ txn_manager, s, query, serial, timeout, expiration # pyright: ignore
+ ):
pass
import dns.asyncbackend
import dns.asyncquery
import dns.exception
+import dns.inet
import dns.name
+import dns.nameserver
import dns.query
import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
+import dns.reversename
# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
answers = await resolver.resolve_name(where, family)
for address in answers.addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
- res = dns.asyncresolver.Resolver(configure=False)
+ res = Resolver(configure=False)
res.nameservers = nameservers
return res
def __init__(self):
pass
- def ok_to_sign(self, _: DNSKEY) -> bool: # pragma: no cover
+ def ok_to_sign(self, key: DNSKEY) -> bool: # pragma: no cover
return False
- def ok_to_validate(self, _: DNSKEY) -> bool: # pragma: no cover
+ def ok_to_validate(self, key: DNSKEY) -> bool: # pragma: no cover
return False
- def ok_to_create_ds(self, _: DSDigest) -> bool: # pragma: no cover
+ def ok_to_create_ds(self, algorithm: DSDigest) -> bool: # pragma: no cover
return False
- def ok_to_validate_ds(self, _: DSDigest) -> bool: # pragma: no cover
+ def ok_to_validate_ds(self, algorithm: DSDigest) -> bool: # pragma: no cover
return False
signature=b"",
)
- data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
+ data = _make_rrsig_signature_data(rrset, rrsig_template, origin)
# pylint: disable=possibly-used-before-assignment
if isinstance(private_key, GenericPrivateKey):
keys = zsks
for private_key, dnskey in keys:
- rrsig = dns.dnssec.sign(
+ rrsig = sign(
rrset=rrset,
private_key=private_key,
dnskey=dnskey,
for domain in domains:
if isinstance(domain, str):
domain = dns.name.from_text(domain)
- qname = dns.e164.from_e164(number, domain)
+ qname = from_e164(number, domain)
try:
return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e:
import dns.enum
import dns.inet
+import dns.ipv4
+import dns.ipv6
+import dns.name
import dns.rdata
import dns.wire
def to_text(self) -> str:
raise NotImplementedError # pragma: no cover
- def to_generic(self) -> "dns.edns.GenericOption":
+ def to_generic(self) -> "GenericOption":
"""Creates a dns.edns.GenericOption equivalent of this rdata.
Returns a ``dns.edns.GenericOption``.
"""
wire = self.to_wire()
assert wire is not None # for mypy
- return dns.edns.GenericOption(self.otype, wire)
+ return GenericOption(self.otype, wire)
@classmethod
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
def to_text(self) -> str:
return "Generic %d" % self.otype
- def to_generic(self) -> "dns.edns.GenericOption":
+ def to_generic(self) -> "GenericOption":
return self
@classmethod
class CookieOption(Option):
def __init__(self, client: bytes, server: bytes):
- super().__init__(dns.edns.OptionType.COOKIE)
+ super().__init__(OptionType.COOKIE)
self.client = client
self.server = server
if len(client) != 8:
import random
import threading
import time
-from typing import Any, Optional
+from typing import Any, Optional, Union
class EntropyPool:
self.seeded = False
self.seed_pid = 0
- def _stir(self, entropy: bytes) -> None:
+ def _stir(self, entropy: Union[bytes, bytearray]) -> None:
for c in entropy:
if self.pool_index == self.hash_len:
self.pool_index = 0
self.pool[self.pool_index] ^= b
self.pool_index += 1
- def stir(self, entropy: bytes) -> None:
+ def stir(self, entropy: Union[bytes, bytearray]) -> None:
with self.lock:
self._stir(entropy)
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import enum
-from typing import Type, TypeVar, Union
+from typing import Any, Optional, Type, TypeVar, Union
+
+import dns.exception
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
@classmethod
def _missing_(cls, value):
cls._check_value(value)
- val = int.__new__(cls, value)
+ val = int.__new__(cls, value) # pyright: ignore
val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
- val._value_ = value
+ val._value_ = value # pyright: ignore
return val
@classmethod
if text.startswith(prefix) and text[len(prefix) :].isdigit():
value = int(text[len(prefix) :])
cls._check_value(value)
- try:
- return cls(value)
- except ValueError:
- return value
+ return cls(value)
raise cls._unknown_exception_class()
@classmethod
return cls.__name__.lower()
@classmethod
- def _prefix(cls):
+ def _prefix(cls) -> str:
return ""
@classmethod
- def _extra_from_text(cls, text): # pylint: disable=W0613
+ def _extra_from_text(cls, text: str) -> Optional[Any]: # pylint: disable=W0613
return None
@classmethod
return current_text
@classmethod
- def _unknown_exception_class(cls):
+ def _unknown_exception_class(cls) -> Type[Exception]:
return ValueError
from typing import Tuple
-import dns
+import dns.exception
def from_text(text: str) -> Tuple[int, int, int]:
"""
# Note that inet_aton() only accepts canonial form, but we still run through
# inet_ntoa() to ensure the output is a str.
- return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))
+ return inet_ntoa(inet_aton(text))
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
- return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))
+ return inet_ntoa(inet_aton(text))
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes.ANY.OPT
+import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.TSIG
import dns.renderer
import dns.rrset
+import dns.tokenizer
import dns.tsig
import dns.ttl
import dns.wire
# worry about that for now. We also don't worry if there is an existing padding
# option, as it is unlikely and probably harmless, as the worst case is that we
# may add another, and this seems to be legal.
- for option in self.opt[0].options:
+ opt_rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0])
+ for option in opt_rdata.options:
wire = option.to_wire()
# We add 4 here to account for the option type and length
size += len(wire) + 4
@property
def keyalgorithm(self) -> Optional[dns.name.Name]:
if self.tsig:
- return self.tsig[0].algorithm
+ rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0])
+ return rdata.algorithm
else:
return None
@property
def mac(self) -> Optional[bytes]:
if self.tsig:
- return self.tsig[0].mac
+ rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0])
+ return rdata.mac
else:
return None
@property
def tsig_error(self) -> Optional[int]:
if self.tsig:
- return self.tsig[0].error
+ rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0])
+ return rdata.error
else:
return None
@property
def payload(self) -> int:
if self.opt:
- return self.opt[0].payload
+ rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0])
+ return rdata.payload
else:
return 0
@property
def options(self) -> Tuple:
if self.opt:
- return self.opt[0].options
+ rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0])
+ return rdata.options
else:
return ()
srrset = self.find_rrset(
self.authority, auname, question.rdclass, dns.rdatatype.SOA
)
- min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum)
+ srdata = cast(dns.rdtypes.ANY.SOA.SOA, srrset[0])
+ min_ttl = min(min_ttl, srrset.ttl, srdata.minimum)
break
except KeyError:
try:
return QueryMessage
elif opcode == dns.opcode.UPDATE:
_maybe_import_update()
- return dns.update.UpdateMessage
+ return dns.update.UpdateMessage # pyright: ignore
else:
return Message
else:
with self.parser.restrict_to(rdlen):
rd = dns.rdata.from_wire_parser(
- rdclass, rdtype, self.parser, self.message.origin
+ rdclass, # pyright: ignore
+ rdtype,
+ self.parser,
+ self.message.origin,
)
covers = rd.covers()
if self.message.xfr and rdtype == dns.rdatatype.SOA:
if rdtype == dns.rdatatype.OPT:
self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
elif rdtype == dns.rdatatype.TSIG:
+ trd = cast(dns.rdtypes.ANY.TSIG.TSIG, rd)
if self.keyring is None or self.keyring is True:
raise UnknownTSIGKey("got signed message without keyring")
elif isinstance(self.keyring, dict):
key = self.keyring.get(absolute_name)
if isinstance(key, bytes):
- key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+ key = dns.tsig.Key(absolute_name, key, trd.algorithm)
elif callable(self.keyring):
key = self.keyring(self.message, absolute_name)
else:
rrset = self.message.find_rrset(
section,
name,
- rdclass,
+ rdclass, # pyright: ignore
rdtype,
covers,
deleting,
def __init__(
self,
- text,
- idna_codec,
- one_rr_per_rrset=False,
- origin=None,
- relativize=True,
- relativize_to=None,
+ text: str,
+ idna_codec: Optional[dns.name.IDNACodec],
+ one_rr_per_rrset: bool = False,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ relativize_to: Optional[dns.name.Name] = None,
):
- self.message = None
+ self.message: Optional[Message] = None # mypy: ignore
self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec)
self.last_name = None
self.one_rr_per_rrset = one_rr_per_rrset
def _question_line(self, section_number):
"""Process one line from the text format question section."""
+ assert self.message is not None
section = self.message.sections[section_number]
token = self.tok.get(want_leading=True)
if not token.is_whitespace():
additional data sections.
"""
+ assert self.message is not None
section = self.message.sections[section_number]
# Name
token = self.tok.get(want_leading=True)
pad = 468
response.use_edns(0, 0, our_payload, query.payload, pad=pad)
if query.had_tsig:
+ assert query.mac is not None
+ assert query.keyalgorithm is not None
response.use_tsig(
query.keyring,
query.keyname,
import dns.immutable
import dns.wire
+# Dnspython will never access idna if the import fails, but pyright can't figure
+# that out, so...
+#
+# pyright: reportAttributeAccessIssue = false, reportPossiblyUnboundVariable = false
+
if dns._features.have("idna"):
import idna # type: ignore
else: # pragma: no cover
have_idna_2008 = False
+
CompressType = Dict["Name", int]
"""DNS Opcodes."""
+from typing import Type
+
import dns.enum
import dns.exception
return 15
@classmethod
- def _unknown_exception_class(cls):
+ def _unknown_exception_class(cls) -> Type[Exception]:
return UnknownOpcode
import dns.name
import dns.quic
import dns.rcode
+import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.serial
self._family = family
def connect_tcp(
- self, host, port, timeout, local_address, socket_options=None
+ self, host, port, timeout=None, local_address=None, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
for address in addresses:
af = dns.inet.af_for_address(address)
if local_address is not None or self._local_port != 0:
+ if local_address is None:
+ local_address = "0.0.0.0"
source = dns.inet.low_level_address_tuple(
(local_address, self._local_port), af
)
raise httpcore.ConnectError
def connect_unix_socket(
- self, path, timeout, socket_options=None
+ self, path, timeout=None, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
- class _HTTPTransport(httpx.HTTPTransport):
+ class _HTTPTransport(httpx.HTTPTransport): # pyright: ignore
def __init__(
self,
*args,
else:
class _HTTPTransport: # type: ignore
+ def __init__(
+ self,
+ *args,
+ local_port=0,
+ bootstrap_address=None,
+ resolver=None,
+ family=socket.AF_UNSPEC,
+ **kwargs,
+ ):
+ pass
+
def connect_tcp(self, host, port, timeout, local_address):
raise NotImplementedError
have_doh = _have_httpx
try:
- import ssl
+ import ssl # pyright: ignore
except ImportError: # pragma: no cover
class ssl: # type: ignore
class WantWriteException(Exception):
pass
+ class SSLWantReadError(Exception):
+ pass
+
+ class SSLWantWriteError(Exception):
+ pass
+
class SSLContext:
pass
class SSLSocket:
- pass
+ def pending(self) -> bool:
+ return False
@classmethod
def create_default_context(cls, *args, **kwargs):
if writable:
events |= selectors.EVENT_WRITE
if events:
- sel.register(fd, events)
+ sel.register(fd, events) # pyright: ignore
if expiration is None:
timeout = None
else:
def _maybe_get_resolver(
- resolver: Optional["dns.resolver.Resolver"],
-) -> "dns.resolver.Resolver":
+ resolver: Optional["dns.resolver.Resolver"], # pyright: ignore
+) -> "dns.resolver.Resolver": # pyright: ignore
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
post: bool = True,
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
- resolver: Optional["dns.resolver.Resolver"] = None,
+ resolver: Optional["dns.resolver.Resolver"] = None, # pyright: ignore
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
+ # we bind url and then override as pyright can't figure out all paths bind.
+ url = where
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
- else:
- url = where
extensions = {}
if bootstrap_address is None:
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
- assert parsed.hostname is not None # for mypy
- answers = resolver.resolve_name(parsed.hostname, family)
+ assert parsed.hostname is not None # pyright: ignore
+ answers = resolver.resolve_name(parsed.hostname, family) # pyright: ignore
bootstrap_address = random.choice(list(answers.addresses()))
return _http3(
q,
bootstrap_address,
- url,
+ url, # pyright: ignore
timeout,
port,
source,
if not have_doh:
raise NoDOH # pragma: no cover
- if session and not isinstance(session, httpx.Client):
+ if session and not isinstance(session, httpx.Client): # pyright: ignore
raise ValueError("session parameter must be an httpx.Client")
wire = q.to_wire()
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
- family=family,
+ family=family, # pyright: ignore
)
- cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
+ cm = httpx.Client( # pyright: ignore
+ http1=h1, http2=h2, verify=verify, transport=transport # pyright: ignore
+ )
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
q.id = 0
wire = q.to_wire()
manager = dns.quic.SyncQuicManager(
- verify_mode=verify, server_name=hostname, h3=True
+ verify_mode=verify, server_name=hostname, h3=True # pyright: ignore
)
with manager:
with cm as s:
if not sock:
# pylint: disable=possibly-used-before-assignment
- _connect(s, destination, expiration)
+ _connect(s, destination, expiration) # pyright: ignore
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(
s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
the_connection = connection
else:
- manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
+ manager = dns.quic.SyncQuicManager(
+ verify_mode=verify, server_name=hostname # pyright: ignore
+ )
the_manager = manager # for type checking happiness
with manager:
if not connection:
- the_connection = the_manager.connect(where, port, source, source_port)
+ the_connection = the_manager.connect( # pyright: ignore
+ where, port, source, source_port
+ )
(start, expiration) = _compute_times(timeout)
- with the_connection.make_stream(timeout) as stream:
+ with the_connection.make_stream(timeout) as stream: # pyright: ignore
stream.send(wire, True)
wire = stream.receive(_remaining(expiration))
finish = time.time()
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
- expiration: float,
+ expiration: Optional[float],
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
+ r: Optional[dns.message.Message] = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
- if query.keyring and not r.had_tsig:
+ if query.keyring and r is not None and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
"""DNS Result Codes."""
-from typing import Tuple
+from typing import Tuple, Type
import dns.enum
import dns.exception
return 4095
@classmethod
- def _unknown_exception_class(cls):
+ def _unknown_exception_class(cls) -> Type[Exception]:
return UnknownRcode
def _to_wire(
self,
- file: Optional[Any],
+ file: Any,
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
self._to_wire(f, compress, origin, canonicalize)
return f.getvalue()
- def to_generic(
- self, origin: Optional[dns.name.Name] = None
- ) -> "dns.rdata.GenericRdata":
+ def to_generic(self, origin: Optional[dns.name.Name] = None) -> "GenericRdata":
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``.
"""
- return dns.rdata.GenericRdata(
- self.rdclass, self.rdtype, self.to_wire(origin=origin)
- )
+ return GenericRdata(self.rdclass, self.rdtype, self.to_wire(origin=origin))
def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This
In the future, all ordering comparisons for rdata with
relative names will be disallowed.
"""
+ # the next two lines are for type checkers, so they are bound
+ our = b""
+ their = b""
try:
our = self.to_digestable()
our_relative = False
relativize: bool = True,
**kw: Dict[str, Any],
) -> str:
- return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
+ return r"\# %d " % len(self.data) + _hexify(self.data, **kw) # pyright: ignore
@classmethod
def from_text(
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(self.data)
- def to_generic(
- self, origin: Optional[dns.name.Name] = None
- ) -> "dns.rdata.GenericRdata":
+ def to_generic(self, origin: Optional[dns.name.Name] = None) -> "GenericRdata":
return self
@classmethod
def get_rdata_class(rdclass, rdtype, use_generic=True):
cls = _rdata_classes.get((rdclass, rdtype))
if not cls:
- cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype))
+ cls = _rdata_classes.get((dns.rdataclass.ANY, rdtype))
if not cls and _dynamic_load_allowed:
rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype)
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
+ assert cls is not None # for type checkers
with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
rdata = None
if cls != GenericRdata:
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
+ assert cls is not None # for type checkers
with dns.exception.ExceptionWrapper(dns.exception.FormError):
return cls.from_wire_parser(rdclass, rdtype, parser, origin)
self.ttl = ttl
def _clone(self):
- obj = super()._clone()
+ obj = cast(Rdataset, super()._clone())
obj.rdclass = self.rdclass
obj.rdtype = self.rdtype
obj.covers = self.covers
elif ttl < self.ttl:
self.ttl = ttl
- def add( # pylint: disable=arguments-differ,arguments-renamed
+ # pylint: disable=arguments-differ,arguments-renamed
+ def add( # pyright: ignore
self, rd: dns.rdata.Rdata, ttl: Optional[int] = None
) -> None:
"""Add the specified rdata to the rdataset.
if len(self) == 0:
return []
else:
- return self[0]._processing_order(iter(self))
+ return self[0]._processing_order(iter(self)) # pyright: ignore
@dns.immutable.immutable
raise TypeError("immutable")
def __copy__(self):
- return ImmutableRdataset(super().copy())
+ return ImmutableRdataset(super().copy()) # pyright: ignore
def copy(self):
- return ImmutableRdataset(super().copy())
+ return ImmutableRdataset(super().copy()) # pyright: ignore
def union(self, other):
- return ImmutableRdataset(super().union(other))
+ return ImmutableRdataset(super().union(other)) # pyright: ignore
def intersection(self, other):
- return ImmutableRdataset(super().intersection(other))
+ return ImmutableRdataset(super().intersection(other)) # pyright: ignore
def difference(self, other):
- return ImmutableRdataset(super().difference(other))
+ return ImmutableRdataset(super().difference(other)) # pyright: ignore
def symmetric_difference(self, other):
- return ImmutableRdataset(super().symmetric_difference(other))
+ return ImmutableRdataset(super().symmetric_difference(other)) # pyright: ignore
def from_text_list(
self.flags,
self.protocol,
self.algorithm,
- dns.rdata._base64ify(self.key, **kw),
+ dns.rdata._base64ify(self.key, **kw), # pyright: ignore
)
@classmethod
self.key_tag,
self.algorithm,
self.digest_type,
- dns.rdata._hexify(self.digest, chunksize=chunksize, **kw),
+ dns.rdata._hexify(
+ self.digest, chunksize=chunksize, **kw # pyright: ignore
+ ),
)
@classmethod
import binascii
+import dns.exception
import dns.immutable
import dns.rdata
# see: rfc7043.txt
__slots__ = ["eui"]
- # define these in subclasses
+ # redefine these in subclasses
+ byte_len = 0
+ text_len = 0
# byte_len = 6 # 0123456789ab (in hex)
# text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab
import base64
import enum
import struct
+from typing import Any, Dict
import dns.enum
import dns.exception
return text
-def _unescape(value):
+def _unescape(value: str) -> bytes:
if value == "":
- return value
+ return b""
unescaped = b""
l = len(value)
i = 0
"""Abstract base class for SVCB parameters"""
@classmethod
- def emptiness(cls):
+ def emptiness(cls) -> Emptiness:
return Emptiness.NEVER
raise NotImplementedError # pragma: no cover
-_class_for_key = {
+_class_for_key: Dict[ParamKey, Any] = {
ParamKey.MANDATORY: MandatoryParam,
ParamKey.ALPN: ALPNParam,
ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam,
raise dns.exception.FormError("keys not in order")
prior_key = key
vlen = parser.get_uint16()
- pcls = _class_for_key.get(key, GenericParam)
+ pkey = ParamKey.make(key)
+ pcls = _class_for_key.get(pkey, GenericParam)
with parser.restrict_to(vlen):
value = pcls.from_wire_parser(parser, origin)
- params[key] = value
+ params[pkey] = value
return cls(rdclass, rdtype, priority, target, params)
def _processing_priority(self):
self.usage,
self.selector,
self.mtype,
- dns.rdata._hexify(self.cert, chunksize=chunksize, **kw),
+ dns.rdata._hexify(self.cert, chunksize=chunksize, **kw), # pyright: ignore
)
@classmethod
import dns.exception
import dns.immutable
+import dns.name
import dns.rdata
+import dns.rdataclass
+import dns.rdatatype
import dns.renderer
import dns.tokenizer
import collections
import random
import struct
-from typing import Any, List
+from typing import Any, Iterable, List, Optional, Tuple, Union
import dns.exception
import dns.ipv4
import dns.ipv6
import dns.name
import dns.rdata
+import dns.rdatatype
+import dns.tokenizer
+import dns.wire
class Gateway:
name = ""
- def __init__(self, type, gateway=None):
+ def __init__(self, type: Any, gateway: Optional[Union[str, dns.name.Name]] = None):
self.type = dns.rdata.Rdata._as_uint8(type)
self.gateway = gateway
self._check()
self.gateway = None
elif self.type == 1:
# check that it's OK
+ assert isinstance(self.gateway, str)
dns.ipv4.inet_aton(self.gateway)
elif self.type == 2:
# check that it's OK
+ assert isinstance(self.gateway, str)
dns.ipv6.inet_aton(self.gateway)
elif self.type == 3:
if not isinstance(self.gateway, dns.name.Name):
elif self.type in (1, 2):
return self.gateway
elif self.type == 3:
+ assert isinstance(self.gateway, dns.name.Name)
return str(self.gateway.choose_relativity(origin, relativize))
else:
raise ValueError(self._invalid_type(self.type)) # pragma: no cover
if self.type == 0:
pass
elif self.type == 1:
+ assert isinstance(self.gateway, str)
file.write(dns.ipv4.inet_aton(self.gateway))
elif self.type == 2:
+ assert isinstance(self.gateway, str)
file.write(dns.ipv6.inet_aton(self.gateway))
elif self.type == 3:
+ assert isinstance(self.gateway, dns.name.Name)
self.gateway.to_wire(file, None, origin, False)
else:
raise ValueError(self._invalid_type(self.type)) # pragma: no cover
type_name = ""
- def __init__(self, windows=None):
+ def __init__(self, windows: Optional[Iterable[Tuple[int, bytes]]] = None):
last_window = -1
+ if windows is None:
+ windows = []
self.windows = windows
for window, bitmap in self.windows:
if not isinstance(window, int):
for i, byte in enumerate(bitmap):
for j in range(0, 8):
if byte & (0x80 >> j):
- rdtype = window * 256 + i * 8 + j
+ rdtype = dns.rdatatype.RdataType.make(window * 256 + i * 8 + j)
bits.append(dns.rdatatype.to_text(rdtype))
text += " " + " ".join(bits)
return text
if weight > r:
break
r -= weight
- total -= weight
- ordered.append(rdata) # pylint: disable=undefined-loop-variable
- del rdatas[n] # pylint: disable=undefined-loop-variable
+ total -= weight # pyright: ignore[reportPossiblyUnboundVariable]
+ # pylint: disable=undefined-loop-variable
+ ordered.append(rdata) # pyright: ignore[reportPossiblyUnboundVariable]
+ del rdatas[n] # pyright: ignore[reportPossiblyUnboundVariable]
ordered.append(rdatas[0])
return ordered
import struct
import time
+import dns.edns
import dns.exception
+import dns.rdataclass
+import dns.rdatatype
import dns.tsig
+# Note we can't import dns.message for cicularity reasons
+
QUESTION = 0
ANSWER = 1
AUTHORITY = 2
pad = b""
options = list(opt_rdata.options)
options.append(dns.edns.GenericOption(dns.edns.OptionType.PADDING, pad))
- opt = dns.message.Message._make_opt(ttl, opt_rdata.rdclass, options)
+ opt = dns.message.Message._make_opt( # pyright: ignore
+ ttl, opt_rdata.rdclass, options
+ )
self.was_padded = True
self.add_rrset(ADDITIONAL, opt)
# make sure the EDNS version in ednsflags agrees with edns
ednsflags &= 0xFF00FFFF
ednsflags |= edns << 16
- opt = dns.message.Message._make_opt(ednsflags, payload, options)
+ opt = dns.message.Message._make_opt( # pyright: ignore
+ ednsflags, payload, options
+ )
self.add_opt(opt)
def add_tsig(
key = secret
else:
key = dns.tsig.Key(keyname, secret, algorithm)
- tsig = dns.message.Message._make_tsig(
+ tsig = dns.message.Message._make_tsig( # pyright: ignore
keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data
)
(tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), request_mac)
key = secret
else:
key = dns.tsig.Key(keyname, secret, algorithm)
- tsig = dns.message.Message._make_tsig(
+ tsig = dns.message.Message._make_tsig( # pyright: ignore
keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data
)
(tsig, ctx) = dns.tsig.sign(
import threading
import time
import warnings
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast
from urllib.parse import urlparse
import dns._ddr
import dns.rdata
import dns.rdataclass
import dns.rdatatype
+import dns.rdtypes.ANY.PTR
import dns.rdtypes.svcbbase
import dns.reversename
import dns.tsig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- def _check_kwargs(self, qnames, responses=None):
+ def _check_kwargs(self, qnames, responses=None): # pyright: ignore
if not isinstance(qnames, (list, tuple, set)):
raise AttributeError("qnames must be a list, tuple or set")
if len(qnames) == 0:
self.expiration = time.time() + self.chaining_result.minimum_ttl
def __getattr__(self, attr): # pragma: no cover
- if attr == "name":
- return self.rrset.name
- elif attr == "ttl":
- return self.rrset.ttl
- elif attr == "covers":
- return self.rrset.covers
- elif attr == "rdclass":
- return self.rrset.rdclass
- elif attr == "rdtype":
- return self.rrset.rdtype
+ if self.rrset is not None:
+ if attr == "name":
+ return self.rrset.name
+ elif attr == "ttl":
+ return self.rrset.ttl
+ elif attr == "covers":
+ return self.rrset.covers
+ elif attr == "rdclass":
+ return self.rrset.rdclass
+ elif attr == "rdtype":
+ return self.rrset.rdtype
else:
raise AttributeError(attr)
def __len__(self) -> int:
- return self.rrset and len(self.rrset) or 0
+ return self.rrset is not None and len(self.rrset) or 0
def __iter__(self) -> Iterator[Any]:
- return self.rrset and iter(self.rrset) or iter(tuple())
+ return self.rrset is not None and iter(self.rrset) or iter(tuple())
def __getitem__(self, i):
if self.rrset is None:
try:
answer = self.resolve(name, raise_on_no_answer=False)
canonical_name = answer.canonical_name
- except dns.resolver.NXDOMAIN as e:
+ except NXDOMAIN as e:
canonical_name = e.canonical_name
return canonical_name
tcp: bool = False,
resolver: Optional[Resolver] = None,
lifetime: Optional[float] = None,
-) -> dns.name.Name:
+) -> dns.name.Name: # pyright: ignore[reportReturnType]
"""Find the name of the zone which contains the specified name.
*name*, an absolute ``dns.name.Name`` or ``str``, the query name.
if answer.rrset.name == name:
return name
# otherwise we were CNAMEd or DNAMEd and need to look higher
- except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer) as e:
- if isinstance(e, dns.resolver.NXDOMAIN):
+ except (NXDOMAIN, NoAnswer) as e:
+ if isinstance(e, NXDOMAIN):
response = e.responses().get(name)
else:
response = e.response() # pylint: disable=no-value-for-parameter
else:
for address in resolver.resolve_name(where, family).addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
- res = dns.resolver.Resolver(configure=False)
+ res = Resolver(configure=False)
res.nameservers = nameservers
return res
# running process.
#
-_protocols_for_socktype = {
+_protocols_for_socktype: Dict[Any, List[Any]] = {
socket.SOCK_DGRAM: [socket.SOL_UDP],
socket.SOCK_STREAM: [socket.SOL_TCP],
}
-_resolver = None
+_resolver: Optional[Resolver] = None
_original_getaddrinfo = socket.getaddrinfo
_original_getnameinfo = socket.getnameinfo
_original_getfqdn = socket.getfqdn
pass
# Something needs resolution!
try:
+ assert _resolver is not None
answers = _resolver.resolve_name(host, family)
addrs = answers.addresses_and_families()
canonical_name = answers.canonical_name().to_text(True)
- except dns.resolver.NXDOMAIN:
+ except NXDOMAIN:
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
except Exception:
# We raise EAI_AGAIN here as the failure may be temporary
except Exception:
if flags & socket.AI_NUMERICSERV == 0:
try:
- port = socket.getservbyname(service)
+ port = socket.getservbyname(service) # pyright: ignore
except Exception:
pass
if port is None:
cname = ""
for addr, af in addrs:
for socktype in socktypes:
- for proto in _protocols_for_socktype[socktype]:
+ for sockproto in _protocols_for_socktype[socktype]:
+ proto = int(sockproto)
addr_tuple = dns.inet.low_level_address_tuple((addr, port), af)
tuples.append((af, socktype, proto, cname, addr_tuple))
if len(tuples) == 0:
qname = dns.reversename.from_address(addr)
if flags & socket.NI_NUMERICHOST == 0:
try:
+ assert _resolver is not None
answer = _resolver.resolve(qname, "PTR")
- hostname = answer.rrset[0].target.to_text(True)
- except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ assert answer.rrset is not None
+ rdata = cast(dns.rdtypes.ANY.PTR.PTR, answer.rrset[0])
+ hostname = rdata.target.to_text(True)
+ except (NXDOMAIN, NoAnswer):
if flags & socket.NI_NAMEREQD:
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
hostname = addr
import binascii
+import dns.exception
import dns.ipv4
import dns.ipv6
import dns.name
from typing import Any, Collection, Dict, Optional, Union, cast
import dns.name
+import dns.rdata
import dns.rdataclass
import dns.rdataset
+import dns.rdatatype
import dns.renderer
self.deleting = deleting
def _clone(self):
- obj = super()._clone()
+ obj = cast(RRset, super()._clone())
obj.name = self.name
obj.deleting = self.deleting
return obj
import dns.exception
import dns.name
import dns.node
+import dns.rdata
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
raise TypeError(f"{method}: expected more arguments")
def _add(self, replace, args):
+ if replace:
+ method = "replace()"
+ else:
+ method = "add()"
try:
args = collections.deque(args)
- if replace:
- method = "replace()"
- else:
- method = "add()"
arg = args.popleft()
if isinstance(arg, str):
arg = dns.name.from_text(arg, None)
raise TypeError(
f"{method} requires a name or RRset as the first argument"
)
+ assert rdataset is not None # for type checkers
if rdataset.rdclass != self.manager.get_class():
raise ValueError(f"{method} has objects of wrong RdataClass")
if rdataset.rdtype == dns.rdatatype.SOA:
raise TypeError(f"not enough parameters to {method}")
def _delete(self, exact, args):
+ if exact:
+ method = "delete_exact()"
+ else:
+ method = "delete()"
try:
args = collections.deque(args)
- if exact:
- method = "delete_exact()"
- else:
- method = "delete()"
arg = args.popleft()
if isinstance(arg, str):
arg = dns.name.from_text(arg, None)
import hashlib
import hmac
import struct
+from typing import Union
import dns.exception
import dns.name
import dns.rcode
import dns.rdataclass
+import dns.rdatatype
class BadTime(dns.exception.DNSException):
if request_mac:
ctx.update(struct.pack("!H", len(request_mac)))
ctx.update(request_mac)
+ assert ctx is not None # for type checkers
ctx.update(struct.pack("!H", rdata.original_id))
ctx.update(wire[2:])
if first:
class Key:
- def __init__(self, name, secret, algorithm=default_algorithm):
+ def __init__(
+ self,
+ name: Union[dns.name.Name, str],
+ secret: Union[bytes, str],
+ algorithm: Union[dns.name.Name, str] = default_algorithm,
+ ):
if isinstance(name, str):
name = dns.name.from_text(name)
self.name = name
import dns.tsig
-def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
+def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, Any]:
"""Convert a dictionary containing (textual DNS name, base64 secret)
pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
a dictionary containing (textual DNS name, (algorithm, base64 secret))
pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs.
@rtype: dict"""
- keyring = {}
+ keyring: Dict[dns.name.Name, Any] = {}
for name, value in textring.items():
kname = dns.name.from_text(name)
if isinstance(value, str):
if isinstance(value, int):
return value
elif isinstance(value, str):
- return dns.ttl.from_text(value)
+ return from_text(value)
else:
raise ValueError("cannot convert value to TTL")
from typing import Any, List, Optional, Union
+import dns.enum
+import dns.exception
import dns.message
import dns.name
import dns.opcode
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
+import dns.rrset
import dns.tsig
# Updates are always one_rr_per_rrset
return True
- def _parse_rr_header(self, section, name, rdclass, rdtype):
+ def _parse_rr_header(self, section, name, rdclass, rdtype): # pyright: ignore
deleting = None
empty = False
if section == UpdateSection.ZONE:
import collections
import threading
-from typing import Callable, Deque, Optional, Set, Union
+from typing import Callable, Deque, Optional, Set, Union, cast
import dns.exception
import dns.immutable
n = v.nodes.get(oname)
if n:
rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
- if rds and rds[0].serial == serial:
+ if rds is None:
+ continue
+ soa = cast(dns.rdtypes.ANY.SOA.SOA, rds[0])
+ if rds and soa.serial == serial:
version = v
break
if version is None:
# Note our definition of least_kept also ensures we do not try to
# delete the greatest version.
if len(self._readers) > 0:
- least_kept = min(txn.version.id for txn in self._readers)
+ least_kept = min(txn.version.id for txn in self._readers) # pyright: ignore
else:
least_kept = self._versions[-1].id
while self._versions[0].id < least_kept and self._pruning_policy(
if max_versions is not None and max_versions < 1:
raise ValueError("max versions must be at least 1")
if max_versions is None:
-
- def policy(zone, _): # pylint: disable=unused-argument
+ # pylint: disable=unused-argument
+ def policy(zone, _): # pyright: ignore
return False
else:
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, cast
+import dns.edns
import dns.exception
import dns.message
import dns.name
import dns.rcode
+import dns.rdata
import dns.rdataset
import dns.rdatatype
+import dns.rdtypes
+import dns.rdtypes.ANY
+import dns.rdtypes.ANY.SMIMEA
+import dns.rdtypes.ANY.SOA
+import dns.rdtypes.svcbbase
import dns.serial
import dns.transaction
import dns.tsig
if rdataset.rdtype != dns.rdatatype.SOA:
raise dns.exception.FormError("first RRset is not an SOA")
answer_index = 1
- self.soa_rdataset = rdataset.copy()
+ self.soa_rdataset = rdataset.copy() # pyright: ignore
if self.rdtype == dns.rdatatype.IXFR:
- if self.soa_rdataset[0].serial == self.serial:
+ assert self.soa_rdataset is not None
+ soa = cast(dns.rdtypes.ANY.SOA.SOA, self.soa_rdataset[0])
+ if soa.serial == self.serial:
#
# We're already up-to-date.
#
self.done = True
- elif dns.serial.Serial(self.soa_rdataset[0].serial) < self.serial:
+ elif dns.serial.Serial(soa.serial) < self.serial:
# It went backwards!
raise SerialWentBackwards
else:
#
# This is the final SOA
#
+ soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0])
if self.expecting_SOA:
# We got an empty IXFR sequence!
raise dns.exception.FormError("empty IXFR sequence")
- if (
- self.rdtype == dns.rdatatype.IXFR
- and self.serial != rdataset[0].serial
- ):
+ if self.rdtype == dns.rdatatype.IXFR and self.serial != soa.serial:
raise dns.exception.FormError("unexpected end of IXFR sequence")
self.txn.replace(name, rdataset)
self.txn.commit()
# This is not the final SOA
#
self.expecting_SOA = False
+ soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0])
if self.rdtype == dns.rdatatype.IXFR:
if self.delete_mode:
# This is the start of an IXFR deletion set
- if rdataset[0].serial != self.serial:
+ if soa.serial != self.serial:
raise dns.exception.FormError(
"IXFR base serial mismatch"
)
else:
# This is the start of an IXFR addition set
- self.serial = rdataset[0].serial
+ self.serial = soa.serial
self.txn.replace(name, rdataset)
else:
# We saw a non-final SOA for the origin in an AXFR.
with txn_manager.reader() as txn:
rdataset = txn.get(origin, "SOA")
if rdataset:
- serial = rdataset[0].serial
+ soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0])
+ serial = soa.serial
rdtype = dns.rdatatype.IXFR
else:
serial = None
return None
elif question.rdtype != dns.rdatatype.IXFR:
raise ValueError("query is not an AXFR or IXFR")
- soa = query.find_rrset(
+ soa_rrset = query.find_rrset(
query.authority, question.name, question.rdclass, dns.rdatatype.SOA
)
- return soa[0].serial
+ soa = cast(dns.rdtypes.ANY.SOA.SOA, soa_rrset[0])
+ return soa.serial
Set,
Tuple,
Union,
+ cast,
)
import dns.exception
for n in names:
l = self[n].to_text(
n,
- origin=self.origin,
- relativize=relativize,
- want_comments=want_comments,
+ origin=self.origin, # pyright: ignore
+ relativize=relativize, # pyright: ignore
+ want_comments=want_comments, # pyright: ignore
)
l_b = l.encode(file_enc)
# an SOA if there is no origin.
raise NoSOA
origin_name = self.origin
- soa: Optional[dns.rdataset.Rdataset]
+ soa_rds: Optional[dns.rdataset.Rdataset]
if txn:
- soa = txn.get(origin_name, dns.rdatatype.SOA)
+ soa_rds = txn.get(origin_name, dns.rdatatype.SOA)
else:
- soa = self.get_rdataset(origin_name, dns.rdatatype.SOA)
- if soa is None:
+ soa_rds = self.get_rdataset(origin_name, dns.rdatatype.SOA)
+ if soa_rds is None:
raise NoSOA
- return soa[0]
+ else:
+ soa = cast(dns.rdtypes.ANY.SOA.SOA, soa_rds[0])
+ return soa
def _compute_digest(
self,
def _end_write(self, txn):
pass
- def _commit_version(self, _, version, origin):
+ def _commit_version(self, txn, version, origin):
self.nodes = version.nodes
if self.origin is None:
self.origin = origin
- def _get_next_version_id(self):
+ def _get_next_version_id(self) -> int:
# Versions are ephemeral and all have id 1
return 1
def _setup_version(self):
assert self.version is None
- factory = self.manager.writable_version_factory
+ factory = self.manager.writable_version_factory # pyright: ignore
if factory is None:
factory = WritableVersion
- self.version = factory(self.zone, self.replacement)
+ self.version = factory(self.zone, self.replacement) # pyright: ignore
def _get_rdataset(self, name, rdtype, covers):
+ assert self.version is not None
return self.version.get_rdataset(name, rdtype, covers)
def _put_rdataset(self, name, rdataset):
assert not self.read_only
+ assert self.version is not None
self.version.put_rdataset(name, rdataset)
def _delete_name(self, name):
assert not self.read_only
+ assert self.version is not None
self.version.delete_node(name)
def _delete_rdataset(self, name, rdtype, covers):
assert not self.read_only
+ assert self.version is not None
self.version.delete_rdataset(name, rdtype, covers)
def _name_exists(self, name):
+ assert self.version is not None
return self.version.get_node(name) is not None
def _changed(self):
if self.read_only:
return False
else:
+ assert self.version is not None
return len(self.version.changed) > 0
def _end_transaction(self, commit):
+ assert self.zone is not None
+ assert self.version is not None
if self.read_only:
- self.zone._end_read(self)
+ self.zone._end_read(self) # pyright: ignore
elif commit and len(self.version.changed) > 0:
if self.make_immutable:
- factory = self.manager.immutable_version_factory
+ factory = self.manager.immutable_version_factory # pyright: ignore
if factory is None:
factory = ImmutableVersion
version = factory(self.version)
else:
version = self.version
- self.zone._commit_version(self, version, self.version.origin)
+ self.zone._commit_version( # pyright: ignore
+ self, version, self.version.origin
+ )
+
else:
# rollback
- self.zone._end_write(self)
+ self.zone._end_write(self) # pyright: ignore
def _set_origin(self, origin):
+ assert self.version is not None
if self.version.origin is None:
self.version.origin = origin
def _iterate_rdatasets(self):
+ assert self.version is not None
for name, node in self.version.items():
for rdataset in node:
yield (name, rdataset)
def _iterate_names(self):
+ assert self.version is not None
return self.version.keys()
def _get_node(self, name):
+ assert self.version is not None
return self.version.get_node(name)
def _origin_information(self):
+ assert self.version is not None
(absolute, relativize, effective) = self.manager.origin_information()
if absolute is None and self.version.origin is not None:
# No origin has been committed yet, but we've learned one as part of
reader.read()
except dns.zonefile.UnknownOrigin:
# for backwards compatibility
- raise dns.zone.UnknownOrigin
+ raise UnknownOrigin
# Now that we're done reading, do some basic checking of the zone.
if check_origin:
zone.check_origin()
import re
import sys
-from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union, cast
import dns.exception
import dns.grange
return
self.tok.unget(token)
name = self.last_name
+ if name is None:
+ raise dns.exception.SyntaxError("the last used name is undefined")
+ assert self.zone_origin is not None
if not name.is_subdomain(self.zone_origin):
self._eat_line()
return
# The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
# TTL from the SOA minttl if no $TTL statement is present before the
# SOA is parsed.
- self.default_ttl = rd.minimum
+ soa_rd = cast(dns.rdtypes.ANY.SOA.SOA, rd)
+ self.default_ttl = soa_rd.minimum
self.default_ttl_known = True
if ttl is None:
# if we didn't have a TTL on the SOA, set it!
- ttl = rd.minimum
+ ttl = soa_rd.minimum
# TTL check. We had to wait until now to do this as the SOA RR's
# own TTL can be inferred from its minimum.
ttl = self.default_ttl
elif self.last_ttl_known:
ttl = self.last_ttl
+ else:
+ # We don't go to the extra "look at the SOA" level of effort for
+ # $GENERATE, because the user really ought to have defined a TTL
+ # somehow!
+ raise dns.exception.SyntaxError("Missing default TTL value")
+
# Class
try:
rdclass = dns.rdataclass.from_text(token.value)
name, self.current_origin, self.tok.idna_codec
)
name = self.last_name
+ assert self.zone_origin is not None
if not name.is_subdomain(self.zone_origin):
self._eat_line()
return
)
rrset.update(rdataset)
rrsets.append(rrset)
- self.manager.set_rrsets(rrsets)
+ self.manager.set_rrsets(rrsets) # pyright: ignore
def _set_origin(self, origin):
pass
class RRSetsReaderManager(dns.transaction.TransactionManager):
def __init__(
- self, origin=dns.name.root, relativize=False, rdclass=dns.rdataclass.IN
+ self,
+ origin: Optional[dns.name.Name] = dns.name.root,
+ relativize: bool = False,
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
):
self.origin = origin
self.relativize = relativize
self.rdclass = rdclass
- self.rrsets = []
+ self.rrsets: List[dns.rrset.RRset] = []
def reader(self): # pragma: no cover
raise NotImplementedError
effective = self.origin
return (self.origin, self.relativize, effective)
- def set_rrsets(self, rrsets):
+ def set_rrsets(self, rrsets: List[dns.rrset.RRset]) -> None:
self.rrsets = rrsets
[[tool.mypy.overrides]]
module = "wmi"
ignore_missing_imports = true
+
+[tool.pyright]
+reportUnsupportedDunderAll = false
+exclude = [
+ "dns/_*_backend.py",
+ "dns/dnssecalgs/*.py",
+ "dns/quic/*.py",
+ "dns/rdtypes/ANY/*.py",
+ "dns/rdtypes/CH/*.py",
+ "dns/rdtypes/IN/*.py",
+ "examples/*.py",
+ "tests/*.py",
+] # (mostly) temporary!