poetry run pytest --lf
potype:
- poetry run python -m mypy examples tests dns/*.py
+ poetry run python -m mypy dns/*.py
+
+potypetests:
+ poetry run python -m mypy --check-untyped-defs examples tests
polint:
poetry run pylint dns
'asyncquery',
'asyncresolver',
'dnssec',
+ 'dnssectypes',
'e164',
'edns',
'entropy',
'wire',
'xfr',
'zone',
+ 'zonetypes',
'zonefile',
]
class DatagramSocket(Socket): # pragma: no cover
+ def __init__(self, family: int):
+ self.family = family
+
async def sendto(self, what, destination, timeout):
raise NotImplementedError
def datagram_connection_required(self):
return False
+
+ async def sleep(self, interval):
+ raise NotImplementedError
class DatagramSocket(dns._asyncbackend.DatagramSocket):
- def __init__(self, family, transport, protocol):
- self.family = family
+ def __init__(self, family: int, transport, protocol):
+ super().__init__(family)
self.transport = transport
self.protocol = protocol
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
+ super().__init__(socket.family)
self.socket = socket
- self.family = socket.family
async def sendto(self, what, destination, timeout):
async with _maybe_timeout(timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
+ super().__init__(socket.family)
self.socket = socket
- self.family = socket.family
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+from typing import Dict
+
import dns.exception
# pylint: disable=unused-import
_default_backend = None
-_backends = {}
+_backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes
_no_sniffio = False
pass
-def get_backend(name):
+def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio",
return backend
-def sniff():
+def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available.
except RuntimeError:
raise AsyncLibraryNotFoundError('no async library detected')
except AttributeError: # pragma: no cover
- # we have to check current_task on 3.6
- if not asyncio.Task.current_task():
+ # we have to check current_task on 3.6; we ignore for mypy
+ # purposes at it is otherwise unhappy on >= 3.7
+ if not asyncio.Task.current_task(): # type: ignore
raise AsyncLibraryNotFoundError('no async library detected')
return 'asyncio'
-def get_default_backend():
+def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary.
"""
if _default_backend:
return set_default_backend(sniff())
-def set_default_backend(name):
+def set_default_backend(name: str):
"""Set the default backend.
It's not normally necessary to call this method, as
+++ /dev/null
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-class Backend:
- ...
-
-def get_backend(name: str) -> Backend:
- ...
-def sniff() -> str:
- ...
-def get_default_backend() -> Backend:
- ...
-def set_default_backend(name: str) -> Backend:
- ...
"""Talk to a DNS server."""
+from typing import Any, Dict, Optional, Tuple, Union
+
import base64
import socket
import struct
import dns.rcode
import dns.rdataclass
import dns.rdatatype
+import dns.transaction
from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
UDPMode, _have_httpx, _have_http2, NoDOH
return None
-async def send_udp(sock, what, destination, expiration=None):
+async def send_udp(sock: dns.asyncbackend.DatagramSocket,
+ what: Union[dns.message.Message, bytes], destination: Any,
+ expiration: Optional[float]=None) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
return (n, sent_time)
-async def receive_udp(sock, destination=None, expiration=None,
+async def receive_udp(sock: dns.asyncbackend.DatagramSocket,
+ destination: Optional[Any]=None, expiration: Optional[float]=None,
ignore_unexpected=False, one_rr_per_rrset=False,
- keyring=None, request_mac=b'', ignore_trailing=False,
- raise_on_truncation=False):
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
+ ignore_trailing=False, raise_on_truncation=False) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
raise_on_truncation=raise_on_truncation)
return (r, received_time, from_address)
-async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
- ignore_unexpected=False, one_rr_per_rrset=False,
- ignore_trailing=False, raise_on_truncation=False, sock=None,
- backend=None):
+async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+ source: Optional[str]=None, source_port=0,
+ ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+ raise_on_truncation=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None,
+ backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
dtuple = None
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
dtuple)
+ assert s is not None
await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(s, destination, expiration,
ignore_unexpected,
if not sock and s:
await s.close()
-async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
- source_port=0, ignore_unexpected=False,
- one_rr_per_rrset=False, ignore_trailing=False,
- udp_sock=None, tcp_sock=None, backend=None):
+async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+ source: Optional[str]=None, source_port=0,
+ ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+ udp_sock: Optional[dns.asyncbackend.DatagramSocket]=None,
+ tcp_sock: Optional[dns.asyncbackend.StreamSocket]=None,
+ backend: Optional[dns.asyncbackend.Backend]=None) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
return (response, True)
-async def send_tcp(sock, what, expiration=None):
+async def send_tcp(sock: dns.asyncbackend.StreamSocket,
+ what: Union[dns.message.Message, bytes],
+ expiration: Optional[float]=None) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
"""
if isinstance(what, dns.message.Message):
- what = what.to_wire()
- l = len(what)
+ wire = what.to_wire()
+ else:
+ wire = what
+ l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
- tcpmsg = struct.pack("!H", l) + what
+ tcpmsg = struct.pack("!H", l) + wire
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
return s
-async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
- keyring=None, request_mac=b'', ignore_trailing=False):
+async def receive_tcp(sock: dns.asyncbackend.StreamSocket,
+ expiration: Optional[float]=None, one_rr_per_rrset=False,
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None,
+ request_mac=b'', ignore_trailing=False) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
return (r, received_time)
-async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False, sock=None,
- backend=None):
+async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+ source: Optional[str]=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ sock: Optional[dns.asyncbackend.StreamSocket]=None,
+ backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
dtuple, timeout)
+ assert s is not None
await send_tcp(s, wire, expiration)
(r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac,
if not sock and s:
await s.close()
-async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False, sock=None,
- backend=None, ssl_context=None, server_hostname=None):
+async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+ port=853, source: Optional[str]=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ sock: Optional[dns.asyncbackend.StreamSocket]=None,
+ backend: Optional[dns.asyncbackend.Backend]=None,
+ ssl_context: Optional[ssl.SSLContext]=None,
+ server_hostname: Optional[str]=None) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
if not sock and s:
await s.close()
-async def https(q, where, timeout=None, port=443, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False, client=None,
+async def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+ port=443, source: Optional[str]=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ client: Optional[httpx.AsyncClient]=None,
path='/dns-query', post=True, verify=True):
"""Return the response obtained after sending a query via DNS-over-HTTPS.
timeout=timeout)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
- wire = wire.decode() # httpx does a repr() if we give it bytes
+ twire = wire.decode() # httpx does a repr() if we give it bytes
response = await client.get(url, headers=headers, timeout=timeout,
- params={"dns": wire})
+ params={"dns": twire})
finally:
if client_to_close:
- await client.aclose()
+ await client_to_close.aclose()
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError('{} responded with status code {}'
- '\nResponse body: {}'.format(where,
+ '\nResponse body: {!r}'.format(where,
response.status_code,
response.content))
r = dns.message.from_wire(response.content,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing)
- r.time = response.elapsed
+ r.time = response.elapsed.total_seconds()
if not q.is_response(r):
raise BadResponse
return r
-async def inbound_xfr(where, txn_manager, query=None,
- port=53, timeout=None, lifetime=None, source=None,
- source_port=0, udp_mode=UDPMode.NEVER, backend=None):
+async def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
+ query: Optional[dns.message.Message]=None,
+ port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
+ source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER,
+ backend: Optional[dns.asyncbackend.Backend]=None):
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
+++ /dev/null
-from typing import Optional, Union, Dict, Generator, Any
-from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
-
-# If the ssl import works, then
-#
-# error: Name 'ssl' already defined (by an import)
-#
-# is expected and can be ignored.
-try:
- import ssl
-except ImportError:
- class ssl: # type: ignore
- SSLContext : Dict = {}
-
-async def udp(q : message.Message, where : str,
- timeout : Optional[float] = None, port=53,
- source : Optional[str] = None, source_port : Optional[int] = 0,
- ignore_unexpected : Optional[bool] = False,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- sock : Optional[asyncbackend.DatagramSocket] = None,
- backend : Optional[asyncbackend.Backend] = None) -> message.Message:
- pass
-
-async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
- af : Optional[int] = None, source : Optional[str] = None,
- source_port : Optional[int] = 0,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- sock : Optional[asyncbackend.StreamSocket] = None,
- backend : Optional[asyncbackend.Backend] = None) -> message.Message:
- pass
-
-async def tls(q : message.Message, where : str,
- timeout : Optional[float] = None, port=53,
- source : Optional[str] = None, source_port : Optional[int] = 0,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- sock : Optional[asyncbackend.StreamSocket] = None,
- backend : Optional[asyncbackend.Backend] = None,
- ssl_context: Optional[ssl.SSLContext] = None,
- server_hostname: Optional[str] = None) -> message.Message:
- pass
"""Asynchronous DNS stub resolver."""
+from typing import Optional, Union
+
import time
import dns.asyncbackend
import dns.asyncquery
import dns.exception
+import dns.name
import dns.query
import dns.resolver # lgtm[py/import-and-import-from]
class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver."""
- async def resolve(self, qname, rdtype=dns.rdatatype.A,
+ async def resolve(self, qname: Union[dns.name.Name, str],
+ rdtype=dns.rdatatype.A,
rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, lifetime=None, search=None,
- backend=None):
+ tcp=False, source: Optional[str]=None,
+ raise_on_no_answer=True, source_port=0,
+ lifetime: Optional[float]=None, search: Optional[bool]=None,
+ backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
if answer is not None:
# cache hit!
return answer
+ assert request is not None # needed for type checking
done = False
while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
if answer is not None:
return answer
- async def resolve_address(self, ipaddr, *args, **kwargs):
+ async def resolve_address(self, ipaddr: str, *args, **kwargs) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR
records.
function.
"""
-
+ # We make a modified kwargs for type checking happiness, as otherwise
+ # we get a legit warning about possibly having rdtype and rdclass
+ # in the kwargs more than once.
+ modified_kwargs = {}
+ modified_kwargs.update(kwargs)
+ modified_kwargs['rdtype'] = dns.rdatatype.PTR
+ modified_kwargs['rdclass'] = dns.rdataclass.IN
return await self.resolve(dns.reversename.from_address(ipaddr),
- rdtype=dns.rdatatype.PTR,
- rdclass=dns.rdataclass.IN,
- *args, **kwargs)
+ *args, **modified_kwargs)
# pylint: disable=redefined-outer-name
- async def canonical_name(self, name):
+ async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
default_resolver = None
-def get_default_resolver():
+def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary."""
if default_resolver is None:
reset_default_resolver()
+ assert default_resolver is not None
return default_resolver
default_resolver = Resolver()
-async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, lifetime=None, search=None, backend=None):
+async def resolve(qname: Union[dns.name.Name, str],
+ rdtype=dns.rdatatype.A,
+ rdclass=dns.rdataclass.IN,
+ tcp=False, source: Optional[str]=None,
+ raise_on_no_answer=True, source_port=0,
+ lifetime: Optional[float]=None, search: Optional[bool]=None,
+ backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver
backend)
-async def resolve_address(ipaddr, *args, **kwargs):
+async def resolve_address(ipaddr: str, *args, **kwargs) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
-async def canonical_name(name):
+async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more
return await get_default_resolver().canonical_name(name)
-async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
- resolver=None, backend=None):
+async def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN,
+ tcp=False, resolver: Optional[Resolver]=None,
+ backend: Optional[dns.asyncbackend.Backend]=None) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
try:
answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass,
tcp, backend=backend)
+ assert answer.rrset is not None
if answer.rrset.name == name:
return name
# otherwise we were CNAMEd or DNAMEd and need to look higher
+++ /dev/null
-from typing import Union, Optional, List, Any, Dict
-from . import exception, rdataclass, name, rdatatype, asyncbackend
-
-async def resolve(qname : str, rdtype : Union[int,str] = 0,
- rdclass : Union[int,str] = 0,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, lifetime : Optional[float]=None,
- search : Optional[bool]=None,
- backend : Optional[asyncbackend.Backend]=None):
- ...
-async def resolve_address(self, ipaddr: str,
- *args: Any, **kwargs: Optional[Dict]):
- ...
-
-class Resolver:
- def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
- configure : Optional[bool] = True):
- self.nameservers : List[str]
- async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
- rdclass : Union[int,str] = rdataclass.IN,
- tcp : bool = False, source : Optional[str] = None,
- raise_on_no_answer=True, source_port : int = 0,
- lifetime : Optional[float]=None,
- search : Optional[bool]=None,
- backend : Optional[asyncbackend.Backend]=None):
- ...
"""Common DNSSEC-related functions and constants."""
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
+
import hashlib
import struct
import time
import base64
-import dns.enum
+from dns.dnssectypes import *
+
import dns.exception
import dns.name
import dns.node
import dns.rdata
import dns.rdatatype
import dns.rdataclass
+import dns.rrset
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+from dns.rdtypes.ANY.DS import DS
+from dns.rdtypes.ANY.RRSIG import RRSIG
class UnsupportedAlgorithm(dns.exception.DNSException):
"""The DNSSEC signature is invalid."""
-class Algorithm(dns.enum.IntEnum):
- RSAMD5 = 1
- DH = 2
- DSA = 3
- ECC = 4
- RSASHA1 = 5
- DSANSEC3SHA1 = 6
- RSASHA1NSEC3SHA1 = 7
- RSASHA256 = 8
- RSASHA512 = 10
- ECCGOST = 12
- ECDSAP256SHA256 = 13
- ECDSAP384SHA384 = 14
- ED25519 = 15
- ED448 = 16
- INDIRECT = 252
- PRIVATEDNS = 253
- PRIVATEOID = 254
-
- @classmethod
- def _maximum(cls):
- return 255
-
-
-def algorithm_from_text(text):
+def algorithm_from_text(text: str) -> Algorithm:
"""Convert text into a DNSSEC algorithm value.
*text*, a ``str``, the text to convert to into an algorithm value.
return Algorithm.from_text(text)
-def algorithm_to_text(value):
+def algorithm_to_text(value: Union[Algorithm, int]) -> str:
"""Convert a DNSSEC algorithm value to text
- *value*, an ``int`` a DNSSEC algorithm.
+ *value*, a ``dns.dnssec.Algorithm``.
Returns a ``str``, the name of a DNSSEC algorithm.
"""
return Algorithm.to_text(value)
-def key_id(key):
+def key_id(key: DNSKEY) -> int:
"""Return the key id (a 16-bit number) for the specified key.
*key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``
total += ((total >> 16) & 0xffff)
return total & 0xffff
-class DSDigest(dns.enum.IntEnum):
- """DNSSEC Delegation Signer Digest Algorithm"""
-
- SHA1 = 1
- SHA256 = 2
- SHA384 = 4
-
- @classmethod
- def _maximum(cls):
- return 255
-
-def make_ds(name, key, algorithm, origin=None):
+def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata,
+ algorithm: Union[DSDigest, str],
+ origin: Optional[dns.name.Name]=None) -> DS:
"""Create a DS record for a DNSSEC key.
*name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record.
algorithm = DSDigest[algorithm.upper()]
except Exception:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
-
+ if not isinstance(key, DNSKEY):
+ raise ValueError('key is not a DNSKEY')
if algorithm == DSDigest.SHA1:
dshash = hashlib.sha1()
elif algorithm == DSDigest.SHA256:
if isinstance(name, str):
name = dns.name.from_text(name, origin)
- dshash.update(name.canonicalize().to_wire())
+ wire = name.canonicalize().to_wire()
+ assert wire is not None
+ dshash.update(wire)
dshash.update(key.to_wire(origin=origin))
digest = dshash.digest()
dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \
digest
- return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0,
- len(dsrdata))
+ ds = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0,
+ len(dsrdata))
+ return cast(DS, ds)
-def _find_candidate_keys(keys, rrsig):
+def _find_candidate_keys(keys, rrsig: RRSIG) -> Optional[List[DNSKEY]]:
value = keys.get(rrsig.signer)
if isinstance(value, dns.node.Node):
rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY)
rdataset = value
if rdataset is None:
return None
- return [rd for rd in rdataset if
+ return [cast(DNSKEY, rd) for rd in rdataset if
rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag]
-def _is_rsa(algorithm):
+def _is_rsa(algorithm: int) -> bool:
return algorithm in (Algorithm.RSAMD5, Algorithm.RSASHA1,
Algorithm.RSASHA1NSEC3SHA1, Algorithm.RSASHA256,
Algorithm.RSASHA512)
-def _is_dsa(algorithm):
+def _is_dsa(algorithm: int) -> bool:
return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1)
-def _is_ecdsa(algorithm):
+def _is_ecdsa(algorithm: int) -> bool:
return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384)
-def _is_eddsa(algorithm):
+def _is_eddsa(algorithm: int) -> bool:
return algorithm in (Algorithm.ED25519, Algorithm.ED448)
-def _is_gost(algorithm):
+def _is_gost(algorithm: int) -> bool:
return algorithm == Algorithm.ECCGOST
-def _is_md5(algorithm):
+def _is_md5(algorithm: int) -> bool:
return algorithm == Algorithm.RSAMD5
-def _is_sha1(algorithm):
+def _is_sha1(algorithm: int) -> bool:
return algorithm in (Algorithm.DSA, Algorithm.RSASHA1,
Algorithm.DSANSEC3SHA1, Algorithm.RSASHA1NSEC3SHA1)
-def _is_sha256(algorithm):
+def _is_sha256(algorithm: int) -> bool:
return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256)
-def _is_sha384(algorithm):
+def _is_sha384(algorithm: int) -> bool:
return algorithm == Algorithm.ECDSAP384SHA384
-def _is_sha512(algorithm):
+def _is_sha512(algorithm: int) -> bool:
return algorithm == Algorithm.RSASHA512
-def _make_hash(algorithm):
+def _make_hash(algorithm: int) -> Any:
if _is_md5(algorithm):
return hashes.MD5()
if _is_sha1(algorithm):
raise ValidationFailure('unknown hash for algorithm %u' % algorithm)
-def _bytes_to_long(b):
+def _bytes_to_long(b: bytes) -> int:
return int.from_bytes(b, 'big')
-def _validate_signature(sig, data, key, chosen_hash):
+def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any):
+ keyptr: bytes
if _is_rsa(key.algorithm):
+ # we ignore because mypy is confused and thinks key.key is a str for unknown reasons.
keyptr = key.key
(bytes_,) = struct.unpack('!B', keyptr[0:1])
keyptr = keyptr[1:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
try:
- public_key = rsa.RSAPublicNumbers(
+ rsa_public_key = rsa.RSAPublicNumbers(
_bytes_to_long(rsa_e),
_bytes_to_long(rsa_n)).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
- public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
+ rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
elif _is_dsa(key.algorithm):
keyptr = key.key
(t,) = struct.unpack('!B', keyptr[0:1])
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
try:
- public_key = dsa.DSAPublicNumbers(
+ dsa_public_key = dsa.DSAPublicNumbers(
_bytes_to_long(dsa_y),
dsa.DSAParameterNumbers(
_bytes_to_long(dsa_p),
_bytes_to_long(dsa_g))).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
- public_key.verify(sig, data, chosen_hash)
+ dsa_public_key.verify(sig, data, chosen_hash)
elif _is_ecdsa(key.algorithm):
keyptr = key.key
+ curve: Any
if key.algorithm == Algorithm.ECDSAP256SHA256:
curve = ec.SECP256R1()
octets = 32
ecdsa_x = keyptr[0:octets]
ecdsa_y = keyptr[octets:octets * 2]
try:
- public_key = ec.EllipticCurvePublicNumbers(
+ ecdsa_public_key = ec.EllipticCurvePublicNumbers(
curve=curve,
x=_bytes_to_long(ecdsa_x),
y=_bytes_to_long(ecdsa_y)).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
- public_key.verify(sig, data, ec.ECDSA(chosen_hash))
+ ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash))
elif _is_eddsa(key.algorithm):
keyptr = key.key
+ loader: Any
if key.algorithm == Algorithm.ED25519:
loader = ed25519.Ed25519PublicKey
else:
loader = ed448.Ed448PublicKey
try:
- public_key = loader.from_public_bytes(keyptr)
+ eddsa_public_key = loader.from_public_bytes(keyptr)
except ValueError:
raise ValidationFailure('invalid public key')
- public_key.verify(sig, data)
+ eddsa_public_key.verify(sig, data)
elif _is_gost(key.algorithm):
raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython' %
raise ValidationFailure('unknown algorithm %u' % key.algorithm)
-def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
+def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ rrsig: RRSIG,
+ keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
+ origin: Optional[dns.name.Name]=None, now: Optional[float]=None):
"""Validate an RRset against a single signature rdata, throwing an
exception if validation is not successful.
*origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative
names.
- *now*, an ``int`` or ``None``, the time, in seconds since the epoch, to
+ *now*, a ``float`` or ``None``, the time, in seconds since the epoch, to
use as the current time when validating. If ``None``, the actual current
time is used.
data += rrsig.signer.to_digestable(origin)
# Derelativize the name before considering labels.
- rrname = rrname.derelativize(origin)
+ if not rrname.is_absolute():
+ if origin is None:
+ raise ValidationFailure('relative RR name without an origin specified')
+ rrname = rrname.derelativize(origin)
if len(rrname) - 1 < rrsig.labels:
raise ValidationFailure('owner name longer than RRSIG labels')
raise ValidationFailure('verify failure')
-def _validate(rrset, rrsigset, keys, origin=None, now=None):
+def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
+ origin: Optional[dns.name.Name]=None, now: Optional[float]=None):
"""Validate an RRset against a signature RRset, throwing an exception
if none of the signatures validate.
raise ValidationFailure("owner names do not match")
for rrsig in rrsigrdataset:
+ if not isinstance(rrsig, RRSIG):
+ raise ValidationFailure('expected an RRSIG')
try:
_validate_rrsig(rrset, rrsig, keys, origin, now)
return
raise ValidationFailure("no RRSIGs validated")
-class NSEC3Hash(dns.enum.IntEnum):
- """NSEC3 hash algorithm"""
-
- SHA1 = 1
-
- @classmethod
- def _maximum(cls):
- return 255
-
def nsec3_hash(domain, salt, iterations, algorithm):
"""
Calculate the NSEC3 hash, according to
+++ /dev/null
-from typing import Union, Dict, Tuple, Optional
-from . import rdataset, rrset, exception, name, rdtypes, rdata, node
-import dns.rdtypes.ANY.DS as DS
-import dns.rdtypes.ANY.DNSKEY as DNSKEY
-
-_have_pyca : bool
-
-def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
- ...
-
-def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
- ...
-
-class ValidationFailure(exception.DNSException):
- ...
-
-def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
- ...
-
-def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
- ...
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""Common DNSSEC-related types."""
+
+# This is a separate file to avoid import circularity between dns.dnssec and
+# the implementations of the DS and DNSKEY types.
+
+import dns.enum
+
+
+class Algorithm(dns.enum.IntEnum):
+ RSAMD5 = 1
+ DH = 2
+ DSA = 3
+ ECC = 4
+ RSASHA1 = 5
+ DSANSEC3SHA1 = 6
+ RSASHA1NSEC3SHA1 = 7
+ RSASHA256 = 8
+ RSASHA512 = 10
+ ECCGOST = 12
+ ECDSAP256SHA256 = 13
+ ECDSAP384SHA384 = 14
+ ED25519 = 15
+ ED448 = 16
+ INDIRECT = 252
+ PRIVATEDNS = 253
+ PRIVATEOID = 254
+
+ @classmethod
+ def _maximum(cls):
+ return 255
+
+
+class DSDigest(dns.enum.IntEnum):
+ """DNSSEC Delegation Signer Digest Algorithm"""
+
+ SHA1 = 1
+ SHA256 = 2
+ SHA384 = 4
+
+ @classmethod
+ def _maximum(cls):
+ return 255
+
+
+class NSEC3Hash(dns.enum.IntEnum):
+ """NSEC3 hash algorithm"""
+
+ SHA1 = 1
+
+ @classmethod
+ def _maximum(cls):
+ return 255
"""DNS E.164 helpers."""
+from typing import Iterable, Optional, Union
+
import dns.exception
import dns.name
import dns.resolver
public_enum_domain = dns.name.from_text('e164.arpa.')
-def from_e164(text, origin=public_enum_domain):
+def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number.
return dns.name.from_text('.'.join(parts), origin=origin)
-def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
+def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=public_enum_domain,
+ want_plus_prefix=True) -> str:
"""Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred
return text.decode()
-def query(number, domains, resolver=None):
+def query(number: str, domains: Iterable[Union[dns.name.Name, str]],
+ resolver: Optional[dns.resolver.Resolver]=None) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
+++ /dev/null
-from typing import Optional, Iterable
-from . import name, resolver
-def from_e164(text : str, origin=name.Name(".")) -> name.Name:
- ...
-
-def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
- ...
-
-def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
- ...
"""EDNS Options"""
+from typing import Any, Dict, Optional, Union
+
import math
import socket
import struct
import dns.enum
import dns.inet
import dns.rdata
+import dns.wire
class OptionType(dns.enum.IntEnum):
"""Base class for all EDNS option types."""
- def __init__(self, otype):
+ def __init__(self, otype: Union[OptionType, str]):
"""Initialize an option.
- *otype*, an ``int``, is the option type.
+ *otype*, a ``dns.edns.OptionType``, is the option type.
"""
self.otype = OptionType.make(otype)
- def to_wire(self, file=None):
+ def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]:
"""Convert an option to wire format.
Returns a ``bytes`` or ``None``.
raise NotImplementedError # pragma: no cover
@classmethod
- def from_wire_parser(cls, otype, parser):
+ def from_wire_parser(cls, otype: OptionType, parser: 'dns.wire.Parser') -> 'Option':
"""Build an EDNS option object from wire format.
- *otype*, an ``int``, is the option type.
+ *otype*, a ``dns.edns.OptionType``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restructed to the option length.
implementation.
"""
- def __init__(self, otype, data):
+ def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True)
- def to_wire(self, file=None):
+ def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]:
if file:
file.write(self.data)
+ return None
else:
return self.data
- def to_text(self):
+ def to_text(self) -> str:
return "Generic %d" % self.otype
@classmethod
- def from_wire_parser(cls, otype, parser):
+ def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option:
return cls(otype, parser.get_remaining())
class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)"""
- def __init__(self, address, srclen=None, scopelen=0):
+ def __init__(self, address: str, srclen: Optional[int]=None, scopelen=0):
"""*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the
else: # pragma: no cover (this will never happen)
raise ValueError('Bad address family')
+ assert srclen is not None
self.address = address
self.srclen = srclen
self.scopelen = scopelen
ord(self.addrdata[-1:]) & (0xff << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last
- def to_text(self):
+ def to_text(self) -> str:
return "ECS {}/{} scope/{}".format(self.address, self.srclen,
self.scopelen)
@staticmethod
- def from_text(text):
+ def from_text(text) -> Option:
"""Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option.
'"{}": srclen must be an integer'.format(srclen))
return ECSOption(address, srclen, scope)
- def to_wire(self, file=None):
+ def to_wire(self, file=None) -> Optional[bytes]:
value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) +
self.addrdata)
if file:
file.write(value)
+ return None
else:
return value
@classmethod
- def from_wire_parser(cls, otype, parser):
+ def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser'):
family, src, scope = parser.get_struct('!HBB')
addrlen = int(math.ceil(src / 8.0))
prefix = parser.get_bytes(addrlen)
class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)"""
- def __init__(self, code, text=None):
+ def __init__(self, code: Union[EDECode, str], text: Optional[str]=None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
self.code = EDECode.make(code)
if text is not None and not isinstance(text, str):
raise ValueError('text must be string or None')
-
- self.code = code
self.text = text
- def to_text(self):
+ def to_text(self) -> str:
output = f'EDE {self.code}'
if self.text is not None:
output += f': {self.text}'
return output
- def to_wire(self, file=None):
+ def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]:
value = struct.pack('!H', self.code)
if self.text is not None:
value += self.text.encode('utf8')
if file:
file.write(value)
+ return None
else:
return value
@classmethod
- def from_wire_parser(cls, otype, parser):
+ def from_wire_parser(cls, otype: Union[OptionType, str], parser) -> Option:
code = parser.get_uint16()
text = parser.get_remaining()
return cls(code, text)
-_type_to_class = {
+_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
}
-def get_option_class(otype):
+def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type.
The GenericOption class is used if a more specific class is not
return cls
-def option_from_wire_parser(otype, parser):
+def option_from_wire_parser(otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
Returns an instance of a subclass of ``dns.edns.Option``.
"""
- cls = get_option_class(otype)
- otype = OptionType.make(otype)
+ the_otype = OptionType.make(otype)
+ cls = get_option_class(the_otype)
return cls.from_wire_parser(otype, parser)
-def option_from_wire(otype, wire, current, olen):
+def option_from_wire(otype: Union[OptionType, str], wire: bytes, current: int, olen: int) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser)
-def register_type(implementation, otype):
+def register_type(implementation: Any, otype: OptionType):
"""Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+from typing import Any, Optional
+
import os
import hashlib
import random
def __init__(self, seed=None):
self.pool_index = 0
- self.digest = None
+ self.digest: Optional[bytearray] = None
self.next_byte = 0
self.lock = _threading.Lock()
self.hash = hashlib.sha1()
seed = bytearray(seed)
self._stir(seed)
- def random_8(self):
+ def random_8(self) -> int:
with self.lock:
self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len:
self.next_byte += 1
return value
- def random_16(self):
+ def random_16(self) -> int:
return self.random_8() * 256 + self.random_8()
- def random_32(self):
+ def random_32(self) -> int:
return self.random_16() * 65536 + self.random_16()
- def random_between(self, first, last):
+ def random_between(self, first: int, last: int) -> int:
size = last - first + 1
if size > 4294967296:
raise ValueError('too big')
pool = EntropyPool()
+system_random: Optional[Any]
try:
system_random = random.SystemRandom()
except Exception: # pragma: no cover
system_random = None
-def random_16():
+def random_16() -> int:
if system_random is not None:
return system_random.randrange(0, 65536)
else:
return pool.random_16()
-def between(first, last):
+def between(first: int, last: int) -> int:
if system_random is not None:
return system_random.randrange(first, last + 1)
else:
+++ /dev/null
-from typing import Optional
-from random import SystemRandom
-
-system_random : Optional[SystemRandom]
-
-def random_16() -> int:
- pass
-
-def between(first: int, last: int) -> int:
- pass
always be subclasses of ``DNSException``.
"""
+
+from typing import Dict, Optional, Set
+
+
class DNSException(Exception):
"""Abstract base class shared by all dnspython exceptions.
and ``fmt`` class variables to get nice parametrized messages.
"""
- msg = None # non-parametrized message
- supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check)
- fmt = None # message parametrized with results from _fmt_kwargs
+ msg: Optional[str] = None # non-parametrized message
+ supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
+ fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs)
supp_kwargs = {'timeout'}
fmt = "The DNS operation timed out after {timeout:.3f} seconds"
+ # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
class ExceptionWrapper:
def __init__(self, exception_class):
+++ /dev/null
-from typing import Set, Optional, Dict
-
-class DNSException(Exception):
- supp_kwargs : Set[str]
- kwargs : Optional[Dict]
- fmt : Optional[str]
-
-class SyntaxError(DNSException): ...
-class FormError(DNSException): ...
-class Timeout(DNSException): ...
-class TooBig(DNSException): ...
-class UnexpectedEnd(SyntaxError): ...
"""DNS Message Flags."""
+from typing import Any
+
import enum
# Standard DNS flags
DO = 0x8000
-def _from_text(text, enum_class):
+def _from_text(text: str, enum_class: Any) -> int:
flags = 0
tokens = text.split()
for t in tokens:
return flags
-def _to_text(flags, enum_class):
+def _to_text(flags: int, enum_class: Any) -> str:
text_flags = []
for k, v in enum_class.__members__.items():
if flags & v != 0:
return ' '.join(text_flags)
-def from_text(text):
+def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags
value.
return _from_text(text, Flag)
-def to_text(flags):
+def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text
values.
return _to_text(flags, Flag)
-def edns_from_text(text):
+def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS
flags value.
return _from_text(text, EDNSFlag)
-def edns_to_text(flags):
+def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag
text values.
"""DNS GENERATE range conversion."""
+from typing import Tuple
+
import dns
-def from_text(text):
+def from_text(text: str) -> Tuple[int, int, int]:
"""Convert the text form of a range in a ``$GENERATE`` statement to an
integer.
"""Generic Internet address helper functions."""
+from typing import Any, Optional, Tuple
+
import socket
import dns.ipv4
AF_INET6 = socket.AF_INET6
-def inet_pton(family, text):
+def inet_pton(family: int, text: str) -> bytes:
"""Convert the textual form of a network address into its binary form.
*family* is an ``int``, the address family.
raise NotImplementedError
-def inet_ntop(family, address):
+def inet_ntop(family: int, address: bytes) -> str:
"""Convert the binary form of a network address into its textual form.
*family* is an ``int``, the address family.
raise NotImplementedError
-def af_for_address(text):
+def af_for_address(text: str) -> int:
"""Determine the address family of a textual-form network address.
*text*, a ``str``, the textual address.
raise ValueError
-def is_multicast(text):
+def is_multicast(text: str) -> bool:
"""Is the textual-form network address a multicast address?
*text*, a ``str``, the textual address.
raise ValueError
-def is_address(text):
+def is_address(text: str) -> bool:
"""Is the specified string an IPv4 or IPv6 address?
*text*, a ``str``, the textual address.
return False
-def low_level_address_tuple(high_tuple, af=None):
+def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None) -> Any:
"""Given a "high-level" address tuple, i.e.
an (address, port) return the appropriate "low-level" address tuple
suitable for use in socket calls.
If an *af* other than ``None`` is provided, it is assumed the
address in the high-level tuple is valid and has that af. If af
is ``None``, then af_for_address will be called.
-
"""
address, port = high_tuple
if af is None:
+++ /dev/null
-from typing import Union
-from socket import AddressFamily
-
-AF_INET6 : Union[int, AddressFamily]
"""IPv4 helper functions."""
+from typing import Union
+
import struct
import dns.exception
-def inet_ntoa(address):
+def inet_ntoa(address: bytes) -> str:
"""Convert an IPv4 address in binary form to text form.
*address*, a ``bytes``, the IPv4 address in binary form.
return ('%u.%u.%u.%u' % (address[0], address[1],
address[2], address[3]))
-def inet_aton(text):
+def inet_aton(text: Union[str, bytes]) -> bytes:
"""Convert an IPv4 address in text form to binary form.
- *text*, a ``str``, the IPv4 address in textual form.
+ *text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Returns a ``bytes``.
"""
if not isinstance(text, bytes):
- text = text.encode()
- parts = text.split(b'.')
+ btext = text.encode()
+ else:
+ btext = text
+ parts = btext.split(b'.')
if len(parts) != 4:
raise dns.exception.SyntaxError
for part in parts:
"""IPv6 helper functions."""
+from typing import List, Union
+
import re
import binascii
_leading_zero = re.compile(r'0+([0-9a-f]+)')
-def inet_ntoa(address):
+def inet_ntoa(address: bytes) -> str:
"""Convert an IPv6 address in binary form to text form.
*address*, a ``bytes``, the IPv6 address in binary form.
prefix = '::'
else:
prefix = '::ffff:'
- hex = prefix + dns.ipv4.inet_ntoa(address[12:])
+ thex = prefix + dns.ipv4.inet_ntoa(address[12:])
else:
- hex = ':'.join(chunks[:best_start]) + '::' + \
+ thex = ':'.join(chunks[:best_start]) + '::' + \
':'.join(chunks[best_start + best_len:])
else:
- hex = ':'.join(chunks)
- return hex
+ thex = ':'.join(chunks)
+ return thex
_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$')
_colon_colon_start = re.compile(br'::.*')
_colon_colon_end = re.compile(br'.*::$')
-def inet_aton(text, ignore_scope=False):
+def inet_aton(text: Union[str, bytes], ignore_scope=False) -> bytes:
"""Convert an IPv6 address in text form to binary form.
*text*, a ``str``, the IPv6 address in textual form.
# Our aim here is not something fast; we just want something that works.
#
if not isinstance(text, bytes):
- text = text.encode()
+ btext = text.encode()
+ else:
+ btext = text
if ignore_scope:
- parts = text.split(b'%')
+ parts = btext.split(b'%')
l = len(parts)
if l == 2:
- text = parts[0]
+ btext = parts[0]
elif l > 2:
raise dns.exception.SyntaxError
- if text == b'':
+ if btext == b'':
raise dns.exception.SyntaxError
- elif text.endswith(b':') and not text.endswith(b'::'):
+ elif btext.endswith(b':') and not btext.endswith(b'::'):
raise dns.exception.SyntaxError
- elif text.startswith(b':') and not text.startswith(b'::'):
+ elif btext.startswith(b':') and not btext.startswith(b'::'):
raise dns.exception.SyntaxError
- elif text == b'::':
- text = b'0::'
+ elif btext == b'::':
+ btext = b'0::'
#
# Get rid of the icky dot-quad syntax if we have it.
#
- m = _v4_ending.match(text)
+ m = _v4_ending.match(btext)
if m is not None:
b = dns.ipv4.inet_aton(m.group(2))
- text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(),
+ btext = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(),
b[0], b[1], b[2],
b[3])).encode()
#
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to
# turn '<whatever>::' into '<whatever>:'
#
- m = _colon_colon_start.match(text)
+ m = _colon_colon_start.match(btext)
if m is not None:
- text = text[1:]
+ btext = btext[1:]
else:
- m = _colon_colon_end.match(text)
+ m = _colon_colon_end.match(btext)
if m is not None:
- text = text[:-1]
+ btext = btext[:-1]
#
# Now canonicalize into 8 chunks of 4 hex digits each
#
- chunks = text.split(b':')
+ chunks = btext.split(b':')
l = len(chunks)
if l > 8:
raise dns.exception.SyntaxError
seen_empty = False
- canonical = []
+ canonical: List[bytes] = []
for c in chunks:
if c == b'':
if seen_empty:
canonical.append(c)
if l < 8 and not seen_empty:
raise dns.exception.SyntaxError
- text = b''.join(canonical)
+ btext = b''.join(canonical)
#
# Finally we can go to binary.
#
try:
- return binascii.unhexlify(text)
+ return binascii.unhexlify(btext)
except (binascii.Error, TypeError):
raise dns.exception.SyntaxError
"""DNS Messages"""
+from typing import Any, Dict, List, Optional, Tuple, Union
+
import contextlib
import io
import time
supp_kwargs = {'message'}
+ # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
def message(self):
"""As much of the message as could be processed.
class MessageError:
- def __init__(self, exception, offset):
+ def __init__(self, exception: Exception, offset: int):
self.exception = exception
self.offset = offset
DEFAULT_EDNS_PAYLOAD = 1232
MAX_CHAIN = 16
+IndexKeyType = Tuple[int, dns.name.Name, dns.rdataclass.RdataClass,
+ dns.rdatatype.RdataType, Optional[dns.rdatatype.RdataType],
+ Optional[dns.rdataclass.RdataClass]]
+IndexType = Dict[IndexKeyType, dns.rrset.RRset]
+SectionType = Union[int, List[dns.rrset.RRset]]
+
class Message:
"""A DNS message."""
_section_enum = MessageSection
- def __init__(self, id=None):
+ def __init__(self, id: Optional[int]=None):
if id is None:
self.id = dns.entropy.random_16()
else:
self.id = id
self.flags = 0
- self.sections = [[], [], [], []]
- self.opt = None
+ self.sections: List[List[dns.rrset.RRset]] = [[], [], [], []]
+ self.opt: Optional[dns.rrset.RRset] = None
self.request_payload = 0
- self.keyring = None
- self.tsig = None
+ self.keyring: Any = None
+ self.tsig: Optional[dns.rrset.RRset] = None
self.request_mac = b''
self.xfr = False
- self.origin = None
- self.tsig_ctx = None
- self.index = {}
- self.errors = []
+ self.origin: Optional[dns.name.Name] = None
+ self.tsig_ctx: Optional[Any] = None
+ self.index: IndexType = {}
+ self.errors: List[MessageError] = []
+ self.time = 0.0
@property
- def question(self):
+ def question(self) -> List[dns.rrset.RRset]:
""" The question section."""
return self.sections[0]
self.sections[0] = v
@property
- def answer(self):
+ def answer(self) -> List[dns.rrset.RRset]:
""" The answer section."""
return self.sections[1]
self.sections[1] = v
@property
- def authority(self):
+ def authority(self) -> List[dns.rrset.RRset]:
""" The authority section."""
return self.sections[2]
self.sections[2] = v
@property
- def additional(self):
+ def additional(self) -> List[dns.rrset.RRset]:
""" The additional data section."""
return self.sections[3]
def __str__(self):
return self.to_text()
- def to_text(self, origin=None, relativize=True, **kw):
+ def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True,
+ **kw):
"""Convert the message to text.
The *origin*, *relativize*, and any other keyword
def __ne__(self, other):
return not self.__eq__(other)
- def is_response(self, other):
+ def is_response(self, other: 'Message') -> bool:
"""Is *other*, also a ``dns.message.Message``, a response to this
message?
return False
return True
- def section_number(self, section):
+ def section_number(self, section: List[dns.rrset.RRset]) -> int:
"""Return the "section number" of the specified section for use
in indexing.
return self._section_enum(i)
raise ValueError('unknown section')
- def section_from_number(self, number):
+ def section_from_number(self, number: int) -> List[dns.rrset.RRset]:
"""Return the section list associated with the specified section
number.
section = self._section_enum.make(number)
return self.sections[section]
- def find_rrset(self, section, name, rdclass, rdtype,
- covers=dns.rdatatype.NONE, deleting=None, create=False,
- force_unique=False):
+ def find_rrset(self,
+ section: SectionType,
+ name: dns.name.Name,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers = dns.rdatatype.NONE,
+ deleting: Optional[dns.rdataclass.RdataClass]=None,
+ create=False,
+ force_unique=False) -> dns.rrset.RRset:
"""Find the RRset with the given attributes in the specified section.
*section*, an ``int`` section number, or one of the section
if isinstance(section, int):
section_number = section
- section = self.section_from_number(section_number)
+ the_section = self.section_from_number(section_number)
else:
section_number = self.section_number(section)
+ the_section = section
key = (section_number, name, rdclass, rdtype, covers, deleting)
if not force_unique:
if self.index is not None:
if rrset is not None:
return rrset
else:
- for rrset in section:
+ for rrset in the_section:
if rrset.full_match(name, rdclass, rdtype, covers,
deleting):
return rrset
if not create:
raise KeyError
rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
- section.append(rrset)
+ the_section.append(rrset)
if self.index is not None:
self.index[key] = rrset
return rrset
- def get_rrset(self, section, name, rdclass, rdtype,
- covers=dns.rdatatype.NONE, deleting=None, create=False,
- force_unique=False):
+ def get_rrset(self,
+ section: SectionType,
+ name: dns.name.Name,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers = dns.rdatatype.NONE,
+ deleting: Optional[dns.rdataclass.RdataClass]=None,
+ create=False,
+ force_unique=False) -> Optional[dns.rrset.RRset]:
"""Get the RRset with the given attributes in the specified section.
If the RRset is not found, None is returned.
rrset = None
return rrset
- def to_wire(self, origin=None, max_size=0, multi=False, tsig_ctx=None,
- **kw):
+ def to_wire(self, origin: Optional[dns.name.Name]=None, max_size=0,
+ multi=False, tsig_ctx: Optional[Any]=None, **kw) -> bytes:
"""Return a string containing the message in DNS compressed wire
format.
original_id, error, other)
return dns.rrset.from_rdata(keyname, 0, tsig)
- def use_tsig(self, keyring, keyname=None, fudge=300,
- original_id=None, tsig_error=0, other_data=b'',
- algorithm=dns.tsig.default_algorithm):
+ def use_tsig(self, keyring: Any, keyname: Optional[dns.name.Name]=None,
+ fudge=300, original_id: Optional[int]=None, tsig_error=0,
+ other_data=b'', algorithm=dns.tsig.default_algorithm):
"""When sending, a TSIG signature using the specified key
should be added.
b'', original_id, tsig_error, other_data)
@property
- def keyname(self):
+ def keyname(self) -> Optional[dns.name.Name]:
if self.tsig:
return self.tsig.name
else:
return None
@property
- def keyalgorithm(self):
+ def keyalgorithm(self) -> Optional[dns.name.Name]:
if self.tsig:
return self.tsig[0].algorithm
else:
return None
@property
- def mac(self):
+ def mac(self) -> Optional[bytes]:
if self.tsig:
return self.tsig[0].mac
else:
return None
@property
- def tsig_error(self):
+ def tsig_error(self) -> Optional[int]:
if self.tsig:
return self.tsig[0].error
else:
return None
@property
- def had_tsig(self):
+ def had_tsig(self) -> bool:
return bool(self.tsig)
@staticmethod
return dns.rrset.from_rdata(dns.name.root, int(flags), opt)
def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD,
- request_payload=None, options=None):
+ request_payload: Optional[int]=None,
+ options: Optional[List[dns.edns.Option]]=None):
"""Configure EDNS behavior.
*edns*, an ``int``, is the EDNS level to use. Specifying
self.request_payload = request_payload
@property
- def edns(self):
+ def edns(self) -> int:
if self.opt:
return (self.ednsflags & 0xff0000) >> 16
else:
return -1
@property
- def ednsflags(self):
+ def ednsflags(self) -> int:
if self.opt:
return self.opt.ttl
else:
self.opt = self._make_opt(v)
@property
- def payload(self):
+ def payload(self) -> int:
if self.opt:
return self.opt[0].payload
else:
return 0
@property
- def options(self):
+ def options(self) -> Tuple:
if self.opt:
return self.opt[0].options
else:
elif self.opt:
self.ednsflags &= ~dns.flags.DO
- def rcode(self):
+ def rcode(self) -> dns.rcode.Rcode:
"""Return the rcode.
- Returns an ``int``.
+ Returns a ``dns.rcode.Rcode``.
"""
return dns.rcode.from_flags(int(self.flags), int(self.ednsflags))
- def set_rcode(self, rcode):
+ def set_rcode(self, rcode: dns.rcode.Rcode):
"""Set the rcode.
- *rcode*, an ``int``, is the rcode to set.
+ *rcode*, a ``dns.rcode.Rcode``, is the rcode to set.
"""
(value, evalue) = dns.rcode.to_flags(rcode)
self.flags &= 0xFFF0
self.ednsflags &= 0x00FFFFFF
self.ednsflags |= evalue
- def opcode(self):
+ def opcode(self) -> dns.opcode.Opcode:
"""Return the opcode.
- Returns an ``int``.
+ Returns a ``dns.opcode.Opcode``.
"""
return dns.opcode.from_flags(int(self.flags))
- def set_opcode(self, opcode):
+ def set_opcode(self, opcode: dns.opcode.Opcode):
"""Set the opcode.
- *opcode*, an ``int``, is the opcode to set.
+ *opcode*, a ``dns.opcode.Opcode``, is the opcode to set.
"""
self.flags &= 0x87FF
self.flags |= dns.opcode.to_flags(opcode)
exist.
The ``canonical_name`` attribute is the canonical name after all
- chaining has been applied (this is the name as ``rrset.name`` in cases
+ chaining has been applied (this is the same name as ``rrset.name`` in cases
where rrset is not ``None``).
The ``minimum_ttl`` attribute is the minimum TTL, i.e. the TTL to
The ``cnames`` attribute is a list of all the CNAME RRSets followed to
get to the canonical name.
"""
- def __init__(self, canonical_name, answer, minimum_ttl, cnames):
+ def __init__(self, canonical_name: dns.name.Name, answer: Optional[dns.rrset.RRset],
+ minimum_ttl: int, cnames: List[dns.rrset.RRset]):
self.canonical_name = canonical_name
self.answer = answer
self.minimum_ttl = minimum_ttl
class QueryMessage(Message):
- def resolve_chaining(self):
+ def resolve_chaining(self) -> ChainingResult:
"""Follow the CNAME chain in the response to determine the answer
RRset.
break
return ChainingResult(qname, answer, min_ttl, cnames)
- def canonical_name(self):
+ def canonical_name(self) -> dns.name.Name:
"""Return the canonical name of the first name in the question
section.
tsig_ctx=None, multi=False,
question_only=False, one_rr_per_rrset=False,
ignore_trailing=False, raise_on_truncation=False,
- continue_on_error=False):
+ continue_on_error=False) -> Message:
"""Convert a DNS wire format message into a message object.
*keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the
def from_text(text, idna_codec=None, one_rr_per_rrset=False,
- origin=None, relativize=True, relativize_to=None):
+ origin=None, relativize=True, relativize_to=None) -> Message:
"""Convert the text format message into a message object.
The reader stops after reading the first blank line in the input to
return reader.read()
-def from_file(f, idna_codec=None, one_rr_per_rrset=False):
+def from_file(f, idna_codec=None, one_rr_per_rrset=False) -> Message:
"""Read the next text format message from the specified file.
Message blocks are separated by a single blank line.
if isinstance(f, str):
f = stack.enter_context(open(f))
return from_text(f, idna_codec, one_rr_per_rrset)
+ assert False # for mypy
def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
- want_dnssec=False, ednsflags=None, payload=None,
- request_payload=None, options=None, idna_codec=None,
- id=None, flags=dns.flags.RD):
+ want_dnssec=False, ednsflags: Optional[int]=None, payload: Optional[int]=None,
+ request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None,
+ idna_codec: Optional[dns.name.IDNACodec]=None, id: Optional[int]=None,
+ flags: int=dns.flags.RD) -> QueryMessage:
"""Make a query message.
The query name, type, and class may all be specified either
# only pass keywords on to use_edns if they have been set to a
# non-None value. Setting a field will turn EDNS on if it hasn't
# been configured.
- kwargs = {}
+ kwargs: Dict[str, Any] = {}
if ednsflags is not None:
kwargs['ednsflags'] = ednsflags
if payload is not None:
def make_response(query, recursion_available=False, our_payload=8192,
- fudge=300, tsig_error=0):
+ fudge=300, tsig_error=0) -> Message:
"""Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all
of the infrastructure required of a response, but none of the
+++ /dev/null
-from typing import Optional, Dict, List, Tuple, Union
-from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode
-import hmac
-
-class Message:
- def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes:
- ...
- def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int,
- covers=rdatatype.NONE, deleting : Optional[int]=None, create=False,
- force_unique=False) -> rrset.RRset:
- ...
- def __init__(self, id : Optional[int] =None) -> None:
- self.id : int
- self.flags = 0
- self.sections : List[List[rrset.RRset]] = [[], [], [], []]
- self.opt : rrset.RRset = None
- self.request_payload = 0
- self.keyring = None
- self.tsig : rrset.RRset = None
- self.request_mac = b''
- self.xfr = False
- self.origin = None
- self.tsig_ctx = None
- self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {}
-
- def is_response(self, other : Message) -> bool:
- ...
-
- def set_rcode(self, rcode : rcode.Rcode):
- ...
-
-def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
- ...
-
-def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
- tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False,
- question_only=False, one_rr_per_rrset=False,
- ignore_trailing=False) -> Message:
- ...
-def make_response(query : Message, recursion_available=False, our_payload=8192,
- fudge=300) -> Message:
- ...
-
-def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None,
- want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None,
- request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message:
- ...
"""DNS Names.
"""
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
import copy
import struct
except ImportError: # pragma: no cover
have_idna_2008 = False
+import dns.enum
import dns.wire
import dns.exception
import dns.immutable
-# fullcompare() result values
-#: The compared names have no relationship to each other.
-NAMERELN_NONE = 0
-#: the first name is a superdomain of the second.
-NAMERELN_SUPERDOMAIN = 1
-#: The first name is a subdomain of the second.
-NAMERELN_SUBDOMAIN = 2
-#: The compared names are equal.
-NAMERELN_EQUAL = 3
-#: The compared names have a common ancestor.
-NAMERELN_COMMONANCESTOR = 4
+CompressType = Dict['Name', int]
+
+
+class NameRelation(dns.enum.IntEnum):
+ """Name relation result from fullcompare()."""
+
+ # This is an IntEnum for backwards compatibility in case anyone
+ # has hardwired the constants.
+
+ #: The compared names have no relationship to each other.
+ NONE = 0
+ #: the first name is a superdomain of the second.
+ SUPERDOMAIN = 1
+ #: The first name is a subdomain of the second.
+ SUBDOMAIN = 2
+ #: The compared names are equal.
+ EQUAL = 3
+ #: The compared names have a common ancestor.
+ COMMONANCESTOR = 4
+
+ @classmethod
+ def _maximum(cls):
+ return cls.COMMONANCESTOR
+
+ @classmethod
+ def _short_name(cls):
+ return cls.__name__
+
+
+# Backwards compatibility
+NAMERELN_NONE = NameRelation.NONE
+NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN
+NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN
+NAMERELN_EQUAL = NameRelation.EQUAL
+NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR
class EmptyLabel(dns.exception.SyntaxError):
supp_kwargs = {'idna_exception'}
fmt = "IDNA processing exception: {idna_exception}"
+ # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+
+_escaped = b'"().;\\@$'
+_escaped_text = '"().;\\@$'
+
+def _escapify(label: Union[bytes, str]) -> str:
+ """Escape the characters in label which need it.
+ @returns: the escaped string
+ @rtype: string"""
+ if isinstance(label, bytes):
+ # Ordinary DNS label mode. Escape special characters and values
+ # < 0x20 or > 0x7f.
+ text = ''
+ for c in label:
+ if c in _escaped:
+ text += '\\' + chr(c)
+ elif c > 0x20 and c < 0x7F:
+ text += chr(c)
+ else:
+ text += '\\%03d' % c
+ return text
+
+ # Unicode label mode. Escape only special characters and values < 0x20
+ text = ''
+ for uc in label:
+ if uc in _escaped_text:
+ text += '\\' + uc
+ elif uc <= '\x20':
+ text += '\\%03d' % ord(uc)
+ else:
+ text += uc
+ return text
+
class IDNACodec:
"""Abstract base class for IDNA encoder/decoders."""
def __init__(self):
pass
- def is_idna(self, label):
+ def is_idna(self, label: bytes) -> bool:
return label.lower().startswith(b'xn--')
- def encode(self, label):
+ def encode(self, label: str) -> bytes:
raise NotImplementedError # pragma: no cover
- def decode(self, label):
+ def decode(self, label: bytes) -> str:
# We do not apply any IDNA policy on decode.
if self.is_idna(label):
try:
- label = label[4:].decode('punycode')
+ slabel = label[4:].decode('punycode')
+ return _escapify(slabel)
except Exception as e:
raise IDNAException(idna_exception=e)
- return _escapify(label)
+ else:
+ return _escapify(label)
class IDNA2003Codec(IDNACodec):
super().__init__()
self.strict_decode = strict_decode
- def encode(self, label):
+ def encode(self, label: str) -> bytes:
"""Encode *label*."""
if label == '':
except UnicodeError:
raise LabelTooLong
- def decode(self, label):
+ def decode(self, label: bytes) -> str:
"""Decode *label*."""
if not self.strict_decode:
return super().decode(label)
self.allow_pure_ascii = allow_pure_ascii
self.strict_decode = strict_decode
- def encode(self, label):
+ def encode(self, label: str) -> bytes:
if label == '':
return b''
if self.allow_pure_ascii and is_all_ascii(label):
else:
raise IDNAException(idna_exception=e)
- def decode(self, label):
+ def decode(self, label: bytes) -> str:
if not self.strict_decode:
return super().decode(label)
if label == b'':
except (idna.IDNAError, UnicodeError) as e:
raise IDNAException(idna_exception=e)
-_escaped = b'"().;\\@$'
-_escaped_text = '"().;\\@$'
-
IDNA_2003_Practical = IDNA2003Codec(False)
IDNA_2003_Strict = IDNA2003Codec(True)
IDNA_2003 = IDNA_2003_Practical
IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
IDNA_2008 = IDNA_2008_Practical
-def _escapify(label):
- """Escape the characters in label which need it.
- @returns: the escaped string
- @rtype: string"""
- if isinstance(label, bytes):
- # Ordinary DNS label mode. Escape special characters and values
- # < 0x20 or > 0x7f.
- text = ''
- for c in label:
- if c in _escaped:
- text += '\\' + chr(c)
- elif c > 0x20 and c < 0x7F:
- text += chr(c)
- else:
- text += '\\%03d' % c
- return text
-
- # Unicode label mode. Escape only special characters and values < 0x20
- text = ''
- for c in label:
- if c in _escaped_text:
- text += '\\' + c
- elif c <= '\x20':
- text += '\\%03d' % ord(c)
- else:
- text += c
- return text
-
-def _validate_labels(labels):
+def _validate_labels(labels: Tuple[bytes, ...]):
"""Check for empty labels in the middle of a label sequence,
labels that are too long, and for too many labels.
raise EmptyLabel
-def _maybe_convert_to_binary(label):
+def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
"""If label is ``str``, convert it to ``bytes``. If it is already
``bytes`` just return it.
__slots__ = ['labels']
- def __init__(self, labels):
+ def __init__(self, labels: Iterable[Union[bytes, str]]):
"""*labels* is any iterable whose values are ``str`` or ``bytes``.
"""
- labels = [_maybe_convert_to_binary(x) for x in labels]
- self.labels = tuple(labels)
+ blabels = [_maybe_convert_to_binary(x) for x in labels]
+ self.labels = tuple(blabels)
_validate_labels(self.labels)
def __copy__(self):
super().__setattr__('labels', state['labels'])
_validate_labels(self.labels)
- def is_absolute(self):
+ def is_absolute(self) -> bool:
"""Is the most significant label of this name the root label?
Returns a ``bool``.
return len(self.labels) > 0 and self.labels[-1] == b''
- def is_wild(self):
+ def is_wild(self) -> bool:
"""Is this name wild? (I.e. Is the least significant label '*'?)
Returns a ``bool``.
return len(self.labels) > 0 and self.labels[0] == b'*'
- def __hash__(self):
+ def __hash__(self) -> int:
"""Return a case-insensitive hash of the name.
Returns an ``int``.
h += (h << 3) + c
return h
- def fullcompare(self, other):
+ def fullcompare(self, other: 'Name') -> Tuple[NameRelation, int, int]:
"""Compare two names, returning a 3-tuple
``(relation, order, nlabels)``.
*relation* describes the relation ship between the names,
- and is one of: ``dns.name.NAMERELN_NONE``,
- ``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``,
- ``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``.
+ and is one of: ``dns.name.NameRelation.NONE``,
+ ``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``,
+ ``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``.
*order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and ==
0 if *self* == *other*. A relative name is always less than an
oabs = other.is_absolute()
if sabs != oabs:
if sabs:
- return (NAMERELN_NONE, 1, 0)
+ return (NameRelation.NONE, 1, 0)
else:
- return (NAMERELN_NONE, -1, 0)
+ return (NameRelation.NONE, -1, 0)
l1 = len(self.labels)
l2 = len(other.labels)
ldiff = l1 - l2
order = 0
nlabels = 0
- namereln = NAMERELN_NONE
+ namereln = NameRelation.NONE
while l > 0:
l -= 1
l1 -= 1
if label1 < label2:
order = -1
if nlabels > 0:
- namereln = NAMERELN_COMMONANCESTOR
+ namereln = NameRelation.COMMONANCESTOR
return (namereln, order, nlabels)
elif label1 > label2:
order = 1
if nlabels > 0:
- namereln = NAMERELN_COMMONANCESTOR
+ namereln = NameRelation.COMMONANCESTOR
return (namereln, order, nlabels)
nlabels += 1
order = ldiff
if ldiff < 0:
- namereln = NAMERELN_SUPERDOMAIN
+ namereln = NameRelation.SUPERDOMAIN
elif ldiff > 0:
- namereln = NAMERELN_SUBDOMAIN
+ namereln = NameRelation.SUBDOMAIN
else:
- namereln = NAMERELN_EQUAL
+ namereln = NameRelation.EQUAL
return (namereln, order, nlabels)
- def is_subdomain(self, other):
+ def is_subdomain(self, other: 'Name') -> bool:
"""Is self a subdomain of other?
Note that the notion of subdomain includes equality, e.g.
"""
(nr, _, _) = self.fullcompare(other)
- if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
+ if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL:
return True
return False
- def is_superdomain(self, other):
+ def is_superdomain(self, other: 'Name') -> bool:
"""Is self a superdomain of other?
Note that the notion of superdomain includes equality, e.g.
"""
(nr, _, _) = self.fullcompare(other)
- if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
+ if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL:
return True
return False
- def canonicalize(self):
+ def canonicalize(self) -> 'Name':
"""Return a name which is equal to the current name, but is in
DNSSEC canonical form.
"""
def __str__(self):
return self.to_text(False)
- def to_text(self, omit_final_dot=False):
+ def to_text(self, omit_final_dot=False) -> str:
"""Convert name to DNS text format.
*omit_final_dot* is a ``bool``. If True, don't emit the final
s = '.'.join(map(_escapify, l))
return s
- def to_unicode(self, omit_final_dot=False, idna_codec=None):
+ def to_unicode(self, omit_final_dot=False, idna_codec: Optional[IDNACodec]=None) -> str:
"""Convert name to Unicode text format.
IDN ACE labels are converted to Unicode.
idna_codec = IDNA_2003_Practical
return '.'.join([idna_codec.decode(x) for x in l])
- def to_digestable(self, origin=None):
+ def to_digestable(self, origin: Optional['Name']=None) -> bytes:
"""Convert name to a format suitable for digesting in hashes.
The name is canonicalized and converted to uncompressed wire
Returns a ``bytes``.
"""
- return self.to_wire(origin=origin, canonicalize=True)
+ digest = self.to_wire(origin=origin, canonicalize=True)
+ assert digest is not None
+ return digest
- def to_wire(self, file=None, compress=None, origin=None,
- canonicalize=False):
+ def to_wire(self, file=None, compress: Optional[CompressType]=None,
+ origin: Optional['Name']=None, canonicalize=False) -> Optional[bytes]:
"""Convert name to wire format, possibly compressing it.
*file* is the file where the name is emitted (typically an
out += label
return bytes(out)
+ labels: Iterable[bytes]
if not self.is_absolute():
if origin is None or not origin.is_absolute():
raise NeedAbsoluteNameOrOrigin
file.write(label.lower())
else:
file.write(label)
+ return None
- def __len__(self):
+ def __len__(self) -> int:
"""The length of the name (in labels).
Returns an ``int``.
def __sub__(self, other):
return self.relativize(other)
- def split(self, depth):
+ def split(self, depth: int) -> Tuple['Name', 'Name']:
"""Split a name into a prefix and suffix names at the specified depth.
*depth* is an ``int`` specifying the number of labels in the suffix
'depth must be >= 0 and <= the length of the name')
return (Name(self[: -depth]), Name(self[-depth:]))
- def concatenate(self, other):
+ def concatenate(self, other: 'Name') -> 'Name':
"""Return a new name which is the concatenation of self and other.
Raises ``dns.name.AbsoluteConcatenation`` if the name is
labels.extend(list(other.labels))
return Name(labels)
- def relativize(self, origin):
+ def relativize(self, origin: 'Name') -> 'Name':
"""If the name is a subdomain of *origin*, return a new name which is
the name relative to origin. Otherwise return the name.
else:
return self
- def derelativize(self, origin):
+ def derelativize(self, origin: 'Name') -> 'Name':
"""If the name is a relative name, return a new name which is the
concatenation of the name and origin. Otherwise return the name.
else:
return self
- def choose_relativity(self, origin=None, relativize=True):
+ def choose_relativity(self, origin: Optional['Name']=None, relativize=True) -> 'Name':
"""Return a name with the relativity desired by the caller.
If *origin* is ``None``, then the name is returned.
else:
return self
- def parent(self):
+ def parent(self) -> 'Name':
"""Return the parent of the name.
For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``.
#: The empty name.
empty = Name([])
-def from_unicode(text, origin=root, idna_codec=None):
+def from_unicode(text: str, origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name:
"""Convert unicode text into a Name object.
Labels are encoded in IDN ACE form according to rules specified by
labels.extend(list(origin.labels))
return Name(labels)
-def is_all_ascii(text):
+def is_all_ascii(text: str) -> bool:
for c in text:
if ord(c) > 0x7f:
return False
return True
-def from_text(text, origin=root, idna_codec=None):
+def from_text(text: Union[bytes, str], origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name:
"""Convert text into a Name object.
- *text*, a ``str``, is the text to convert into a name.
+ *text*, a ``bytes`` or ``str``, is the text to convert into a name.
*origin*, a ``dns.name.Name``, specifies the origin to
append to non-absolute names. The default is the root name.
labels.extend(list(origin.labels))
return Name(labels)
+# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other.
-def from_wire_parser(parser):
+def from_wire_parser(parser: 'dns.wire.Parser') -> Name:
"""Convert possibly compressed wire format into a Name.
*parser* is a dns.wire.Parser.
return Name(labels)
-def from_wire(message, current):
+def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
"""Convert possibly compressed wire format into a Name.
*message* is a ``bytes`` containing an entire DNS message in DNS
+++ /dev/null
-from typing import Optional, Union, Tuple, Iterable, List
-
-have_idna_2008: bool
-
-class Name:
- def is_subdomain(self, o : Name) -> bool: ...
- def is_superdomain(self, o : Name) -> bool: ...
- def __init__(self, labels : Iterable[Union[bytes,str]]) -> None:
- self.labels : List[bytes]
- def is_absolute(self) -> bool: ...
- def is_wild(self) -> bool: ...
- def fullcompare(self, other) -> Tuple[int,int,int]: ...
- def canonicalize(self) -> Name: ...
- def __eq__(self, other) -> bool: ...
- def __ne__(self, other) -> bool: ...
- def __lt__(self, other : Name) -> bool: ...
- def __le__(self, other : Name) -> bool: ...
- def __ge__(self, other : Name) -> bool: ...
- def __gt__(self, other : Name) -> bool: ...
- def to_text(self, omit_final_dot=False) -> str: ...
- def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ...
- def to_digestable(self, origin=None) -> bytes: ...
- def to_wire(self, file=None, compress=None, origin=None,
- canonicalize=False) -> Optional[bytes]: ...
- def __add__(self, other : Name) -> Name: ...
- def __sub__(self, other : Name) -> Name: ...
- def split(self, depth) -> List[Tuple[str,str]]: ...
- def concatenate(self, other : Name) -> Name: ...
- def relativize(self, origin) -> Name: ...
- def derelativize(self, origin) -> Name: ...
- def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ...
- def parent(self) -> Name: ...
-
-class IDNACodec:
- pass
-
-def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name:
- ...
-
-empty : Name
"""DNS nodes. A node is a set of rdatasets."""
+from typing import List, Optional, Union
+
import enum
import io
import dns.immutable
+import dns.name
+import dns.rdataclass
import dns.rdataset
import dns.rdatatype
+import dns.rrset
import dns.renderer
CNAME = 2
@classmethod
- def classify(cls, rdtype, covers):
+ def classify(cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType) -> 'NodeKind':
if _matches_type_or_its_signature(_cname_types, rdtype, covers):
return NodeKind.CNAME
elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
return NodeKind.REGULAR
@classmethod
- def classify_rdataset(cls, rdataset):
+ def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> 'NodeKind':
return cls.classify(rdataset.rdtype, rdataset.covers)
def __init__(self):
# the set of rdatasets, represented as a list.
- self.rdatasets = []
+ self.rdatasets: List[dns.rdataset.Rdataset] = []
- def to_text(self, name, **kw):
+ def to_text(self, name: dns.name.Name, **kw) -> str:
"""Convert a node to text format.
Each rdataset at the node is printed. Any keyword arguments
to this method are passed on to the rdataset's to_text() method.
- *name*, a ``dns.name.Name`` or ``str``, the owner name of the
+ *name*, a ``dns.name.Name``, the owner name of the
rdatasets.
Returns a ``str``.
# edit self.rdatasets.
self.rdatasets.append(rdataset)
- def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
- create=False):
+ def find_rdataset(self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+ create=False) -> dns.rdataset.Rdataset:
"""Find an rdataset matching the specified properties in the
current node.
- *rdclass*, an ``int``, the class of the rdataset.
+ *rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset.
- *rdtype*, an ``int``, the type of the rdataset.
+ *rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset.
- *covers*, an ``int`` or ``None``, the covered type.
+ *covers*, a ``dns.rdatatype.RdataType``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
self._append_rdataset(rds)
return rds
- def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
- create=False):
+ def get_rdataset(self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+ create=False) -> Optional[dns.rdataset.Rdataset]:
"""Get an rdataset matching the specified properties in the
current node.
rds = None
return rds
- def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+ def delete_rdataset(self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType=dns.rdatatype.NONE):
"""Delete the rdataset matching the specified properties in the
current node.
if rds is not None:
self.rdatasets.remove(rds)
- def replace_rdataset(self, replacement):
+ def replace_rdataset(self, replacement: dns.rdataset.Rdataset):
"""Replace an rdataset.
It is not an error if there is no rdataset matching *replacement*.
replacement.covers)
self._append_rdataset(replacement)
- def classify(self):
+ def classify(self) -> NodeKind:
"""Classify a node.
A node which contains a CNAME or RRSIG(CNAME) is a
return kind
return NodeKind.NEUTRAL
- def is_immutable(self):
+ def is_immutable(self) -> bool:
return False
def replace_rdataset(self, replacement):
raise TypeError("immutable")
- def is_immutable(self):
+ def is_immutable(self) -> bool:
return True
+++ /dev/null
-from typing import List, Optional, Union
-from . import rdataset, rdatatype, name
-class Node:
- def __init__(self):
- self.rdatasets : List[rdataset.Rdataset]
- def to_text(self, name : Union[str,name.Name], **kw) -> str:
- ...
- def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
- create=False) -> rdataset.Rdataset:
- ...
- def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
- create=False) -> Optional[rdataset.Rdataset]:
- ...
- def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE):
- ...
- def replace_rdataset(self, replacement : rdataset.Rdataset) -> None:
- ...
"""An DNS opcode is unknown."""
-def from_text(text):
+def from_text(text: str) -> Opcode:
"""Convert text into an opcode.
*text*, a ``str``, the textual opcode
return Opcode.from_text(text)
-def from_flags(flags):
+def from_flags(flags: int) -> Opcode:
"""Extract an opcode from DNS message flags.
*flags*, an ``int``, the DNS flags.
Returns an ``int``.
"""
- return (flags & 0x7800) >> 11
+ return Opcode((flags & 0x7800) >> 11)
-def to_flags(value):
+def to_flags(value: Opcode) -> int:
"""Convert an opcode to a value suitable for ORing into DNS message
flags.
return (value << 11) & 0x7800
-def to_text(value):
+def to_text(value: Opcode) -> str:
"""Convert an opcode to text.
*value*, an ``int`` the opcode value,
return Opcode.to_text(value)
-def is_update(flags):
+def is_update(flags: int) -> bool:
"""Is the opcode in flags UPDATE?
*flags*, an ``int``, the DNS message flags.
"""Talk to a DNS server."""
+from typing import Any, Dict, Optional, Tuple, Union
+
import base64
import contextlib
import enum
import dns.rdataclass
import dns.rdatatype
import dns.serial
+import dns.transaction
+import dns.tsig
import dns.xfr
try:
class WantWriteException(Exception):
pass
+ class SSLContext:
+ pass
+
class SSLSocket:
pass
# Prefer poll() on platforms that support it because it has no
# limits on the maximum value of a file descriptor (plus it will
# be more efficient for high values).
- _selector_class = selectors.PollSelector
+ #
+ # We ignore typing here as we can't say _selector_class is Any
+ # on python < 3.8 due to a bug.
+ _selector_class = selectors.PollSelector # type: ignore
else:
- _selector_class = selectors.SelectSelector # pragma: no cover
+ _selector_class = selectors.SelectSelector # type: ignore
def _wait_for_readable(s, expiration):
s.close()
raise
-def https(q, where, timeout=None, port=443, source=None, source_port=0,
+def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+ port=443, source: Optional[str]=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False,
- session=None, path='/dns-query', post=True,
- bootstrap_address=None, verify=True):
+ session: Optional[Any]=None, path='/dns-query', post=True,
+ bootstrap_address: Optional[str]=None, verify=True) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*q*, a ``dns.message.Message``, the query to send.
elif bootstrap_address is not None:
_httpx_ok = False
split_url = urllib.parse.urlsplit(where)
+ if split_url.hostname is None:
+ raise ValueError('DoH URL has no hostname')
headers['Host'] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address)
if _have_requests:
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
if _is_httpx:
- wire = wire.decode() # httpx does a repr() if we give it bytes
+ twire = wire.decode() # httpx does a repr() if we give it bytes
response = session.get(url, headers=headers,
timeout=timeout,
- params={"dns": wire})
+ params={"dns": twire})
else:
response = session.get(url, headers=headers,
timeout=timeout, verify=verify,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing)
- r.time = response.elapsed
+ r.time = response.elapsed.total_seconds()
if not q.is_response(r):
raise BadResponse
return r
_wait_for_writable(sock, expiration)
-def send_udp(sock, what, destination, expiration=None):
+def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: Any,
+ expiration: Optional[float]=None) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``socket``.
return (n, sent_time)
-def receive_udp(sock, destination=None, expiration=None,
+def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional[float]=None,
ignore_unexpected=False, one_rr_per_rrset=False,
- keyring=None, request_mac=b'', ignore_trailing=False,
- raise_on_truncation=False):
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
+ ignore_trailing=False, raise_on_truncation=False) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``socket``.
else:
return (r, received_time, from_address)
-def udp(q, where, timeout=None, port=53, source=None, source_port=0,
+def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+ source: Optional[str]=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
- raise_on_truncation=False, sock=None):
+ raise_on_truncation=False, sock: Optional[Any]=None) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*q*, a ``dns.message.Message``, the query to send
if not q.is_response(r):
raise BadResponse
return r
+ assert False # help mypy figure out we can't get here
-def udp_with_fallback(q, where, timeout=None, port=53, source=None,
- source_port=0, ignore_unexpected=False,
- one_rr_per_rrset=False, ignore_trailing=False,
- udp_sock=None, tcp_sock=None):
+def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+ source: Optional[str]=None, source_port=0,
+ ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+ udp_sock: Optional[Any]=None,
+ tcp_sock: Optional[Any]=None) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
_wait_for_readable(sock, expiration)
-def send_tcp(sock, what, expiration=None):
+def send_tcp(sock: Any, what: Union[dns.message.Message, bytes],
+ expiration: Optional[float]=None) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``socket``.
"""
if isinstance(what, dns.message.Message):
- what = what.to_wire()
- l = len(what)
+ wire = what.to_wire()
+ else:
+ wire = what
+ l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
- tcpmsg = struct.pack("!H", l) + what
+ tcpmsg = struct.pack("!H", l) + wire
sent_time = time.time()
_net_write(sock, tcpmsg, expiration)
return (len(tcpmsg), sent_time)
-def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
- keyring=None, request_mac=b'', ignore_trailing=False):
+def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset=False,
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
+ ignore_trailing=False) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``socket``.
raise OSError(err, os.strerror(err))
-def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False, sock=None):
+def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+ source: Optional[str]=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[Any]=None) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*q*, a ``dns.message.Message``, the query to send
if not q.is_response(r):
raise BadResponse
return r
+ assert False # help mypy figure out we can't get here
def _tls_handshake(s, expiration):
_wait_for_writable(s, expiration)
-def tls(q, where, timeout=None, port=853, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False, sock=None,
- ssl_context=None, server_hostname=None):
+def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+ port=853, source: Optional[str]=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[ssl.SSLSocket]=None,
+ ssl_context: Optional[ssl.SSLContext]=None,
+ server_hostname: Optional[str]=None) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*q*, a ``dns.message.Message``, the query to send
if not q.is_response(r):
raise BadResponse
return r
-
+ assert False # help mypy figure out we can't get here
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
timeout=None, port=53, keyring=None, keyname=None, relativize=True,
ONLY = 2
-def inbound_xfr(where, txn_manager, query=None,
- port=53, timeout=None, lifetime=None, source=None,
- source_port=0, udp_mode=UDPMode.NEVER):
+def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
+ query: Optional[dns.message.Message]=None,
+ port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
+ source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER):
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
+++ /dev/null
-from typing import Optional, Union, Dict, Generator, Any
-from . import tsig, rdatatype, rdataclass, name, message
-from requests.sessions import Session
-
-import socket
-
-# If the ssl import works, then
-#
-# error: Name 'ssl' already defined (by an import)
-#
-# is expected and can be ignored.
-try:
- import ssl
-except ImportError:
- class ssl: # type: ignore
- SSLContext : Dict = {}
-
-have_doh: bool
-
-def https(q : message.Message, where: str, timeout : Optional[float] = None,
- port : Optional[int] = 443, source : Optional[str] = None,
- source_port : Optional[int] = 0,
- session: Optional[Session] = None,
- path : Optional[str] = '/dns-query', post : Optional[bool] = True,
- bootstrap_address : Optional[str] = None,
- verify : Optional[bool] = True) -> message.Message:
- pass
-
-def tcp(q : message.Message, where : str, timeout : float = None, port=53,
- af : Optional[int] = None, source : Optional[str] = None,
- source_port : Optional[int] = 0,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- sock : Optional[socket.socket] = None) -> message.Message:
- pass
-
-def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR,
- rdclass=rdataclass.IN,
- timeout : Optional[float] = None, port=53,
- keyring : Optional[Dict[name.Name, bytes]] = None,
- keyname : Union[str,name.Name]= None, relativize=True,
- lifetime : Optional[float] = None,
- source : Optional[str] = None, source_port=0, serial=0,
- use_udp : Optional[bool] = False,
- keyalgorithm=tsig.default_algorithm) \
- -> Generator[Any,Any,message.Message]:
- pass
-
-def udp(q : message.Message, where : str, timeout : Optional[float] = None,
- port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
- ignore_unexpected : Optional[bool] = False,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- sock : Optional[socket.socket] = None) -> message.Message:
- pass
-
-def tls(q : message.Message, where : str, timeout : Optional[float] = None,
- port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- sock : Optional[socket.socket] = None,
- ssl_context: Optional[ssl.SSLContext] = None,
- server_hostname: Optional[str] = None) -> message.Message:
- pass
"""DNS rdata."""
+from typing import Any, Dict, Optional, Tuple, Union
+
from importlib import import_module
import base64
import binascii
self.rdclass = self._as_rdataclass(rdclass)
self.rdtype = self._as_rdatatype(rdtype)
- self.rdcomment = None
+ self.rdcomment: Optional[str] = None
def _get_all_slots(self):
return itertools.chain.from_iterable(getattr(cls, '__slots__', [])
# it if needed.
object.__setattr__(self, 'rdcomment', None)
- def covers(self):
+ def covers(self) -> dns.rdatatype.RdataType:
"""Return the type a Rdata covers.
DNS SIG/RRSIG rdatas apply to a specific type; this type is
creating rdatasets, allowing the rdataset to contain only RRSIGs
of a particular type, e.g. RRSIG(NS).
- Returns an ``int``.
+ Returns a ``dns.rdatatype.RdataType``.
"""
return dns.rdatatype.NONE
- def extended_rdatatype(self):
+ def extended_rdatatype(self) -> int:
"""Return a 32-bit type value, the least significant 16 bits of
which are the ordinary DNS type, and the upper 16 bits of which are
the "covered" type, if any.
return self.covers() << 16 | self.rdtype
- def to_text(self, origin=None, relativize=True, **kw):
+ def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw):
"""Convert an rdata to text format.
Returns a ``str``.
raise NotImplementedError # pragma: no cover
- def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ def _to_wire(self, file, compress: Optional[dns.name.CompressType]=None,
+ origin: Optional[dns.name.Name]=None, canonicalize=False):
raise NotImplementedError # pragma: no cover
def to_wire(self, file=None, compress=None, origin=None,
- canonicalize=False):
+ canonicalize=False) -> bytes:
"""Convert an rdata to wire format.
Returns a ``bytes`` or ``None``.
self._to_wire(f, compress, origin, canonicalize)
return f.getvalue()
- def to_generic(self, origin=None):
+ def to_generic(self, origin: Optional[dns.name.Name]=None) -> 'dns.rdata.GenericRdata':
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``.
return dns.rdata.GenericRdata(self.rdclass, self.rdtype,
self.to_wire(origin=origin))
- def to_digestable(self, origin=None):
+ def to_digestable(self, origin: Optional[dns.name.Name]=None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This
is also the DNSSEC canonical form.
return hash(self.to_digestable(dns.name.root))
@classmethod
- def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
- relativize_to=None):
+ def from_text(cls, rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize=True,
+ relativize_to: Optional[dns.name.Name]=None):
raise NotImplementedError # pragma: no cover
@classmethod
- def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ def from_wire_parser(cls, rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None):
raise NotImplementedError # pragma: no cover
def replace(self, **kwargs):
return dns.rdatatype.RdataType.make(value)
@classmethod
- def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True):
+ def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True) -> bytes:
if encode and isinstance(value, str):
- value = value.encode()
+ bvalue = value.encode()
elif isinstance(value, bytearray):
- value = bytes(value)
- elif not isinstance(value, bytes):
+ bvalue = bytes(value)
+ elif isinstance(value, bytes):
+ bvalue = value
+ else:
raise ValueError('not bytes')
- if max_length is not None and len(value) > max_length:
+ if max_length is not None and len(bvalue) > max_length:
raise ValueError('too long')
- if not empty_ok and len(value) == 0:
+ if not empty_ok and len(bvalue) == 0:
raise ValueError('empty bytes not allowed')
- return value
+ return bvalue
@classmethod
def _as_name(cls, value):
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
return cls(rdclass, rdtype, parser.get_remaining())
-_rdata_classes = {}
+_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = {}
_module_prefix = 'dns.rdtypes'
def get_rdata_class(rdclass, rdtype):
return cls
-def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
- relativize_to=None, idna_codec=None):
+def from_text(rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ tok: Union[dns.tokenizer.Tokenizer, str],
+ origin: Optional[dns.name.Name]=None,
+ relativize=True, relativize_to: Optional[dns.name.Name]=None,
+ idna_codec: Optional[dns.name.IDNACodec]=None) -> Rdata:
"""Build an rdata object from text format.
This function attempts to dynamically load a class which
If *tok* is a ``str``, then a tokenizer is created and the string
is used as its input.
- *rdclass*, an ``int``, the rdataclass.
+ *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
- *rdtype*, an ``int``, the rdatatype.
+ *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``.
return rdata
-def from_wire_parser(rdclass, rdtype, parser, origin=None):
+def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
Once a class is chosen, its from_wire() class method is called
with the parameters to this function.
- *rdclass*, an ``int``, the rdataclass.
+ *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
- *rdtype*, an ``int``, the rdatatype.
+ *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the rdata length.
return cls.from_wire_parser(rdclass, rdtype, parser, origin)
-def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
+def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ wire: bytes, current: int, rdlen: int,
+ origin: Optional[dns.name.Name]=None) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
+++ /dev/null
-from typing import Dict, Tuple, Any, Optional, BinaryIO
-from .name import Name, IDNACodec
-class Rdata:
- def __init__(self):
- self.address : str
- def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]:
- ...
- @classmethod
- def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True):
- ...
-_rdata_modules : Dict[Tuple[Any,Rdata],Any]
-
-def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None,
- relativize : bool = True, relativize_to : Optional[Name] = None,
- idna_codec : Optional[IDNACodec] = None):
- ...
-
-def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None):
- ...
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
+from typing import Any, cast, Collection, Dict, List, Optional, Union
+
import io
import random
import struct
import dns.exception
import dns.immutable
+import dns.name
import dns.rdatatype
import dns.rdataclass
import dns.rdata
import dns.set
+import dns.ttl
# define SimpleSet here for backwards compatibility
SimpleSet = dns.set.Set
__slots__ = ['rdclass', 'rdtype', 'covers', 'ttl']
- def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0):
+ def __init__(self, rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers=dns.rdatatype.NONE, ttl=0):
"""Create a new rdataset of the specified class and type.
- *rdclass*, an ``int``, the rdataclass.
+ *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
- *rdtype*, an ``int``, the rdatatype.
+ *rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype.
- *covers*, an ``int``, the covered rdatatype.
+ *covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype.
*ttl*, an ``int``, the TTL.
"""
super().__init__()
self.rdclass = rdclass
- self.rdtype = rdtype
- self.covers = covers
+ self.rdtype: dns.rdatatype.RdataType = rdtype
+ self.covers: dns.rdatatype.RdataType = covers
self.ttl = ttl
def _clone(self):
obj.ttl = self.ttl
return obj
- def update_ttl(self, ttl):
+ def update_ttl(self, ttl: int):
"""Perform TTL minimization.
Set the TTL of the rdataset to be the lesser of the set's current
elif ttl < self.ttl:
self.ttl = ttl
- def add(self, rd, ttl=None): # pylint: disable=arguments-differ
+ def add(self, rd, ttl: Optional[int]=None): # pylint: disable=arguments-differ
"""Add the specified rdata to the rdataset.
If the optional *ttl* parameter is supplied, then
def __ne__(self, other):
return not self.__eq__(other)
- def to_text(self, name=None, origin=None, relativize=True,
- override_rdclass=None, want_comments=False, **kw):
+ def to_text(self, name: Optional[dns.name.Name]=None,
+ origin: Optional[dns.name.Name]=None,
+ relativize=True,
+ override_rdclass: Optional[dns.rdataclass.RdataClass]=None,
+ want_comments=False, **kw) -> str:
"""Convert the rdataset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information
#
return s.getvalue()[:-1]
- def to_wire(self, name, file, compress=None, origin=None,
- override_rdclass=None, want_shuffle=True):
+ def to_wire(self, name: dns.name.Name, file: Any,
+ compress: Optional[dns.name.CompressType]=None,
+ origin: Optional[dns.name.Name]=None,
+ override_rdclass: Optional[dns.rdataclass.RdataClass]=None,
+ want_shuffle=True) -> int:
"""Convert the rdataset to wire format.
*name*, a ``dns.name.Name`` is the owner name to use.
file.write(stuff)
return 1
else:
+ l: Union[Rdataset, List[dns.rdata.Rdata]]
if want_shuffle:
l = list(self)
random.shuffle(l)
file.seek(0, io.SEEK_END)
return len(self)
- def match(self, rdclass, rdtype, covers):
+ def match(self, rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType) -> bool:
"""Returns ``True`` if this rdataset matches the specified class,
type, and covers.
"""
return True
return False
- def processing_order(self):
+ def processing_order(self) -> List[dns.rdata.Rdata]:
"""Return rdatas in a valid processing order according to the type's
specification. For example, MX records are in preference order from
lowest to highest preferences, with items of the same preference
_clone_class = Rdataset
- def __init__(self, rdataset):
+ def __init__(self, rdataset: Rdataset):
"""Create an immutable rdataset from the specified rdataset."""
super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
return ImmutableRdataset(super().symmetric_difference(other))
-def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
- origin=None, relativize=True, relativize_to=None):
+def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ ttl: int, text_rdatas: Collection[str],
+ idna_codec: Optional[dns.name.IDNACodec]=None,
+ origin: Optional[dns.name.Name]=None,
+ relativize=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified list of rdatas in text format.
Returns a ``dns.rdataset.Rdataset`` object.
"""
- rdclass = dns.rdataclass.RdataClass.make(rdclass)
- rdtype = dns.rdatatype.RdataType.make(rdtype)
- r = Rdataset(rdclass, rdtype)
+ the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+ r = Rdataset(the_rdclass, the_rdtype)
r.update_ttl(ttl)
for t in text_rdatas:
rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize,
return r
-def from_text(rdclass, rdtype, ttl, *text_rdatas):
+def from_text(rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ ttl: int, *text_rdatas) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified rdatas in text format.
Returns a ``dns.rdataset.Rdataset`` object.
"""
- return from_text_list(rdclass, rdtype, ttl, text_rdatas)
+ return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas))
-def from_rdata_list(ttl, rdatas):
+def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified list of rdata objects.
r = Rdataset(rd.rdclass, rd.rdtype)
r.update_ttl(ttl)
r.add(rd)
+ assert r is not None
return r
-def from_rdata(ttl, *rdatas):
+def from_rdata(ttl: int, *rdatas) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified rdata objects.
Returns a ``dns.rdataset.Rdataset`` object.
"""
- return from_rdata_list(ttl, rdatas)
+ return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas))
+++ /dev/null
-from typing import Optional, Dict, List, Union
-from io import BytesIO
-from . import exception, name, set, rdatatype, rdata, rdataset
-
-class DifferingCovers(exception.DNSException):
- """An attempt was made to add a DNS SIG/RRSIG whose covered type
- is not the same as that of the other rdatas in the rdataset."""
-
-
-class IncompatibleTypes(exception.DNSException):
- """An attempt was made to add DNS RR data of an incompatible type."""
-
-
-class Rdataset(set.Set):
- def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0):
- self.rdclass : int = rdclass
- self.rdtype : int = rdtype
- self.covers : int = covers
- self.ttl : int = ttl
-
- def update_ttl(self, ttl : int) -> None:
- ...
-
- def add(self, rd : rdata.Rdata, ttl : Optional[int] =None):
- ...
-
- def union_update(self, other : Rdataset):
- ...
-
- def intersection_update(self, other : Rdataset):
- ...
-
- def update(self, other : Rdataset):
- ...
-
- def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True,
- override_rdclass : Optional[int] =None, **kw) -> bytes:
- ...
-
- def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None,
- override_rdclass : Optional[int] = None, want_shuffle=True) -> int:
- ...
-
- def match(self, rdclass : int, rdtype : int, covers : int) -> bool:
- ...
-
-
-def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset:
- ...
-
-def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset:
- ...
-
-def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
- ...
-
-def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
- ...
"""DNS Rdata Types."""
+from typing import Dict
+
import dns.enum
import dns.exception
def _unknown_exception_class(cls):
return UnknownRdatatype
-_registered_by_text = {}
-_registered_by_value = {}
+_registered_by_text: Dict[str, RdataType] = {}
+_registered_by_value: Dict[RdataType, str] = {}
_metatypes = {RdataType.OPT}
import dns.exception
import dns.immutable
-import dns.dnssec
+import dns.dnssectypes
import dns.rdata
import dns.tokenizer
def to_text(self, origin=None, relativize=True, **kw):
certificate_type = _ctype_to_text(self.certificate_type)
return "%s %d %s %s" % (certificate_type, self.key_tag,
- dns.dnssec.algorithm_to_text(self.algorithm),
+ dns.dnssectypes.Algorithm.to_text(self.algorithm),
dns.rdata._base64ify(self.certificate, **kw))
@classmethod
relativize_to=None):
certificate_type = _ctype_from_text(tok.get_string())
key_tag = tok.get_uint16()
- algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
+ algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
b64 = tok.concatenate_remaining_identifiers().encode()
certificate = base64.b64decode(b64)
return cls(rdclass, rdtype, certificate_type, key_tag,
import struct
import time
-import dns.dnssec
+import dns.dnssectypes
import dns.immutable
import dns.exception
import dns.rdata
signature):
super().__init__(rdclass, rdtype)
self.type_covered = self._as_rdatatype(type_covered)
- self.algorithm = dns.dnssec.Algorithm.make(algorithm)
+ self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.labels = self._as_uint8(labels)
self.original_ttl = self._as_ttl(original_ttl)
self.expiration = self._as_uint32(expiration)
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
type_covered = dns.rdatatype.from_text(tok.get_string())
- algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
+ algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
labels = tok.get_int()
original_ttl = tok.get_ttl()
expiration = sigtime_to_posixtime(tok.get_string())
import base64
import struct
-import dns.dnssec
import dns.immutable
import dns.exception
import dns.rdata
import dns.immutable
import dns.rdata
import dns.rdatatype
-import dns.zone
+import dns.zonetypes
@dns.immutable.immutable
def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest):
super().__init__(rdclass, rdtype)
self.serial = self._as_uint32(serial)
- self.scheme = dns.zone.DigestScheme.make(scheme)
- self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm)
+ self.scheme = dns.zonetypes.DigestScheme.make(scheme)
+ self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm)
self.digest = self._as_bytes(digest)
if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2
if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3
raise ValueError('hash_algorithm 0 is reserved')
- hasher = dns.zone._digest_hashers.get(self.hash_algorithm)
+ hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm)
if hasher and hasher().digest_size != len(self.digest):
raise ValueError('digest length inconsistent with hash algorithm')
import dns.exception
import dns.immutable
-import dns.dnssec
+import dns.dnssectypes
import dns.rdata
# wildcard import
super().__init__(rdclass, rdtype)
self.flags = self._as_uint16(flags)
self.protocol = self._as_uint8(protocol)
- self.algorithm = dns.dnssec.Algorithm.make(algorithm)
+ self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw):
+++ /dev/null
-from typing import Set, Any
-
-SEP : int
-REVOKE : int
-ZONE : int
-
-def flags_to_text_set(flags : int) -> Set[str]:
- ...
-
-def flags_from_text_set(texts_set) -> int:
- ...
-
-from .. import rdata
-
-class DNSKEYBase(rdata.Rdata):
- def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
- self.flags : int
- self.protocol : int
- self.key : str
- self.algorithm : int
-
- def to_text(self, origin : Any = None, relativize=True, **kw : Any):
- ...
-
- @classmethod
- def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
- relativize_to=None):
- ...
-
- def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
- ...
-
- @classmethod
- def from_parser(cls, rdclass, rdtype, parser, origin=None):
- ...
-
- def flags_to_text_set(self) -> Set[str]:
- ...
import struct
import binascii
-import dns.dnssec
+import dns.dnssectypes
import dns.immutable
import dns.rdata
import dns.rdatatype
digest):
super().__init__(rdclass, rdtype)
self.key_tag = self._as_uint16(key_tag)
- self.algorithm = dns.dnssec.Algorithm.make(algorithm)
+ self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.digest_type = self._as_uint8(digest_type)
self.digest = self._as_bytes(digest)
try:
"""TXT-like base class."""
+from typing import Iterable, Optional, Tuple, Union
+
import struct
import dns.exception
__slots__ = ['strings']
- def __init__(self, rdclass, rdtype, strings):
+ def __init__(self, rdclass, rdtype, strings: Iterable[Union[bytes, str]]):
"""Initialize a TXT-like rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata.
*strings*, a tuple of ``bytes``
"""
super().__init__(rdclass, rdtype)
- self.strings = self._as_tuple(strings,
- lambda x: self._as_bytes(x, True, 255))
+ self.strings: Tuple[bytes] = self._as_tuple(strings, lambda x: self._as_bytes(x, True, 255))
- def to_text(self, origin=None, relativize=True, **kw):
+ def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw):
txt = ''
prefix = ''
for s in self.strings:
return txt
@classmethod
- def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
- relativize_to=None):
+ def from_text(cls, rdclass, rdtype, tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None,
+ relativize=True, relativize_to: Optional[dns.name.Name]=None):
strings = []
for token in tok.get_remaining():
token = token.unescape_to_bytes()
+++ /dev/null
-import typing
-from .. import rdata
-
-class TXTBase(rdata.Rdata):
- strings: typing.Tuple[bytes, ...]
-
- def __init__(self, rdclass: int, rdtype: int, strings: typing.Iterable[bytes]) -> None:
- ...
- def to_text(self, origin: typing.Any, relativize: bool, **kw: typing.Any) -> str:
- ...
-class TXT(TXTBase):
- ...
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS stub resolver."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
from urllib.parse import urlparse
import contextlib
import socket
# pylint: disable=arguments-differ
+ # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
def _check_kwargs(self, qnames,
responses=None):
if not isinstance(qnames, (list, tuple, set)):
"""The DNS query name is too long after DNAME substitution."""
-def _errors_to_text(errors):
+ErrorTuple = Tuple[str, bool, int, Exception, dns.message.Message]
+
+
+def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
"""Turn a resolution errors trace into a list of text."""
texts = []
for err in errors:
fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1]
supp_kwargs = {'timeout', 'errors'}
+ # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
def _fmt_kwargs(self, **kwargs):
srv_msgs = _errors_to_text(kwargs['errors'])
return super()._fmt_kwargs(timeout=kwargs['timeout'],
'to the question: {query}'
supp_kwargs = {'response'}
+ # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
def _fmt_kwargs(self, **kwargs):
return super()._fmt_kwargs(query=kwargs['response'].question)
fmt = "%s {query}: {errors}" % msg[:-1]
supp_kwargs = {'request', 'errors'}
+ # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
def _fmt_kwargs(self, **kwargs):
srv_msgs = _errors_to_text(kwargs['errors'])
return super()._fmt_kwargs(query=kwargs['request'].question,
RRset's name might not be the query name.
"""
- def __init__(self, qname, rdtype, rdclass, response, nameserver=None,
- port=None):
+ def __init__(self, qname: dns.name.Name, rdtype: dns.rdatatype.RdataType,
+ rdclass: dns.rdataclass.RdataClass, response: dns.message.QueryMessage,
+ nameserver: Optional[str]=None, port: Optional[int]=None):
self.qname = qname
self.rdtype = rdtype
self.rdclass = rdclass
self.hits = 0
self.misses = 0
- def clone(self):
+ def clone(self) -> 'CacheStatistics':
return CacheStatistics(self.hits, self.misses)
with self.lock:
return self.statistics.misses
- def get_statistics_snapshot(self):
+ def get_statistics_snapshot(self) -> CacheStatistics:
"""Return a consistent snapshot of all the statistics.
If running with multiple threads, it's better to take a
return self.statistics.clone()
+CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass]
+
+
class Cache(CacheBase):
"""Simple thread-safe DNS answer cache."""
now = time.time()
self.next_cleaning = now + self.cleaning_interval
- def get(self, key):
+ def get(self, key: CacheKey) -> Optional[Answer]:
"""Get the answer associated with *key*.
Returns None if no answer is cached for the key.
- *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+ *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
query name, rdtype, and rdclass respectively.
Returns a ``dns.resolver.Answer`` or ``None``.
self.statistics.hits += 1
return v
- def put(self, key, value):
+ def put(self, key: CacheKey, value: Answer):
"""Associate key and value in the cache.
- *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+ *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
query name, rdtype, and rdclass respectively.
*value*, a ``dns.resolver.Answer``, the answer.
self._maybe_clean()
self.data[key] = value
- def flush(self, key=None):
+ def flush(self, key: Optional[CacheKey]=None):
"""Flush the cache.
If *key* is not ``None``, only that item is flushed. Otherwise
the entire cache is flushed.
- *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+ *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
query name, rdtype, and rdclass respectively.
"""
max_size = 1
self.max_size = max_size
- def get(self, key):
+ def get(self, key: CacheKey) -> Optional[Answer]:
"""Get the answer associated with *key*.
Returns None if no answer is cached for the key.
- *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+ *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
query name, rdtype, and rdclass respectively.
Returns a ``dns.resolver.Answer`` or ``None``.
node.hits += 1
return node.value
- def get_hits_for_key(self, key):
+ def get_hits_for_key(self, key: CacheKey) -> int:
"""Return the number of cache hits associated with the specified key."""
with self.lock:
node = self.data.get(key)
else:
return node.hits
- def put(self, key, value):
+ def put(self, key: CacheKey, value: Answer):
"""Associate key and value in the cache.
- *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+ *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
query name, rdtype, and rdclass respectively.
*value*, a ``dns.resolver.Answer``, the answer.
node.link_after(self.sentinel)
self.data[key] = node
- def flush(self, key=None):
+ def flush(self, key: Optional[CacheKey]=None):
"""Flush the cache.
If *key* is not ``None``, only that item is flushed. Otherwise
the entire cache is flushed.
- *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+ *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
query name, rdtype, and rdclass respectively.
"""
resolver data structures directly.
"""
- def __init__(self, resolver, qname, rdtype, rdclass, tcp,
- raise_on_no_answer, search):
+ def __init__(self, resolver: 'BaseResolver', qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ tcp: bool, raise_on_no_answer: bool, search: Optional[bool]):
if isinstance(qname, str):
qname = dns.name.from_text(qname, None)
rdtype = dns.rdatatype.RdataType.make(rdtype)
self.rdclass = rdclass
self.tcp = tcp
self.raise_on_no_answer = raise_on_no_answer
- self.nxdomain_responses = {}
- #
+ self.nxdomain_responses: Dict[dns.name.Name, Answer] = {}
# Initialize other things to help analysis tools
self.qname = dns.name.empty
- self.nameservers = []
- self.current_nameservers = []
- self.errors = []
- self.nameserver = None
+ self.nameservers: List[str] = []
+ self.current_nameservers: List[str] = []
+ self.errors: List[ErrorTuple] = []
+ self.nameserver: Optional[str] = None
self.port = 0
self.tcp_attempt = False
self.retry_with_tcp = False
- self.request = None
- self.backoff = 0
+ self.request: Optional[dns.message.QueryMessage] = None
+ self.backoff = 0.0
- def next_request(self):
+ def next_request(self) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]:
"""Get the next request to send, and check the cache.
Returns a (request, answer) tuple. At most one of request or
dns.rcode.to_text(rcode), response))
return (None, False)
+
class BaseResolver:
"""DNS stub resolver."""
dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
if len(self.domain) == 0:
self.domain = dns.name.root
- self.nameservers = []
- self.nameserver_ports = {}
+ self.nameservers: List[str] = []
+ self.nameserver_ports: Dict[str, int] = {}
self.port = 53
- self.search = []
+ self.search: List[dns.name.Name] = []
self.use_search_by_default = False
self.timeout = 2.0
self.lifetime = 5.0
self.keyalgorithm = dns.tsig.default_algorithm
self.edns = -1
self.ednsflags = 0
- self.ednsoptions = None
+ self.ednsoptions: Optional[List[dns.edns.Option]] = None
self.payload = 0
self.cache = None
self.flags = None
self.retry_servfail = False
self.rotate = False
- self.ndots = None
+ self.ndots: Optional[int] = None
def read_resolv_conf(self, f):
"""Process *f* as a file in the /etc/resolv.conf format. If f is
except AttributeError:
raise NotImplementedError
- def _compute_timeout(self, start, lifetime=None, errors=None):
+ def _compute_timeout(self, start: float, lifetime: Optional[float]=None,
+ errors: Optional[List[ErrorTuple]]=None) -> float:
lifetime = self.lifetime if lifetime is None else lifetime
now = time.time()
duration = now - start
raise LifetimeTimeout(timeout=duration, errors=errors)
return min(lifetime - duration, self.timeout)
- def _get_qnames_to_try(self, qname, search):
+ def _get_qnames_to_try(self, qname: dns.name.Name, search: Optional[bool]) -> List[dns.name.Name]:
# This is a separate method so we can unit test the search
# rules without requiring the Internet.
if search is None:
self.payload = payload
self.ednsoptions = options
- def set_flags(self, flags):
+ def set_flags(self, flags: int):
"""Overrides the default flags with your own.
*flags*, an ``int``, the message flags to use.
self.flags = flags
@property
- def nameservers(self):
+ def nameservers(self) -> List[str]:
return self._nameservers
@nameservers.setter
- def nameservers(self, nameservers):
+ def nameservers(self, nameservers: List[str]):
"""
*nameservers*, a ``list`` of nameservers.
class Resolver(BaseResolver):
"""DNS stub resolver."""
- def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True, source_port=0,
- lifetime=None, search=None): # pylint: disable=arguments-differ
+ def resolve(self, qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+ tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+ lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pylint: disable=arguments-differ
"""Query nameservers to find the answer to the question.
The *qname*, *rdtype*, and *rdclass* parameters may be objects
if answer is not None:
# cache hit!
return answer
+ assert request is not None # needed for type checking
done = False
while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
if answer is not None:
return answer
- def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True, source_port=0,
- lifetime=None): # pragma: no cover
+ def query(self, qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+ tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+ lifetime: Optional[float]=None) -> Answer: # pragma: no cover
"""Query nameservers to find the answer to the question.
This method calls resolve() with ``search=True``, and is
raise_on_no_answer, source_port, lifetime,
True)
- def resolve_address(self, ipaddr, *args, **kwargs):
+ def resolve_address(self, ipaddr: str, *args, **kwargs) -> Answer:
"""Use a resolver to run a reverse query for PTR records.
This utilizes the resolve() method to perform a PTR lookup on the
except for rdtype and rdclass are also supported by this
function.
"""
-
+ # We make a modified kwargs for type checking happiness, as otherwise
+ # we get a legit warning about possibly having rdtype and rdclass
+ # in the kwargs more than once.
+ modified_kwargs = {}
+ modified_kwargs.update(kwargs)
+ modified_kwargs['rdtype'] = dns.rdatatype.PTR
+ modified_kwargs['rdclass'] = dns.rdataclass.IN
return self.resolve(dns.reversename.from_address(ipaddr),
- rdtype=dns.rdatatype.PTR,
- rdclass=dns.rdataclass.IN,
- *args, **kwargs)
+ *args, **modified_kwargs)
# pylint: disable=redefined-outer-name
- def canonical_name(self, name):
+ def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
#: The default resolver.
-default_resolver = None
+default_resolver: Optional[Resolver] = None
-def get_default_resolver():
+def get_default_resolver() -> Resolver:
"""Get the default resolver, initializing it if necessary."""
if default_resolver is None:
reset_default_resolver()
+ assert default_resolver is not None
return default_resolver
default_resolver = Resolver()
-def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, lifetime=None, search=None):
+def resolve(qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+ tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+ lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pragma: no cover
+
"""Query nameservers to find the answer to the question.
This is a convenience function that uses the default resolver
raise_on_no_answer, source_port,
lifetime, search)
-def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, lifetime=None): # pragma: no cover
+def query(qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+ tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+ lifetime: Optional[float]=None) -> Answer: # pragma: no cover
"""Query nameservers to find the answer to the question.
This method calls resolve() with ``search=True``, and is
True)
-def resolve_address(ipaddr, *args, **kwargs):
+def resolve_address(ipaddr: str, *args, **kwargs) -> Answer:
"""Use a resolver to run a reverse query for PTR records.
See ``dns.resolver.Resolver.resolve_address`` for more information on the
return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
-def canonical_name(name):
+def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
See ``dns.resolver.Resolver.canonical_name`` for more information on the
return get_default_resolver().canonical_name(name)
-def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None,
- lifetime=None):
+def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN,
+ tcp=False, resolver: Optional[Resolver]=None,
+ lifetime: Optional[float]=None) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
*name*, an absolute ``dns.name.Name`` or ``str``, the query name.
rlifetime = None
answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp,
lifetime=rlifetime)
+ assert answer.rrset is not None
if answer.rrset.name == name:
return name
# otherwise we were CNAMEd or DNAMEd and need to look higher
return (canonical, aliases, addresses)
-def override_system_resolver(resolver=None):
+def override_system_resolver(resolver: Optional[Resolver]=None):
"""Override the system resolver routines in the socket module with
versions which use dnspython's resolver.
+++ /dev/null
-from typing import Union, Optional, List, Any, Dict
-from . import exception, rdataclass, name, rdatatype
-
-import socket
-_gethostbyname = socket.gethostbyname
-
-class NXDOMAIN(exception.DNSException): ...
-class YXDOMAIN(exception.DNSException): ...
-class NoAnswer(exception.DNSException): ...
-class NoNameservers(exception.DNSException): ...
-class NotAbsolute(exception.DNSException): ...
-class NoRootSOA(exception.DNSException): ...
-class NoMetaqueries(exception.DNSException): ...
-class NoResolverConfiguration(exception.DNSException): ...
-Timeout = exception.Timeout
-
-def resolve(qname : str, rdtype : Union[int,str] = 0,
- rdclass : Union[int,str] = 0,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, lifetime : Optional[float]=None,
- search : Optional[bool]=None):
- ...
-def query(qname : str, rdtype : Union[int,str] = 0,
- rdclass : Union[int,str] = 0,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, lifetime : Optional[float]=None):
- ...
-def resolve_address(ipaddr: str, *args: Any, **kwargs: Optional[Dict]):
- ...
-class LRUCache:
- def __init__(self, max_size=1000):
- ...
- def get(self, key):
- ...
- def put(self, key, val):
- ...
-class Answer:
- def __init__(self, qname, rdtype, rdclass, response,
- raise_on_no_answer=True):
- ...
-def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False,
- resolver : Optional[Resolver] = None):
- ...
-
-class Resolver:
- def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
- configure : Optional[bool] = True):
- self.nameservers : List[str]
- def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
- rdclass : Union[int,str] = rdataclass.IN,
- tcp : bool = False, source : Optional[str] = None,
- raise_on_no_answer=True, source_port : int = 0,
- lifetime : Optional[float]=None,
- search : Optional[bool]=None):
- ...
- def query(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
- rdclass : Union[int,str] = rdataclass.IN,
- tcp : bool = False, source : Optional[str] = None,
- raise_on_no_answer=True, source_port : int = 0,
- lifetime : Optional[float]=None):
- ...
-default_resolver: typing.Optional[Resolver]
-def reset_default_resolver() -> None:
- ...
-def get_default_resolver() -> Resolver:
- ...
ipv6_reverse_domain = dns.name.from_text('ip6.arpa.')
-def from_address(text, v4_origin=ipv4_reverse_domain,
- v6_origin=ipv6_reverse_domain):
+def from_address(text: str, v4_origin=ipv4_reverse_domain,
+ v6_origin=ipv6_reverse_domain) -> dns.name.Name:
"""Convert an IPv4 or IPv6 address in textual form into a Name object whose
value is the reverse-map domain name of the address.
return dns.name.from_text('.'.join(reversed(parts)), origin=origin)
-def to_address(name, v4_origin=ipv4_reverse_domain,
- v6_origin=ipv6_reverse_domain):
+def to_address(name: dns.name.Name, v4_origin=ipv4_reverse_domain,
+ v6_origin=ipv6_reverse_domain) -> str:
"""Convert a reverse map domain name into textual address form.
*name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name
+++ /dev/null
-from . import name
-def from_address(text : str) -> name.Name:
- ...
-
-def to_address(name : name.Name) -> str:
- ...
"""DNS RRsets (an RRset is a named rdataset)"""
+from typing import cast, Collection, Optional, Union
import dns.name
import dns.rdataset
__slots__ = ['name', 'deleting']
- def __init__(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE,
- deleting=None):
+ def __init__(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+ deleting: Optional[dns.rdataclass.RdataClass]=None):
"""Create a new RRset."""
super().__init__(rdclass, rdtype, covers)
return False
return super().__eq__(other)
- def match(self, *args, **kwargs):
+ def match(self, *args, **kwargs) -> bool:
"""Does this rrset match the specified attributes?
Behaves as :py:func:`full_match()` if the first argument is a
else:
return super().match(*args, **kwargs)
- def full_match(self, name, rdclass, rdtype, covers,
- deleting=None):
+ def full_match(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType,
+ deleting: Optional[dns.rdataclass.RdataClass]=None) -> bool:
"""Returns ``True`` if this rrset matches the specified name, class,
type, covers, and deletion state.
"""
# pylint: disable=arguments-differ
- def to_text(self, origin=None, relativize=True, **kw):
+ def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw) -> str: # type: ignore
"""Convert the RRset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information
return super().to_text(self.name, origin, relativize,
self.deleting, **kw)
- def to_wire(self, file, compress=None, origin=None,
- **kw):
+ def to_wire(self, file, compress: Optional[dns.name.CompressType]=None, # type: ignore
+ origin: Optional[dns.name.Name]=None, **kw) -> int:
"""Convert the RRset to wire format.
All keyword arguments are passed to ``dns.rdataset.to_wire()``; see
# pylint: enable=arguments-differ
- def to_rdataset(self):
+ def to_rdataset(self) -> dns.rdataset.Rdataset:
"""Convert an RRset into an Rdataset.
Returns a ``dns.rdataset.Rdataset``.
return dns.rdataset.from_rdata_list(self.ttl, list(self))
-def from_text_list(name, ttl, rdclass, rdtype, text_rdatas,
- idna_codec=None, origin=None, relativize=True,
- relativize_to=None):
+def from_text_list(name: Union[dns.name.Name, str], ttl: int,
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ text_rdatas: Collection[str],
+ idna_codec: Optional[dns.name.IDNACodec]=None,
+ origin: Optional[dns.name.Name]=None, relativize=True,
+ relativize_to: Optional[dns.name.Name]=None) -> RRset:
"""Create an RRset with the specified name, TTL, class, and type, and with
the specified list of rdatas in text format.
if isinstance(name, str):
name = dns.name.from_text(name, None, idna_codec=idna_codec)
- rdclass = dns.rdataclass.RdataClass.make(rdclass)
- rdtype = dns.rdatatype.RdataType.make(rdtype)
- r = RRset(name, rdclass, rdtype)
+ the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+ r = RRset(name, the_rdclass, the_rdtype)
r.update_ttl(ttl)
for t in text_rdatas:
rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize,
return r
-def from_text(name, ttl, rdclass, rdtype, *text_rdatas):
+def from_text(name: Union[dns.name.Name, str], ttl: int,
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ *text_rdatas) -> RRset:
"""Create an RRset with the specified name, TTL, class, and type and with
the specified rdatas in text format.
Returns a ``dns.rrset.RRset`` object.
"""
- return from_text_list(name, ttl, rdclass, rdtype, text_rdatas)
+ return from_text_list(name, ttl, rdclass, rdtype,
+ cast(Collection[str], text_rdatas))
-def from_rdata_list(name, ttl, rdatas, idna_codec=None):
+def from_rdata_list(name: Union[dns.name.Name, str], ttl:int,
+ rdatas: Collection[dns.rdata.Rdata],
+ idna_codec: Optional[dns.name.IDNACodec]=None) -> RRset:
"""Create an RRset with the specified name and TTL, and with
the specified list of rdata objects.
r = RRset(name, rd.rdclass, rd.rdtype)
r.update_ttl(ttl)
r.add(rd)
+ assert r is not None
return r
-def from_rdata(name, ttl, *rdatas):
+def from_rdata(name: Union[dns.name.Name, str], ttl:int, *rdatas) -> RRset:
"""Create an RRset with the specified name and TTL, and with
the specified rdata objects.
Returns a ``dns.rrset.RRset`` object.
"""
- return from_rdata_list(name, ttl, rdatas)
+ return from_rdata_list(name, ttl, cast(Collection[dns.rdata.Rdata], rdatas))
+++ /dev/null
-from typing import List, Optional
-from . import rdataset, rdatatype
-
-class RRset(rdataset.Rdataset):
- def __init__(self, name, rdclass : int , rdtype : int, covers=rdatatype.NONE,
- deleting : Optional[int] =None) -> None:
- self.name = name
- self.deleting = deleting
-def from_text(name : str, ttl : int, rdclass : str, rdtype : str, *text_rdatas : str):
- ...
"""Serial Number Arthimetic from RFC 1982"""
class Serial:
- def __init__(self, value, bits=32):
+ def __init__(self, value:int , bits=32):
self.value = value % 2 ** bits
self.bits = bits
"""Tokenize DNS zone file format"""
+from typing import Optional, List, Tuple
+
import io
import sys
has_escape: Does the token value contain escapes?
"""
- def __init__(self, ttype, value='', has_escape=False, comment=None):
+ def __init__(self, ttype: int, value='', has_escape=False, comment: Optional[str]=None):
"""Initialize a token instance."""
self.ttype = ttype
self.has_escape = has_escape
self.comment = comment
- def is_eof(self):
+ def is_eof(self) -> bool:
return self.ttype == EOF
- def is_eol(self):
+ def is_eol(self) -> bool:
return self.ttype == EOL
- def is_whitespace(self):
+ def is_whitespace(self) -> bool:
return self.ttype == WHITESPACE
- def is_identifier(self):
+ def is_identifier(self) -> bool:
return self.ttype == IDENTIFIER
- def is_quoted_string(self):
+ def is_quoted_string(self) -> bool:
return self.ttype == QUOTED_STRING
- def is_comment(self):
+ def is_comment(self) -> bool:
return self.ttype == COMMENT
- def is_delimiter(self): # pragma: no cover (we don't return delimiters yet)
+ def is_delimiter(self) -> bool: # pragma: no cover (we don't return delimiters yet)
return self.ttype == DELIMITER
- def is_eol_or_eof(self):
+ def is_eol_or_eof(self) -> bool:
return self.ttype == EOL or self.ttype == EOF
def __eq__(self, other):
def __str__(self):
return '%d "%s"' % (self.ttype, self.value)
- def unescape(self):
+ def unescape(self) -> 'Token':
if not self.has_escape:
return self
unescaped = ''
unescaped += c
return Token(self.ttype, unescaped)
- def unescape_to_bytes(self):
+ def unescape_to_bytes(self) -> 'Token':
# We used to use unescape() for TXT-like records, but this
# caused problems as we'd process DNS escapes into Unicode code
# points instead of byte values, and then a to_text() of the
encoder/decoder is used.
"""
- def __init__(self, f=sys.stdin, filename=None, idna_codec=None):
+ def __init__(self, f=sys.stdin, filename: Optional[str]=None,
+ idna_codec: Optional[dns.name.IDNACodec]=None):
"""Initialize a tokenizer instance.
f: The file to tokenize. The default is sys.stdin.
else:
filename = '<file>'
self.file = f
- self.ungotten_char = None
- self.ungotten_token = None
+ self.ungotten_char: Optional[str] = None
+ self.ungotten_token: Optional[Token] = None
self.multiline = 0
self.quoting = False
self.eof = False
self.delimiters = _DELIMITERS
self.line_number = 1
+ assert filename is not None
self.filename = filename
if idna_codec is None:
- idna_codec = dns.name.IDNA_2003
- self.idna_codec = idna_codec
+ self.idna_codec: dns.name.IDNACodec = dns.name.IDNA_2003
+ else:
+ self.idna_codec = idna_codec
- def _get_char(self):
+ def _get_char(self) -> str:
"""Read a character from input.
"""
self.ungotten_char = None
return c
- def where(self):
+ def where(self) -> Tuple[str, int]:
"""Return the current location in the input.
Returns a (string, int) tuple. The first item is the filename of
return skipped
skipped += 1
- def get(self, want_leading=False, want_comment=False):
+ def get(self, want_leading=False, want_comment=False) -> Token:
"""Get the next token.
want_leading: If True, return a WHITESPACE token if the
"""
if self.ungotten_token is not None:
- token = self.ungotten_token
+ utoken = self.ungotten_token
self.ungotten_token = None
- if token.is_whitespace():
+ if utoken.is_whitespace():
if want_leading:
- return token
- elif token.is_comment():
+ return utoken
+ elif utoken.is_comment():
if want_comment:
- return token
+ return utoken
else:
- return token
+ return utoken
skipped = self.skip_whitespace()
if want_leading and skipped > 0:
return Token(WHITESPACE, ' ')
ttype = EOF
return Token(ttype, token, has_escape)
- def unget(self, token):
+ def unget(self, token: Token):
"""Unget a token.
The unget buffer for tokens is only one token large; it is
raise dns.exception.SyntaxError('expecting an integer')
return int(token.value, base)
- def get_uint8(self):
+ def get_uint8(self) -> int:
"""Read the next token and interpret it as an 8-bit unsigned
integer.
'%d is not an unsigned 8-bit integer' % value)
return value
- def get_uint16(self, base=10):
+ def get_uint16(self, base=10) -> int:
"""Read the next token and interpret it as a 16-bit unsigned
integer.
'%d is not an unsigned 16-bit integer' % value)
return value
- def get_uint32(self, base=10):
+ def get_uint32(self, base=10) -> int:
"""Read the next token and interpret it as a 32-bit unsigned
integer.
'%d is not an unsigned 32-bit integer' % value)
return value
- def get_uint48(self, base=10):
+ def get_uint48(self, base=10) -> int:
"""Read the next token and interpret it as a 48-bit unsigned
integer.
'%d is not an unsigned 48-bit integer' % value)
return value
- def get_string(self, max_length=None):
+ def get_string(self, max_length=None) -> str:
"""Read the next token and interpret it as a string.
Raises dns.exception.SyntaxError if not a string.
raise dns.exception.SyntaxError("string too long")
return token.value
- def get_identifier(self):
+ def get_identifier(self) -> str:
"""Read the next token, which should be an identifier.
Raises dns.exception.SyntaxError if not an identifier.
raise dns.exception.SyntaxError('expecting an identifier')
return token.value
- def get_remaining(self, max_tokens=None):
+ def get_remaining(self, max_tokens=None) -> List[Token]:
"""Return the remaining tokens on the line, until an EOL or EOF is seen.
max_tokens: If not None, stop after this number of tokens.
break
return tokens
- def concatenate_remaining_identifiers(self, allow_empty=False):
+ def concatenate_remaining_identifiers(self, allow_empty=False) -> str:
"""Read the remaining tokens on the line, which should be identifiers.
Raises dns.exception.SyntaxError if there are no remaining tokens,
raise dns.exception.SyntaxError('expecting another identifier')
return s
- def as_name(self, token, origin=None, relativize=False, relativize_to=None):
+ def as_name(self, token: Token, origin: Optional[dns.name.Name]=None,
+ relativize=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name:
"""Try to interpret the token as a DNS name.
Raises dns.exception.SyntaxError if not a name.
name = dns.name.from_text(token.value, origin, self.idna_codec)
return name.choose_relativity(relativize_to or origin, relativize)
- def get_name(self, origin=None, relativize=False, relativize_to=None):
+ def get_name(self, origin: Optional[dns.name.Name]=None, relativize=False,
+ relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name:
"""Read the next token and interpret it as a DNS name.
Raises dns.exception.SyntaxError if not a name.
token = self.get()
return self.as_name(token, origin, relativize, relativize_to)
- def get_eol_as_token(self):
+ def get_eol_as_token(self) -> Token:
"""Read the next token and raise an exception if it isn't EOL or
EOF.
token.value))
return token
- def get_eol(self):
+ def get_eol(self) -> str:
return self.get_eol_as_token().value
- def get_ttl(self):
+ def get_ttl(self) -> int:
"""Read the next token and interpret it as a DNS TTL.
Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+from typing import Callable, List, Optional, Tuple, Union
+
import collections
import dns.exception
import dns.name
+import dns.node
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
class TransactionManager:
- def reader(self):
+ def reader(self) -> 'Transaction':
"""Begin a read-only transaction."""
raise NotImplementedError # pragma: no cover
- def writer(self, replacement=False):
+ def writer(self, replacement=False) -> 'Transaction':
"""Begin a writable transaction.
*replacement*, a ``bool``. If `True`, the content of the
"""
raise NotImplementedError # pragma: no cover
- def origin_information(self):
+ def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
"""Returns a tuple
(absolute_origin, relativize, effective_origin)
"""
raise NotImplementedError # pragma: no cover
- def get_class(self):
+ def get_class(self) -> dns.rdataclass.RdataClass:
"""The class of the transaction manager.
"""
raise NotImplementedError # pragma: no cover
- def from_wire_origin(self):
+ def from_wire_origin(self) -> Optional[dns.name.Name]:
"""Origin to use in from_wire() calls.
"""
(absolute_origin, relativize, _) = self.origin_information()
return dns.node.ImmutableNode(node)
+CheckPutRdatasetType = Callable[['Transaction', dns.name.Name, dns.rdataset.Rdataset], None]
+CheckDeleteRdatasetType = Callable[['Transaction', dns.name.Name,
+ dns.rdatatype.RdataType, dns.rdatatype.RdataType], None]
+CheckDeleteNameType = Callable[['Transaction', dns.name.Name], None]
+
+
class Transaction:
- def __init__(self, manager, replacement=False, read_only=False):
+ def __init__(self, manager: TransactionManager, replacement=False, read_only=False):
self.manager = manager
self.replacement = replacement
self.read_only = read_only
self._ended = False
- self._check_put_rdataset = []
- self._check_delete_rdataset = []
- self._check_delete_name = []
+ self._check_put_rdataset: List[CheckPutRdatasetType]= []
+ self._check_delete_rdataset: List[CheckDeleteRdatasetType] = []
+ self._check_delete_name: List[CheckDeleteNameType] = []
#
# This is the high level API
#
+ # Note that we currently use non-immutable types in the return type signature to avoid
+ # covariance problems, e.g. if the caller has a List[Rdataset], mypy will be unhappy if we
+ # return an ImmutableRdataset.
- def get(self, name, rdtype, covers=dns.rdatatype.NONE):
+ def get(self, name: Optional[Union[dns.name.Name,str]],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rdataset.Rdataset:
"""Return the rdataset associated with *name*, *rdtype*, and *covers*,
or `None` if not found.
if isinstance(name, str):
name = dns.name.from_text(name, None)
rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
rdataset = self._get_rdataset(name, rdtype, covers)
return _ensure_immutable_rdataset(rdataset)
- def get_node(self, name):
+ def get_node(self, name) -> dns.node.Node:
"""Return the node at *name*, if any.
Returns an immutable node or ``None``.
self._check_read_only()
return self._delete(True, args)
- def name_exists(self, name):
+ def name_exists(self, name: Union[dns.name.Name, str]) -> bool:
"""Does the specified name exist?"""
self._check_ended()
if isinstance(name, str):
self._check_ended()
return self._iterate_rdatasets()
- def changed(self):
+ def changed(self) -> bool:
"""Has this transaction changed anything?
For read-only transactions, the result is always `False`.
"""
self._end(False)
- def check_put_rdataset(self, check):
+ def check_put_rdataset(self, check: CheckPutRdatasetType):
"""Call *check* before putting (storing) an rdataset.
The function is called with the transaction, the name, and the rdataset.
"""
self._check_put_rdataset.append(check)
- def check_delete_rdataset(self, check):
+ def check_delete_rdataset(self, check: CheckDeleteRdatasetType):
"""Call *check* before deleting an rdataset.
The function is called with the transaction, the name, the rdatatype,
"""
self._check_delete_rdataset.append(check)
- def check_delete_name(self, check):
+ def check_delete_name(self, check: CheckDeleteNameType):
"""Call *check* before putting (storing) an rdataset.
The function is called with the transaction and the name.
"""A place to store TSIG keys."""
+from typing import Any, Dict, Union
+
import base64
import dns.name
import dns.tsig
-def from_text(textring):
+def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
"""Convert a dictionary containing (textual DNS name, base64 secret)
pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
a dictionary containing (textual DNS name, (algorithm, base64 secret))
keyring = {}
for (name, value) in textring.items():
- name = dns.name.from_text(name)
+ kname = dns.name.from_text(name)
if isinstance(value, str):
- keyring[name] = dns.tsig.Key(name, value).secret
+ keyring[kname] = dns.tsig.Key(kname, value).secret
else:
(algorithm, secret) = value
- keyring[name] = dns.tsig.Key(name, secret, algorithm)
+ keyring[kname] = dns.tsig.Key(kname, secret, algorithm)
return keyring
-def to_text(keyring):
+def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]:
"""Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs
into a text keyring which has (textual DNS name, (textual algorithm,
base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes)
def b64encode(secret):
return base64.encodebytes(secret).decode().rstrip()
for (name, key) in keyring.items():
- name = name.to_text()
+ tname = name.to_text()
if isinstance(key, bytes):
- textring[name] = b64encode(key)
+ textring[tname] = b64encode(key)
else:
if isinstance(key.secret, bytes):
text_secret = b64encode(key.secret)
else:
text_secret = str(key.secret)
- textring[name] = (key.algorithm.to_text(), text_secret)
+ textring[tname] = (key.algorithm.to_text(), text_secret)
return textring
+++ /dev/null
-from typing import Dict
-from . import name
-
-def from_text(textring : Dict[str,str]) -> Dict[name.Name,bytes]:
- ...
-def to_text(keyring : Dict[name.Name,bytes]) -> Dict[str, str]:
- ...
"""DNS TTL conversion."""
+from typing import Union
+
import dns.exception
# Technically TTLs are supposed to be between 0 and 2**31 - 1, with values
"""DNS TTL value is not well-formed."""
-def from_text(text):
+def from_text(text: str) -> int:
"""Convert the text form of a TTL to an integer.
The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported.
return total
-def make(value):
+def make(value: Union[int, str]) -> int:
if isinstance(value, int):
return value
elif isinstance(value, str):
"""DNS Dynamic Update Support"""
+from typing import Any, Optional, Union
import dns.message
import dns.name
class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
- _section_enum = UpdateSection
+ # ignore the mypy error here as we mean to use a different enum
+ _section_enum = UpdateSection # type: ignore
- def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None,
- keyname=None, keyalgorithm=dns.tsig.default_algorithm,
- id=None):
+ def __init__(self, zone: Optional[Union[dns.name.Name, str]]=None,
+ rdclass=dns.rdataclass.IN,
+ keyring: Optional[Any]=None, keyname: Optional[dns.name.Name]=None,
+ keyalgorithm=dns.tsig.default_algorithm,
+ id: Optional[int]=None):
"""Initialize a new DNS Update object.
See the documentation of the Message class for a complete
self.origin)
self._add_rr(name, ttl, rd, section=section)
- def add(self, name, *args):
+ def add(self, name: Union[dns.name.Name, str], *args):
"""Add records.
The first argument is always a name. The other
self._add(False, self.update, name, *args)
- def delete(self, name, *args):
+ def delete(self, name: Union[dns.name.Name, str], *args):
"""Delete records.
The first argument is always a name. The other
if len(args) == 0:
self.find_rrset(self.update, name, dns.rdataclass.ANY,
dns.rdatatype.ANY, dns.rdatatype.NONE,
- dns.rdatatype.ANY, True, True)
+ dns.rdataclass.ANY, True, True)
elif isinstance(args[0], dns.rdataset.Rdataset):
for rds in args:
for rd in rds:
self._add_rr(name, 0, rd, dns.rdataclass.NONE)
else:
- args = list(args)
- if isinstance(args[0], dns.rdata.Rdata):
- for rd in args:
+ largs = list(args)
+ if isinstance(largs[0], dns.rdata.Rdata):
+ for rd in largs:
self._add_rr(name, 0, rd, dns.rdataclass.NONE)
else:
- rdtype = dns.rdatatype.RdataType.make(args.pop(0))
- if len(args) == 0:
+ rdtype = dns.rdatatype.RdataType.make(largs.pop(0))
+ if len(largs) == 0:
self.find_rrset(self.update, name,
self.zone_rdclass, rdtype,
dns.rdatatype.NONE,
dns.rdataclass.ANY,
True, True)
else:
- for s in args:
+ for s in largs:
rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s,
self.origin)
self._add_rr(name, 0, rd, dns.rdataclass.NONE)
- def replace(self, name, *args):
+ def replace(self, name: Union[dns.name.Name, str], *args):
"""Replace records.
The first argument is always a name. The other
self._add(True, self.update, name, *args)
- def present(self, name, *args):
+ def present(self, name: Union[dns.name.Name, str], *args):
"""Require that an owner name (and optionally an rdata type,
or specific rdataset) exists as a prerequisite to the
execution of the update.
len(args) > 1:
if not isinstance(args[0], dns.rdataset.Rdataset):
# Add a 0 TTL
- args = list(args)
- args.insert(0, 0)
- self._add(False, self.prerequisite, name, *args)
+ largs = list(args)
+ largs.insert(0, 0)
+ self._add(False, self.prerequisite, name, *largs)
+ else:
+ self._add(False, self.prerequisite, name, *args)
else:
rdtype = dns.rdatatype.RdataType.make(args[0])
self.find_rrset(self.prerequisite, name,
dns.rdatatype.NONE, None,
True, True)
- def absent(self, name, rdtype=None):
+ def absent(self, name: Union[dns.name.Name, str], rdtype=None):
"""Require that an owner name (and optionally an rdata type) does
not exist as a prerequisite to the execution of the update."""
+++ /dev/null
-from typing import Optional,Dict,Union,Any
-
-from . import message, tsig, rdataclass, name
-
-class Update(message.Message):
- def __init__(self, zone : Union[name.Name, str], rdclass : Union[int,str] = rdataclass.IN, keyring : Optional[Dict[name.Name,bytes]] = None,
- keyname : Optional[name.Name] = None, keyalgorithm : Optional[name.Name] = tsig.default_algorithm) -> None:
- self.id : int
- def add(self, name : Union[str,name.Name], *args : Any):
- ...
- def delete(self, name, *args : Any):
- ...
- def replace(self, name : Union[str,name.Name], *args : Any):
- ...
- def present(self, name : Union[str,name.Name], *args : Any):
- ...
- def absent(self, name : Union[str,name.Name], rdtype=None):
- """Require that an owner name (and optionally an rdata type) does
- not exist as a prerequisite to the execution of the update."""
- def to_wire(self, origin : Optional[name.Name] = None, max_size=65535, **kw) -> bytes:
- ...
"""DNS Versioned Zones."""
+from typing import Callable, Deque, Optional, Set, Union
+
import collections
try:
import threading as _threading
node_factory = Node
- def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
- pruning_policy=None):
+ def __init__(self, origin: Optional[Union[dns.name.Name, str]], rdclass=dns.rdataclass.IN, relativize=True,
+ pruning_policy: Optional[Callable[['Zone', Version], Optional[bool]]]=None):
"""Initialize a versioned zone object.
*origin* is the origin of the zone. It may be a ``dns.name.Name``,
*relativize*, a ``bool``, determine's whether domain names are
relativized to the zone's origin. The default is ``True``.
- *pruning policy*, a function taking a `Version` and returning
- a `bool`, or `None`. Should the version be pruned? If `None`,
+ *pruning policy*, a function taking a ``Zone`` and a ``Version`` and returning
+ a ``bool``, or ``None``. Should the version be pruned? If ``None``,
the default policy, which retains one version is used.
"""
super().__init__(origin, rdclass, relativize)
- self._versions = collections.deque()
+ self._versions: Deque[Version] = collections.deque()
self._version_lock = _threading.Lock()
if pruning_policy is None:
self._pruning_policy = self._default_pruning_policy
else:
self._pruning_policy = pruning_policy
- self._write_txn = None
- self._write_event = None
- self._write_waiters = collections.deque()
- self._readers = set()
+ self._write_txn: Optional[Transaction] = None
+ self._write_event: Optional[_threading.Event] = None
+ self._write_waiters: Deque[_threading.Event] = collections.deque()
+ self._readers: Set[Transaction] = set()
self._commit_version_unlocked(None,
WritableVersion(self, replacement=True),
origin)
- def reader(self, id=None, serial=None): # pylint: disable=arguments-differ
+ def reader(self, id: Optional[int]=None, serial: Optional[int]=None) -> Transaction: # pylint: disable=arguments-differ
if id is not None and serial is not None:
raise ValueError('cannot specify both id and serial')
with self._version_lock:
if self.relativize:
oname = dns.name.empty
else:
+ assert self.origin is not None
oname = self.origin
version = None
for v in reversed(self._versions):
self._readers.add(txn)
return txn
- def writer(self, replacement=False):
+ def writer(self, replacement=False) -> Transaction:
event = None
while True:
with self._version_lock:
self._pruning_policy(self, self._versions[0]):
self._versions.popleft()
- def set_max_versions(self, max_versions):
+ def set_max_versions(self, max_versions: Optional[int]):
"""Set a pruning policy that retains up to the specified number
of versions
"""
if max_versions is not None and max_versions < 1:
raise ValueError('max versions must be at least 1')
if max_versions is None:
- def policy(*_):
+ def policy(zone, _): # pylint: disable=unused-argument
return False
else:
def policy(zone, _):
return len(zone._versions) > max_versions
self.set_pruning_policy(policy)
- def set_pruning_policy(self, policy):
+ def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]):
"""Set the pruning policy for the zone.
The *policy* function takes a `Version` and returns `True` if
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+from typing import Optional, Tuple
+
import contextlib
import struct
import dns.name
class Parser:
- def __init__(self, wire, current=0):
+ def __init__(self, wire: bytes, current=0):
self.wire = wire
self.current = 0
self.end = len(self.wire)
def remaining(self):
return self.end - self.current
- def get_bytes(self, size):
+ def get_bytes(self, size=int) -> bytes:
+ assert size >= 0
if size > self.remaining():
raise dns.exception.FormError
output = self.wire[self.current:self.current + size]
self.furthest = max(self.furthest, self.current)
return output
- def get_counted_bytes(self, length_size=1):
+ def get_counted_bytes(self, length_size=1) -> bytes:
length = int.from_bytes(self.get_bytes(length_size), 'big')
return self.get_bytes(length)
- def get_remaining(self):
+ def get_remaining(self) -> bytes:
return self.get_bytes(self.remaining())
- def get_uint8(self):
+ def get_uint8(self) -> int:
return struct.unpack('!B', self.get_bytes(1))[0]
- def get_uint16(self):
+ def get_uint16(self) -> int:
return struct.unpack('!H', self.get_bytes(2))[0]
- def get_uint32(self):
+ def get_uint32(self) -> int:
return struct.unpack('!I', self.get_bytes(4))[0]
- def get_uint48(self):
+ def get_uint48(self) -> int:
return int.from_bytes(self.get_bytes(6), 'big')
- def get_struct(self, format):
+ def get_struct(self, format: str) -> Tuple:
return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
- def get_name(self, origin=None):
+ def get_name(self, origin: Optional['dns.name.Name']=None) -> 'dns.name.Name':
name = dns.name.from_wire_parser(self)
if origin:
name = name.relativize(origin)
return name
- def seek(self, where):
+ def seek(self, where: int):
# Note that seeking to the end is OK! (If you try to read
# after such a seek, you'll get an exception as expected.)
if where < 0 or where > self.end:
self.current = where
@contextlib.contextmanager
- def restrict_to(self, size):
+ def restrict_to(self, size: int):
+ assert size >= 0
if size > self.remaining():
raise dns.exception.FormError
saved_end = self.end
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+from typing import Any, List, Optional, Tuple
+
import dns.exception
import dns.message
import dns.name
import dns.rcode
import dns.serial
+import dns.rdataset
import dns.rdatatype
+import dns.transaction
+import dns.tsig
import dns.zone
State machine for zone transfers.
"""
- def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR,
- serial=None, is_udp=False):
+ def __init__(self, txn_manager: dns.transaction.TransactionManager, rdtype=dns.rdatatype.AXFR,
+ serial: Optional[int]=None, is_udp=False):
"""Initialize an inbound zone transfer.
*txn_manager* is a :py:class:`dns.transaction.TransactionManager`.
XFR.
"""
self.txn_manager = txn_manager
- self.txn = None
+ self.txn: Optional[dns.transaction.Transaction] = None
self.rdtype = rdtype
if rdtype == dns.rdatatype.IXFR:
if serial is None:
self.serial = serial
self.is_udp = is_udp
(_, _, self.origin) = txn_manager.origin_information()
- self.soa_rdataset = None
+ self.soa_rdataset: Optional[dns.rdataset.Rdataset] = None
self.done = False
self.expecting_SOA = False
self.delete_mode = False
- def process_message(self, message):
+ def process_message(self, message: dns.message.Message) -> bool:
"""Process one message in the transfer.
The message should have the same relativization as was specified when
rdataset = rrset
if self.done:
raise dns.exception.FormError("answers after final SOA")
+ assert self.txn is not None # for mypy
if rdataset.rdtype == dns.rdatatype.SOA and \
name == self.origin:
#
return False
-def make_query(txn_manager, serial=0,
- use_edns=None, ednsflags=None, payload=None,
- request_payload=None, options=None,
- keyring=None, keyname=None,
- keyalgorithm=dns.tsig.default_algorithm):
+def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional[int]=0,
+ use_edns=None, ednsflags: Optional[int]=None, payload: Optional[int]=None,
+ request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None,
+ keyring: Any=None, keyname: Optional[dns.name.Name]=None,
+ keyalgorithm=dns.tsig.default_algorithm) -> Tuple[dns.message.QueryMessage, Optional[int]]:
"""Make an AXFR or IXFR query.
*txn_manager* is a ``dns.transaction.TransactionManager``, typically a
Returns a `(query, serial)` tuple.
"""
(zone_origin, _, origin) = txn_manager.origin_information()
+ if zone_origin is None:
+ raise ValueError('no zone origin')
if serial is None:
rdtype = dns.rdatatype.AXFR
elif not isinstance(serial, int):
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
return (q, serial)
-def extract_serial_from_query(query):
+def extract_serial_from_query(query: dns.message.Message) -> Optional[int]:
"""Extract the SOA serial number from query if it is an IXFR and return
it, otherwise return None.
*query* is a dns.message.QueryMessage that is an IXFR or AXFR request.
Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have
- an appropriate SOA RRset in the authority section."""
-
+ an appropriate SOA RRset in the authority section.
+ """
+ if not isinstance(query, dns.message.QueryMessage):
+ raise ValueError('query not a QueryMessage')
question = query.question[0]
if question.rdtype == dns.rdatatype.AXFR:
return None
"""DNS Zones."""
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
+
import contextlib
import hashlib
import io
import dns.rdataclass
import dns.rdatatype
import dns.rdata
+import dns.rdataset
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.ZONEMD
import dns.rrset
import dns.ttl
import dns.grange
import dns.zonefile
+from dns.zonetypes import DigestScheme, DigestHashAlgorithm, _digest_hashers
class BadZone(dns.exception.DNSException):
"""The ZONEMD digest failed to verify."""
-class DigestScheme(dns.enum.IntEnum):
- """ZONEMD Scheme"""
-
- SIMPLE = 1
-
- @classmethod
- def _maximum(cls):
- return 255
-
-
-class DigestHashAlgorithm(dns.enum.IntEnum):
- """ZONEMD Hash Algorithm"""
-
- SHA384 = 1
- SHA512 = 2
-
- @classmethod
- def _maximum(cls):
- return 255
-
-
-_digest_hashers = {
- DigestHashAlgorithm.SHA384: hashlib.sha384,
- DigestHashAlgorithm.SHA512: hashlib.sha512,
-}
-
-
class Zone(dns.transaction.TransactionManager):
"""A DNS zone.
__slots__ = ['rdclass', 'origin', 'nodes', 'relativize']
- def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True):
+ def __init__(self, origin: Optional[Union[dns.name.Name, str]],
+ rdclass=dns.rdataclass.IN, relativize=True):
"""Initialize a zone object.
*origin* is the origin of the zone. It may be a ``dns.name.Name``,
raise ValueError("origin parameter must be an absolute name")
self.origin = origin
self.rdclass = rdclass
- self.nodes = {}
+ self.nodes: Dict[dns.name.Name, dns.node.Node] = {}
self.relativize = relativize
def __eq__(self, other):
return not self.__eq__(other)
- def _validate_name(self, name):
+ def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
if isinstance(name, str):
name = dns.name.from_text(name, None)
elif not isinstance(name, dns.name.Name):
raise KeyError("name parameter must be convertible to a DNS name")
if name.is_absolute():
+ if self.origin is None:
+ # This should probably never happen as other code (e.g.
+ # _rr_line) will notice the lack of an origin before us, but
+ # we check just in case!
+ raise KeyError('no zone origin is defined')
if not name.is_subdomain(self.origin):
raise KeyError(
"name parameter must be a subdomain of the zone origin")
if self.relativize:
name = name.relativize(self.origin)
+ elif not self.relativize:
+ # We have a relative name in a non-relative zone, so derelativize.
+ if self.origin is None:
+ raise KeyError('no zone origin is defined')
+ name = name.derelativize(self.origin)
return name
def __getitem__(self, key):
key = self._validate_name(key)
return key in self.nodes
- def find_node(self, name, create=False):
+ def find_node(self, name: Union[dns.name.Name, str], create=False):
"""Find a node in the zone, possibly creating it.
*name*: the name of the node to find.
self.nodes[name] = node
return node
- def get_node(self, name, create=False):
+ def get_node(self, name: Union[dns.name.Name, str], create=False):
"""Get a node in the zone, possibly creating it.
This method is like ``find_node()``, except it returns None instead
node = None
return node
- def delete_node(self, name):
+ def delete_node(self, name: Union[dns.name.Name, str]):
"""Delete the specified node if it exists.
*name*: the name of the node to find.
if name in self.nodes:
del self.nodes[name]
- def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
- create=False):
+ def find_rdataset(self, name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE,
+ create=False) -> dns.rdataset.Rdataset:
"""Look for an rdataset with the specified name and type in the zone,
and return an rdataset encapsulating it.
name must be a subdomain of the zone's origin. If ``zone.relativize``
is ``True``, then the name will be relativized.
- *rdtype*, an ``int`` or ``str``, the rdata type desired.
+ *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
- *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+ *covers*, a ``dns.rdatatype.RdataType`` or ``str`` the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
name = self._validate_name(name)
rdtype = dns.rdatatype.RdataType.make(rdtype)
- if covers is not None:
- covers = dns.rdatatype.RdataType.make(covers)
+ covers = dns.rdatatype.RdataType.make(covers)
node = self.find_node(name, create)
return node.find_rdataset(self.rdclass, rdtype, covers, create)
def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
- create=False):
+ create=False) -> Optional[dns.rdataset.Rdataset]:
"""Look for an rdataset with the specified name and type in the zone.
This method is like ``find_rdataset()``, except it returns None instead
name must be a subdomain of the zone's origin. If ``zone.relativize``
is ``True``, then the name will be relativized.
- *rdtype*, an ``int`` or ``str``, the rdata type desired.
+ *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
- *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+ *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
rdataset = None
return rdataset
- def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
+ def delete_rdataset(self, name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE):
"""Delete the rdataset matching *rdtype* and *covers*, if it
exists at the node specified by *name*.
name must be a subdomain of the zone's origin. If ``zone.relativize``
is ``True``, then the name will be relativized.
- *rdtype*, an ``int`` or ``str``, the rdata type desired.
+ *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
- *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+ *covers*, a ``dns.rdatatype.RdataType`` or ``str`` or ``None``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
name = self._validate_name(name)
rdtype = dns.rdatatype.RdataType.make(rdtype)
- if covers is not None:
- covers = dns.rdatatype.RdataType.make(covers)
+ covers = dns.rdatatype.RdataType.make(covers)
node = self.get_node(name)
if node is not None:
node.delete_rdataset(self.rdclass, rdtype, covers)
if len(node) == 0:
self.delete_node(name)
- def replace_rdataset(self, name, replacement):
+ def replace_rdataset(self, name: Union[dns.name.Name, str],
+ replacement: dns.rdataset.Rdataset):
"""Replace an rdataset at name.
It is not an error if there is no rdataset matching I{replacement}.
node = self.find_node(name, True)
node.replace_rdataset(replacement)
- def find_rrset(self, name, rdtype, covers=dns.rdatatype.NONE):
+ def find_rrset(self, name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rrset.RRset:
"""Look for an rdataset with the specified name and type in the zone,
and return an RRset encapsulating it.
name must be a subdomain of the zone's origin. If ``zone.relativize``
is ``True``, then the name will be relativized.
- *rdtype*, an ``int`` or ``str``, the rdata type desired.
+ *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
- *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+ *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
Returns a ``dns.rrset.RRset`` or ``None``.
"""
- name = self._validate_name(name)
- rdtype = dns.rdatatype.RdataType.make(rdtype)
- if covers is not None:
- covers = dns.rdatatype.RdataType.make(covers)
- rdataset = self.nodes[name].find_rdataset(self.rdclass, rdtype, covers)
- rrset = dns.rrset.RRset(name, self.rdclass, rdtype, covers)
+ vname = self._validate_name(name)
+ the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+ the_covers = dns.rdatatype.RdataType.make(covers)
+ rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers)
+ rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers)
rrset.update(rdataset)
return rrset
- def get_rrset(self, name, rdtype, covers=dns.rdatatype.NONE):
+ def get_rrset(self, name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> Optional[dns.rrset.RRset]:
"""Look for an rdataset with the specified name and type in the zone,
and return an RRset encapsulating it.
name must be a subdomain of the zone's origin. If ``zone.relativize``
is ``True``, then the name will be relativized.
- *rdtype*, an ``int`` or ``str``, the rdata type desired.
+ *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired.
- *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+ *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
return rrset
def iterate_rdatasets(self, rdtype=dns.rdatatype.ANY,
- covers=dns.rdatatype.NONE):
+ covers=dns.rdatatype.NONE) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
"""Return a generator which yields (name, rdataset) tuples for
all rdatasets in the zone which have the specified *rdtype*
and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default,
then all rdatasets will be matched.
- *rdtype*, an ``int`` or ``str``, the rdata type desired.
+ *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired.
- *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+ *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
"""
rdtype = dns.rdatatype.RdataType.make(rdtype)
- if covers is not None:
- covers = dns.rdatatype.RdataType.make(covers)
+ covers = dns.rdatatype.RdataType.make(covers)
for (name, node) in self.items():
for rds in node:
if rdtype == dns.rdatatype.ANY or \
yield (name, rds)
def iterate_rdatas(self, rdtype=dns.rdatatype.ANY,
- covers=dns.rdatatype.NONE):
+ covers=dns.rdatatype.NONE) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]:
"""Return a generator which yields (name, ttl, rdata) tuples for
all rdatas in the zone which have the specified *rdtype*
and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default,
then all rdatas will be matched.
- *rdtype*, an ``int`` or ``str``, the rdata type desired.
+ *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired.
- *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+ *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
"""
rdtype = dns.rdatatype.RdataType.make(rdtype)
- if covers is not None:
- covers = dns.rdatatype.RdataType.make(covers)
+ covers = dns.rdatatype.RdataType.make(covers)
for (name, node) in self.items():
for rds in node:
if rdtype == dns.rdatatype.ANY or \
for rdata in rds:
yield (name, rds.ttl, rdata)
- def to_file(self, f, sorted=True, relativize=True, nl=None,
+ def to_file(self, f: Any, sorted=True, relativize=True, nl: Optional[str]=None,
want_comments=False, want_origin=False):
"""Write a zone to a file.
nl = nl.decode()
if want_origin:
+ assert self.origin is not None
l = '$ORIGIN ' + self.origin.to_text()
l_b = l.encode(file_enc)
try:
f.write(l)
f.write(nl)
- def to_text(self, sorted=True, relativize=True, nl=None,
+ def to_text(self, sorted=True, relativize=True, nl: Optional[str]=None,
want_comments=False, want_origin=False):
"""Return a zone's text as though it were written to a file.
if self.get_rdataset(name, dns.rdatatype.NS) is None:
raise NoNS
- def get_soa(self, txn=None):
+ def get_soa(self, txn: Optional[dns.transaction.Transaction]=None):
"""Get the zone SOA RR.
Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
if self.relativize:
origin_name = dns.name.empty
else:
+ if self.origin is None:
+ # get_soa() has been called very early, and there must not be
+ # an SOA if there is no origin.
+ raise NoSOA
origin_name = self.origin
+ soa: Optional[dns.rdataset.Rdataset]
if txn:
soa = txn.get(origin_name, dns.rdatatype.SOA)
else:
raise NoSOA
return soa[0]
- def _compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE):
+ def _compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> bytes:
hashinfo = _digest_hashers.get(hash_algorithm)
if not hashinfo:
raise UnsupportedDigestHashAlgorithm
if self.relativize:
origin_name = dns.name.empty
else:
+ assert self.origin is not None
origin_name = self.origin
hasher = hashinfo()
for (name, node) in sorted(self.items()):
hasher.update(rrnamebuf + rrfixed + rrlen + rdata)
return hasher.digest()
- def compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE):
- if self.relativize:
- origin_name = dns.name.empty
- else:
- origin_name = self.origin
+ def compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD:
serial = self.get_soa().serial
digest = self._compute_digest(hash_algorithm, scheme)
return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass,
serial, scheme, hash_algorithm,
digest)
- def verify_digest(self, zonemd=None):
+ def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None):
+ digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]]
if zonemd:
digests = [zonemd]
else:
- digests = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD)
- if digests is None:
+ rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD)
+ if rds is None:
raise NoDigest
+ digests = rds
for digest in digests:
try:
computed = self._compute_digest(digest.hash_algorithm,
# TransactionManager methods
- def reader(self):
+ def reader(self) -> 'Transaction':
return Transaction(self, False,
Version(self, 1, self.nodes, self.origin))
- def writer(self, replacement=False):
+ def writer(self, replacement=False) -> 'Transaction':
txn = Transaction(self, replacement)
txn._setup_version()
return txn
- def origin_information(self):
+ def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
+ effective: Optional[dns.name.Name]
if self.relativize:
effective = dns.name.empty
else:
class Version:
- def __init__(self, zone, id, nodes=None, origin=None):
+ def __init__(self, zone: Zone, id: int,
+ nodes: Optional[Dict[dns.name.Name, dns.node.Node]]=None,
+ origin: Optional[dns.name.Name]=None):
self.zone = zone
self.id = id
if nodes is not None:
self.nodes = {}
self.origin = origin
- def _validate_name(self, name):
+ def _validate_name(self, name: dns.name.Name):
if name.is_absolute():
if self.origin is None:
# This should probably never happen as other code (e.g.
raise KeyError("name is not a subdomain of the zone origin")
if self.zone.relativize:
name = name.relativize(self.origin)
+ elif not self.zone.relativize:
+ # We have a relative name in a non-relative zone, so derelativize.
+ if self.origin is None:
+ raise KeyError('no zone origin is defined')
+ name = name.derelativize(self.origin)
return name
- def get_node(self, name):
+ def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
name = self._validate_name(name)
return self.nodes.get(name)
- def get_rdataset(self, name, rdtype, covers):
+ def get_rdataset(self, name: dns.name.Name, rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType) -> Optional[dns.rdataset.Rdataset]:
node = self.get_node(name)
if node is None:
return None
class WritableVersion(Version):
- def __init__(self, zone, replacement=False):
+ def __init__(self, zone: Zone, replacement=False):
# The zone._versions_lock must be held by our caller in a versioned
# zone.
id = zone._get_next_version_id()
# We have to copy the zone origin as it may be None in the first
# version, and we don't want to mutate the zone until we commit.
self.origin = zone.origin
- self.changed = set()
+ self.changed: Set[dns.name.Name] = set()
- def _maybe_cow(self, name):
+ def _maybe_cow(self, name: dns.name.Name):
name = self._validate_name(name)
node = self.nodes.get(name)
if node is None or name not in self.changed:
# code used new_node.id != self.id for the "do we need to CoW?"
# test. Now we use the changed set as this works with both
# regular zones and versioned zones.
- new_node.id = self.id
+ #
+ # We ignore the mypy error as this is safe but it doesn't see it.
+ new_node.id = self.id # type: ignore
if node is not None:
# moo! copy on write!
new_node.rdatasets.extend(node.rdatasets)
else:
return node
- def delete_node(self, name):
+ def delete_node(self, name: dns.name.Name):
name = self._validate_name(name)
if name in self.nodes:
del self.nodes[name]
self.changed.add(name)
- def put_rdataset(self, name, rdataset):
+ def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset):
node = self._maybe_cow(name)
node.replace_rdataset(rdataset)
- def delete_rdataset(self, name, rdtype, covers):
+ def delete_rdataset(self, name: dns.name.Name, rdtype:dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType):
node = self._maybe_cow(name)
node.delete_rdataset(self.zone.rdclass, rdtype, covers)
if len(node) == 0:
@dns.immutable.immutable
class ImmutableVersion(Version):
- def __init__(self, version):
+ def __init__(self, version: WritableVersion):
# We tell super() that it's a replacement as we don't want it
# to copy the nodes, as we're about to do that with an
# immutable Dict.
# it might not exist if we deleted it in the version
if node:
version.nodes[name] = ImmutableVersionedNode(node)
- self.nodes = dns.immutable.Dict(version.nodes, True)
+ # We're changing the type of the nodes dictionary here on purpose, so
+ # we ignore the mypy error.
+ self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore
class Transaction(dns.transaction.Transaction):
return (absolute, relativize, effective)
-def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
- relativize=True, zone_factory=Zone, filename=None,
- allow_include=False, check_origin=True, idna_codec=None):
+def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None,
+ rdclass=dns.rdataclass.IN,
+ relativize=True, zone_factory=Zone, filename: Optional[str]=None,
+ allow_include=False, check_origin=True,
+ idna_codec: Optional[dns.name.IDNACodec]=None) -> Zone:
"""Build a zone object from a zone file format string.
*text*, a ``str``, the zone file format input.
of the zone; if not specified, the first ``$ORIGIN`` statement in the
zone file will determine the origin of the zone.
- *rdclass*, an ``int``, the zone's rdata class; the default is class IN.
+ *rdclass*, a ``dns.rdataclass.RdataClass``, the zone's rdata class; the default is class IN.
*relativize*, a ``bool``, determine's whether domain names are
relativized to the zone's origin. The default is ``True``.
return zone
-def from_file(f, origin=None, rdclass=dns.rdataclass.IN,
- relativize=True, zone_factory=Zone, filename=None,
- allow_include=True, check_origin=True):
+def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None,
+ rdclass=dns.rdataclass.IN,
+ relativize=True, zone_factory=Zone, filename: Optional[str]=None,
+ allow_include=True, check_origin=True) -> Zone:
"""Read a zone file and build a zone object.
*f*, a file or ``str``. If *f* is a string, it is treated
f = stack.enter_context(open(f))
return from_text(f, origin, rdclass, relativize, zone_factory,
filename, allow_include, check_origin)
+ assert False # make mypy happy
def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):
+++ /dev/null
-from typing import Generator, Optional, Union, Tuple, Iterable, Callable, Any, Iterator, TextIO, BinaryIO, Dict
-from . import rdata, zone, rdataclass, name, rdataclass, message, rdatatype, exception, node, rdataset, rrset, rdatatype
-
-class BadZone(exception.DNSException): ...
-class NoSOA(BadZone): ...
-class NoNS(BadZone): ...
-class UnknownOrigin(BadZone): ...
-
-class Zone:
- def __getitem__(self, key : str) -> node.Node:
- ...
- def __init__(self, origin : Union[str,name.Name], rdclass : int = rdataclass.IN, relativize : bool = True) -> None:
- self.nodes : Dict[str,node.Node]
- self.origin = origin
- def values(self):
- return self.nodes.values()
- def iterate_rdatas(self, rdtype : Union[int,str] = rdatatype.ANY, covers : Union[int,str] = None) -> Iterable[Tuple[name.Name, int, rdata.Rdata]]:
- ...
- def __iter__(self) -> Iterator[str]:
- ...
- def get_node(self, name : Union[name.Name,str], create=False) -> Optional[node.Node]:
- ...
- def find_rrset(self, name : Union[str,name.Name], rdtype : Union[int,str], covers=rdatatype.NONE) -> rrset.RRset:
- ...
- def find_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE,
- create=False) -> rdataset.Rdataset:
- ...
- def get_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, create=False) -> Optional[rdataset.Rdataset]:
- ...
- def get_rrset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> Optional[rrset.RRset]:
- ...
- def replace_rdataset(self, name : Union[str,name.Name], replacement : rdataset.Rdataset) -> None:
- ...
- def delete_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> None:
- ...
- def iterate_rdatasets(self, rdtype : Union[str,int] =rdatatype.ANY,
- covers : Union[str,int] =rdatatype.NONE):
- ...
- def to_file(self, f : Union[TextIO, BinaryIO, str], sorted=True, relativize=True, nl : Optional[bytes] = None):
- ...
- def to_text(self, sorted=True, relativize=True, nl : Optional[str] = None) -> str:
- ...
-
-def from_xfr(xfr : Generator[Any,Any,message.Message], zone_factory : Callable[..., zone.Zone] = zone.Zone, relativize=True, check_origin=True):
- ...
-
-def from_text(text : str, origin : Optional[Union[str,name.Name]] = None, rdclass : int = rdataclass.IN,
- relativize=True, zone_factory : Callable[...,zone.Zone] = zone.Zone, filename : Optional[str] = None,
- allow_include=False, check_origin=True) -> zone.Zone:
- ...
-
-def from_file(f, origin : Optional[Union[str,name.Name]] = None, rdclass=rdataclass.IN,
- relativize=True, zone_factory : Callable[..., zone.Zone] = Zone, filename : Optional[str] = None,
- allow_include=True, check_origin=True) -> zone.Zone:
- ...
"""DNS Zones."""
+from typing import Any, List, Optional, Tuple, Union
+
import re
import sys
# adding the rdataset is ok
+SavedStateType = Tuple[dns.tokenizer.Tokenizer,
+ Optional[dns.name.Name], # current_origin
+ Optional[dns.name.Name], # last_name
+ Optional[str], # current_file
+ int, # last_ttl
+ bool, # last_ttl_known
+ int, # default_ttl
+ bool] # default_ttl_known
+
+
class Reader:
"""Read a DNS zone file into a transaction."""
- def __init__(self, tok, rdclass, txn, allow_include=False,
- allow_directives=True, force_name=None,
- force_ttl=None, force_rdclass=None, force_rdtype=None,
- default_ttl=None):
+ def __init__(self, tok: dns.tokenizer.Tokenizer, rdclass: dns.rdataclass.RdataClass,
+ txn: dns.transaction.Transaction, allow_include=False,
+ allow_directives=True, force_name: Optional[dns.name.Name]=None,
+ force_ttl: Optional[int]=None,
+ force_rdclass: Optional[dns.rdataclass.RdataClass]=None,
+ force_rdtype: Optional[dns.rdatatype.RdataType]=None,
+ default_ttl: Optional[int]=None):
self.tok = tok
(self.zone_origin, self.relativize, _) = \
txn.manager.origin_information()
self.last_name = self.current_origin
self.zone_rdclass = rdclass
self.txn = txn
- self.saved_state = []
+ self.saved_state: List[SavedStateType] = []
self.current_file = None
self.allow_include = allow_include
self.allow_directives = allow_directives
self.rrsets = rrsets
-def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN,
- default_rdclass=dns.rdataclass.IN,
- rdtype=None, default_ttl=None, idna_codec=None,
- origin=dns.name.root, relativize=False):
+def read_rrsets(text: Any,
+ name: Optional[Union[dns.name.Name, str]]=None,
+ ttl: Optional[int]=None,
+ rdclass: Optional[Union[dns.rdataclass.RdataClass, str]]=dns.rdataclass.IN,
+ default_rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+ rdtype: Optional[Union[dns.rdatatype.RdataType, str]]=None,
+ default_ttl: Optional[Union[int, str]]=None,
+ idna_codec: Optional[dns.name.IDNACodec]=None,
+ origin: Optional[Union[dns.name.Name, str]]=dns.name.root,
+ relativize=False) -> List[dns.rrset.RRset]:
"""Read one or more rrsets from the specified text, possibly subject
to restrictions.
if isinstance(default_ttl, str):
default_ttl = dns.ttl.from_text(default_ttl)
if rdclass is not None:
- rdclass = dns.rdataclass.RdataClass.make(rdclass)
- default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
+ the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ else:
+ the_rdclass = None
+ the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
if rdtype is not None:
- rdtype = dns.rdatatype.RdataType.make(rdtype)
+ the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+ else:
+ the_rdtype = None
manager = RRSetsReaderManager(origin, relativize, default_rdclass)
with manager.writer(True) as txn:
tok = dns.tokenizer.Tokenizer(text, '<input>', idna_codec=idna_codec)
- reader = Reader(tok, default_rdclass, txn, allow_directives=False,
- force_name=name, force_ttl=ttl, force_rdclass=rdclass,
- force_rdtype=rdtype, default_ttl=default_ttl)
+ reader = Reader(tok, the_default_rdclass, txn, allow_directives=False,
+ force_name=name, force_ttl=ttl, force_rdclass=the_rdclass,
+ force_rdtype=the_rdtype, default_ttl=default_ttl)
reader.read()
return manager.rrsets
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""Common zone-related types."""
+
+# This is a separate file to avoid import circularity between dns.zone and
+# the implementation of the ZONEMD type.
+
+import hashlib
+
+import dns.enum
+
+
+class DigestScheme(dns.enum.IntEnum):
+ """ZONEMD Scheme"""
+
+ SIMPLE = 1
+
+ @classmethod
+ def _maximum(cls):
+ return 255
+
+
+class DigestHashAlgorithm(dns.enum.IntEnum):
+ """ZONEMD Hash Algorithm"""
+
+ SHA384 = 1
+ SHA512 = 2
+
+ @classmethod
+ def _maximum(cls):
+ return 255
+
+
+_digest_hashers = {
+ DigestHashAlgorithm.SHA384: hashlib.sha384,
+ DigestHashAlgorithm.SHA512: hashlib.sha512,
+}
async
exceptions
utilities
- typing
threads
examples
+++ /dev/null
-.. _typing:
-
-A Note on Typing
-----------------
-
-Dnspython has partial support for type annotations in separate .pyi
-files. Type information will not be integrated into the main files
-until major LTS versions of various Linux distributions containing 3.6
-are beyond their support times. Improvements to the .pyi files are
-welcome during this time.
[mypy-requests_toolbelt.*]
ignore_missing_imports = True
+
+[mypy-curio]
+ignore_missing_imports = True
+
+[mypy-trio]
+ignore_missing_imports = True
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+from typing import Any
+
import unittest
import dns.dnssec
import dns.rdata
import dns.rdataclass
import dns.rdatatype
+import dns.rdtypes.ANY.CDS
+import dns.rdtypes.ANY.DS
import dns.rrset
# pylint: disable=line-too-long
self.assertEqual(good_ds, good_ds_mnemonic)
def testMakeExampleSHA1DS(self): # type: () -> None
+ algorithm: Any
for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
self.assertEqual(ds, example_ds_sha1)
self.assertEqual(ds, example_ds_sha1)
def testMakeExampleSHA256DS(self): # type: () -> None
+ algorithm: Any
for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
self.assertEqual(ds, example_ds_sha256)
def testMakeExampleSHA384DS(self): # type: () -> None
+ algorithm: Any
for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
self.assertEqual(ds, example_ds_sha384)
self.assertEqual(ds, good_ds)
def testInvalidAlgorithm(self): # type: () -> None
+ algorithm: Any
for algorithm in (10, 'shax'):
with self.assertRaises(dns.dnssec.UnsupportedAlgorithm):
ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
for rdtype in digest_types:
rd = dns.rdata.from_text(dns.rdataclass.IN, rdtype,
f'18673 3 5 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7')
+ assert isinstance(rd, dns.rdtypes.ANY.DS.DS) or isinstance(rd, dns.rdtypes.ANY.CDS.CDS)
self.assertEqual(rd.digest_type, 5)
self.assertEqual(rd.digest, bytes.fromhex('71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7'))
try:
dns.name.from_text(t)
except Exception:
- self.fail("good test '%s' raised an exception" % t)
+ self.fail("good test '%r' raised an exception" % t)
for t in bad:
caught = False
try:
except Exception:
caught = True
if not caught:
- self.fail("bad test '%s' did not raise an exception" % t)
+ self.fail("bad test '%r' did not raise an exception" % t)
def testImmutable1(self):
def bad():
def testImmutable2(self):
def bad():
- self.origin.labels[0] = 'foo'
+ self.origin.labels[0] = 'foo' # type: ignore
self.assertRaises(TypeError, bad)
def testAbs1(self):
def testReverseIPv6(self):
e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.')
- n = dns.reversename.from_address(b'::1')
+ n = dns.reversename.from_address('::1')
self.assertEqual(e, n)
def testReverseIPv6MappedIpv4(self):
def testReverseIPv6AlternateOrigin(self):
e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.foo.bar.')
origin = dns.name.from_text('foo.bar')
- n = dns.reversename.from_address(b'::1', v6_origin=origin)
+ n = dns.reversename.from_address('::1', v6_origin=origin)
self.assertEqual(e, n)
def testForwardIPv4(self):
def testFromUnicodeNotString(self):
def bad():
- dns.name.from_unicode(b'123')
+ dns.name.from_unicode(b'123') # type: ignore
self.assertRaises(ValueError, bad)
def testFromUnicodeBadOrigin(self):
def bad():
- dns.name.from_unicode('example', 123)
+ dns.name.from_unicode('example', 123) # type: ignore
self.assertRaises(ValueError, bad)
def testFromUnicodeEmptyLabel(self):
def testFromTextNotString(self):
def bad():
- dns.name.from_text(123)
+ dns.name.from_text(123) # type: ignore
self.assertRaises(ValueError, bad)
def testFromTextBadOrigin(self):
def bad():
- dns.name.from_text('example', 123)
+ dns.name.from_text('example', 123) # type: ignore
self.assertRaises(ValueError, bad)
def testFromWireNotBytes(self):
def bad():
- dns.name.from_wire(123, 0)
+ dns.name.from_wire(123, 0) # type: ignore
self.assertRaises(ValueError, bad)
def testBadPunycode(self):
c.encode('Königsgäßchen')
with self.assertRaises(dns.name.NoIDNA2008):
c = dns.name.IDNA2008Codec(strict_decode=True)
- c.decode('xn--eckwd4c7c.xn--zckzah.')
+ c.decode(b'xn--eckwd4c7c.xn--zckzah.')
dns.name.have_idna_2008 = True
@unittest.skipUnless(dns.name.have_idna_2008,
import dns.rdata
import dns.rdataset
+import dns.rdtypes.IN.SRV
def test_processing_order_shuffle():
for j in range(3):
assert rds[j] in po
assert rds[0] == po[0]
+ assert isinstance(po[1], dns.rdtypes.IN.SRV.SRV)
if po[1].weight == 90:
weight_90_count += 1
else:
rds = z.get_rdataset('@', 'loc')
self.assertTrue(rds is None)
+ def testGetRdatasetWithRelativeNameFromAbsoluteZone(self):
+ z = dns.zone.from_text(example_text, 'example.', relativize=False)
+ rds = z.get_rdataset(dns.name.empty, 'soa')
+ self.assertIsNotNone(rds)
+ exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo.example. bar.example. 1 2 3 4 5')
+ self.assertEqual(rds, exrds)
+
def testGetRRset1(self):
z = dns.zone.from_text(example_text, 'example.', relativize=True)
rrs = z.get_rrset('@', 'soa')
self.assertTrue(soa.rdtype, dns.rdatatype.SOA)
self.assertEqual(soa.serial, 1)
-
def testGetSoaEmptyZone(self):
z = dns.zone.Zone('example.')
with self.assertRaises(dns.zone.NoSOA):