]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add integrated typing to much of dnspython. 785/head
authorBob Halley <halley@dnspython.org>
Fri, 25 Feb 2022 21:29:09 +0000 (13:29 -0800)
committerBob Halley <halley@dnspython.org>
Sat, 5 Mar 2022 20:39:50 +0000 (12:39 -0800)
79 files changed:
Makefile
dns/__init__.py
dns/_asyncbackend.py
dns/_asyncio_backend.py
dns/_curio_backend.py
dns/_trio_backend.py
dns/asyncbackend.py
dns/asyncbackend.pyi [deleted file]
dns/asyncquery.py
dns/asyncquery.pyi [deleted file]
dns/asyncresolver.py
dns/asyncresolver.pyi [deleted file]
dns/dnssec.py
dns/dnssec.pyi [deleted file]
dns/dnssectypes.py [new file with mode: 0644]
dns/e164.py
dns/e164.pyi [deleted file]
dns/edns.py
dns/entropy.py
dns/entropy.pyi [deleted file]
dns/exception.py
dns/exception.pyi [deleted file]
dns/flags.py
dns/grange.py
dns/inet.py
dns/inet.pyi [deleted file]
dns/ipv4.py
dns/ipv6.py
dns/message.py
dns/message.pyi [deleted file]
dns/name.py
dns/name.pyi [deleted file]
dns/node.py
dns/node.pyi [deleted file]
dns/opcode.py
dns/query.py
dns/query.pyi [deleted file]
dns/rdata.py
dns/rdata.pyi [deleted file]
dns/rdataset.py
dns/rdataset.pyi [deleted file]
dns/rdatatype.py
dns/rdtypes/ANY/CERT.py
dns/rdtypes/ANY/RRSIG.py
dns/rdtypes/ANY/TKEY.py
dns/rdtypes/ANY/ZONEMD.py
dns/rdtypes/dnskeybase.py
dns/rdtypes/dnskeybase.pyi [deleted file]
dns/rdtypes/dsbase.py
dns/rdtypes/txtbase.py
dns/rdtypes/txtbase.pyi [deleted file]
dns/resolver.py
dns/resolver.pyi [deleted file]
dns/reversename.py
dns/reversename.pyi [deleted file]
dns/rrset.py
dns/rrset.pyi [deleted file]
dns/serial.py
dns/tokenizer.py
dns/transaction.py
dns/tsigkeyring.py
dns/tsigkeyring.pyi [deleted file]
dns/ttl.py
dns/update.py
dns/update.pyi [deleted file]
dns/versioned.py
dns/wire.py
dns/xfr.py
dns/zone.py
dns/zone.pyi [deleted file]
dns/zonefile.py
dns/zonetypes.py [new file with mode: 0644]
doc/manual.rst
doc/typing.rst [deleted file]
mypy.ini
tests/test_dnssec.py
tests/test_name.py
tests/test_processing_order.py
tests/test_zone.py

index 76e70286284ad1de87350713d974724f79d9f578..fe4e8bd9f012216bffb23a2cded296d93fd2c270 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -56,7 +56,10 @@ potestlf:
        poetry run pytest --lf
 
 potype:
-       poetry run python -m mypy examples tests dns/*.py
+       poetry run python -m mypy dns/*.py
+
+potypetests:
+       poetry run python -m mypy --check-untyped-defs examples tests
 
 polint:
        poetry run pylint dns
index 0473ca175cb6b8ef20a3fac55b62ff5c9fcbb898..a620f97575f2e98513ad1f9de208af307bae6d5d 100644 (file)
@@ -22,6 +22,7 @@ __all__ = [
     'asyncquery',
     'asyncresolver',
     'dnssec',
+    'dnssectypes',
     'e164',
     'edns',
     'entropy',
@@ -60,6 +61,7 @@ __all__ = [
     'wire',
     'xfr',
     'zone',
+    'zonetypes',
     'zonefile',
 ]
 
index 1f3a8287174f3381b17e696e9a187924c5d6258e..674bf6eaa7d91da7428d47cd1e7282ce8320bd95 100644 (file)
@@ -41,6 +41,9 @@ class Socket:  # pragma: no cover
 
 
 class DatagramSocket(Socket):  # pragma: no cover
+    def __init__(self, family: int):
+        self.family = family
+
     async def sendto(self, what, destination, timeout):
         raise NotImplementedError
 
@@ -67,3 +70,6 @@ class Backend:    # pragma: no cover
 
     def datagram_connection_required(self):
         return False
+
+    async def sleep(self, interval):
+        raise NotImplementedError
index d737d13c7e4421d05d4510164819a654786963ab..9d458da0cd1819e23072171243bcfdbb40079b68 100644 (file)
@@ -55,8 +55,8 @@ async def _maybe_wait_for(awaitable, timeout):
 
 
 class DatagramSocket(dns._asyncbackend.DatagramSocket):
-    def __init__(self, family, transport, protocol):
-        self.family = family
+    def __init__(self, family: int, transport, protocol):
+        super().__init__(family)
         self.transport = transport
         self.protocol = protocol
 
index 535eb84d22257c0b0ba4b7b94fd458335c28d46c..3f22b5d3e777f0de08190242420c8eda731f6405 100644 (file)
@@ -26,8 +26,8 @@ _lltuple = dns.inet.low_level_address_tuple
 
 class DatagramSocket(dns._asyncbackend.DatagramSocket):
     def __init__(self, socket):
+        super().__init__(socket.family)
         self.socket = socket
-        self.family = socket.family
 
     async def sendto(self, what, destination, timeout):
         async with _maybe_timeout(timeout):
index 863d413e84ff7dd0177418a74495456f7860733f..8a337e9d872ee8406cafb75e8c87ef6028779ee9 100644 (file)
@@ -26,8 +26,8 @@ _lltuple = dns.inet.low_level_address_tuple
 
 class DatagramSocket(dns._asyncbackend.DatagramSocket):
     def __init__(self, socket):
+        super().__init__(socket.family)
         self.socket = socket
-        self.family = socket.family
 
     async def sendto(self, what, destination, timeout):
         with _maybe_timeout(timeout):
index ad79a572b2827a9ef0c69f75dbad1af503302261..a8f794ac6fe80c76ef687fde0000ec5a7edd89fa 100644 (file)
@@ -1,5 +1,7 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+from typing import Dict
+
 import dns.exception
 
 # pylint: disable=unused-import
@@ -10,7 +12,7 @@ from dns._asyncbackend import Socket, DatagramSocket, StreamSocket, Backend  # n
 
 _default_backend = None
 
-_backends = {}
+_backends: Dict[str, Backend] = {}
 
 # Allow sniffio import to be disabled for testing purposes
 _no_sniffio = False
@@ -19,7 +21,7 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException):
     pass
 
 
-def get_backend(name):
+def get_backend(name: str) -> Backend:
     """Get the specified asynchronous backend.
 
     *name*, a ``str``, the name of the backend.  Currently the "trio",
@@ -46,7 +48,7 @@ def get_backend(name):
     return backend
 
 
-def sniff():
+def sniff() -> str:
     """Attempt to determine the in-use asynchronous I/O library by using
     the ``sniffio`` module if it is available.
 
@@ -71,13 +73,14 @@ def sniff():
         except RuntimeError:
             raise AsyncLibraryNotFoundError('no async library detected')
         except AttributeError:  # pragma: no cover
-            # we have to check current_task on 3.6
-            if not asyncio.Task.current_task():
+            # we have to check current_task on 3.6; we ignore for mypy
+            # purposes at it is otherwise unhappy on >= 3.7
+            if not asyncio.Task.current_task():  # type: ignore
                 raise AsyncLibraryNotFoundError('no async library detected')
             return 'asyncio'
 
 
-def get_default_backend():
+def get_default_backend() -> Backend:
     """Get the default backend, initializing it if necessary.
     """
     if _default_backend:
@@ -86,7 +89,7 @@ def get_default_backend():
     return set_default_backend(sniff())
 
 
-def set_default_backend(name):
+def set_default_backend(name: str):
     """Set the default backend.
 
     It's not normally necessary to call this method, as
diff --git a/dns/asyncbackend.pyi b/dns/asyncbackend.pyi
deleted file mode 100644 (file)
index 1ec9d32..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-class Backend:
-    ...
-
-def get_backend(name: str) -> Backend:
-    ...
-def sniff() -> str:
-    ...
-def get_default_backend() -> Backend:
-    ...
-def set_default_backend(name: str) -> Backend:
-    ...
index 13f687fb37501d9c42191cf2b3d4973b7bd65b39..8c35d1aabc8cde308edab1e21281964503047874 100644 (file)
@@ -17,6 +17,8 @@
 
 """Talk to a DNS server."""
 
+from typing import Any, Dict, Optional, Tuple, Union
+
 import base64
 import socket
 import struct
@@ -31,6 +33,7 @@ import dns.message
 import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
+import dns.transaction
 
 from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
     UDPMode, _have_httpx, _have_http2, NoDOH
@@ -67,7 +70,9 @@ def _timeout(expiration, now=None):
         return None
 
 
-async def send_udp(sock, what, destination, expiration=None):
+async def send_udp(sock: dns.asyncbackend.DatagramSocket,
+                   what: Union[dns.message.Message, bytes], destination: Any,
+                   expiration: Optional[float]=None) -> Tuple[int, float]:
     """Send a DNS message to the specified UDP socket.
 
     *sock*, a ``dns.asyncbackend.DatagramSocket``.
@@ -91,10 +96,11 @@ async def send_udp(sock, what, destination, expiration=None):
     return (n, sent_time)
 
 
-async def receive_udp(sock, destination=None, expiration=None,
+async def receive_udp(sock: dns.asyncbackend.DatagramSocket,
+                      destination: Optional[Any]=None, expiration: Optional[float]=None,
                       ignore_unexpected=False, one_rr_per_rrset=False,
-                      keyring=None, request_mac=b'', ignore_trailing=False,
-                      raise_on_truncation=False):
+                      keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
+                      ignore_trailing=False, raise_on_truncation=False) -> Any:
     """Read a DNS message from a UDP socket.
 
     *sock*, a ``dns.asyncbackend.DatagramSocket``.
@@ -116,10 +122,11 @@ async def receive_udp(sock, destination=None, expiration=None,
                               raise_on_truncation=raise_on_truncation)
     return (r, received_time, from_address)
 
-async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
-              ignore_unexpected=False, one_rr_per_rrset=False,
-              ignore_trailing=False, raise_on_truncation=False, sock=None,
-              backend=None):
+async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+              source: Optional[str]=None, source_port=0,
+              ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+              raise_on_truncation=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None,
+              backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via UDP.
 
     *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
@@ -152,6 +159,7 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
                 dtuple = None
             s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
                                           dtuple)
+        assert s is not None
         await send_udp(s, wire, destination, expiration)
         (r, received_time, _) = await receive_udp(s, destination, expiration,
                                                   ignore_unexpected,
@@ -167,10 +175,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
         if not sock and s:
             await s.close()
 
-async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
-                            source_port=0, ignore_unexpected=False,
-                            one_rr_per_rrset=False, ignore_trailing=False,
-                            udp_sock=None, tcp_sock=None, backend=None):
+async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+                            source: Optional[str]=None, source_port=0,
+                            ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+                            udp_sock: Optional[dns.asyncbackend.DatagramSocket]=None,
+                            tcp_sock: Optional[dns.asyncbackend.StreamSocket]=None,
+                            backend: Optional[dns.asyncbackend.Backend]=None) -> Tuple[dns.message.Message, bool]:
     """Return the response to the query, trying UDP first and falling back
     to TCP if UDP results in a truncated response.
 
@@ -203,7 +213,9 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
         return (response, True)
 
 
-async def send_tcp(sock, what, expiration=None):
+async def send_tcp(sock: dns.asyncbackend.StreamSocket,
+                   what: Union[dns.message.Message, bytes],
+                   expiration: Optional[float]=None) -> Tuple[int, float]:
     """Send a DNS message to the specified TCP socket.
 
     *sock*, a ``dns.asyncbackend.StreamSocket``.
@@ -213,12 +225,14 @@ async def send_tcp(sock, what, expiration=None):
     """
 
     if isinstance(what, dns.message.Message):
-        what = what.to_wire()
-    l = len(what)
+        wire = what.to_wire()
+    else:
+        wire = what
+    l = len(wire)
     # copying the wire into tcpmsg is inefficient, but lets us
     # avoid writev() or doing a short write that would get pushed
     # onto the net
-    tcpmsg = struct.pack("!H", l) + what
+    tcpmsg = struct.pack("!H", l) + wire
     sent_time = time.time()
     await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
     return (len(tcpmsg), sent_time)
@@ -238,8 +252,10 @@ async def _read_exactly(sock, count, expiration):
     return s
 
 
-async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
-                      keyring=None, request_mac=b'', ignore_trailing=False):
+async def receive_tcp(sock: dns.asyncbackend.StreamSocket,
+                      expiration: Optional[float]=None, one_rr_per_rrset=False,
+                      keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None,
+                      request_mac=b'', ignore_trailing=False) -> Tuple[dns.message.Message, float]:
     """Read a DNS message from a TCP socket.
 
     *sock*, a ``dns.asyncbackend.StreamSocket``.
@@ -258,9 +274,11 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
     return (r, received_time)
 
 
-async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
-              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
-              backend=None):
+async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+              source: Optional[str]=None, source_port=0,
+              one_rr_per_rrset=False, ignore_trailing=False,
+              sock: Optional[dns.asyncbackend.StreamSocket]=None,
+              backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via TCP.
 
     *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
@@ -297,6 +315,7 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
                 backend = dns.asyncbackend.get_default_backend()
             s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
                                           dtuple, timeout)
+        assert s is not None
         await send_tcp(s, wire, expiration)
         (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
                                                q.keyring, q.mac,
@@ -309,9 +328,13 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
         if not sock and s:
             await s.close()
 
-async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
-              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
-              backend=None, ssl_context=None, server_hostname=None):
+async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+              port=853, source: Optional[str]=None, source_port=0,
+              one_rr_per_rrset=False, ignore_trailing=False,
+              sock: Optional[dns.asyncbackend.StreamSocket]=None,
+              backend: Optional[dns.asyncbackend.Backend]=None,
+              ssl_context: Optional[ssl.SSLContext]=None,
+              server_hostname: Optional[str]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via TLS.
 
     *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
@@ -363,8 +386,10 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
         if not sock and s:
             await s.close()
 
-async def https(q, where, timeout=None, port=443, source=None, source_port=0,
-                one_rr_per_rrset=False, ignore_trailing=False, client=None,
+async def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+                port=443, source: Optional[str]=None, source_port=0,
+                one_rr_per_rrset=False, ignore_trailing=False,
+                client: Optional[httpx.AsyncClient]=None,
                 path='/dns-query', post=True, verify=True):
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -419,18 +444,18 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
                                          timeout=timeout)
         else:
             wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
-            wire = wire.decode()  # httpx does a repr() if we give it bytes
+            twire = wire.decode()  # httpx does a repr() if we give it bytes
             response = await client.get(url, headers=headers, timeout=timeout,
-                                        params={"dns": wire})
+                                        params={"dns": twire})
     finally:
         if client_to_close:
-            await client.aclose()
+            await client_to_close.aclose()
 
     # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
     # status codes
     if response.status_code < 200 or response.status_code > 299:
         raise ValueError('{} responded with status code {}'
-                         '\nResponse body: {}'.format(where,
+                         '\nResponse body: {!r}'.format(where,
                                                       response.status_code,
                                                       response.content))
     r = dns.message.from_wire(response.content,
@@ -438,14 +463,16 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
                               request_mac=q.request_mac,
                               one_rr_per_rrset=one_rr_per_rrset,
                               ignore_trailing=ignore_trailing)
-    r.time = response.elapsed
+    r.time = response.elapsed.total_seconds()
     if not q.is_response(r):
         raise BadResponse
     return r
 
-async def inbound_xfr(where, txn_manager, query=None,
-                      port=53, timeout=None, lifetime=None, source=None,
-                      source_port=0, udp_mode=UDPMode.NEVER, backend=None):
+async def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
+                      query: Optional[dns.message.Message]=None,
+                      port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
+                      source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER,
+                      backend: Optional[dns.asyncbackend.Backend]=None):
     """Conduct an inbound transfer and apply it via a transaction from the
     txn_manager.
 
diff --git a/dns/asyncquery.pyi b/dns/asyncquery.pyi
deleted file mode 100644 (file)
index a03434c..0000000
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import Optional, Union, Dict, Generator, Any
-from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
-
-# If the ssl import works, then
-#
-#    error: Name 'ssl' already defined (by an import)
-#
-# is expected and can be ignored.
-try:
-    import ssl
-except ImportError:
-    class ssl:    # type: ignore
-        SSLContext : Dict = {}
-
-async def udp(q : message.Message, where : str,
-              timeout : Optional[float] = None, port=53,
-              source : Optional[str] = None, source_port : Optional[int] = 0,
-              ignore_unexpected : Optional[bool] = False,
-              one_rr_per_rrset : Optional[bool] = False,
-              ignore_trailing : Optional[bool] = False,
-              sock : Optional[asyncbackend.DatagramSocket] = None,
-              backend : Optional[asyncbackend.Backend] = None) -> message.Message:
-    pass
-
-async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
-        af : Optional[int] = None, source : Optional[str] = None,
-        source_port : Optional[int] = 0,
-        one_rr_per_rrset : Optional[bool] = False,
-        ignore_trailing : Optional[bool] = False,
-        sock : Optional[asyncbackend.StreamSocket] = None,
-        backend : Optional[asyncbackend.Backend] = None) -> message.Message:
-    pass
-
-async def tls(q : message.Message, where : str,
-              timeout : Optional[float] = None, port=53,
-              source : Optional[str] = None, source_port : Optional[int] = 0,
-              one_rr_per_rrset : Optional[bool] = False,
-              ignore_trailing : Optional[bool] = False,
-              sock : Optional[asyncbackend.StreamSocket] = None,
-              backend : Optional[asyncbackend.Backend] = None,
-              ssl_context: Optional[ssl.SSLContext] = None,
-              server_hostname: Optional[str] = None) -> message.Message:
-    pass
index b483756744d451a5746f973a614729b7b7708f9d..72ef0412c55bc84a7eaab1a6acaa51682d2d4b73 100644 (file)
 
 """Asynchronous DNS stub resolver."""
 
+from typing import Optional, Union
+
 import time
 
 import dns.asyncbackend
 import dns.asyncquery
 import dns.exception
+import dns.name
 import dns.query
 import dns.resolver  # lgtm[py/import-and-import-from]
 
@@ -37,11 +40,13 @@ _tcp = dns.asyncquery.tcp
 class Resolver(dns.resolver.BaseResolver):
     """Asynchronous DNS stub resolver."""
 
-    async def resolve(self, qname, rdtype=dns.rdatatype.A,
+    async def resolve(self, qname: Union[dns.name.Name, str],
+                      rdtype=dns.rdatatype.A,
                       rdclass=dns.rdataclass.IN,
-                      tcp=False, source=None, raise_on_no_answer=True,
-                      source_port=0, lifetime=None, search=None,
-                      backend=None):
+                      tcp=False, source: Optional[str]=None,
+                      raise_on_no_answer=True, source_port=0,
+                      lifetime: Optional[float]=None, search: Optional[bool]=None,
+                      backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer:
         """Query nameservers asynchronously to find the answer to the question.
 
         *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
@@ -66,6 +71,7 @@ class Resolver(dns.resolver.BaseResolver):
             if answer is not None:
                 # cache hit!
                 return answer
+            assert request is not None  # needed for type checking
             done = False
             while not done:
                 (nameserver, port, tcp, backoff) = resolution.next_nameserver()
@@ -101,7 +107,7 @@ class Resolver(dns.resolver.BaseResolver):
                 if answer is not None:
                     return answer
 
-    async def resolve_address(self, ipaddr, *args, **kwargs):
+    async def resolve_address(self, ipaddr: str, *args, **kwargs) -> dns.resolver.Answer:
         """Use an asynchronous resolver to run a reverse query for PTR
         records.
 
@@ -116,15 +122,19 @@ class Resolver(dns.resolver.BaseResolver):
         function.
 
         """
-
+        # We make a modified kwargs for type checking happiness, as otherwise
+        # we get a legit warning about possibly having rdtype and rdclass
+        # in the kwargs more than once.
+        modified_kwargs = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs['rdtype'] = dns.rdatatype.PTR
+        modified_kwargs['rdclass'] = dns.rdataclass.IN
         return await self.resolve(dns.reversename.from_address(ipaddr),
-                                  rdtype=dns.rdatatype.PTR,
-                                  rdclass=dns.rdataclass.IN,
-                                  *args, **kwargs)
+                                  *args, **modified_kwargs)
 
     # pylint: disable=redefined-outer-name
 
-    async def canonical_name(self, name):
+    async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
         """Determine the canonical name of *name*.
 
         The canonical name is the name the resolver uses for queries
@@ -149,10 +159,11 @@ class Resolver(dns.resolver.BaseResolver):
 default_resolver = None
 
 
-def get_default_resolver():
+def get_default_resolver() -> Resolver:
     """Get the default asynchronous resolver, initializing it if necessary."""
     if default_resolver is None:
         reset_default_resolver()
+    assert default_resolver is not None
     return default_resolver
 
 
@@ -167,9 +178,13 @@ def reset_default_resolver():
     default_resolver = Resolver()
 
 
-async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-                  tcp=False, source=None, raise_on_no_answer=True,
-                  source_port=0, lifetime=None, search=None, backend=None):
+async def resolve(qname: Union[dns.name.Name, str],
+                  rdtype=dns.rdatatype.A,
+                  rdclass=dns.rdataclass.IN,
+                  tcp=False, source: Optional[str]=None,
+                  raise_on_no_answer=True, source_port=0,
+                  lifetime: Optional[float]=None, search: Optional[bool]=None,
+                  backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer:
     """Query nameservers asynchronously to find the answer to the question.
 
     This is a convenience function that uses the default resolver
@@ -185,7 +200,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
                                                 backend)
 
 
-async def resolve_address(ipaddr, *args, **kwargs):
+async def resolve_address(ipaddr: str, *args, **kwargs) -> dns.resolver.Answer:
     """Use a resolver to run a reverse query for PTR records.
 
     See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
@@ -194,7 +209,7 @@ async def resolve_address(ipaddr, *args, **kwargs):
 
     return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
 
-async def canonical_name(name):
+async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
     """Determine the canonical name of *name*.
 
     See :py:func:`dns.resolver.Resolver.canonical_name` for more
@@ -203,8 +218,9 @@ async def canonical_name(name):
 
     return await get_default_resolver().canonical_name(name)
 
-async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
-                        resolver=None, backend=None):
+async def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN,
+                        tcp=False, resolver: Optional[Resolver]=None,
+                        backend: Optional[dns.asyncbackend.Backend]=None) -> dns.name.Name:
     """Find the name of the zone which contains the specified name.
 
     See :py:func:`dns.resolver.Resolver.zone_for_name` for more
@@ -221,6 +237,7 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
         try:
             answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass,
                                             tcp, backend=backend)
+            assert answer.rrset is not None
             if answer.rrset.name == name:
                 return name
             # otherwise we were CNAMEd or DNAMEd and need to look higher
diff --git a/dns/asyncresolver.pyi b/dns/asyncresolver.pyi
deleted file mode 100644 (file)
index 92759d2..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import Union, Optional, List, Any, Dict
-from . import exception, rdataclass, name, rdatatype, asyncbackend
-
-async def resolve(qname : str, rdtype : Union[int,str] = 0,
-                  rdclass : Union[int,str] = 0,
-                  tcp=False, source=None, raise_on_no_answer=True,
-                  source_port=0, lifetime : Optional[float]=None,
-                  search : Optional[bool]=None,
-                  backend : Optional[asyncbackend.Backend]=None):
-    ...
-async def resolve_address(self, ipaddr: str,
-                          *args: Any, **kwargs: Optional[Dict]):
-    ...
-
-class Resolver:
-    def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
-                 configure : Optional[bool] = True):
-        self.nameservers : List[str]
-    async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
-                      rdclass : Union[int,str] = rdataclass.IN,
-                      tcp : bool = False, source : Optional[str] = None,
-                      raise_on_no_answer=True, source_port : int = 0,
-                      lifetime : Optional[float]=None,
-                      search : Optional[bool]=None,
-                      backend : Optional[asyncbackend.Backend]=None):
-        ...
index dee4e61813216f377703fc61f1132865a89254b4..bb20005e6ea804a4a5251c186ff0ea016260c558 100644 (file)
 
 """Common DNSSEC-related functions and constants."""
 
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
+
 import hashlib
 import struct
 import time
 import base64
 
-import dns.enum
+from dns.dnssectypes import *
+
 import dns.exception
 import dns.name
 import dns.node
@@ -30,6 +33,10 @@ import dns.rdataset
 import dns.rdata
 import dns.rdatatype
 import dns.rdataclass
+import dns.rrset
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+from dns.rdtypes.ANY.DS import DS
+from dns.rdtypes.ANY.RRSIG import RRSIG
 
 
 class UnsupportedAlgorithm(dns.exception.DNSException):
@@ -40,31 +47,7 @@ class ValidationFailure(dns.exception.DNSException):
     """The DNSSEC signature is invalid."""
 
 
-class Algorithm(dns.enum.IntEnum):
-    RSAMD5 = 1
-    DH = 2
-    DSA = 3
-    ECC = 4
-    RSASHA1 = 5
-    DSANSEC3SHA1 = 6
-    RSASHA1NSEC3SHA1 = 7
-    RSASHA256 = 8
-    RSASHA512 = 10
-    ECCGOST = 12
-    ECDSAP256SHA256 = 13
-    ECDSAP384SHA384 = 14
-    ED25519 = 15
-    ED448 = 16
-    INDIRECT = 252
-    PRIVATEDNS = 253
-    PRIVATEOID = 254
-
-    @classmethod
-    def _maximum(cls):
-        return 255
-
-
-def algorithm_from_text(text):
+def algorithm_from_text(text: str) -> Algorithm:
     """Convert text into a DNSSEC algorithm value.
 
     *text*, a ``str``, the text to convert to into an algorithm value.
@@ -75,10 +58,10 @@ def algorithm_from_text(text):
     return Algorithm.from_text(text)
 
 
-def algorithm_to_text(value):
+def algorithm_to_text(value: Union[Algorithm, int]) -> str:
     """Convert a DNSSEC algorithm value to text
 
-    *value*, an ``int`` a DNSSEC algorithm.
+    *value*, a ``dns.dnssec.Algorithm``.
 
     Returns a ``str``, the name of a DNSSEC algorithm.
     """
@@ -86,7 +69,7 @@ def algorithm_to_text(value):
     return Algorithm.to_text(value)
 
 
-def key_id(key):
+def key_id(key: DNSKEY) -> int:
     """Return the key id (a 16-bit number) for the specified key.
 
     *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``
@@ -107,19 +90,10 @@ def key_id(key):
         total += ((total >> 16) & 0xffff)
         return total & 0xffff
 
-class DSDigest(dns.enum.IntEnum):
-    """DNSSEC Delegation Signer Digest Algorithm"""
-
-    SHA1 = 1
-    SHA256 = 2
-    SHA384 = 4
-
-    @classmethod
-    def _maximum(cls):
-        return 255
-
 
-def make_ds(name, key, algorithm, origin=None):
+def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata,
+            algorithm: Union[DSDigest, str],
+            origin: Optional[dns.name.Name]=None) -> DS:
     """Create a DS record for a DNSSEC key.
 
     *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record.
@@ -143,7 +117,8 @@ def make_ds(name, key, algorithm, origin=None):
             algorithm = DSDigest[algorithm.upper()]
     except Exception:
         raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
-
+    if not isinstance(key, DNSKEY):
+        raise ValueError('key is not a DNSKEY')
     if algorithm == DSDigest.SHA1:
         dshash = hashlib.sha1()
     elif algorithm == DSDigest.SHA256:
@@ -155,17 +130,20 @@ def make_ds(name, key, algorithm, origin=None):
 
     if isinstance(name, str):
         name = dns.name.from_text(name, origin)
-    dshash.update(name.canonicalize().to_wire())
+    wire = name.canonicalize().to_wire()
+    assert wire is not None
+    dshash.update(wire)
     dshash.update(key.to_wire(origin=origin))
     digest = dshash.digest()
 
     dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \
         digest
-    return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0,
-                               len(dsrdata))
+    ds = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0,
+                             len(dsrdata))
+    return cast(DS, ds)
 
 
-def _find_candidate_keys(keys, rrsig):
+def _find_candidate_keys(keys, rrsig: RRSIG) -> Optional[List[DNSKEY]]:
     value = keys.get(rrsig.signer)
     if isinstance(value, dns.node.Node):
         rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY)
@@ -173,54 +151,54 @@ def _find_candidate_keys(keys, rrsig):
         rdataset = value
     if rdataset is None:
         return None
-    return [rd for rd in rdataset if
+    return [cast(DNSKEY, rd) for rd in rdataset if
             rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag]
 
 
-def _is_rsa(algorithm):
+def _is_rsa(algorithm: int) -> bool:
     return algorithm in (Algorithm.RSAMD5, Algorithm.RSASHA1,
                          Algorithm.RSASHA1NSEC3SHA1, Algorithm.RSASHA256,
                          Algorithm.RSASHA512)
 
 
-def _is_dsa(algorithm):
+def _is_dsa(algorithm: int) -> bool:
     return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1)
 
 
-def _is_ecdsa(algorithm):
+def _is_ecdsa(algorithm: int) -> bool:
     return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384)
 
 
-def _is_eddsa(algorithm):
+def _is_eddsa(algorithm: int) -> bool:
     return algorithm in (Algorithm.ED25519, Algorithm.ED448)
 
 
-def _is_gost(algorithm):
+def _is_gost(algorithm: int) -> bool:
     return algorithm == Algorithm.ECCGOST
 
 
-def _is_md5(algorithm):
+def _is_md5(algorithm: int) -> bool:
     return algorithm == Algorithm.RSAMD5
 
 
-def _is_sha1(algorithm):
+def _is_sha1(algorithm: int) -> bool:
     return algorithm in (Algorithm.DSA, Algorithm.RSASHA1,
                          Algorithm.DSANSEC3SHA1, Algorithm.RSASHA1NSEC3SHA1)
 
 
-def _is_sha256(algorithm):
+def _is_sha256(algorithm: int) -> bool:
     return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256)
 
 
-def _is_sha384(algorithm):
+def _is_sha384(algorithm: int) -> bool:
     return algorithm == Algorithm.ECDSAP384SHA384
 
 
-def _is_sha512(algorithm):
+def _is_sha512(algorithm: int) -> bool:
     return algorithm == Algorithm.RSASHA512
 
 
-def _make_hash(algorithm):
+def _make_hash(algorithm: int) -> Any:
     if _is_md5(algorithm):
         return hashes.MD5()
     if _is_sha1(algorithm):
@@ -239,12 +217,14 @@ def _make_hash(algorithm):
     raise ValidationFailure('unknown hash for algorithm %u' % algorithm)
 
 
-def _bytes_to_long(b):
+def _bytes_to_long(b: bytes) -> int:
     return int.from_bytes(b, 'big')
 
 
-def _validate_signature(sig, data, key, chosen_hash):
+def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any):
+    keyptr: bytes
     if _is_rsa(key.algorithm):
+        # we ignore because mypy is confused and thinks key.key is a str for unknown reasons.
         keyptr = key.key
         (bytes_,) = struct.unpack('!B', keyptr[0:1])
         keyptr = keyptr[1:]
@@ -254,12 +234,12 @@ def _validate_signature(sig, data, key, chosen_hash):
         rsa_e = keyptr[0:bytes_]
         rsa_n = keyptr[bytes_:]
         try:
-            public_key = rsa.RSAPublicNumbers(
+            rsa_public_key = rsa.RSAPublicNumbers(
                 _bytes_to_long(rsa_e),
                 _bytes_to_long(rsa_n)).public_key(default_backend())
         except ValueError:
             raise ValidationFailure('invalid public key')
-        public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
+        rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
     elif _is_dsa(key.algorithm):
         keyptr = key.key
         (t,) = struct.unpack('!B', keyptr[0:1])
@@ -273,7 +253,7 @@ def _validate_signature(sig, data, key, chosen_hash):
         keyptr = keyptr[octets:]
         dsa_y = keyptr[0:octets]
         try:
-            public_key = dsa.DSAPublicNumbers(
+            dsa_public_key = dsa.DSAPublicNumbers(
                 _bytes_to_long(dsa_y),
                 dsa.DSAParameterNumbers(
                     _bytes_to_long(dsa_p),
@@ -281,9 +261,10 @@ def _validate_signature(sig, data, key, chosen_hash):
                     _bytes_to_long(dsa_g))).public_key(default_backend())
         except ValueError:
             raise ValidationFailure('invalid public key')
-        public_key.verify(sig, data, chosen_hash)
+        dsa_public_key.verify(sig, data, chosen_hash)
     elif _is_ecdsa(key.algorithm):
         keyptr = key.key
+        curve: Any
         if key.algorithm == Algorithm.ECDSAP256SHA256:
             curve = ec.SECP256R1()
             octets = 32
@@ -293,24 +274,25 @@ def _validate_signature(sig, data, key, chosen_hash):
         ecdsa_x = keyptr[0:octets]
         ecdsa_y = keyptr[octets:octets * 2]
         try:
-            public_key = ec.EllipticCurvePublicNumbers(
+            ecdsa_public_key = ec.EllipticCurvePublicNumbers(
                 curve=curve,
                 x=_bytes_to_long(ecdsa_x),
                 y=_bytes_to_long(ecdsa_y)).public_key(default_backend())
         except ValueError:
             raise ValidationFailure('invalid public key')
-        public_key.verify(sig, data, ec.ECDSA(chosen_hash))
+        ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash))
     elif _is_eddsa(key.algorithm):
         keyptr = key.key
+        loader: Any
         if key.algorithm == Algorithm.ED25519:
             loader = ed25519.Ed25519PublicKey
         else:
             loader = ed448.Ed448PublicKey
         try:
-            public_key = loader.from_public_bytes(keyptr)
+            eddsa_public_key = loader.from_public_bytes(keyptr)
         except ValueError:
             raise ValidationFailure('invalid public key')
-        public_key.verify(sig, data)
+        eddsa_public_key.verify(sig, data)
     elif _is_gost(key.algorithm):
         raise UnsupportedAlgorithm(
             'algorithm "%s" not supported by dnspython' %
@@ -319,7 +301,10 @@ def _validate_signature(sig, data, key, chosen_hash):
         raise ValidationFailure('unknown algorithm %u' % key.algorithm)
 
 
-def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
+def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+                    rrsig: RRSIG,
+                    keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
+                    origin: Optional[dns.name.Name]=None, now: Optional[float]=None):
     """Validate an RRset against a single signature rdata, throwing an
     exception if validation is not successful.
 
@@ -337,7 +322,7 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
     *origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative
     names.
 
-    *now*, an ``int`` or ``None``, the time, in seconds since the epoch, to
+    *now*, a ``float`` or ``None``, the time, in seconds since the epoch, to
     use as the current time when validating.  If ``None``, the actual current
     time is used.
 
@@ -394,7 +379,10 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
     data += rrsig.signer.to_digestable(origin)
 
     # Derelativize the name before considering labels.
-    rrname = rrname.derelativize(origin)
+    if not rrname.is_absolute():
+        if origin is None:
+            raise ValidationFailure('relative RR name without an origin specified')
+        rrname = rrname.derelativize(origin)
 
     if len(rrname) - 1 < rrsig.labels:
         raise ValidationFailure('owner name longer than RRSIG labels')
@@ -425,7 +413,10 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
     raise ValidationFailure('verify failure')
 
 
-def _validate(rrset, rrsigset, keys, origin=None, now=None):
+def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+              rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+              keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
+              origin: Optional[dns.name.Name]=None, now: Optional[float]=None):
     """Validate an RRset against a signature RRset, throwing an exception
     if none of the signatures validate.
 
@@ -475,6 +466,8 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None):
         raise ValidationFailure("owner names do not match")
 
     for rrsig in rrsigrdataset:
+        if not isinstance(rrsig, RRSIG):
+            raise ValidationFailure('expected an RRSIG')
         try:
             _validate_rrsig(rrset, rrsig, keys, origin, now)
             return
@@ -483,15 +476,6 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None):
     raise ValidationFailure("no RRSIGs validated")
 
 
-class NSEC3Hash(dns.enum.IntEnum):
-    """NSEC3 hash algorithm"""
-
-    SHA1 = 1
-
-    @classmethod
-    def _maximum(cls):
-        return 255
-
 def nsec3_hash(domain, salt, iterations, algorithm):
     """
     Calculate the NSEC3 hash, according to
diff --git a/dns/dnssec.pyi b/dns/dnssec.pyi
deleted file mode 100644 (file)
index e126f9b..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Union, Dict, Tuple, Optional
-from . import rdataset, rrset, exception, name, rdtypes, rdata, node
-import dns.rdtypes.ANY.DS as DS
-import dns.rdtypes.ANY.DNSKEY as DNSKEY
-
-_have_pyca : bool
-
-def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
-    ...
-
-def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
-    ...
-
-class ValidationFailure(exception.DNSException):
-    ...
-
-def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
-    ...
-
-def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
-    ...
diff --git a/dns/dnssectypes.py b/dns/dnssectypes.py
new file mode 100644 (file)
index 0000000..2a74716
--- /dev/null
@@ -0,0 +1,69 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""Common DNSSEC-related types."""
+
+# This is a separate file to avoid import circularity between dns.dnssec and
+# the implementations of the DS and DNSKEY types.
+
+import dns.enum
+
+
+class Algorithm(dns.enum.IntEnum):
+    RSAMD5 = 1
+    DH = 2
+    DSA = 3
+    ECC = 4
+    RSASHA1 = 5
+    DSANSEC3SHA1 = 6
+    RSASHA1NSEC3SHA1 = 7
+    RSASHA256 = 8
+    RSASHA512 = 10
+    ECCGOST = 12
+    ECDSAP256SHA256 = 13
+    ECDSAP384SHA384 = 14
+    ED25519 = 15
+    ED448 = 16
+    INDIRECT = 252
+    PRIVATEDNS = 253
+    PRIVATEOID = 254
+
+    @classmethod
+    def _maximum(cls):
+        return 255
+
+
+class DSDigest(dns.enum.IntEnum):
+    """DNSSEC Delegation Signer Digest Algorithm"""
+
+    SHA1 = 1
+    SHA256 = 2
+    SHA384 = 4
+
+    @classmethod
+    def _maximum(cls):
+        return 255
+
+
+class NSEC3Hash(dns.enum.IntEnum):
+    """NSEC3 hash algorithm"""
+
+    SHA1 = 1
+
+    @classmethod
+    def _maximum(cls):
+        return 255
index 83731b2c56260b79b35fe46634f04c7e339fe51d..8c9a3ac58bd3031fbbb438447916a749e5150360 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS E.164 helpers."""
 
+from typing import Iterable, Optional, Union
+
 import dns.exception
 import dns.name
 import dns.resolver
@@ -25,7 +27,7 @@ import dns.resolver
 public_enum_domain = dns.name.from_text('e164.arpa.')
 
 
-def from_e164(text, origin=public_enum_domain):
+def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain) -> dns.name.Name:
     """Convert an E.164 number in textual form into a Name object whose
     value is the ENUM domain name for that number.
 
@@ -45,7 +47,8 @@ def from_e164(text, origin=public_enum_domain):
     return dns.name.from_text('.'.join(parts), origin=origin)
 
 
-def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
+def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=public_enum_domain,
+            want_plus_prefix=True) -> str:
     """Convert an ENUM domain name into an E.164 number.
 
     Note that dnspython does not have any information about preferred
@@ -77,7 +80,8 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
     return text.decode()
 
 
-def query(number, domains, resolver=None):
+def query(number: str, domains: Iterable[Union[dns.name.Name, str]],
+          resolver: Optional[dns.resolver.Resolver]=None) -> dns.resolver.Answer:
     """Look for NAPTR RRs for the specified number in the specified domains.
 
     e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
diff --git a/dns/e164.pyi b/dns/e164.pyi
deleted file mode 100644 (file)
index 37a99fe..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-from typing import Optional, Iterable
-from . import name, resolver
-def from_e164(text : str, origin=name.Name(".")) -> name.Name:
-    ...
-
-def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
-    ...
-
-def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
-    ...
index fa4e98b1d83847bd5ec698e96f55e94b88d862a6..15c646de1deffec41793c90bc1d35bc107fc7a16 100644 (file)
@@ -17,6 +17,8 @@
 
 """EDNS Options"""
 
+from typing import Any, Dict, Optional, Union
+
 import math
 import socket
 import struct
@@ -24,6 +26,7 @@ import struct
 import dns.enum
 import dns.inet
 import dns.rdata
+import dns.wire
 
 
 class OptionType(dns.enum.IntEnum):
@@ -59,14 +62,14 @@ class Option:
 
     """Base class for all EDNS option types."""
 
-    def __init__(self, otype):
+    def __init__(self, otype: Union[OptionType, str]):
         """Initialize an option.
 
-        *otype*, an ``int``, is the option type.
+        *otype*, a ``dns.edns.OptionType``, is the option type.
         """
         self.otype = OptionType.make(otype)
 
-    def to_wire(self, file=None):
+    def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]:
         """Convert an option to wire format.
 
         Returns a ``bytes`` or ``None``.
@@ -75,10 +78,10 @@ class Option:
         raise NotImplementedError  # pragma: no cover
 
     @classmethod
-    def from_wire_parser(cls, otype, parser):
+    def from_wire_parser(cls, otype: OptionType, parser: 'dns.wire.Parser') -> 'Option':
         """Build an EDNS option object from wire format.
 
-        *otype*, an ``int``, is the option type.
+        *otype*, a ``dns.edns.OptionType``, is the option type.
 
         *parser*, a ``dns.wire.Parser``, the parser, which should be
         restructed to the option length.
@@ -150,28 +153,29 @@ class GenericOption(Option):  # lgtm[py/missing-equals]
     implementation.
     """
 
-    def __init__(self, otype, data):
+    def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
         super().__init__(otype)
         self.data = dns.rdata.Rdata._as_bytes(data, True)
 
-    def to_wire(self, file=None):
+    def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]:
         if file:
             file.write(self.data)
+            return None
         else:
             return self.data
 
-    def to_text(self):
+    def to_text(self) -> str:
         return "Generic %d" % self.otype
 
     @classmethod
-    def from_wire_parser(cls, otype, parser):
+    def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option:
         return cls(otype, parser.get_remaining())
 
 
 class ECSOption(Option):  # lgtm[py/missing-equals]
     """EDNS Client Subnet (ECS, RFC7871)"""
 
-    def __init__(self, address, srclen=None, scopelen=0):
+    def __init__(self, address: str, srclen: Optional[int]=None, scopelen=0):
         """*address*, a ``str``, is the client address information.
 
         *srclen*, an ``int``, the source prefix length, which is the
@@ -202,6 +206,7 @@ class ECSOption(Option):  # lgtm[py/missing-equals]
         else:  # pragma: no cover   (this will never happen)
             raise ValueError('Bad address family')
 
+        assert srclen is not None
         self.address = address
         self.srclen = srclen
         self.scopelen = scopelen
@@ -218,12 +223,12 @@ class ECSOption(Option):  # lgtm[py/missing-equals]
                                ord(self.addrdata[-1:]) & (0xff << (8 - nbits)))
             self.addrdata = self.addrdata[:-1] + last
 
-    def to_text(self):
+    def to_text(self) -> str:
         return "ECS {}/{} scope/{}".format(self.address, self.srclen,
                                            self.scopelen)
 
     @staticmethod
-    def from_text(text):
+    def from_text(text) -> Option:
         """Convert a string into a `dns.edns.ECSOption`
 
         *text*, a `str`, the text form of the option.
@@ -277,16 +282,17 @@ class ECSOption(Option):  # lgtm[py/missing-equals]
                              '"{}": srclen must be an integer'.format(srclen))
         return ECSOption(address, srclen, scope)
 
-    def to_wire(self, file=None):
+    def to_wire(self, file=None) -> Optional[bytes]:
         value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) +
                  self.addrdata)
         if file:
             file.write(value)
+            return None
         else:
             return value
 
     @classmethod
-    def from_wire_parser(cls, otype, parser):
+    def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser'):
         family, src, scope = parser.get_struct('!HBB')
         addrlen = int(math.ceil(src / 8.0))
         prefix = parser.get_bytes(addrlen)
@@ -337,7 +343,7 @@ class EDECode(dns.enum.IntEnum):
 class EDEOption(Option):  # lgtm[py/missing-equals]
     """Extended DNS Error (EDE, RFC8914)"""
 
-    def __init__(self, code, text=None):
+    def __init__(self, code: Union[EDECode, str], text: Optional[str]=None):
         """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
         extended error.
 
@@ -350,28 +356,27 @@ class EDEOption(Option):  # lgtm[py/missing-equals]
         self.code = EDECode.make(code)
         if text is not None and not isinstance(text, str):
             raise ValueError('text must be string or None')
-
-        self.code = code
         self.text = text
 
-    def to_text(self):
+    def to_text(self) -> str:
         output = f'EDE {self.code}'
         if self.text is not None:
             output += f': {self.text}'
         return output
 
-    def to_wire(self, file=None):
+    def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]:
         value = struct.pack('!H', self.code)
         if self.text is not None:
             value += self.text.encode('utf8')
 
         if file:
             file.write(value)
+            return None
         else:
             return value
 
     @classmethod
-    def from_wire_parser(cls, otype, parser):
+    def from_wire_parser(cls, otype: Union[OptionType, str], parser) -> Option:
         code = parser.get_uint16()
         text = parser.get_remaining()
 
@@ -385,13 +390,13 @@ class EDEOption(Option):  # lgtm[py/missing-equals]
         return cls(code, text)
 
 
-_type_to_class = {
+_type_to_class: Dict[OptionType, Any] = {
     OptionType.ECS: ECSOption,
     OptionType.EDE: EDEOption,
 }
 
 
-def get_option_class(otype):
+def get_option_class(otype: OptionType) -> Any:
     """Return the class for the specified option type.
 
     The GenericOption class is used if a more specific class is not
@@ -404,7 +409,7 @@ def get_option_class(otype):
     return cls
 
 
-def option_from_wire_parser(otype, parser):
+def option_from_wire_parser(otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option:
     """Build an EDNS option object from wire format.
 
     *otype*, an ``int``, is the option type.
@@ -414,12 +419,12 @@ def option_from_wire_parser(otype, parser):
 
     Returns an instance of a subclass of ``dns.edns.Option``.
     """
-    cls = get_option_class(otype)
-    otype = OptionType.make(otype)
+    the_otype = OptionType.make(otype)
+    cls = get_option_class(the_otype)
     return cls.from_wire_parser(otype, parser)
 
 
-def option_from_wire(otype, wire, current, olen):
+def option_from_wire(otype: Union[OptionType, str], wire: bytes, current: int, olen: int) -> Option:
     """Build an EDNS option object from wire format.
 
     *otype*, an ``int``, is the option type.
@@ -437,7 +442,7 @@ def option_from_wire(otype, wire, current, olen):
     with parser.restrict_to(olen):
         return option_from_wire_parser(otype, parser)
 
-def register_type(implementation, otype):
+def register_type(implementation: Any, otype: OptionType):
     """Register the implementation of an option type.
 
     *implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
index 086bba787d4d2466f714b7400591b1b8f52c192a..528a628bf7d6f2750991579d3b7e060eb863774e 100644 (file)
@@ -15,6 +15,8 @@
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+from typing import Any, Optional
+
 import os
 import hashlib
 import random
@@ -34,7 +36,7 @@ class EntropyPool:
 
     def __init__(self, seed=None):
         self.pool_index = 0
-        self.digest = None
+        self.digest: Optional[bytearray] = None
         self.next_byte = 0
         self.lock = _threading.Lock()
         self.hash = hashlib.sha1()
@@ -76,7 +78,7 @@ class EntropyPool:
             seed = bytearray(seed)
             self._stir(seed)
 
-    def random_8(self):
+    def random_8(self) -> int:
         with self.lock:
             self._maybe_seed()
             if self.digest is None or self.next_byte == self.hash_len:
@@ -88,13 +90,13 @@ class EntropyPool:
             self.next_byte += 1
         return value
 
-    def random_16(self):
+    def random_16(self) -> int:
         return self.random_8() * 256 + self.random_8()
 
-    def random_32(self):
+    def random_32(self) -> int:
         return self.random_16() * 65536 + self.random_16()
 
-    def random_between(self, first, last):
+    def random_between(self, first: int, last: int) -> int:
         size = last - first + 1
         if size > 4294967296:
             raise ValueError('too big')
@@ -111,18 +113,19 @@ class EntropyPool:
 
 pool = EntropyPool()
 
+system_random: Optional[Any]
 try:
     system_random = random.SystemRandom()
 except Exception:  # pragma: no cover
     system_random = None
 
-def random_16():
+def random_16() -> int:
     if system_random is not None:
         return system_random.randrange(0, 65536)
     else:
         return pool.random_16()
 
-def between(first, last):
+def between(first: int, last: int) -> int:
     if system_random is not None:
         return system_random.randrange(first, last + 1)
     else:
diff --git a/dns/entropy.pyi b/dns/entropy.pyi
deleted file mode 100644 (file)
index 818f805..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-from typing import Optional
-from random import SystemRandom
-
-system_random : Optional[SystemRandom]
-
-def random_16() -> int:
-   pass
-
-def between(first: int, last: int) -> int:
-    pass
index 5376458805bbd0d3c5b70ca750c6a5df7a399957..550a1bcfcdd1e1b4102f7aa3bf6f79643f9d1127 100644 (file)
@@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will
 always be subclasses of ``DNSException``.
 """
 
+
+from typing import Dict, Optional, Set
+
+
 class DNSException(Exception):
     """Abstract base class shared by all dnspython exceptions.
 
@@ -44,9 +48,9 @@ class DNSException(Exception):
     and ``fmt`` class variables to get nice parametrized messages.
     """
 
-    msg = None  # non-parametrized message
-    supp_kwargs = set()  # accepted parameters for _fmt_kwargs (sanity check)
-    fmt = None  # message parametrized with results from _fmt_kwargs
+    msg: Optional[str] = None  # non-parametrized message
+    supp_kwargs: Set[str] = set()  # accepted parameters for _fmt_kwargs (sanity check)
+    fmt: Optional[str] = None  # message parametrized with results from _fmt_kwargs
 
     def __init__(self, *args, **kwargs):
         self._check_params(*args, **kwargs)
@@ -128,6 +132,10 @@ class Timeout(DNSException):
     supp_kwargs = {'timeout'}
     fmt = "The DNS operation timed out after {timeout:.3f} seconds"
 
+    # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
 
 class ExceptionWrapper:
     def __init__(self, exception_class):
diff --git a/dns/exception.pyi b/dns/exception.pyi
deleted file mode 100644 (file)
index dc57126..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-from typing import Set, Optional, Dict
-
-class DNSException(Exception):
-    supp_kwargs : Set[str]
-    kwargs : Optional[Dict]
-    fmt : Optional[str]
-
-class SyntaxError(DNSException): ...
-class FormError(DNSException): ...
-class Timeout(DNSException): ...
-class TooBig(DNSException): ...
-class UnexpectedEnd(SyntaxError): ...
index 965228798c82e23e8cd8cdbcabc8b160f6f3f9b2..6fe1afd3cc1ef1704890312eef819649e7f60275 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS Message Flags."""
 
+from typing import Any
+
 import enum
 
 # Standard DNS flags
@@ -45,7 +47,7 @@ class EDNSFlag(enum.IntFlag):
     DO = 0x8000
 
 
-def _from_text(text, enum_class):
+def _from_text(text: str, enum_class: Any) -> int:
     flags = 0
     tokens = text.split()
     for t in tokens:
@@ -53,7 +55,7 @@ def _from_text(text, enum_class):
     return flags
 
 
-def _to_text(flags, enum_class):
+def _to_text(flags: int, enum_class: Any) -> str:
     text_flags = []
     for k, v in enum_class.__members__.items():
         if flags & v != 0:
@@ -61,7 +63,7 @@ def _to_text(flags, enum_class):
     return ' '.join(text_flags)
 
 
-def from_text(text):
+def from_text(text: str) -> int:
     """Convert a space-separated list of flag text values into a flags
     value.
 
@@ -71,7 +73,7 @@ def from_text(text):
     return _from_text(text, Flag)
 
 
-def to_text(flags):
+def to_text(flags: int) -> str:
     """Convert a flags value into a space-separated list of flag text
     values.
 
@@ -81,7 +83,7 @@ def to_text(flags):
     return _to_text(flags, Flag)
 
 
-def edns_from_text(text):
+def edns_from_text(text: str) -> int:
     """Convert a space-separated list of EDNS flag text values into a EDNS
     flags value.
 
@@ -91,7 +93,7 @@ def edns_from_text(text):
     return _from_text(text, EDNSFlag)
 
 
-def edns_to_text(flags):
+def edns_to_text(flags: int) -> str:
     """Convert an EDNS flags value into a space-separated list of EDNS flag
     text values.
 
index 112ede47c432b40d927a880a69c99309576aa54d..ebb64d2d86981281e8b82dd1958c7e6adac52372 100644 (file)
 
 """DNS GENERATE range conversion."""
 
+from typing import Tuple
+
 import dns
 
-def from_text(text):
+def from_text(text: str) -> Tuple[int, int, int]:
     """Convert the text form of a range in a ``$GENERATE`` statement to an
     integer.
 
index d3bdc64c8800c9cb6741406f23be753fb6329326..b3ed9995e60e07f8df2cb49e9e0022be67db001e 100644 (file)
@@ -17,6 +17,8 @@
 
 """Generic Internet address helper functions."""
 
+from typing import Any, Optional, Tuple
+
 import socket
 
 import dns.ipv4
@@ -30,7 +32,7 @@ AF_INET = socket.AF_INET
 AF_INET6 = socket.AF_INET6
 
 
-def inet_pton(family, text):
+def inet_pton(family: int, text: str) -> bytes:
     """Convert the textual form of a network address into its binary form.
 
     *family* is an ``int``, the address family.
@@ -51,7 +53,7 @@ def inet_pton(family, text):
         raise NotImplementedError
 
 
-def inet_ntop(family, address):
+def inet_ntop(family: int, address: bytes) -> str:
     """Convert the binary form of a network address into its textual form.
 
     *family* is an ``int``, the address family.
@@ -72,7 +74,7 @@ def inet_ntop(family, address):
         raise NotImplementedError
 
 
-def af_for_address(text):
+def af_for_address(text: str) -> int:
     """Determine the address family of a textual-form network address.
 
     *text*, a ``str``, the textual address.
@@ -94,7 +96,7 @@ def af_for_address(text):
             raise ValueError
 
 
-def is_multicast(text):
+def is_multicast(text: str) -> bool:
     """Is the textual-form network address a multicast address?
 
     *text*, a ``str``, the textual address.
@@ -116,7 +118,7 @@ def is_multicast(text):
             raise ValueError
 
 
-def is_address(text):
+def is_address(text: str) -> bool:
     """Is the specified string an IPv4 or IPv6 address?
 
     *text*, a ``str``, the textual address.
@@ -135,7 +137,7 @@ def is_address(text):
             return False
 
 
-def low_level_address_tuple(high_tuple, af=None):
+def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None) -> Any:
     """Given a "high-level" address tuple, i.e.
     an (address, port) return the appropriate "low-level" address tuple
     suitable for use in socket calls.
@@ -143,7 +145,6 @@ def low_level_address_tuple(high_tuple, af=None):
     If an *af* other than ``None`` is provided, it is assumed the
     address in the high-level tuple is valid and has that af.  If af
     is ``None``, then af_for_address will be called.
-
     """
     address, port = high_tuple
     if af is None:
diff --git a/dns/inet.pyi b/dns/inet.pyi
deleted file mode 100644 (file)
index 6d9dcc7..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-from typing import Union
-from socket import AddressFamily
-
-AF_INET6 : Union[int, AddressFamily]
index e1f38d3d4fcf0123fee1a124541b2d3dc6bf0ccf..fddad1b1e6b4c6613ad4dec5ee7a7fe8984639df 100644 (file)
 
 """IPv4 helper functions."""
 
+from typing import Union
+
 import struct
 
 import dns.exception
 
-def inet_ntoa(address):
+def inet_ntoa(address: bytes) -> str:
     """Convert an IPv4 address in binary form to text form.
 
     *address*, a ``bytes``, the IPv4 address in binary form.
@@ -34,17 +36,19 @@ def inet_ntoa(address):
     return ('%u.%u.%u.%u' % (address[0], address[1],
                              address[2], address[3]))
 
-def inet_aton(text):
+def inet_aton(text: Union[str, bytes]) -> bytes:
     """Convert an IPv4 address in text form to binary form.
 
-    *text*, a ``str``, the IPv4 address in textual form.
+    *text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
 
     Returns a ``bytes``.
     """
 
     if not isinstance(text, bytes):
-        text = text.encode()
-    parts = text.split(b'.')
+        btext = text.encode()
+    else:
+        btext = text
+    parts = btext.split(b'.')
     if len(parts) != 4:
         raise dns.exception.SyntaxError
     for part in parts:
index 0db6fcfaa9763e12008b695d05768f26a442e859..1d5bffde96f528fdbd3c99593dc2e0d8cc83629a 100644 (file)
@@ -17,6 +17,8 @@
 
 """IPv6 helper functions."""
 
+from typing import List, Union
+
 import re
 import binascii
 
@@ -25,7 +27,7 @@ import dns.ipv4
 
 _leading_zero = re.compile(r'0+([0-9a-f]+)')
 
-def inet_ntoa(address):
+def inet_ntoa(address: bytes) -> str:
     """Convert an IPv6 address in binary form to text form.
 
     *address*, a ``bytes``, the IPv6 address in binary form.
@@ -84,19 +86,19 @@ def inet_ntoa(address):
                 prefix = '::'
             else:
                 prefix = '::ffff:'
-            hex = prefix + dns.ipv4.inet_ntoa(address[12:])
+            thex = prefix + dns.ipv4.inet_ntoa(address[12:])
         else:
-            hex = ':'.join(chunks[:best_start]) + '::' + \
+            thex = ':'.join(chunks[:best_start]) + '::' + \
                   ':'.join(chunks[best_start + best_len:])
     else:
-        hex = ':'.join(chunks)
-    return hex
+        thex = ':'.join(chunks)
+    return thex
 
 _v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$')
 _colon_colon_start = re.compile(br'::.*')
 _colon_colon_end = re.compile(br'.*::$')
 
-def inet_aton(text, ignore_scope=False):
+def inet_aton(text: Union[str, bytes], ignore_scope=False) -> bytes:
     """Convert an IPv6 address in text form to binary form.
 
     *text*, a ``str``, the IPv6 address in textual form.
@@ -111,53 +113,55 @@ def inet_aton(text, ignore_scope=False):
     # Our aim here is not something fast; we just want something that works.
     #
     if not isinstance(text, bytes):
-        text = text.encode()
+        btext = text.encode()
+    else:
+        btext = text
 
     if ignore_scope:
-        parts = text.split(b'%')
+        parts = btext.split(b'%')
         l = len(parts)
         if l == 2:
-            text = parts[0]
+            btext = parts[0]
         elif l > 2:
             raise dns.exception.SyntaxError
 
-    if text == b'':
+    if btext == b'':
         raise dns.exception.SyntaxError
-    elif text.endswith(b':') and not text.endswith(b'::'):
+    elif btext.endswith(b':') and not btext.endswith(b'::'):
         raise dns.exception.SyntaxError
-    elif text.startswith(b':') and not text.startswith(b'::'):
+    elif btext.startswith(b':') and not btext.startswith(b'::'):
         raise dns.exception.SyntaxError
-    elif text == b'::':
-        text = b'0::'
+    elif btext == b'::':
+        btext = b'0::'
     #
     # Get rid of the icky dot-quad syntax if we have it.
     #
-    m = _v4_ending.match(text)
+    m = _v4_ending.match(btext)
     if m is not None:
         b = dns.ipv4.inet_aton(m.group(2))
-        text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(),
+        btext = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(),
                                                       b[0], b[1], b[2],
                                                       b[3])).encode()
     #
     # Try to turn '::<whatever>' into ':<whatever>'; if no match try to
     # turn '<whatever>::' into '<whatever>:'
     #
-    m = _colon_colon_start.match(text)
+    m = _colon_colon_start.match(btext)
     if m is not None:
-        text = text[1:]
+        btext = btext[1:]
     else:
-        m = _colon_colon_end.match(text)
+        m = _colon_colon_end.match(btext)
         if m is not None:
-            text = text[:-1]
+            btext = btext[:-1]
     #
     # Now canonicalize into 8 chunks of 4 hex digits each
     #
-    chunks = text.split(b':')
+    chunks = btext.split(b':')
     l = len(chunks)
     if l > 8:
         raise dns.exception.SyntaxError
     seen_empty = False
-    canonical = []
+    canonical: List[bytes] = []
     for c in chunks:
         if c == b'':
             if seen_empty:
@@ -174,13 +178,13 @@ def inet_aton(text, ignore_scope=False):
             canonical.append(c)
     if l < 8 and not seen_empty:
         raise dns.exception.SyntaxError
-    text = b''.join(canonical)
+    btext = b''.join(canonical)
 
     #
     # Finally we can go to binary.
     #
     try:
-        return binascii.unhexlify(text)
+        return binascii.unhexlify(btext)
     except (binascii.Error, TypeError):
         raise dns.exception.SyntaxError
 
index c2751a90c8f6f03accb2f3c298d7ff0d40e6265b..46c0a684b307c465ec42dd8c3cb5f491b988c93b 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS Messages"""
 
+from typing import Any, Dict, List, Optional, Tuple, Union
+
 import contextlib
 import io
 import time
@@ -73,6 +75,10 @@ class Truncated(dns.exception.DNSException):
 
     supp_kwargs = {'message'}
 
+    # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def message(self):
         """As much of the message as could be processed.
 
@@ -109,7 +115,7 @@ class MessageSection(dns.enum.IntEnum):
 
 
 class MessageError:
-    def __init__(self, exception, offset):
+    def __init__(self, exception: Exception, offset: int):
         self.exception = exception
         self.offset = offset
 
@@ -117,31 +123,38 @@ class MessageError:
 DEFAULT_EDNS_PAYLOAD = 1232
 MAX_CHAIN = 16
 
+IndexKeyType = Tuple[int, dns.name.Name, dns.rdataclass.RdataClass,
+                     dns.rdatatype.RdataType, Optional[dns.rdatatype.RdataType],
+                     Optional[dns.rdataclass.RdataClass]]
+IndexType = Dict[IndexKeyType, dns.rrset.RRset]
+SectionType = Union[int, List[dns.rrset.RRset]]
+
 class Message:
     """A DNS message."""
 
     _section_enum = MessageSection
 
-    def __init__(self, id=None):
+    def __init__(self, id: Optional[int]=None):
         if id is None:
             self.id = dns.entropy.random_16()
         else:
             self.id = id
         self.flags = 0
-        self.sections = [[], [], [], []]
-        self.opt = None
+        self.sections: List[List[dns.rrset.RRset]] = [[], [], [], []]
+        self.opt: Optional[dns.rrset.RRset] = None
         self.request_payload = 0
-        self.keyring = None
-        self.tsig = None
+        self.keyring: Any = None
+        self.tsig: Optional[dns.rrset.RRset] = None
         self.request_mac = b''
         self.xfr = False
-        self.origin = None
-        self.tsig_ctx = None
-        self.index = {}
-        self.errors = []
+        self.origin: Optional[dns.name.Name] = None
+        self.tsig_ctx: Optional[Any] = None
+        self.index: IndexType = {}
+        self.errors: List[MessageError] = []
+        self.time = 0.0
 
     @property
-    def question(self):
+    def question(self) -> List[dns.rrset.RRset]:
         """ The question section."""
         return self.sections[0]
 
@@ -150,7 +163,7 @@ class Message:
         self.sections[0] = v
 
     @property
-    def answer(self):
+    def answer(self) -> List[dns.rrset.RRset]:
         """ The answer section."""
         return self.sections[1]
 
@@ -159,7 +172,7 @@ class Message:
         self.sections[1] = v
 
     @property
-    def authority(self):
+    def authority(self) -> List[dns.rrset.RRset]:
         """ The authority section."""
         return self.sections[2]
 
@@ -168,7 +181,7 @@ class Message:
         self.sections[2] = v
 
     @property
-    def additional(self):
+    def additional(self) -> List[dns.rrset.RRset]:
         """ The additional data section."""
         return self.sections[3]
 
@@ -182,7 +195,8 @@ class Message:
     def __str__(self):
         return self.to_text()
 
-    def to_text(self, origin=None, relativize=True, **kw):
+    def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True,
+                **kw):
         """Convert the message to text.
 
         The *origin*, *relativize*, and any other keyword
@@ -242,7 +256,7 @@ class Message:
     def __ne__(self, other):
         return not self.__eq__(other)
 
-    def is_response(self, other):
+    def is_response(self, other: 'Message') -> bool:
         """Is *other*, also a ``dns.message.Message``, a response to this
         message?
 
@@ -275,7 +289,7 @@ class Message:
                 return False
         return True
 
-    def section_number(self, section):
+    def section_number(self, section: List[dns.rrset.RRset]) -> int:
         """Return the "section number" of the specified section for use
         in indexing.
 
@@ -291,7 +305,7 @@ class Message:
                 return self._section_enum(i)
         raise ValueError('unknown section')
 
-    def section_from_number(self, number):
+    def section_from_number(self, number: int) -> List[dns.rrset.RRset]:
         """Return the section list associated with the specified section
         number.
 
@@ -306,9 +320,15 @@ class Message:
         section = self._section_enum.make(number)
         return self.sections[section]
 
-    def find_rrset(self, section, name, rdclass, rdtype,
-                   covers=dns.rdatatype.NONE, deleting=None, create=False,
-                   force_unique=False):
+    def find_rrset(self,
+                   section: SectionType,
+                   name: dns.name.Name,
+                   rdclass: dns.rdataclass.RdataClass,
+                   rdtype: dns.rdatatype.RdataType,
+                   covers = dns.rdatatype.NONE,
+                   deleting: Optional[dns.rdataclass.RdataClass]=None,
+                   create=False,
+                   force_unique=False) -> dns.rrset.RRset:
         """Find the RRset with the given attributes in the specified section.
 
         *section*, an ``int`` section number, or one of the section
@@ -346,9 +366,10 @@ class Message:
 
         if isinstance(section, int):
             section_number = section
-            section = self.section_from_number(section_number)
+            the_section = self.section_from_number(section_number)
         else:
             section_number = self.section_number(section)
+            the_section = section
         key = (section_number, name, rdclass, rdtype, covers, deleting)
         if not force_unique:
             if self.index is not None:
@@ -356,21 +377,27 @@ class Message:
                 if rrset is not None:
                     return rrset
             else:
-                for rrset in section:
+                for rrset in the_section:
                     if rrset.full_match(name, rdclass, rdtype, covers,
                                         deleting):
                         return rrset
         if not create:
             raise KeyError
         rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
-        section.append(rrset)
+        the_section.append(rrset)
         if self.index is not None:
             self.index[key] = rrset
         return rrset
 
-    def get_rrset(self, section, name, rdclass, rdtype,
-                  covers=dns.rdatatype.NONE, deleting=None, create=False,
-                  force_unique=False):
+    def get_rrset(self,
+                  section: SectionType,
+                  name: dns.name.Name,
+                  rdclass: dns.rdataclass.RdataClass,
+                  rdtype: dns.rdatatype.RdataType,
+                  covers = dns.rdatatype.NONE,
+                  deleting: Optional[dns.rdataclass.RdataClass]=None,
+                  create=False,
+                  force_unique=False) -> Optional[dns.rrset.RRset]:
         """Get the RRset with the given attributes in the specified section.
 
         If the RRset is not found, None is returned.
@@ -412,8 +439,8 @@ class Message:
             rrset = None
         return rrset
 
-    def to_wire(self, origin=None, max_size=0, multi=False, tsig_ctx=None,
-                **kw):
+    def to_wire(self, origin: Optional[dns.name.Name]=None, max_size=0,
+                multi=False, tsig_ctx: Optional[Any]=None, **kw) -> bytes:
         """Return a string containing the message in DNS compressed wire
         format.
 
@@ -486,9 +513,9 @@ class Message:
                                          original_id, error, other)
         return dns.rrset.from_rdata(keyname, 0, tsig)
 
-    def use_tsig(self, keyring, keyname=None, fudge=300,
-                 original_id=None, tsig_error=0, other_data=b'',
-                 algorithm=dns.tsig.default_algorithm):
+    def use_tsig(self, keyring: Any, keyname: Optional[dns.name.Name]=None,
+                 fudge=300, original_id: Optional[int]=None, tsig_error=0,
+                 other_data=b'', algorithm=dns.tsig.default_algorithm):
         """When sending, a TSIG signature using the specified key
         should be added.
 
@@ -546,35 +573,35 @@ class Message:
                                     b'', original_id, tsig_error, other_data)
 
     @property
-    def keyname(self):
+    def keyname(self) -> Optional[dns.name.Name]:
         if self.tsig:
             return self.tsig.name
         else:
             return None
 
     @property
-    def keyalgorithm(self):
+    def keyalgorithm(self) -> Optional[dns.name.Name]:
         if self.tsig:
             return self.tsig[0].algorithm
         else:
             return None
 
     @property
-    def mac(self):
+    def mac(self) -> Optional[bytes]:
         if self.tsig:
             return self.tsig[0].mac
         else:
             return None
 
     @property
-    def tsig_error(self):
+    def tsig_error(self) -> Optional[int]:
         if self.tsig:
             return self.tsig[0].error
         else:
             return None
 
     @property
-    def had_tsig(self):
+    def had_tsig(self) -> bool:
         return bool(self.tsig)
 
     @staticmethod
@@ -584,7 +611,8 @@ class Message:
         return dns.rrset.from_rdata(dns.name.root, int(flags), opt)
 
     def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD,
-                 request_payload=None, options=None):
+                 request_payload: Optional[int]=None,
+                 options: Optional[List[dns.edns.Option]]=None):
         """Configure EDNS behavior.
 
         *edns*, an ``int``, is the EDNS level to use.  Specifying
@@ -625,14 +653,14 @@ class Message:
             self.request_payload = request_payload
 
     @property
-    def edns(self):
+    def edns(self) -> int:
         if self.opt:
             return (self.ednsflags & 0xff0000) >> 16
         else:
             return -1
 
     @property
-    def ednsflags(self):
+    def ednsflags(self) -> int:
         if self.opt:
             return self.opt.ttl
         else:
@@ -646,14 +674,14 @@ class Message:
             self.opt = self._make_opt(v)
 
     @property
-    def payload(self):
+    def payload(self) -> int:
         if self.opt:
             return self.opt[0].payload
         else:
             return 0
 
     @property
-    def options(self):
+    def options(self) -> Tuple:
         if self.opt:
             return self.opt[0].options
         else:
@@ -673,17 +701,17 @@ class Message:
         elif self.opt:
             self.ednsflags &= ~dns.flags.DO
 
-    def rcode(self):
+    def rcode(self) -> dns.rcode.Rcode:
         """Return the rcode.
 
-        Returns an ``int``.
+        Returns a ``dns.rcode.Rcode``.
         """
         return dns.rcode.from_flags(int(self.flags), int(self.ednsflags))
 
-    def set_rcode(self, rcode):
+    def set_rcode(self, rcode: dns.rcode.Rcode):
         """Set the rcode.
 
-        *rcode*, an ``int``, is the rcode to set.
+        *rcode*, a ``dns.rcode.Rcode``, is the rcode to set.
         """
         (value, evalue) = dns.rcode.to_flags(rcode)
         self.flags &= 0xFFF0
@@ -691,17 +719,17 @@ class Message:
         self.ednsflags &= 0x00FFFFFF
         self.ednsflags |= evalue
 
-    def opcode(self):
+    def opcode(self) -> dns.opcode.Opcode:
         """Return the opcode.
 
-        Returns an ``int``.
+        Returns a ``dns.opcode.Opcode``.
         """
         return dns.opcode.from_flags(int(self.flags))
 
-    def set_opcode(self, opcode):
+    def set_opcode(self, opcode: dns.opcode.Opcode):
         """Set the opcode.
 
-        *opcode*, an ``int``, is the opcode to set.
+        *opcode*, a ``dns.opcode.Opcode``, is the opcode to set.
         """
         self.flags &= 0x87FF
         self.flags |= dns.opcode.to_flags(opcode)
@@ -738,7 +766,7 @@ class ChainingResult:
     exist.
 
     The ``canonical_name`` attribute is the canonical name after all
-    chaining has been applied (this is the name as ``rrset.name`` in cases
+    chaining has been applied (this is the same name as ``rrset.name`` in cases
     where rrset is not ``None``).
 
     The ``minimum_ttl`` attribute is the minimum TTL, i.e. the TTL to
@@ -749,7 +777,8 @@ class ChainingResult:
     The ``cnames`` attribute is a list of all the CNAME RRSets followed to
     get to the canonical name.
     """
-    def __init__(self, canonical_name, answer, minimum_ttl, cnames):
+    def __init__(self, canonical_name: dns.name.Name, answer: Optional[dns.rrset.RRset],
+                 minimum_ttl: int, cnames: List[dns.rrset.RRset]):
         self.canonical_name = canonical_name
         self.answer = answer
         self.minimum_ttl = minimum_ttl
@@ -757,7 +786,7 @@ class ChainingResult:
 
 
 class QueryMessage(Message):
-    def resolve_chaining(self):
+    def resolve_chaining(self) -> ChainingResult:
         """Follow the CNAME chain in the response to determine the answer
         RRset.
 
@@ -831,7 +860,7 @@ class QueryMessage(Message):
                         break
         return ChainingResult(qname, answer, min_ttl, cnames)
 
-    def canonical_name(self):
+    def canonical_name(self) -> dns.name.Name:
         """Return the canonical name of the first name in the question
         section.
 
@@ -1042,7 +1071,7 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
               tsig_ctx=None, multi=False,
               question_only=False, one_rr_per_rrset=False,
               ignore_trailing=False, raise_on_truncation=False,
-              continue_on_error=False):
+              continue_on_error=False) -> Message:
     """Convert a DNS wire format message into a message object.
 
     *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the
@@ -1354,7 +1383,7 @@ class _TextReader:
 
 
 def from_text(text, idna_codec=None, one_rr_per_rrset=False,
-              origin=None, relativize=True, relativize_to=None):
+              origin=None, relativize=True, relativize_to=None) -> Message:
     """Convert the text format message into a message object.
 
     The reader stops after reading the first blank line in the input to
@@ -1394,7 +1423,7 @@ def from_text(text, idna_codec=None, one_rr_per_rrset=False,
     return reader.read()
 
 
-def from_file(f, idna_codec=None, one_rr_per_rrset=False):
+def from_file(f, idna_codec=None, one_rr_per_rrset=False) -> Message:
     """Read the next text format message from the specified file.
 
     Message blocks are separated by a single blank line.
@@ -1420,12 +1449,14 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False):
         if isinstance(f, str):
             f = stack.enter_context(open(f))
         return from_text(f, idna_codec, one_rr_per_rrset)
+    assert False  # for mypy
 
 
 def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
-               want_dnssec=False, ednsflags=None, payload=None,
-               request_payload=None, options=None, idna_codec=None,
-               id=None, flags=dns.flags.RD):
+               want_dnssec=False, ednsflags: Optional[int]=None, payload: Optional[int]=None,
+               request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None,
+               idna_codec: Optional[dns.name.IDNACodec]=None, id: Optional[int]=None,
+               flags: int=dns.flags.RD) -> QueryMessage:
     """Make a query message.
 
     The query name, type, and class may all be specified either
@@ -1487,7 +1518,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
     # only pass keywords on to use_edns if they have been set to a
     # non-None value.  Setting a field will turn EDNS on if it hasn't
     # been configured.
-    kwargs = {}
+    kwargs: Dict[str, Any] = {}
     if ednsflags is not None:
         kwargs['ednsflags'] = ednsflags
     if payload is not None:
@@ -1505,7 +1536,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
 
 
 def make_response(query, recursion_available=False, our_payload=8192,
-                  fudge=300, tsig_error=0):
+                  fudge=300, tsig_error=0) -> Message:
     """Make a message which is a response for the specified query.
     The message returned is really a response skeleton; it has all
     of the infrastructure required of a response, but none of the
diff --git a/dns/message.pyi b/dns/message.pyi
deleted file mode 100644 (file)
index 252a411..0000000
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Optional, Dict, List, Tuple, Union
-from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode
-import hmac
-
-class Message:
-    def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes:
-        ...
-    def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int,
-                   covers=rdatatype.NONE, deleting : Optional[int]=None, create=False,
-                   force_unique=False) -> rrset.RRset:
-        ...
-    def __init__(self, id : Optional[int] =None) -> None:
-        self.id : int
-        self.flags = 0
-        self.sections : List[List[rrset.RRset]] = [[], [], [], []]
-        self.opt : rrset.RRset = None
-        self.request_payload = 0
-        self.keyring = None
-        self.tsig : rrset.RRset = None
-        self.request_mac = b''
-        self.xfr = False
-        self.origin = None
-        self.tsig_ctx = None
-        self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {}
-
-    def is_response(self, other : Message) -> bool:
-        ...
-
-    def set_rcode(self, rcode : rcode.Rcode):
-        ...
-
-def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
-    ...
-
-def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
-              tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False,
-              question_only=False, one_rr_per_rrset=False,
-              ignore_trailing=False) -> Message:
-    ...
-def make_response(query : Message, recursion_available=False, our_payload=8192,
-                  fudge=300) -> Message:
-    ...
-
-def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None,
-               want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None,
-               request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message:
-    ...
index 8905d70f723963f0e30e85fe084c6a52ecb6c246..29078eede887d3c92d531e280bec17326e49d606 100644 (file)
@@ -18,6 +18,8 @@
 """DNS Names.
 """
 
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
 import copy
 import struct
 
@@ -28,22 +30,47 @@ try:
 except ImportError:  # pragma: no cover
     have_idna_2008 = False
 
+import dns.enum
 import dns.wire
 import dns.exception
 import dns.immutable
 
-# fullcompare() result values
 
-#: The compared names have no relationship to each other.
-NAMERELN_NONE = 0
-#: the first name is a superdomain of the second.
-NAMERELN_SUPERDOMAIN = 1
-#: The first name is a subdomain of the second.
-NAMERELN_SUBDOMAIN = 2
-#: The compared names are equal.
-NAMERELN_EQUAL = 3
-#: The compared names have a common ancestor.
-NAMERELN_COMMONANCESTOR = 4
+CompressType = Dict['Name', int]
+
+
+class NameRelation(dns.enum.IntEnum):
+    """Name relation result from fullcompare()."""
+
+    # This is an IntEnum for backwards compatibility in case anyone
+    # has hardwired the constants.
+
+    #: The compared names have no relationship to each other.
+    NONE = 0
+    #: the first name is a superdomain of the second.
+    SUPERDOMAIN = 1
+    #: The first name is a subdomain of the second.
+    SUBDOMAIN = 2
+    #: The compared names are equal.
+    EQUAL = 3
+    #: The compared names have a common ancestor.
+    COMMONANCESTOR = 4
+
+    @classmethod
+    def _maximum(cls):
+        return cls.COMMONANCESTOR
+
+    @classmethod
+    def _short_name(cls):
+        return cls.__name__
+
+
+# Backwards compatibility
+NAMERELN_NONE = NameRelation.NONE
+NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN
+NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN
+NAMERELN_EQUAL = NameRelation.EQUAL
+NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR
 
 
 class EmptyLabel(dns.exception.SyntaxError):
@@ -95,6 +122,42 @@ class IDNAException(dns.exception.DNSException):
     supp_kwargs = {'idna_exception'}
     fmt = "IDNA processing exception: {idna_exception}"
 
+    # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+_escaped = b'"().;\\@$'
+_escaped_text = '"().;\\@$'
+
+def _escapify(label: Union[bytes, str]) -> str:
+    """Escape the characters in label which need it.
+    @returns: the escaped string
+    @rtype: string"""
+    if isinstance(label, bytes):
+        # Ordinary DNS label mode.  Escape special characters and values
+        # < 0x20 or > 0x7f.
+        text = ''
+        for c in label:
+            if c in _escaped:
+                text += '\\' + chr(c)
+            elif c > 0x20 and c < 0x7F:
+                text += chr(c)
+            else:
+                text += '\\%03d' % c
+        return text
+
+    # Unicode label mode.  Escape only special characters and values < 0x20
+    text = ''
+    for uc in label:
+        if uc in _escaped_text:
+            text += '\\' + uc
+        elif uc <= '\x20':
+            text += '\\%03d' % ord(uc)
+        else:
+            text += uc
+    return text
+
 
 class IDNACodec:
     """Abstract base class for IDNA encoder/decoders."""
@@ -102,20 +165,22 @@ class IDNACodec:
     def __init__(self):
         pass
 
-    def is_idna(self, label):
+    def is_idna(self, label: bytes) -> bool:
         return label.lower().startswith(b'xn--')
 
-    def encode(self, label):
+    def encode(self, label: str) -> bytes:
         raise NotImplementedError  # pragma: no cover
 
-    def decode(self, label):
+    def decode(self, label: bytes) -> str:
         # We do not apply any IDNA policy on decode.
         if self.is_idna(label):
             try:
-                label = label[4:].decode('punycode')
+                slabel = label[4:].decode('punycode')
+                return _escapify(slabel)
             except Exception as e:
                 raise IDNAException(idna_exception=e)
-        return _escapify(label)
+        else:
+            return _escapify(label)
 
 
 class IDNA2003Codec(IDNACodec):
@@ -132,7 +197,7 @@ class IDNA2003Codec(IDNACodec):
         super().__init__()
         self.strict_decode = strict_decode
 
-    def encode(self, label):
+    def encode(self, label: str) -> bytes:
         """Encode *label*."""
 
         if label == '':
@@ -142,7 +207,7 @@ class IDNA2003Codec(IDNACodec):
         except UnicodeError:
             raise LabelTooLong
 
-    def decode(self, label):
+    def decode(self, label: bytes) -> str:
         """Decode *label*."""
         if not self.strict_decode:
             return super().decode(label)
@@ -188,7 +253,7 @@ class IDNA2008Codec(IDNACodec):
         self.allow_pure_ascii = allow_pure_ascii
         self.strict_decode = strict_decode
 
-    def encode(self, label):
+    def encode(self, label: str) -> bytes:
         if label == '':
             return b''
         if self.allow_pure_ascii and is_all_ascii(label):
@@ -208,7 +273,7 @@ class IDNA2008Codec(IDNACodec):
             else:
                 raise IDNAException(idna_exception=e)
 
-    def decode(self, label):
+    def decode(self, label: bytes) -> str:
         if not self.strict_decode:
             return super().decode(label)
         if label == b'':
@@ -223,9 +288,6 @@ class IDNA2008Codec(IDNACodec):
         except (idna.IDNAError, UnicodeError) as e:
             raise IDNAException(idna_exception=e)
 
-_escaped = b'"().;\\@$'
-_escaped_text = '"().;\\@$'
-
 IDNA_2003_Practical = IDNA2003Codec(False)
 IDNA_2003_Strict = IDNA2003Codec(True)
 IDNA_2003 = IDNA_2003_Practical
@@ -235,35 +297,7 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True)
 IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
 IDNA_2008 = IDNA_2008_Practical
 
-def _escapify(label):
-    """Escape the characters in label which need it.
-    @returns: the escaped string
-    @rtype: string"""
-    if isinstance(label, bytes):
-        # Ordinary DNS label mode.  Escape special characters and values
-        # < 0x20 or > 0x7f.
-        text = ''
-        for c in label:
-            if c in _escaped:
-                text += '\\' + chr(c)
-            elif c > 0x20 and c < 0x7F:
-                text += chr(c)
-            else:
-                text += '\\%03d' % c
-        return text
-
-    # Unicode label mode.  Escape only special characters and values < 0x20
-    text = ''
-    for c in label:
-        if c in _escaped_text:
-            text += '\\' + c
-        elif c <= '\x20':
-            text += '\\%03d' % ord(c)
-        else:
-            text += c
-    return text
-
-def _validate_labels(labels):
+def _validate_labels(labels: Tuple[bytes, ...]):
     """Check for empty labels in the middle of a label sequence,
     labels that are too long, and for too many labels.
 
@@ -293,7 +327,7 @@ def _validate_labels(labels):
         raise EmptyLabel
 
 
-def _maybe_convert_to_binary(label):
+def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
     """If label is ``str``, convert it to ``bytes``.  If it is already
     ``bytes`` just return it.
 
@@ -318,12 +352,12 @@ class Name:
 
     __slots__ = ['labels']
 
-    def __init__(self, labels):
+    def __init__(self, labels: Iterable[Union[bytes, str]]):
         """*labels* is any iterable whose values are ``str`` or ``bytes``.
         """
 
-        labels = [_maybe_convert_to_binary(x) for x in labels]
-        self.labels = tuple(labels)
+        blabels = [_maybe_convert_to_binary(x) for x in labels]
+        self.labels = tuple(blabels)
         _validate_labels(self.labels)
 
     def __copy__(self):
@@ -340,7 +374,7 @@ class Name:
         super().__setattr__('labels', state['labels'])
         _validate_labels(self.labels)
 
-    def is_absolute(self):
+    def is_absolute(self) -> bool:
         """Is the most significant label of this name the root label?
 
         Returns a ``bool``.
@@ -348,7 +382,7 @@ class Name:
 
         return len(self.labels) > 0 and self.labels[-1] == b''
 
-    def is_wild(self):
+    def is_wild(self) -> bool:
         """Is this name wild?  (I.e. Is the least significant label '*'?)
 
         Returns a ``bool``.
@@ -356,7 +390,7 @@ class Name:
 
         return len(self.labels) > 0 and self.labels[0] == b'*'
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         """Return a case-insensitive hash of the name.
 
         Returns an ``int``.
@@ -368,14 +402,14 @@ class Name:
                 h += (h << 3) + c
         return h
 
-    def fullcompare(self, other):
+    def fullcompare(self, other: 'Name') -> Tuple[NameRelation, int, int]:
         """Compare two names, returning a 3-tuple
         ``(relation, order, nlabels)``.
 
         *relation* describes the relation ship between the names,
-        and is one of: ``dns.name.NAMERELN_NONE``,
-        ``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``,
-        ``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``.
+        and is one of: ``dns.name.NameRelation.NONE``,
+        ``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``,
+        ``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``.
 
         *order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and ==
         0 if *self* == *other*.  A relative name is always less than an
@@ -404,9 +438,9 @@ class Name:
         oabs = other.is_absolute()
         if sabs != oabs:
             if sabs:
-                return (NAMERELN_NONE, 1, 0)
+                return (NameRelation.NONE, 1, 0)
             else:
-                return (NAMERELN_NONE, -1, 0)
+                return (NameRelation.NONE, -1, 0)
         l1 = len(self.labels)
         l2 = len(other.labels)
         ldiff = l1 - l2
@@ -417,7 +451,7 @@ class Name:
 
         order = 0
         nlabels = 0
-        namereln = NAMERELN_NONE
+        namereln = NameRelation.NONE
         while l > 0:
             l -= 1
             l1 -= 1
@@ -427,24 +461,24 @@ class Name:
             if label1 < label2:
                 order = -1
                 if nlabels > 0:
-                    namereln = NAMERELN_COMMONANCESTOR
+                    namereln = NameRelation.COMMONANCESTOR
                 return (namereln, order, nlabels)
             elif label1 > label2:
                 order = 1
                 if nlabels > 0:
-                    namereln = NAMERELN_COMMONANCESTOR
+                    namereln = NameRelation.COMMONANCESTOR
                 return (namereln, order, nlabels)
             nlabels += 1
         order = ldiff
         if ldiff < 0:
-            namereln = NAMERELN_SUPERDOMAIN
+            namereln = NameRelation.SUPERDOMAIN
         elif ldiff > 0:
-            namereln = NAMERELN_SUBDOMAIN
+            namereln = NameRelation.SUBDOMAIN
         else:
-            namereln = NAMERELN_EQUAL
+            namereln = NameRelation.EQUAL
         return (namereln, order, nlabels)
 
-    def is_subdomain(self, other):
+    def is_subdomain(self, other: 'Name') -> bool:
         """Is self a subdomain of other?
 
         Note that the notion of subdomain includes equality, e.g.
@@ -454,11 +488,11 @@ class Name:
         """
 
         (nr, _, _) = self.fullcompare(other)
-        if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
+        if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL:
             return True
         return False
 
-    def is_superdomain(self, other):
+    def is_superdomain(self, other: 'Name') -> bool:
         """Is self a superdomain of other?
 
         Note that the notion of superdomain includes equality, e.g.
@@ -468,11 +502,11 @@ class Name:
         """
 
         (nr, _, _) = self.fullcompare(other)
-        if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
+        if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL:
             return True
         return False
 
-    def canonicalize(self):
+    def canonicalize(self) -> 'Name':
         """Return a name which is equal to the current name, but is in
         DNSSEC canonical form.
         """
@@ -521,7 +555,7 @@ class Name:
     def __str__(self):
         return self.to_text(False)
 
-    def to_text(self, omit_final_dot=False):
+    def to_text(self, omit_final_dot=False) -> str:
         """Convert name to DNS text format.
 
         *omit_final_dot* is a ``bool``.  If True, don't emit the final
@@ -542,7 +576,7 @@ class Name:
         s = '.'.join(map(_escapify, l))
         return s
 
-    def to_unicode(self, omit_final_dot=False, idna_codec=None):
+    def to_unicode(self, omit_final_dot=False, idna_codec: Optional[IDNACodec]=None) -> str:
         """Convert name to Unicode text format.
 
         IDN ACE labels are converted to Unicode.
@@ -572,7 +606,7 @@ class Name:
             idna_codec = IDNA_2003_Practical
         return '.'.join([idna_codec.decode(x) for x in l])
 
-    def to_digestable(self, origin=None):
+    def to_digestable(self, origin: Optional['Name']=None) -> bytes:
         """Convert name to a format suitable for digesting in hashes.
 
         The name is canonicalized and converted to uncompressed wire
@@ -589,10 +623,12 @@ class Name:
         Returns a ``bytes``.
         """
 
-        return self.to_wire(origin=origin, canonicalize=True)
+        digest = self.to_wire(origin=origin, canonicalize=True)
+        assert digest is not None
+        return digest
 
-    def to_wire(self, file=None, compress=None, origin=None,
-                canonicalize=False):
+    def to_wire(self, file=None, compress: Optional[CompressType]=None,
+                origin: Optional['Name']=None, canonicalize=False) -> Optional[bytes]:
         """Convert name to wire format, possibly compressing it.
 
         *file* is the file where the name is emitted (typically an
@@ -638,6 +674,7 @@ class Name:
                         out += label
             return bytes(out)
 
+        labels: Iterable[bytes]
         if not self.is_absolute():
             if origin is None or not origin.is_absolute():
                 raise NeedAbsoluteNameOrOrigin
@@ -670,8 +707,9 @@ class Name:
                         file.write(label.lower())
                     else:
                         file.write(label)
+        return None
 
-    def __len__(self):
+    def __len__(self) -> int:
         """The length of the name (in labels).
 
         Returns an ``int``.
@@ -688,7 +726,7 @@ class Name:
     def __sub__(self, other):
         return self.relativize(other)
 
-    def split(self, depth):
+    def split(self, depth: int) -> Tuple['Name', 'Name']:
         """Split a name into a prefix and suffix names at the specified depth.
 
         *depth* is an ``int`` specifying the number of labels in the suffix
@@ -709,7 +747,7 @@ class Name:
                 'depth must be >= 0 and <= the length of the name')
         return (Name(self[: -depth]), Name(self[-depth:]))
 
-    def concatenate(self, other):
+    def concatenate(self, other: 'Name') -> 'Name':
         """Return a new name which is the concatenation of self and other.
 
         Raises ``dns.name.AbsoluteConcatenation`` if the name is
@@ -724,7 +762,7 @@ class Name:
         labels.extend(list(other.labels))
         return Name(labels)
 
-    def relativize(self, origin):
+    def relativize(self, origin: 'Name') -> 'Name':
         """If the name is a subdomain of *origin*, return a new name which is
         the name relative to origin.  Otherwise return the name.
 
@@ -740,7 +778,7 @@ class Name:
         else:
             return self
 
-    def derelativize(self, origin):
+    def derelativize(self, origin: 'Name') -> 'Name':
         """If the name is a relative name, return a new name which is the
         concatenation of the name and origin.  Otherwise return the name.
 
@@ -756,7 +794,7 @@ class Name:
         else:
             return self
 
-    def choose_relativity(self, origin=None, relativize=True):
+    def choose_relativity(self, origin: Optional['Name']=None, relativize=True) -> 'Name':
         """Return a name with the relativity desired by the caller.
 
         If *origin* is ``None``, then the name is returned.
@@ -775,7 +813,7 @@ class Name:
         else:
             return self
 
-    def parent(self):
+    def parent(self) -> 'Name':
         """Return the parent of the name.
 
         For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``.
@@ -796,7 +834,7 @@ root = Name([b''])
 #: The empty name.
 empty = Name([])
 
-def from_unicode(text, origin=root, idna_codec=None):
+def from_unicode(text: str, origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name:
     """Convert unicode text into a Name object.
 
     Labels are encoded in IDN ACE form according to rules specified by
@@ -870,16 +908,16 @@ def from_unicode(text, origin=root, idna_codec=None):
         labels.extend(list(origin.labels))
     return Name(labels)
 
-def is_all_ascii(text):
+def is_all_ascii(text: str) -> bool:
     for c in text:
         if ord(c) > 0x7f:
             return False
     return True
 
-def from_text(text, origin=root, idna_codec=None):
+def from_text(text: Union[bytes, str], origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name:
     """Convert text into a Name object.
 
-    *text*, a ``str``, is the text to convert into a name.
+    *text*, a ``bytes`` or ``str``, is the text to convert into a name.
 
     *origin*, a ``dns.name.Name``, specifies the origin to
     append to non-absolute names.  The default is the root name.
@@ -958,8 +996,9 @@ def from_text(text, origin=root, idna_codec=None):
         labels.extend(list(origin.labels))
     return Name(labels)
 
+# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other.
 
-def from_wire_parser(parser):
+def from_wire_parser(parser: 'dns.wire.Parser') -> Name:
     """Convert possibly compressed wire format into a Name.
 
     *parser* is a dns.wire.Parser.
@@ -992,7 +1031,7 @@ def from_wire_parser(parser):
     return Name(labels)
 
 
-def from_wire(message, current):
+def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
     """Convert possibly compressed wire format into a Name.
 
     *message* is a ``bytes`` containing an entire DNS message in DNS
diff --git a/dns/name.pyi b/dns/name.pyi
deleted file mode 100644 (file)
index c48d4bd..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-from typing import Optional, Union, Tuple, Iterable, List
-
-have_idna_2008: bool
-
-class Name:
-    def is_subdomain(self, o : Name) -> bool: ...
-    def is_superdomain(self, o : Name) -> bool: ...
-    def __init__(self, labels : Iterable[Union[bytes,str]]) -> None:
-        self.labels : List[bytes]
-    def is_absolute(self) -> bool: ...
-    def is_wild(self) -> bool: ...
-    def fullcompare(self, other) -> Tuple[int,int,int]: ...
-    def canonicalize(self) -> Name: ...
-    def __eq__(self, other) -> bool: ...
-    def __ne__(self, other) -> bool: ...
-    def __lt__(self, other : Name) -> bool: ...
-    def __le__(self, other : Name) -> bool: ...
-    def __ge__(self, other : Name) -> bool: ...
-    def __gt__(self, other : Name) -> bool: ...
-    def to_text(self, omit_final_dot=False) -> str: ...
-    def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ...
-    def to_digestable(self, origin=None) -> bytes: ...
-    def to_wire(self, file=None, compress=None, origin=None,
-                canonicalize=False) -> Optional[bytes]: ...
-    def __add__(self, other : Name) -> Name: ...
-    def __sub__(self, other : Name) -> Name: ...
-    def split(self, depth) -> List[Tuple[str,str]]: ...
-    def concatenate(self, other : Name) -> Name: ...
-    def relativize(self, origin) -> Name: ...
-    def derelativize(self, origin) -> Name: ...
-    def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ...
-    def parent(self) -> Name: ...
-
-class IDNACodec:
-    pass
-
-def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name:
-    ...
-
-empty : Name
index 63ce008b938c0f1e4cb87eef2a4b2083d22cfba7..a4c17f966ea074a031d67fbc63a151e4948bbe85 100644 (file)
 
 """DNS nodes.  A node is a set of rdatasets."""
 
+from typing import List, Optional, Union
+
 import enum
 import io
 
 import dns.immutable
+import dns.name
+import dns.rdataclass
 import dns.rdataset
 import dns.rdatatype
+import dns.rrset
 import dns.renderer
 
 
@@ -51,7 +56,7 @@ class NodeKind(enum.Enum):
     CNAME = 2
 
     @classmethod
-    def classify(cls, rdtype, covers):
+    def classify(cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType) -> 'NodeKind':
         if _matches_type_or_its_signature(_cname_types, rdtype, covers):
             return NodeKind.CNAME
         elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
@@ -60,7 +65,7 @@ class NodeKind(enum.Enum):
             return NodeKind.REGULAR
 
     @classmethod
-    def classify_rdataset(cls, rdataset):
+    def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> 'NodeKind':
         return cls.classify(rdataset.rdtype, rdataset.covers)
 
 
@@ -85,15 +90,15 @@ class Node:
 
     def __init__(self):
         # the set of rdatasets, represented as a list.
-        self.rdatasets = []
+        self.rdatasets: List[dns.rdataset.Rdataset] = []
 
-    def to_text(self, name, **kw):
+    def to_text(self, name: dns.name.Name, **kw) -> str:
         """Convert a node to text format.
 
         Each rdataset at the node is printed.  Any keyword arguments
         to this method are passed on to the rdataset's to_text() method.
 
-        *name*, a ``dns.name.Name`` or ``str``, the owner name of the
+        *name*, a ``dns.name.Name``, the owner name of the
         rdatasets.
 
         Returns a ``str``.
@@ -155,16 +160,19 @@ class Node:
             # edit self.rdatasets.
         self.rdatasets.append(rdataset)
 
-    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                      create=False):
+    def find_rdataset(self,
+                      rdclass: dns.rdataclass.RdataClass,
+                      rdtype: dns.rdatatype.RdataType,
+                      covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+                      create=False) -> dns.rdataset.Rdataset:
         """Find an rdataset matching the specified properties in the
         current node.
 
-        *rdclass*, an ``int``, the class of the rdataset.
+        *rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset.
 
-        *rdtype*, an ``int``, the type of the rdataset.
+        *rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset.
 
-        *covers*, an ``int`` or ``None``, the covered type.
+        *covers*, a ``dns.rdatatype.RdataType``, the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -191,8 +199,11 @@ class Node:
         self._append_rdataset(rds)
         return rds
 
-    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                     create=False):
+    def get_rdataset(self,
+                     rdclass: dns.rdataclass.RdataClass,
+                     rdtype: dns.rdatatype.RdataType,
+                     covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+                     create=False) -> Optional[dns.rdataset.Rdataset]:
         """Get an rdataset matching the specified properties in the
         current node.
 
@@ -223,7 +234,10 @@ class Node:
             rds = None
         return rds
 
-    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+    def delete_rdataset(self,
+                        rdclass: dns.rdataclass.RdataClass,
+                        rdtype: dns.rdatatype.RdataType,
+                        covers: dns.rdatatype.RdataType=dns.rdatatype.NONE):
         """Delete the rdataset matching the specified properties in the
         current node.
 
@@ -240,7 +254,7 @@ class Node:
         if rds is not None:
             self.rdatasets.remove(rds)
 
-    def replace_rdataset(self, replacement):
+    def replace_rdataset(self, replacement: dns.rdataset.Rdataset):
         """Replace an rdataset.
 
         It is not an error if there is no rdataset matching *replacement*.
@@ -265,7 +279,7 @@ class Node:
                              replacement.covers)
         self._append_rdataset(replacement)
 
-    def classify(self):
+    def classify(self) -> NodeKind:
         """Classify a node.
 
         A node which contains a CNAME or RRSIG(CNAME) is a
@@ -286,7 +300,7 @@ class Node:
                 return kind
         return NodeKind.NEUTRAL
 
-    def is_immutable(self):
+    def is_immutable(self) -> bool:
         return False
 
 
@@ -316,5 +330,5 @@ class ImmutableNode(Node):
     def replace_rdataset(self, replacement):
         raise TypeError("immutable")
 
-    def is_immutable(self):
+    def is_immutable(self) -> bool:
         return True
diff --git a/dns/node.pyi b/dns/node.pyi
deleted file mode 100644 (file)
index 0997edf..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import List, Optional, Union
-from . import rdataset, rdatatype, name
-class Node:
-    def __init__(self):
-        self.rdatasets : List[rdataset.Rdataset]
-    def to_text(self, name : Union[str,name.Name], **kw) -> str:
-        ...
-    def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
-                      create=False) -> rdataset.Rdataset:
-        ...
-    def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
-                     create=False) -> Optional[rdataset.Rdataset]:
-        ...
-    def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE):
-        ...
-    def replace_rdataset(self, replacement : rdataset.Rdataset) -> None:
-        ...
index 5cf6143c710f25e95ac8149b5e9c4fe6812a5b6d..971b62c8a29f3096335020357ad0991ef37e28a5 100644 (file)
@@ -45,7 +45,7 @@ class UnknownOpcode(dns.exception.DNSException):
     """An DNS opcode is unknown."""
 
 
-def from_text(text):
+def from_text(text: str) -> Opcode:
     """Convert text into an opcode.
 
     *text*, a ``str``, the textual opcode
@@ -58,7 +58,7 @@ def from_text(text):
     return Opcode.from_text(text)
 
 
-def from_flags(flags):
+def from_flags(flags: int) -> Opcode:
     """Extract an opcode from DNS message flags.
 
     *flags*, an ``int``, the DNS flags.
@@ -66,10 +66,10 @@ def from_flags(flags):
     Returns an ``int``.
     """
 
-    return (flags & 0x7800) >> 11
+    return Opcode((flags & 0x7800) >> 11)
 
 
-def to_flags(value):
+def to_flags(value: Opcode) -> int:
     """Convert an opcode to a value suitable for ORing into DNS message
     flags.
 
@@ -81,7 +81,7 @@ def to_flags(value):
     return (value << 11) & 0x7800
 
 
-def to_text(value):
+def to_text(value: Opcode) -> str:
     """Convert an opcode to text.
 
     *value*, an ``int`` the opcode value,
@@ -94,7 +94,7 @@ def to_text(value):
     return Opcode.to_text(value)
 
 
-def is_update(flags):
+def is_update(flags: int) -> bool:
     """Is the opcode in flags UPDATE?
 
     *flags*, an ``int``, the DNS message flags.
index 19894df65913bde5b3da704d74f174040f182799..e2dca20dcadcbd22f463e590ae67c23451f98c27 100644 (file)
@@ -17,6 +17,8 @@
 
 """Talk to a DNS server."""
 
+from typing import Any, Dict, Optional, Tuple, Union
+
 import base64
 import contextlib
 import enum
@@ -37,6 +39,8 @@ import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 import dns.serial
+import dns.transaction
+import dns.tsig
 import dns.xfr
 
 try:
@@ -74,6 +78,9 @@ except ImportError:  # pragma: no cover
         class WantWriteException(Exception):
             pass
 
+        class SSLContext:
+            pass
+
         class SSLSocket:
             pass
 
@@ -149,9 +156,12 @@ if hasattr(selectors, 'PollSelector'):
     # Prefer poll() on platforms that support it because it has no
     # limits on the maximum value of a file descriptor (plus it will
     # be more efficient for high values).
-    _selector_class = selectors.PollSelector
+    #
+    # We ignore typing here as we can't say _selector_class is Any
+    # on python < 3.8 due to a bug.
+    _selector_class = selectors.PollSelector  # type: ignore
 else:
-    _selector_class = selectors.SelectSelector  # pragma: no cover
+    _selector_class = selectors.SelectSelector  # type: ignore
 
 
 def _wait_for_readable(s, expiration):
@@ -248,10 +258,11 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
         s.close()
         raise
 
-def https(q, where, timeout=None, port=443, source=None, source_port=0,
+def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+          port=443, source: Optional[str]=None, source_port=0,
           one_rr_per_rrset=False, ignore_trailing=False,
-          session=None, path='/dns-query', post=True,
-          bootstrap_address=None, verify=True):
+          session: Optional[Any]=None, path='/dns-query', post=True,
+          bootstrap_address: Optional[str]=None, verify=True) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
     *q*, a ``dns.message.Message``, the query to send.
@@ -314,6 +325,8 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
     elif bootstrap_address is not None:
         _httpx_ok = False
         split_url = urllib.parse.urlsplit(where)
+        if split_url.hostname is None:
+            raise ValueError('DoH URL has no hostname')
         headers['Host'] = split_url.hostname
         url = where.replace(split_url.hostname, bootstrap_address)
         if _have_requests:
@@ -374,10 +387,10 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
         else:
             wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
             if _is_httpx:
-                wire = wire.decode()  # httpx does a repr() if we give it bytes
+                twire = wire.decode()  # httpx does a repr() if we give it bytes
                 response = session.get(url, headers=headers,
                                        timeout=timeout,
-                                       params={"dns": wire})
+                                       params={"dns": twire})
             else:
                 response = session.get(url, headers=headers,
                                        timeout=timeout, verify=verify,
@@ -395,7 +408,7 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
                               request_mac=q.request_mac,
                               one_rr_per_rrset=one_rr_per_rrset,
                               ignore_trailing=ignore_trailing)
-    r.time = response.elapsed
+    r.time = response.elapsed.total_seconds()
     if not q.is_response(r):
         raise BadResponse
     return r
@@ -427,7 +440,8 @@ def _udp_send(sock, data, destination, expiration):
             _wait_for_writable(sock, expiration)
 
 
-def send_udp(sock, what, destination, expiration=None):
+def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: Any,
+             expiration: Optional[float]=None) -> Tuple[int, float]:
     """Send a DNS message to the specified UDP socket.
 
     *sock*, a ``socket``.
@@ -451,10 +465,10 @@ def send_udp(sock, what, destination, expiration=None):
     return (n, sent_time)
 
 
-def receive_udp(sock, destination=None, expiration=None,
+def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional[float]=None,
                 ignore_unexpected=False, one_rr_per_rrset=False,
-                keyring=None, request_mac=b'', ignore_trailing=False,
-                raise_on_truncation=False):
+                keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
+                ignore_trailing=False, raise_on_truncation=False) -> Any:
     """Read a DNS message from a UDP socket.
 
     *sock*, a ``socket``.
@@ -512,9 +526,10 @@ def receive_udp(sock, destination=None, expiration=None,
     else:
         return (r, received_time, from_address)
 
-def udp(q, where, timeout=None, port=53, source=None, source_port=0,
+def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+        source: Optional[str]=None, source_port=0,
         ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
-        raise_on_truncation=False, sock=None):
+        raise_on_truncation=False, sock: Optional[Any]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via UDP.
 
     *q*, a ``dns.message.Message``, the query to send
@@ -571,11 +586,13 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0,
         if not q.is_response(r):
             raise BadResponse
         return r
+    assert False  # help mypy figure out we can't get here
 
-def udp_with_fallback(q, where, timeout=None, port=53, source=None,
-                      source_port=0, ignore_unexpected=False,
-                      one_rr_per_rrset=False, ignore_trailing=False,
-                      udp_sock=None, tcp_sock=None):
+def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+                      source: Optional[str]=None, source_port=0,
+                      ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+                      udp_sock: Optional[Any]=None,
+                      tcp_sock: Optional[Any]=None) -> Tuple[dns.message.Message, bool]:
     """Return the response to the query, trying UDP first and falling back
     to TCP if UDP results in a truncated response.
 
@@ -665,7 +682,8 @@ def _net_write(sock, data, expiration):
             _wait_for_readable(sock, expiration)
 
 
-def send_tcp(sock, what, expiration=None):
+def send_tcp(sock: Any, what: Union[dns.message.Message, bytes],
+             expiration: Optional[float]=None) -> Tuple[int, float]:
     """Send a DNS message to the specified TCP socket.
 
     *sock*, a ``socket``.
@@ -680,18 +698,21 @@ def send_tcp(sock, what, expiration=None):
     """
 
     if isinstance(what, dns.message.Message):
-        what = what.to_wire()
-    l = len(what)
+        wire = what.to_wire()
+    else:
+        wire = what
+    l = len(wire)
     # copying the wire into tcpmsg is inefficient, but lets us
     # avoid writev() or doing a short write that would get pushed
     # onto the net
-    tcpmsg = struct.pack("!H", l) + what
+    tcpmsg = struct.pack("!H", l) + wire
     sent_time = time.time()
     _net_write(sock, tcpmsg, expiration)
     return (len(tcpmsg), sent_time)
 
-def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
-                keyring=None, request_mac=b'', ignore_trailing=False):
+def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset=False,
+                keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
+                ignore_trailing=False) -> Tuple[dns.message.Message, float]:
     """Read a DNS message from a TCP socket.
 
     *sock*, a ``socket``.
@@ -737,8 +758,9 @@ def _connect(s, address, expiration):
         raise OSError(err, os.strerror(err))
 
 
-def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
-        one_rr_per_rrset=False, ignore_trailing=False, sock=None):
+def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
+        source: Optional[str]=None, source_port=0,
+        one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[Any]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via TCP.
 
     *q*, a ``dns.message.Message``, the query to send
@@ -790,6 +812,7 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
         if not q.is_response(r):
             raise BadResponse
         return r
+    assert False  # help mypy figure out we can't get here
 
 
 def _tls_handshake(s, expiration):
@@ -803,9 +826,11 @@ def _tls_handshake(s, expiration):
             _wait_for_writable(s, expiration)
 
 
-def tls(q, where, timeout=None, port=853, source=None, source_port=0,
-        one_rr_per_rrset=False, ignore_trailing=False, sock=None,
-        ssl_context=None, server_hostname=None):
+def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
+        port=853, source: Optional[str]=None, source_port=0,
+        one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[ssl.SSLSocket]=None,
+        ssl_context: Optional[ssl.SSLContext]=None,
+        server_hostname: Optional[str]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via TLS.
 
     *q*, a ``dns.message.Message``, the query to send
@@ -885,7 +910,7 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
         if not q.is_response(r):
             raise BadResponse
         return r
-
+    assert False  # help mypy figure out we can't get here
 
 def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
         timeout=None, port=53, keyring=None, keyname=None, relativize=True,
@@ -1066,9 +1091,10 @@ class UDPMode(enum.IntEnum):
     ONLY = 2
 
 
-def inbound_xfr(where, txn_manager, query=None,
-                port=53, timeout=None, lifetime=None, source=None,
-                source_port=0, udp_mode=UDPMode.NEVER):
+def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
+                query: Optional[dns.message.Message]=None,
+                port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
+                source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER):
     """Conduct an inbound transfer and apply it via a transaction from the
     txn_manager.
 
diff --git a/dns/query.pyi b/dns/query.pyi
deleted file mode 100644 (file)
index a22e229..0000000
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Optional, Union, Dict, Generator, Any
-from . import tsig, rdatatype, rdataclass, name, message
-from requests.sessions import Session
-
-import socket
-
-# If the ssl import works, then
-#
-#    error: Name 'ssl' already defined (by an import)
-#
-# is expected and can be ignored.
-try:
-    import ssl
-except ImportError:
-    class ssl:    # type: ignore
-        SSLContext : Dict = {}
-
-have_doh: bool
-
-def https(q : message.Message, where: str, timeout : Optional[float] = None,
-          port : Optional[int] = 443, source : Optional[str] = None,
-          source_port : Optional[int] = 0,
-          session: Optional[Session] = None,
-          path : Optional[str] = '/dns-query', post : Optional[bool] = True,
-          bootstrap_address : Optional[str] = None,
-          verify : Optional[bool] = True) -> message.Message:
-    pass
-
-def tcp(q : message.Message, where : str, timeout : float = None, port=53,
-        af : Optional[int] = None, source : Optional[str] = None,
-        source_port : Optional[int] = 0,
-        one_rr_per_rrset : Optional[bool] = False,
-        ignore_trailing : Optional[bool] = False,
-        sock : Optional[socket.socket] = None) -> message.Message:
-    pass
-
-def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR,
-        rdclass=rdataclass.IN,
-        timeout : Optional[float] = None, port=53,
-        keyring : Optional[Dict[name.Name, bytes]] = None,
-        keyname : Union[str,name.Name]= None, relativize=True,
-        lifetime : Optional[float] = None,
-        source : Optional[str] = None, source_port=0, serial=0,
-        use_udp : Optional[bool] = False,
-        keyalgorithm=tsig.default_algorithm) \
-        -> Generator[Any,Any,message.Message]:
-    pass
-
-def udp(q : message.Message, where : str, timeout : Optional[float] = None,
-        port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
-        ignore_unexpected : Optional[bool] = False,
-        one_rr_per_rrset : Optional[bool] = False,
-        ignore_trailing : Optional[bool] = False,
-        sock : Optional[socket.socket] = None) -> message.Message:
-    pass
-
-def tls(q : message.Message, where : str, timeout : Optional[float] = None,
-        port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
-        one_rr_per_rrset : Optional[bool] = False,
-        ignore_trailing : Optional[bool] = False,
-        sock : Optional[socket.socket] = None,
-        ssl_context: Optional[ssl.SSLContext] = None,
-        server_hostname: Optional[str] = None) -> message.Message:
-    pass
index 6b5b5c5a39eb31cab89b2607c06d00205e9b69b6..1e1992be8303e61a61d79f88f97229c6c282c9ea 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS rdata."""
 
+from typing import Any, Dict, Optional, Tuple, Union
+
 from importlib import import_module
 import base64
 import binascii
@@ -137,7 +139,7 @@ class Rdata:
 
         self.rdclass = self._as_rdataclass(rdclass)
         self.rdtype = self._as_rdatatype(rdtype)
-        self.rdcomment = None
+        self.rdcomment: Optional[str] = None
 
     def _get_all_slots(self):
         return itertools.chain.from_iterable(getattr(cls, '__slots__', [])
@@ -165,7 +167,7 @@ class Rdata:
             # it if needed.
             object.__setattr__(self, 'rdcomment', None)
 
-    def covers(self):
+    def covers(self) -> dns.rdatatype.RdataType:
         """Return the type a Rdata covers.
 
         DNS SIG/RRSIG rdatas apply to a specific type; this type is
@@ -174,12 +176,12 @@ class Rdata:
         creating rdatasets, allowing the rdataset to contain only RRSIGs
         of a particular type, e.g. RRSIG(NS).
 
-        Returns an ``int``.
+        Returns a ``dns.rdatatype.RdataType``.
         """
 
         return dns.rdatatype.NONE
 
-    def extended_rdatatype(self):
+    def extended_rdatatype(self) -> int:
         """Return a 32-bit type value, the least significant 16 bits of
         which are the ordinary DNS type, and the upper 16 bits of which are
         the "covered" type, if any.
@@ -189,7 +191,7 @@ class Rdata:
 
         return self.covers() << 16 | self.rdtype
 
-    def to_text(self, origin=None, relativize=True, **kw):
+    def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw):
         """Convert an rdata to text format.
 
         Returns a ``str``.
@@ -197,11 +199,12 @@ class Rdata:
 
         raise NotImplementedError  # pragma: no cover
 
-    def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+    def _to_wire(self, file, compress: Optional[dns.name.CompressType]=None,
+                 origin: Optional[dns.name.Name]=None, canonicalize=False):
         raise NotImplementedError  # pragma: no cover
 
     def to_wire(self, file=None, compress=None, origin=None,
-                canonicalize=False):
+                canonicalize=False) -> bytes:
         """Convert an rdata to wire format.
 
         Returns a ``bytes`` or ``None``.
@@ -214,7 +217,7 @@ class Rdata:
             self._to_wire(f, compress, origin, canonicalize)
             return f.getvalue()
 
-    def to_generic(self, origin=None):
+    def to_generic(self, origin: Optional[dns.name.Name]=None) -> 'dns.rdata.GenericRdata':
         """Creates a dns.rdata.GenericRdata equivalent of this rdata.
 
         Returns a ``dns.rdata.GenericRdata``.
@@ -222,7 +225,7 @@ class Rdata:
         return dns.rdata.GenericRdata(self.rdclass, self.rdtype,
                                       self.to_wire(origin=origin))
 
-    def to_digestable(self, origin=None):
+    def to_digestable(self, origin: Optional[dns.name.Name]=None) -> bytes:
         """Convert rdata to a format suitable for digesting in hashes.  This
         is also the DNSSEC canonical form.
 
@@ -348,12 +351,16 @@ class Rdata:
         return hash(self.to_digestable(dns.name.root))
 
     @classmethod
-    def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
-                  relativize_to=None):
+    def from_text(cls, rdclass: dns.rdataclass.RdataClass,
+                  rdtype: dns.rdatatype.RdataType,
+                  tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize=True,
+                  relativize_to: Optional[dns.name.Name]=None):
         raise NotImplementedError  # pragma: no cover
 
     @classmethod
-    def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+    def from_wire_parser(cls, rdclass: dns.rdataclass.RdataClass,
+                         rdtype: dns.rdatatype.RdataType,
+                         parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None):
         raise NotImplementedError  # pragma: no cover
 
     def replace(self, **kwargs):
@@ -408,18 +415,20 @@ class Rdata:
         return dns.rdatatype.RdataType.make(value)
 
     @classmethod
-    def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True):
+    def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True) -> bytes:
         if encode and isinstance(value, str):
-            value = value.encode()
+            bvalue = value.encode()
         elif isinstance(value, bytearray):
-            value = bytes(value)
-        elif not isinstance(value, bytes):
+            bvalue = bytes(value)
+        elif isinstance(value, bytes):
+            bvalue = value
+        else:
             raise ValueError('not bytes')
-        if max_length is not None and len(value) > max_length:
+        if max_length is not None and len(bvalue) > max_length:
             raise ValueError('too long')
-        if not empty_ok and len(value) == 0:
+        if not empty_ok and len(bvalue) == 0:
             raise ValueError('empty bytes not allowed')
-        return value
+        return bvalue
 
     @classmethod
     def _as_name(cls, value):
@@ -571,7 +580,7 @@ class GenericRdata(Rdata):
     def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
         return cls(rdclass, rdtype, parser.get_remaining())
 
-_rdata_classes = {}
+_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = {}
 _module_prefix = 'dns.rdtypes'
 
 def get_rdata_class(rdclass, rdtype):
@@ -602,8 +611,12 @@ def get_rdata_class(rdclass, rdtype):
     return cls
 
 
-def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
-              relativize_to=None, idna_codec=None):
+def from_text(rdclass: Union[dns.rdataclass.RdataClass, str],
+              rdtype: Union[dns.rdatatype.RdataType, str],
+              tok: Union[dns.tokenizer.Tokenizer, str],
+              origin: Optional[dns.name.Name]=None,
+              relativize=True, relativize_to: Optional[dns.name.Name]=None,
+              idna_codec: Optional[dns.name.IDNACodec]=None) -> Rdata:
     """Build an rdata object from text format.
 
     This function attempts to dynamically load a class which
@@ -617,9 +630,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
     If *tok* is a ``str``, then a tokenizer is created and the string
     is used as its input.
 
-    *rdclass*, an ``int``, the rdataclass.
+    *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
 
-    *rdtype*, an ``int``, the rdatatype.
+    *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
 
     *tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``.
 
@@ -681,7 +694,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
         return rdata
 
 
-def from_wire_parser(rdclass, rdtype, parser, origin=None):
+def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str],
+                     rdtype: Union[dns.rdatatype.RdataType, str],
+                     parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> Rdata:
     """Build an rdata object from wire format
 
     This function attempts to dynamically load a class which
@@ -692,9 +707,9 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
     Once a class is chosen, its from_wire() class method is called
     with the parameters to this function.
 
-    *rdclass*, an ``int``, the rdataclass.
+    *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
 
-    *rdtype*, an ``int``, the rdatatype.
+    *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
 
     *parser*, a ``dns.wire.Parser``, the parser, which should be
     restricted to the rdata length.
@@ -712,7 +727,10 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
         return cls.from_wire_parser(rdclass, rdtype, parser, origin)
 
 
-def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
+def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str],
+              rdtype: Union[dns.rdatatype.RdataType, str],
+              wire: bytes, current: int, rdlen: int,
+              origin: Optional[dns.name.Name]=None) -> Rdata:
     """Build an rdata object from wire format
 
     This function attempts to dynamically load a class which
diff --git a/dns/rdata.pyi b/dns/rdata.pyi
deleted file mode 100644 (file)
index f394791..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-from typing import Dict, Tuple, Any, Optional, BinaryIO
-from .name import Name, IDNACodec
-class Rdata:
-    def __init__(self):
-        self.address : str
-    def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]:
-        ...
-    @classmethod
-    def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True):
-        ...
-_rdata_modules : Dict[Tuple[Any,Rdata],Any]
-
-def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None,
-              relativize : bool = True, relativize_to : Optional[Name] = None,
-              idna_codec : Optional[IDNACodec] = None):
-    ...
-
-def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None):
-    ...
index e6e954804f3dda27c61b26c7c350d84c1583a985..218adba3e696fa28c0661a383066791c6ef1681d 100644 (file)
 
 """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
 
+from typing import Any, cast, Collection, Dict, List, Optional, Union
+
 import io
 import random
 import struct
 
 import dns.exception
 import dns.immutable
+import dns.name
 import dns.rdatatype
 import dns.rdataclass
 import dns.rdata
 import dns.set
+import dns.ttl
 
 # define SimpleSet here for backwards compatibility
 SimpleSet = dns.set.Set
@@ -47,22 +51,24 @@ class Rdataset(dns.set.Set):
 
     __slots__ = ['rdclass', 'rdtype', 'covers', 'ttl']
 
-    def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0):
+    def __init__(self, rdclass: dns.rdataclass.RdataClass,
+                 rdtype: dns.rdatatype.RdataType,
+                 covers=dns.rdatatype.NONE, ttl=0):
         """Create a new rdataset of the specified class and type.
 
-        *rdclass*, an ``int``, the rdataclass.
+        *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
 
-        *rdtype*, an ``int``, the rdatatype.
+        *rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype.
 
-        *covers*, an ``int``, the covered rdatatype.
+        *covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype.
 
         *ttl*, an ``int``, the TTL.
         """
 
         super().__init__()
         self.rdclass = rdclass
-        self.rdtype = rdtype
-        self.covers = covers
+        self.rdtype: dns.rdatatype.RdataType = rdtype
+        self.covers: dns.rdatatype.RdataType = covers
         self.ttl = ttl
 
     def _clone(self):
@@ -73,7 +79,7 @@ class Rdataset(dns.set.Set):
         obj.ttl = self.ttl
         return obj
 
-    def update_ttl(self, ttl):
+    def update_ttl(self, ttl: int):
         """Perform TTL minimization.
 
         Set the TTL of the rdataset to be the lesser of the set's current
@@ -88,7 +94,7 @@ class Rdataset(dns.set.Set):
         elif ttl < self.ttl:
             self.ttl = ttl
 
-    def add(self, rd, ttl=None):  # pylint: disable=arguments-differ
+    def add(self, rd, ttl: Optional[int]=None):  # pylint: disable=arguments-differ
         """Add the specified rdata to the rdataset.
 
         If the optional *ttl* parameter is supplied, then
@@ -176,8 +182,11 @@ class Rdataset(dns.set.Set):
     def __ne__(self, other):
         return not self.__eq__(other)
 
-    def to_text(self, name=None, origin=None, relativize=True,
-                override_rdclass=None, want_comments=False, **kw):
+    def to_text(self, name: Optional[dns.name.Name]=None,
+                origin: Optional[dns.name.Name]=None,
+                relativize=True,
+                override_rdclass: Optional[dns.rdataclass.RdataClass]=None,
+                want_comments=False, **kw) -> str:
         """Convert the rdataset into DNS zone file format.
 
         See ``dns.name.Name.choose_relativity`` for more information
@@ -241,8 +250,11 @@ class Rdataset(dns.set.Set):
         #
         return s.getvalue()[:-1]
 
-    def to_wire(self, name, file, compress=None, origin=None,
-                override_rdclass=None, want_shuffle=True):
+    def to_wire(self, name: dns.name.Name, file: Any,
+                compress: Optional[dns.name.CompressType]=None,
+                origin: Optional[dns.name.Name]=None,
+                override_rdclass: Optional[dns.rdataclass.RdataClass]=None,
+                want_shuffle=True) -> int:
         """Convert the rdataset to wire format.
 
         *name*, a ``dns.name.Name`` is the owner name to use.
@@ -279,6 +291,7 @@ class Rdataset(dns.set.Set):
             file.write(stuff)
             return 1
         else:
+            l: Union[Rdataset, List[dns.rdata.Rdata]]
             if want_shuffle:
                 l = list(self)
                 random.shuffle(l)
@@ -299,7 +312,9 @@ class Rdataset(dns.set.Set):
                 file.seek(0, io.SEEK_END)
             return len(self)
 
-    def match(self, rdclass, rdtype, covers):
+    def match(self, rdclass: dns.rdataclass.RdataClass,
+              rdtype: dns.rdatatype.RdataType,
+              covers: dns.rdatatype.RdataType) -> bool:
         """Returns ``True`` if this rdataset matches the specified class,
         type, and covers.
         """
@@ -309,7 +324,7 @@ class Rdataset(dns.set.Set):
             return True
         return False
 
-    def processing_order(self):
+    def processing_order(self) -> List[dns.rdata.Rdata]:
         """Return rdatas in a valid processing order according to the type's
         specification.  For example, MX records are in preference order from
         lowest to highest preferences, with items of the same preference
@@ -331,7 +346,7 @@ class ImmutableRdataset(Rdataset):  # lgtm[py/missing-equals]
 
     _clone_class = Rdataset
 
-    def __init__(self, rdataset):
+    def __init__(self, rdataset: Rdataset):
         """Create an immutable rdataset from the specified rdataset."""
 
         super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
@@ -394,8 +409,12 @@ class ImmutableRdataset(Rdataset):  # lgtm[py/missing-equals]
         return ImmutableRdataset(super().symmetric_difference(other))
 
 
-def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
-                   origin=None, relativize=True, relativize_to=None):
+def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str],
+                   rdtype: Union[dns.rdatatype.RdataType, str],
+                   ttl: int, text_rdatas: Collection[str],
+                   idna_codec: Optional[dns.name.IDNACodec]=None,
+                   origin: Optional[dns.name.Name]=None,
+                   relativize=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset:
     """Create an rdataset with the specified class, type, and TTL, and with
     the specified list of rdatas in text format.
 
@@ -414,9 +433,9 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
     Returns a ``dns.rdataset.Rdataset`` object.
     """
 
-    rdclass = dns.rdataclass.RdataClass.make(rdclass)
-    rdtype = dns.rdatatype.RdataType.make(rdtype)
-    r = Rdataset(rdclass, rdtype)
+    the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+    r = Rdataset(the_rdclass, the_rdtype)
     r.update_ttl(ttl)
     for t in text_rdatas:
         rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize,
@@ -425,17 +444,19 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
     return r
 
 
-def from_text(rdclass, rdtype, ttl, *text_rdatas):
+def from_text(rdclass: Union[dns.rdataclass.RdataClass, str],
+              rdtype: Union[dns.rdatatype.RdataType, str],
+              ttl: int, *text_rdatas) -> Rdataset:
     """Create an rdataset with the specified class, type, and TTL, and with
     the specified rdatas in text format.
 
     Returns a ``dns.rdataset.Rdataset`` object.
     """
 
-    return from_text_list(rdclass, rdtype, ttl, text_rdatas)
+    return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas))
 
 
-def from_rdata_list(ttl, rdatas):
+def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset:
     """Create an rdataset with the specified TTL, and with
     the specified list of rdata objects.
 
@@ -450,14 +471,15 @@ def from_rdata_list(ttl, rdatas):
             r = Rdataset(rd.rdclass, rd.rdtype)
             r.update_ttl(ttl)
         r.add(rd)
+    assert r is not None
     return r
 
 
-def from_rdata(ttl, *rdatas):
+def from_rdata(ttl: int, *rdatas) -> Rdataset:
     """Create an rdataset with the specified TTL, and with
     the specified rdata objects.
 
     Returns a ``dns.rdataset.Rdataset`` object.
     """
 
-    return from_rdata_list(ttl, rdatas)
+    return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas))
diff --git a/dns/rdataset.pyi b/dns/rdataset.pyi
deleted file mode 100644 (file)
index a7bbf2d..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-from typing import Optional, Dict, List, Union
-from io import BytesIO
-from . import exception, name, set, rdatatype, rdata, rdataset
-
-class DifferingCovers(exception.DNSException):
-    """An attempt was made to add a DNS SIG/RRSIG whose covered type
-    is not the same as that of the other rdatas in the rdataset."""
-
-
-class IncompatibleTypes(exception.DNSException):
-    """An attempt was made to add DNS RR data of an incompatible type."""
-
-
-class Rdataset(set.Set):
-    def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0):
-        self.rdclass : int = rdclass
-        self.rdtype : int = rdtype
-        self.covers : int = covers
-        self.ttl : int = ttl
-
-    def update_ttl(self, ttl : int) -> None:
-        ...
-
-    def add(self, rd : rdata.Rdata, ttl : Optional[int] =None):
-        ...
-
-    def union_update(self, other : Rdataset):
-        ...
-
-    def intersection_update(self, other : Rdataset):
-        ...
-
-    def update(self, other : Rdataset):
-        ...
-
-    def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True,
-                override_rdclass : Optional[int] =None, **kw) -> bytes:
-        ...
-
-    def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None,
-                override_rdclass : Optional[int] = None, want_shuffle=True) -> int:
-        ...
-
-    def match(self, rdclass : int, rdtype : int, covers : int) -> bool:
-        ...
-
-
-def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset:
-    ...
-
-def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset:
-    ...
-
-def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
-    ...
-
-def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
-    ...
index 9499c7b9b62d20e0ee385c9c54198c3b11235b74..80f8acaf16fa6e9d9ddafa58ef9a3e97775627d8 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS Rdata Types."""
 
+from typing import Dict
+
 import dns.enum
 import dns.exception
 
@@ -120,8 +122,8 @@ class RdataType(dns.enum.IntEnum):
     def _unknown_exception_class(cls):
         return UnknownRdatatype
 
-_registered_by_text = {}
-_registered_by_value = {}
+_registered_by_text: Dict[str, RdataType] = {}
+_registered_by_value: Dict[RdataType, str] = {}
 
 _metatypes = {RdataType.OPT}
 
index f35ce3adf9de941dea0074cf9b841c13b6f5b229..f8990ebed1fd39d0e75b646d62d69a464a00fd46 100644 (file)
@@ -20,7 +20,7 @@ import base64
 
 import dns.exception
 import dns.immutable
-import dns.dnssec
+import dns.dnssectypes
 import dns.rdata
 import dns.tokenizer
 
@@ -85,7 +85,7 @@ class CERT(dns.rdata.Rdata):
     def to_text(self, origin=None, relativize=True, **kw):
         certificate_type = _ctype_to_text(self.certificate_type)
         return "%s %d %s %s" % (certificate_type, self.key_tag,
-                                dns.dnssec.algorithm_to_text(self.algorithm),
+                                dns.dnssectypes.Algorithm.to_text(self.algorithm),
                                 dns.rdata._base64ify(self.certificate, **kw))
 
     @classmethod
@@ -93,7 +93,7 @@ class CERT(dns.rdata.Rdata):
                   relativize_to=None):
         certificate_type = _ctype_from_text(tok.get_string())
         key_tag = tok.get_uint16()
-        algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
+        algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
         b64 = tok.concatenate_remaining_identifiers().encode()
         certificate = base64.b64decode(b64)
         return cls(rdclass, rdtype, certificate_type, key_tag,
index d050ccc6fb79008bbc3ef12004cf48d2de3f2d2f..82650c0f11ed7d36d1aefb6f29c6f583a6d8b11f 100644 (file)
@@ -20,7 +20,7 @@ import calendar
 import struct
 import time
 
-import dns.dnssec
+import dns.dnssectypes
 import dns.immutable
 import dns.exception
 import dns.rdata
@@ -65,7 +65,7 @@ class RRSIG(dns.rdata.Rdata):
                  signature):
         super().__init__(rdclass, rdtype)
         self.type_covered = self._as_rdatatype(type_covered)
-        self.algorithm = dns.dnssec.Algorithm.make(algorithm)
+        self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
         self.labels = self._as_uint8(labels)
         self.original_ttl = self._as_ttl(original_ttl)
         self.expiration = self._as_uint32(expiration)
@@ -94,7 +94,7 @@ class RRSIG(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         type_covered = dns.rdatatype.from_text(tok.get_string())
-        algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
+        algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
         labels = tok.get_int()
         original_ttl = tok.get_ttl()
         expiration = sigtime_to_posixtime(tok.get_string())
index 861fc4e35966ba539aca3d92bf752292daa18b75..59ffe039134d36e4f03e8ea325baad00192ccf00 100644 (file)
@@ -18,7 +18,6 @@
 import base64
 import struct
 
-import dns.dnssec
 import dns.immutable
 import dns.exception
 import dns.rdata
index 035f7b327c4def5aace92ea13f45b37c79c63868..75f99e5e1d7ab5c87c595ba3fd1386ae61744fc8 100644 (file)
@@ -6,7 +6,7 @@ import binascii
 import dns.immutable
 import dns.rdata
 import dns.rdatatype
-import dns.zone
+import dns.zonetypes
 
 
 @dns.immutable.immutable
@@ -21,8 +21,8 @@ class ZONEMD(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest):
         super().__init__(rdclass, rdtype)
         self.serial = self._as_uint32(serial)
-        self.scheme = dns.zone.DigestScheme.make(scheme)
-        self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm)
+        self.scheme = dns.zonetypes.DigestScheme.make(scheme)
+        self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm)
         self.digest = self._as_bytes(digest)
 
         if self.scheme == 0:  # reserved, RFC 8976 Sec. 5.2
@@ -30,7 +30,7 @@ class ZONEMD(dns.rdata.Rdata):
         if self.hash_algorithm == 0:  # reserved, RFC 8976 Sec. 5.3
             raise ValueError('hash_algorithm 0 is reserved')
 
-        hasher = dns.zone._digest_hashers.get(self.hash_algorithm)
+        hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm)
         if hasher and hasher().digest_size != len(self.digest):
             raise ValueError('digest length inconsistent with hash algorithm')
 
index 788bb2bf9010b18d51fc3bdc073a26f4e485cfc7..832df2d7b3c2cb7c94fd26ebd913b02450f79d4a 100644 (file)
@@ -21,7 +21,7 @@ import struct
 
 import dns.exception
 import dns.immutable
-import dns.dnssec
+import dns.dnssectypes
 import dns.rdata
 
 # wildcard import
@@ -44,7 +44,7 @@ class DNSKEYBase(dns.rdata.Rdata):
         super().__init__(rdclass, rdtype)
         self.flags = self._as_uint16(flags)
         self.protocol = self._as_uint8(protocol)
-        self.algorithm = dns.dnssec.Algorithm.make(algorithm)
+        self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
         self.key = self._as_bytes(key)
 
     def to_text(self, origin=None, relativize=True, **kw):
diff --git a/dns/rdtypes/dnskeybase.pyi b/dns/rdtypes/dnskeybase.pyi
deleted file mode 100644 (file)
index 1b999cf..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-from typing import Set, Any
-
-SEP : int
-REVOKE : int
-ZONE : int
-
-def flags_to_text_set(flags : int) -> Set[str]:
-    ...
-
-def flags_from_text_set(texts_set) -> int:
-    ...
-
-from .. import rdata
-
-class DNSKEYBase(rdata.Rdata):
-    def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
-        self.flags : int
-        self.protocol : int
-        self.key : str
-        self.algorithm : int
-
-    def to_text(self, origin : Any = None, relativize=True, **kw : Any):
-        ...
-
-    @classmethod
-    def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
-                  relativize_to=None):
-        ...
-
-    def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
-        ...
-
-    @classmethod
-    def from_parser(cls, rdclass, rdtype, parser, origin=None):
-        ...
-
-    def flags_to_text_set(self) -> Set[str]:
-        ...
index 0c2e7471b10ab2f37c7469fb2d000d8871f0b8a0..3bf93accb00cd518f4b50b904fbf3bf26e0ce3c4 100644 (file)
@@ -18,7 +18,7 @@
 import struct
 import binascii
 
-import dns.dnssec
+import dns.dnssectypes
 import dns.immutable
 import dns.rdata
 import dns.rdatatype
@@ -43,7 +43,7 @@ class DSBase(dns.rdata.Rdata):
                  digest):
         super().__init__(rdclass, rdtype)
         self.key_tag = self._as_uint16(key_tag)
-        self.algorithm = dns.dnssec.Algorithm.make(algorithm)
+        self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
         self.digest_type = self._as_uint8(digest_type)
         self.digest = self._as_bytes(digest)
         try:
index 68071ee0abda00da55edb71162841605edad756e..7ad7914f9565336562b26b6c4efe2a92bbcb9795 100644 (file)
@@ -17,6 +17,8 @@
 
 """TXT-like base class."""
 
+from typing import Iterable, Optional, Tuple, Union
+
 import struct
 
 import dns.exception
@@ -32,7 +34,7 @@ class TXTBase(dns.rdata.Rdata):
 
     __slots__ = ['strings']
 
-    def __init__(self, rdclass, rdtype, strings):
+    def __init__(self, rdclass, rdtype, strings: Iterable[Union[bytes, str]]):
         """Initialize a TXT-like rdata.
 
         *rdclass*, an ``int`` is the rdataclass of the Rdata.
@@ -42,10 +44,9 @@ class TXTBase(dns.rdata.Rdata):
         *strings*, a tuple of ``bytes``
         """
         super().__init__(rdclass, rdtype)
-        self.strings = self._as_tuple(strings,
-                                      lambda x: self._as_bytes(x, True, 255))
+        self.strings: Tuple[bytes] = self._as_tuple(strings, lambda x: self._as_bytes(x, True, 255))
 
-    def to_text(self, origin=None, relativize=True, **kw):
+    def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw):
         txt = ''
         prefix = ''
         for s in self.strings:
@@ -54,8 +55,8 @@ class TXTBase(dns.rdata.Rdata):
         return txt
 
     @classmethod
-    def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
-                  relativize_to=None):
+    def from_text(cls, rdclass, rdtype, tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None,
+                  relativize=True, relativize_to: Optional[dns.name.Name]=None):
         strings = []
         for token in tok.get_remaining():
             token = token.unescape_to_bytes()
diff --git a/dns/rdtypes/txtbase.pyi b/dns/rdtypes/txtbase.pyi
deleted file mode 100644 (file)
index f8d5df9..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-import typing
-from .. import rdata
-
-class TXTBase(rdata.Rdata):
-    strings: typing.Tuple[bytes, ...]
-
-    def __init__(self, rdclass: int, rdtype: int, strings: typing.Iterable[bytes]) -> None:
-        ...
-    def to_text(self, origin: typing.Any, relativize: bool, **kw: typing.Any) -> str:
-        ...
-class TXT(TXTBase):
-    ...
index 332c82c0376ae608e8e92383bfc57259bf096da0..42d228d9a7a6b08de0ff147b60430ec1a045a094 100644 (file)
@@ -16,6 +16,9 @@
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 """DNS stub resolver."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
 from urllib.parse import urlparse
 import contextlib
 import socket
@@ -52,6 +55,10 @@ class NXDOMAIN(dns.exception.DNSException):
 
     # pylint: disable=arguments-differ
 
+    # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def _check_kwargs(self, qnames,
                       responses=None):
         if not isinstance(qnames, (list, tuple, set)):
@@ -132,7 +139,10 @@ class YXDOMAIN(dns.exception.DNSException):
     """The DNS query name is too long after DNAME substitution."""
 
 
-def _errors_to_text(errors):
+ErrorTuple = Tuple[str, bool, int, Exception, dns.message.Message]
+
+
+def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
     """Turn a resolution errors trace into a list of text."""
     texts = []
     for err in errors:
@@ -148,6 +158,10 @@ class LifetimeTimeout(dns.exception.Timeout):
     fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1]
     supp_kwargs = {'timeout', 'errors'}
 
+    # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def _fmt_kwargs(self, **kwargs):
         srv_msgs = _errors_to_text(kwargs['errors'])
         return super()._fmt_kwargs(timeout=kwargs['timeout'],
@@ -166,6 +180,10 @@ class NoAnswer(dns.exception.DNSException):
           'to the question: {query}'
     supp_kwargs = {'response'}
 
+    # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def _fmt_kwargs(self, **kwargs):
         return super()._fmt_kwargs(query=kwargs['response'].question)
 
@@ -186,6 +204,10 @@ class NoNameservers(dns.exception.DNSException):
     fmt = "%s {query}: {errors}" % msg[:-1]
     supp_kwargs = {'request', 'errors'}
 
+    # We do this as otherwise mypy complains about unexpected keyword argument idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def _fmt_kwargs(self, **kwargs):
         srv_msgs = _errors_to_text(kwargs['errors'])
         return super()._fmt_kwargs(query=kwargs['request'].question,
@@ -222,8 +244,9 @@ class Answer:
     RRset's name might not be the query name.
     """
 
-    def __init__(self, qname, rdtype, rdclass, response, nameserver=None,
-                 port=None):
+    def __init__(self, qname: dns.name.Name, rdtype: dns.rdatatype.RdataType,
+                 rdclass: dns.rdataclass.RdataClass, response: dns.message.QueryMessage,
+                 nameserver: Optional[str]=None, port: Optional[int]=None):
         self.qname = qname
         self.rdtype = rdtype
         self.rdclass = rdclass
@@ -280,7 +303,7 @@ class CacheStatistics:
         self.hits = 0
         self.misses = 0
 
-    def clone(self):
+    def clone(self) -> 'CacheStatistics':
         return CacheStatistics(self.hits, self.misses)
 
 
@@ -304,7 +327,7 @@ class CacheBase:
         with self.lock:
             return self.statistics.misses
 
-    def get_statistics_snapshot(self):
+    def get_statistics_snapshot(self) -> CacheStatistics:
         """Return a consistent snapshot of all the statistics.
 
         If running with multiple threads, it's better to take a
@@ -315,6 +338,9 @@ class CacheBase:
             return self.statistics.clone()
 
 
+CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass]
+
+
 class Cache(CacheBase):
     """Simple thread-safe DNS answer cache."""
 
@@ -342,12 +368,12 @@ class Cache(CacheBase):
             now = time.time()
             self.next_cleaning = now + self.cleaning_interval
 
-    def get(self, key):
+    def get(self, key: CacheKey) -> Optional[Answer]:
         """Get the answer associated with *key*.
 
         Returns None if no answer is cached for the key.
 
-        *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+        *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
         query name, rdtype, and rdclass respectively.
 
         Returns a ``dns.resolver.Answer`` or ``None``.
@@ -362,10 +388,10 @@ class Cache(CacheBase):
             self.statistics.hits += 1
             return v
 
-    def put(self, key, value):
+    def put(self, key: CacheKey, value: Answer):
         """Associate key and value in the cache.
 
-        *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+        *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
         query name, rdtype, and rdclass respectively.
 
         *value*, a ``dns.resolver.Answer``, the answer.
@@ -375,13 +401,13 @@ class Cache(CacheBase):
             self._maybe_clean()
             self.data[key] = value
 
-    def flush(self, key=None):
+    def flush(self, key: Optional[CacheKey]=None):
         """Flush the cache.
 
         If *key* is not ``None``, only that item is flushed.  Otherwise
         the entire cache is flushed.
 
-        *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+        *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
         query name, rdtype, and rdclass respectively.
         """
 
@@ -442,12 +468,12 @@ class LRUCache(CacheBase):
             max_size = 1
         self.max_size = max_size
 
-    def get(self, key):
+    def get(self, key: CacheKey) -> Optional[Answer]:
         """Get the answer associated with *key*.
 
         Returns None if no answer is cached for the key.
 
-        *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+        *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
         query name, rdtype, and rdclass respectively.
 
         Returns a ``dns.resolver.Answer`` or ``None``.
@@ -470,7 +496,7 @@ class LRUCache(CacheBase):
             node.hits += 1
             return node.value
 
-    def get_hits_for_key(self, key):
+    def get_hits_for_key(self, key: CacheKey) -> int:
         """Return the number of cache hits associated with the specified key."""
         with self.lock:
             node = self.data.get(key)
@@ -479,10 +505,10 @@ class LRUCache(CacheBase):
             else:
                 return node.hits
 
-    def put(self, key, value):
+    def put(self, key: CacheKey, value: Answer):
         """Associate key and value in the cache.
 
-        *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+        *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
         query name, rdtype, and rdclass respectively.
 
         *value*, a ``dns.resolver.Answer``, the answer.
@@ -501,13 +527,13 @@ class LRUCache(CacheBase):
             node.link_after(self.sentinel)
             self.data[key] = node
 
-    def flush(self, key=None):
+    def flush(self, key: Optional[CacheKey]=None):
         """Flush the cache.
 
         If *key* is not ``None``, only that item is flushed.  Otherwise
         the entire cache is flushed.
 
-        *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the
+        *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
         query name, rdtype, and rdclass respectively.
         """
 
@@ -537,8 +563,10 @@ class _Resolution:
     resolver data structures directly.
     """
 
-    def __init__(self, resolver, qname, rdtype, rdclass, tcp,
-                 raise_on_no_answer, search):
+    def __init__(self, resolver: 'BaseResolver', qname: Union[dns.name.Name, str],
+                 rdtype: Union[dns.rdatatype.RdataType, str],
+                 rdclass: Union[dns.rdataclass.RdataClass, str],
+                 tcp: bool, raise_on_no_answer: bool, search: Optional[bool]):
         if isinstance(qname, str):
             qname = dns.name.from_text(qname, None)
         rdtype = dns.rdatatype.RdataType.make(rdtype)
@@ -554,21 +582,20 @@ class _Resolution:
         self.rdclass = rdclass
         self.tcp = tcp
         self.raise_on_no_answer = raise_on_no_answer
-        self.nxdomain_responses = {}
-        #
+        self.nxdomain_responses: Dict[dns.name.Name, Answer] = {}
         # Initialize other things to help analysis tools
         self.qname = dns.name.empty
-        self.nameservers = []
-        self.current_nameservers = []
-        self.errors = []
-        self.nameserver = None
+        self.nameservers: List[str] = []
+        self.current_nameservers: List[str] = []
+        self.errors: List[ErrorTuple] = []
+        self.nameserver: Optional[str] = None
         self.port = 0
         self.tcp_attempt = False
         self.retry_with_tcp = False
-        self.request = None
-        self.backoff = 0
+        self.request: Optional[dns.message.QueryMessage] = None
+        self.backoff = 0.0
 
-    def next_request(self):
+    def next_request(self) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]:
         """Get the next request to send, and check the cache.
 
         Returns a (request, answer) tuple.  At most one of request or
@@ -732,6 +759,7 @@ class _Resolution:
                                 dns.rcode.to_text(rcode), response))
             return (None, False)
 
+
 class BaseResolver:
     """DNS stub resolver."""
 
@@ -765,10 +793,10 @@ class BaseResolver:
             dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
         if len(self.domain) == 0:
             self.domain = dns.name.root
-        self.nameservers = []
-        self.nameserver_ports = {}
+        self.nameservers: List[str] = []
+        self.nameserver_ports: Dict[str, int] = {}
         self.port = 53
-        self.search = []
+        self.search: List[dns.name.Name] = []
         self.use_search_by_default = False
         self.timeout = 2.0
         self.lifetime = 5.0
@@ -777,13 +805,13 @@ class BaseResolver:
         self.keyalgorithm = dns.tsig.default_algorithm
         self.edns = -1
         self.ednsflags = 0
-        self.ednsoptions = None
+        self.ednsoptions: Optional[List[dns.edns.Option]] = None
         self.payload = 0
         self.cache = None
         self.flags = None
         self.retry_servfail = False
         self.rotate = False
-        self.ndots = None
+        self.ndots: Optional[int] = None
 
     def read_resolv_conf(self, f):
         """Process *f* as a file in the /etc/resolv.conf format.  If f is
@@ -862,7 +890,8 @@ class BaseResolver:
         except AttributeError:
             raise NotImplementedError
 
-    def _compute_timeout(self, start, lifetime=None, errors=None):
+    def _compute_timeout(self, start: float, lifetime: Optional[float]=None,
+                         errors: Optional[List[ErrorTuple]]=None) -> float:
         lifetime = self.lifetime if lifetime is None else lifetime
         now = time.time()
         duration = now - start
@@ -881,7 +910,7 @@ class BaseResolver:
             raise LifetimeTimeout(timeout=duration, errors=errors)
         return min(lifetime - duration, self.timeout)
 
-    def _get_qnames_to_try(self, qname, search):
+    def _get_qnames_to_try(self, qname: dns.name.Name, search: Optional[bool]) -> List[dns.name.Name]:
         # This is a separate method so we can unit test the search
         # rules without requiring the Internet.
         if search is None:
@@ -960,7 +989,7 @@ class BaseResolver:
         self.payload = payload
         self.ednsoptions = options
 
-    def set_flags(self, flags):
+    def set_flags(self, flags: int):
         """Overrides the default flags with your own.
 
         *flags*, an ``int``, the message flags to use.
@@ -969,11 +998,11 @@ class BaseResolver:
         self.flags = flags
 
     @property
-    def nameservers(self):
+    def nameservers(self) -> List[str]:
         return self._nameservers
 
     @nameservers.setter
-    def nameservers(self, nameservers):
+    def nameservers(self, nameservers: List[str]):
         """
         *nameservers*, a ``list`` of nameservers.
 
@@ -998,9 +1027,11 @@ class BaseResolver:
 class Resolver(BaseResolver):
     """DNS stub resolver."""
 
-    def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-                tcp=False, source=None, raise_on_no_answer=True, source_port=0,
-                lifetime=None, search=None):  # pylint: disable=arguments-differ
+    def resolve(self, qname: Union[dns.name.Name, str],
+                rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+                rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+                tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+                lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pylint: disable=arguments-differ
         """Query nameservers to find the answer to the question.
 
         The *qname*, *rdtype*, and *rdclass* parameters may be objects
@@ -1064,6 +1095,7 @@ class Resolver(BaseResolver):
             if answer is not None:
                 # cache hit!
                 return answer
+            assert request is not None  # needed for type checking
             done = False
             while not done:
                 (nameserver, port, tcp, backoff) = resolution.next_nameserver()
@@ -1101,9 +1133,11 @@ class Resolver(BaseResolver):
                 if answer is not None:
                     return answer
 
-    def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-              tcp=False, source=None, raise_on_no_answer=True, source_port=0,
-              lifetime=None):  # pragma: no cover
+    def query(self, qname: Union[dns.name.Name, str],
+              rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+              rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+              tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+              lifetime: Optional[float]=None) -> Answer:  # pragma: no cover
         """Query nameservers to find the answer to the question.
 
         This method calls resolve() with ``search=True``, and is
@@ -1117,7 +1151,7 @@ class Resolver(BaseResolver):
                             raise_on_no_answer, source_port, lifetime,
                             True)
 
-    def resolve_address(self, ipaddr, *args, **kwargs):
+    def resolve_address(self, ipaddr: str, *args, **kwargs) -> Answer:
         """Use a resolver to run a reverse query for PTR records.
 
         This utilizes the resolve() method to perform a PTR lookup on the
@@ -1130,15 +1164,19 @@ class Resolver(BaseResolver):
         except for rdtype and rdclass are also supported by this
         function.
         """
-
+        # We make a modified kwargs for type checking happiness, as otherwise
+        # we get a legit warning about possibly having rdtype and rdclass
+        # in the kwargs more than once.
+        modified_kwargs = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs['rdtype'] = dns.rdatatype.PTR
+        modified_kwargs['rdclass'] = dns.rdataclass.IN
         return self.resolve(dns.reversename.from_address(ipaddr),
-                            rdtype=dns.rdatatype.PTR,
-                            rdclass=dns.rdataclass.IN,
-                            *args, **kwargs)
+                            *args, **modified_kwargs)
 
     # pylint: disable=redefined-outer-name
 
-    def canonical_name(self, name):
+    def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
         """Determine the canonical name of *name*.
 
         The canonical name is the name the resolver uses for queries
@@ -1163,13 +1201,14 @@ class Resolver(BaseResolver):
 
 
 #: The default resolver.
-default_resolver = None
+default_resolver: Optional[Resolver] = None
 
 
-def get_default_resolver():
+def get_default_resolver() -> Resolver:
     """Get the default resolver, initializing it if necessary."""
     if default_resolver is None:
         reset_default_resolver()
+    assert default_resolver is not None
     return default_resolver
 
 
@@ -1184,9 +1223,12 @@ def reset_default_resolver():
     default_resolver = Resolver()
 
 
-def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-            tcp=False, source=None, raise_on_no_answer=True,
-            source_port=0, lifetime=None, search=None):
+def resolve(qname: Union[dns.name.Name, str],
+            rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+            rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+            tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+            lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer:  # pragma: no cover
+
     """Query nameservers to find the answer to the question.
 
     This is a convenience function that uses the default resolver
@@ -1200,9 +1242,11 @@ def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
                                           raise_on_no_answer, source_port,
                                           lifetime, search)
 
-def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-          tcp=False, source=None, raise_on_no_answer=True,
-          source_port=0, lifetime=None):  # pragma: no cover
+def query(qname: Union[dns.name.Name, str],
+          rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+          rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+          tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+          lifetime: Optional[float]=None) -> Answer:  # pragma: no cover
     """Query nameservers to find the answer to the question.
 
     This method calls resolve() with ``search=True``, and is
@@ -1217,7 +1261,7 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
                    True)
 
 
-def resolve_address(ipaddr, *args, **kwargs):
+def resolve_address(ipaddr: str, *args, **kwargs) -> Answer:
     """Use a resolver to run a reverse query for PTR records.
 
     See ``dns.resolver.Resolver.resolve_address`` for more information on the
@@ -1227,7 +1271,7 @@ def resolve_address(ipaddr, *args, **kwargs):
     return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
 
 
-def canonical_name(name):
+def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
     """Determine the canonical name of *name*.
 
     See ``dns.resolver.Resolver.canonical_name`` for more information on the
@@ -1237,8 +1281,9 @@ def canonical_name(name):
     return get_default_resolver().canonical_name(name)
 
 
-def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None,
-                  lifetime=None):
+def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN,
+                  tcp=False, resolver: Optional[Resolver]=None,
+                  lifetime: Optional[float]=None) -> dns.name.Name:
     """Find the name of the zone which contains the specified name.
 
     *name*, an absolute ``dns.name.Name`` or ``str``, the query name.
@@ -1285,6 +1330,7 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None,
                 rlifetime = None
             answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp,
                                       lifetime=rlifetime)
+            assert answer.rrset is not None
             if answer.rrset.name == name:
                 return name
             # otherwise we were CNAMEd or DNAMEd and need to look higher
@@ -1544,7 +1590,7 @@ def _gethostbyaddr(ip):
     return (canonical, aliases, addresses)
 
 
-def override_system_resolver(resolver=None):
+def override_system_resolver(resolver: Optional[Resolver]=None):
     """Override the system resolver routines in the socket module with
     versions which use dnspython's resolver.
 
diff --git a/dns/resolver.pyi b/dns/resolver.pyi
deleted file mode 100644 (file)
index 348df4d..0000000
+++ /dev/null
@@ -1,66 +0,0 @@
-from typing import Union, Optional, List, Any, Dict
-from . import exception, rdataclass, name, rdatatype
-
-import socket
-_gethostbyname = socket.gethostbyname
-
-class NXDOMAIN(exception.DNSException): ...
-class YXDOMAIN(exception.DNSException): ...
-class NoAnswer(exception.DNSException): ...
-class NoNameservers(exception.DNSException): ...
-class NotAbsolute(exception.DNSException): ...
-class NoRootSOA(exception.DNSException): ...
-class NoMetaqueries(exception.DNSException): ...
-class NoResolverConfiguration(exception.DNSException): ...
-Timeout = exception.Timeout
-
-def resolve(qname : str, rdtype : Union[int,str] = 0,
-            rdclass : Union[int,str] = 0,
-            tcp=False, source=None, raise_on_no_answer=True,
-            source_port=0, lifetime : Optional[float]=None,
-            search : Optional[bool]=None):
-    ...
-def query(qname : str, rdtype : Union[int,str] = 0,
-          rdclass : Union[int,str] = 0,
-          tcp=False, source=None, raise_on_no_answer=True,
-          source_port=0, lifetime : Optional[float]=None):
-    ...
-def resolve_address(ipaddr: str, *args: Any, **kwargs: Optional[Dict]):
-    ...
-class LRUCache:
-    def __init__(self, max_size=1000):
-        ...
-    def get(self, key):
-        ...
-    def put(self, key, val):
-        ...
-class Answer:
-    def __init__(self, qname, rdtype, rdclass, response,
-                 raise_on_no_answer=True):
-        ...
-def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False,
-                  resolver : Optional[Resolver] = None):
-    ...
-
-class Resolver:
-    def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
-                 configure : Optional[bool] = True):
-        self.nameservers : List[str]
-    def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
-                rdclass : Union[int,str] = rdataclass.IN,
-                tcp : bool = False, source : Optional[str] = None,
-                raise_on_no_answer=True, source_port : int = 0,
-                lifetime : Optional[float]=None,
-                search : Optional[bool]=None):
-        ...
-    def query(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
-              rdclass : Union[int,str] = rdataclass.IN,
-              tcp : bool = False, source : Optional[str] = None,
-              raise_on_no_answer=True, source_port : int = 0,
-              lifetime : Optional[float]=None):
-        ...
-default_resolver: typing.Optional[Resolver]
-def reset_default_resolver() -> None:
-    ...
-def get_default_resolver() -> Resolver:
-    ...
index e0beb03df95131b2bd258fef1d7d39d8ee135d05..4b70cf6495ad677c78c6ca5bf2bafe541f77b4a9 100644 (file)
@@ -27,8 +27,8 @@ ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.')
 ipv6_reverse_domain = dns.name.from_text('ip6.arpa.')
 
 
-def from_address(text, v4_origin=ipv4_reverse_domain,
-                 v6_origin=ipv6_reverse_domain):
+def from_address(text: str, v4_origin=ipv4_reverse_domain,
+                 v6_origin=ipv6_reverse_domain) -> dns.name.Name:
     """Convert an IPv4 or IPv6 address in textual form into a Name object whose
     value is the reverse-map domain name of the address.
 
@@ -63,8 +63,8 @@ def from_address(text, v4_origin=ipv4_reverse_domain,
     return dns.name.from_text('.'.join(reversed(parts)), origin=origin)
 
 
-def to_address(name, v4_origin=ipv4_reverse_domain,
-               v6_origin=ipv6_reverse_domain):
+def to_address(name: dns.name.Name, v4_origin=ipv4_reverse_domain,
+               v6_origin=ipv6_reverse_domain) -> str:
     """Convert a reverse map domain name into textual address form.
 
     *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name
diff --git a/dns/reversename.pyi b/dns/reversename.pyi
deleted file mode 100644 (file)
index 97f072e..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-from . import name
-def from_address(text : str) -> name.Name:
-    ...
-
-def to_address(name : name.Name) -> str:
-    ...
index a71d45737c3b873d17bec113546ca88979fe21e5..3745857145cbfdb91862c8089ff40eb5c25afd1e 100644 (file)
@@ -17,6 +17,7 @@
 
 """DNS RRsets (an RRset is a named rdataset)"""
 
+from typing import cast, Collection, Optional, Union
 
 import dns.name
 import dns.rdataset
@@ -37,8 +38,10 @@ class RRset(dns.rdataset.Rdataset):
 
     __slots__ = ['name', 'deleting']
 
-    def __init__(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                 deleting=None):
+    def __init__(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass,
+                 rdtype: dns.rdatatype.RdataType,
+                 covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+                 deleting: Optional[dns.rdataclass.RdataClass]=None):
         """Create a new RRset."""
 
         super().__init__(rdclass, rdtype, covers)
@@ -76,7 +79,7 @@ class RRset(dns.rdataset.Rdataset):
             return False
         return super().__eq__(other)
 
-    def match(self, *args, **kwargs):
+    def match(self, *args, **kwargs) -> bool:
         """Does this rrset match the specified attributes?
 
         Behaves as :py:func:`full_match()` if the first argument is a
@@ -93,8 +96,9 @@ class RRset(dns.rdataset.Rdataset):
         else:
             return super().match(*args, **kwargs)
 
-    def full_match(self, name, rdclass, rdtype, covers,
-                    deleting=None):
+    def full_match(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass,
+                   rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType,
+                   deleting: Optional[dns.rdataclass.RdataClass]=None) -> bool:
         """Returns ``True`` if this rrset matches the specified name, class,
         type, covers, and deletion state.
         """
@@ -106,7 +110,7 @@ class RRset(dns.rdataset.Rdataset):
 
     # pylint: disable=arguments-differ
 
-    def to_text(self, origin=None, relativize=True, **kw):
+    def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw) -> str:  # type: ignore
         """Convert the RRset into DNS zone file format.
 
         See ``dns.name.Name.choose_relativity`` for more information
@@ -126,8 +130,8 @@ class RRset(dns.rdataset.Rdataset):
         return super().to_text(self.name, origin, relativize,
                                self.deleting, **kw)
 
-    def to_wire(self, file, compress=None, origin=None,
-                **kw):
+    def to_wire(self, file, compress: Optional[dns.name.CompressType]=None,  # type: ignore
+                origin: Optional[dns.name.Name]=None, **kw) -> int:
         """Convert the RRset to wire format.
 
         All keyword arguments are passed to ``dns.rdataset.to_wire()``; see
@@ -141,7 +145,7 @@ class RRset(dns.rdataset.Rdataset):
 
     # pylint: enable=arguments-differ
 
-    def to_rdataset(self):
+    def to_rdataset(self) -> dns.rdataset.Rdataset:
         """Convert an RRset into an Rdataset.
 
         Returns a ``dns.rdataset.Rdataset``.
@@ -149,9 +153,13 @@ class RRset(dns.rdataset.Rdataset):
         return dns.rdataset.from_rdata_list(self.ttl, list(self))
 
 
-def from_text_list(name, ttl, rdclass, rdtype, text_rdatas,
-                   idna_codec=None, origin=None, relativize=True,
-                   relativize_to=None):
+def from_text_list(name: Union[dns.name.Name, str], ttl: int,
+                   rdclass: Union[dns.rdataclass.RdataClass, str],
+                   rdtype: Union[dns.rdatatype.RdataType, str],
+                   text_rdatas: Collection[str],
+                   idna_codec: Optional[dns.name.IDNACodec]=None,
+                   origin: Optional[dns.name.Name]=None, relativize=True,
+                   relativize_to: Optional[dns.name.Name]=None) -> RRset:
     """Create an RRset with the specified name, TTL, class, and type, and with
     the specified list of rdatas in text format.
 
@@ -172,9 +180,9 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas,
 
     if isinstance(name, str):
         name = dns.name.from_text(name, None, idna_codec=idna_codec)
-    rdclass = dns.rdataclass.RdataClass.make(rdclass)
-    rdtype = dns.rdatatype.RdataType.make(rdtype)
-    r = RRset(name, rdclass, rdtype)
+    the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+    r = RRset(name, the_rdclass, the_rdtype)
     r.update_ttl(ttl)
     for t in text_rdatas:
         rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize,
@@ -183,17 +191,23 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas,
     return r
 
 
-def from_text(name, ttl, rdclass, rdtype, *text_rdatas):
+def from_text(name: Union[dns.name.Name, str], ttl: int,
+              rdclass: Union[dns.rdataclass.RdataClass, str],
+              rdtype: Union[dns.rdatatype.RdataType, str],
+              *text_rdatas) -> RRset:
     """Create an RRset with the specified name, TTL, class, and type and with
     the specified rdatas in text format.
 
     Returns a ``dns.rrset.RRset`` object.
     """
 
-    return from_text_list(name, ttl, rdclass, rdtype, text_rdatas)
+    return from_text_list(name, ttl, rdclass, rdtype,
+                          cast(Collection[str], text_rdatas))
 
 
-def from_rdata_list(name, ttl, rdatas, idna_codec=None):
+def from_rdata_list(name: Union[dns.name.Name, str], ttl:int,
+                    rdatas: Collection[dns.rdata.Rdata],
+                    idna_codec: Optional[dns.name.IDNACodec]=None) -> RRset:
     """Create an RRset with the specified name and TTL, and with
     the specified list of rdata objects.
 
@@ -216,14 +230,15 @@ def from_rdata_list(name, ttl, rdatas, idna_codec=None):
             r = RRset(name, rd.rdclass, rd.rdtype)
             r.update_ttl(ttl)
         r.add(rd)
+    assert r is not None
     return r
 
 
-def from_rdata(name, ttl, *rdatas):
+def from_rdata(name: Union[dns.name.Name, str], ttl:int, *rdatas) -> RRset:
     """Create an RRset with the specified name and TTL, and with
     the specified rdata objects.
 
     Returns a ``dns.rrset.RRset`` object.
     """
 
-    return from_rdata_list(name, ttl, rdatas)
+    return from_rdata_list(name, ttl, cast(Collection[dns.rdata.Rdata], rdatas))
diff --git a/dns/rrset.pyi b/dns/rrset.pyi
deleted file mode 100644 (file)
index 0a81a2a..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-from typing import List, Optional
-from . import rdataset, rdatatype
-
-class RRset(rdataset.Rdataset):
-    def __init__(self, name, rdclass : int , rdtype : int, covers=rdatatype.NONE,
-                 deleting : Optional[int] =None) -> None:
-        self.name = name
-        self.deleting = deleting
-def from_text(name : str, ttl : int, rdclass : str, rdtype : str, *text_rdatas : str):
-    ...
index b0474151c85086303fb030a58cd692ca3656dfde..138ffbf966929f4103da27af5399127b1308eb98 100644 (file)
@@ -3,7 +3,7 @@
 """Serial Number Arthimetic from RFC 1982"""
 
 class Serial:
-    def __init__(self, value, bits=32):
+    def __init__(self, value:int , bits=32):
         self.value = value % 2 ** bits
         self.bits = bits
 
index cb6a6302d0cdf779d6d04dd861929ba360c02f1a..bb94ce94c3666199ce59237a13edc45f01da4738 100644 (file)
@@ -17,6 +17,8 @@
 
 """Tokenize DNS zone file format"""
 
+from typing import Optional, List, Tuple
+
 import io
 import sys
 
@@ -48,7 +50,7 @@ class Token:
     has_escape: Does the token value contain escapes?
     """
 
-    def __init__(self, ttype, value='', has_escape=False, comment=None):
+    def __init__(self, ttype: int, value='', has_escape=False, comment: Optional[str]=None):
         """Initialize a token instance."""
 
         self.ttype = ttype
@@ -56,28 +58,28 @@ class Token:
         self.has_escape = has_escape
         self.comment = comment
 
-    def is_eof(self):
+    def is_eof(self) -> bool:
         return self.ttype == EOF
 
-    def is_eol(self):
+    def is_eol(self) -> bool:
         return self.ttype == EOL
 
-    def is_whitespace(self):
+    def is_whitespace(self) -> bool:
         return self.ttype == WHITESPACE
 
-    def is_identifier(self):
+    def is_identifier(self) -> bool:
         return self.ttype == IDENTIFIER
 
-    def is_quoted_string(self):
+    def is_quoted_string(self) -> bool:
         return self.ttype == QUOTED_STRING
 
-    def is_comment(self):
+    def is_comment(self) -> bool:
         return self.ttype == COMMENT
 
-    def is_delimiter(self):  # pragma: no cover (we don't return delimiters yet)
+    def is_delimiter(self) -> bool:  # pragma: no cover (we don't return delimiters yet)
         return self.ttype == DELIMITER
 
-    def is_eol_or_eof(self):
+    def is_eol_or_eof(self) -> bool:
         return self.ttype == EOL or self.ttype == EOF
 
     def __eq__(self, other):
@@ -95,7 +97,7 @@ class Token:
     def __str__(self):
         return '%d "%s"' % (self.ttype, self.value)
 
-    def unescape(self):
+    def unescape(self) -> 'Token':
         if not self.has_escape:
             return self
         unescaped = ''
@@ -127,7 +129,7 @@ class Token:
             unescaped += c
         return Token(self.ttype, unescaped)
 
-    def unescape_to_bytes(self):
+    def unescape_to_bytes(self) -> 'Token':
         # We used to use unescape() for TXT-like records, but this
         # caused problems as we'd process DNS escapes into Unicode code
         # points instead of byte values, and then a to_text() of the
@@ -223,7 +225,8 @@ class Tokenizer:
     encoder/decoder is used.
     """
 
-    def __init__(self, f=sys.stdin, filename=None, idna_codec=None):
+    def __init__(self, f=sys.stdin, filename: Optional[str]=None,
+                 idna_codec: Optional[dns.name.IDNACodec]=None):
         """Initialize a tokenizer instance.
 
         f: The file to tokenize.  The default is sys.stdin.
@@ -253,19 +256,21 @@ class Tokenizer:
                 else:
                     filename = '<file>'
         self.file = f
-        self.ungotten_char = None
-        self.ungotten_token = None
+        self.ungotten_char: Optional[str] = None
+        self.ungotten_token: Optional[Token] = None
         self.multiline = 0
         self.quoting = False
         self.eof = False
         self.delimiters = _DELIMITERS
         self.line_number = 1
+        assert filename is not None
         self.filename = filename
         if idna_codec is None:
-            idna_codec = dns.name.IDNA_2003
-        self.idna_codec = idna_codec
+            self.idna_codec: dns.name.IDNACodec = dns.name.IDNA_2003
+        else:
+            self.idna_codec = idna_codec
 
-    def _get_char(self):
+    def _get_char(self) -> str:
         """Read a character from input.
         """
 
@@ -283,7 +288,7 @@ class Tokenizer:
             self.ungotten_char = None
         return c
 
-    def where(self):
+    def where(self) -> Tuple[str, int]:
         """Return the current location in the input.
 
         Returns a (string, int) tuple.  The first item is the filename of
@@ -328,7 +333,7 @@ class Tokenizer:
                     return skipped
             skipped += 1
 
-    def get(self, want_leading=False, want_comment=False):
+    def get(self, want_leading=False, want_comment=False) -> Token:
         """Get the next token.
 
         want_leading: If True, return a WHITESPACE token if the
@@ -345,16 +350,16 @@ class Tokenizer:
         """
 
         if self.ungotten_token is not None:
-            token = self.ungotten_token
+            utoken = self.ungotten_token
             self.ungotten_token = None
-            if token.is_whitespace():
+            if utoken.is_whitespace():
                 if want_leading:
-                    return token
-            elif token.is_comment():
+                    return utoken
+            elif utoken.is_comment():
                 if want_comment:
-                    return token
+                    return utoken
             else:
-                return token
+                return utoken
         skipped = self.skip_whitespace()
         if want_leading and skipped > 0:
             return Token(WHITESPACE, ' ')
@@ -438,7 +443,7 @@ class Tokenizer:
             ttype = EOF
         return Token(ttype, token, has_escape)
 
-    def unget(self, token):
+    def unget(self, token: Token):
         """Unget a token.
 
         The unget buffer for tokens is only one token large; it is
@@ -487,7 +492,7 @@ class Tokenizer:
             raise dns.exception.SyntaxError('expecting an integer')
         return int(token.value, base)
 
-    def get_uint8(self):
+    def get_uint8(self) -> int:
         """Read the next token and interpret it as an 8-bit unsigned
         integer.
 
@@ -502,7 +507,7 @@ class Tokenizer:
                 '%d is not an unsigned 8-bit integer' % value)
         return value
 
-    def get_uint16(self, base=10):
+    def get_uint16(self, base=10) -> int:
         """Read the next token and interpret it as a 16-bit unsigned
         integer.
 
@@ -521,7 +526,7 @@ class Tokenizer:
                     '%d is not an unsigned 16-bit integer' % value)
         return value
 
-    def get_uint32(self, base=10):
+    def get_uint32(self, base=10) -> int:
         """Read the next token and interpret it as a 32-bit unsigned
         integer.
 
@@ -536,7 +541,7 @@ class Tokenizer:
                 '%d is not an unsigned 32-bit integer' % value)
         return value
 
-    def get_uint48(self, base=10):
+    def get_uint48(self, base=10) -> int:
         """Read the next token and interpret it as a 48-bit unsigned
         integer.
 
@@ -551,7 +556,7 @@ class Tokenizer:
                 '%d is not an unsigned 48-bit integer' % value)
         return value
 
-    def get_string(self, max_length=None):
+    def get_string(self, max_length=None) -> str:
         """Read the next token and interpret it as a string.
 
         Raises dns.exception.SyntaxError if not a string.
@@ -568,7 +573,7 @@ class Tokenizer:
             raise dns.exception.SyntaxError("string too long")
         return token.value
 
-    def get_identifier(self):
+    def get_identifier(self) -> str:
         """Read the next token, which should be an identifier.
 
         Raises dns.exception.SyntaxError if not an identifier.
@@ -581,7 +586,7 @@ class Tokenizer:
             raise dns.exception.SyntaxError('expecting an identifier')
         return token.value
 
-    def get_remaining(self, max_tokens=None):
+    def get_remaining(self, max_tokens=None) -> List[Token]:
         """Return the remaining tokens on the line, until an EOL or EOF is seen.
 
         max_tokens: If not None, stop after this number of tokens.
@@ -600,7 +605,7 @@ class Tokenizer:
                 break
         return tokens
 
-    def concatenate_remaining_identifiers(self, allow_empty=False):
+    def concatenate_remaining_identifiers(self, allow_empty=False) -> str:
         """Read the remaining tokens on the line, which should be identifiers.
 
         Raises dns.exception.SyntaxError if there are no remaining tokens,
@@ -625,7 +630,8 @@ class Tokenizer:
             raise dns.exception.SyntaxError('expecting another identifier')
         return s
 
-    def as_name(self, token, origin=None, relativize=False, relativize_to=None):
+    def as_name(self, token: Token, origin: Optional[dns.name.Name]=None,
+                relativize=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name:
         """Try to interpret the token as a DNS name.
 
         Raises dns.exception.SyntaxError if not a name.
@@ -637,7 +643,8 @@ class Tokenizer:
         name = dns.name.from_text(token.value, origin, self.idna_codec)
         return name.choose_relativity(relativize_to or origin, relativize)
 
-    def get_name(self, origin=None, relativize=False, relativize_to=None):
+    def get_name(self, origin: Optional[dns.name.Name]=None, relativize=False,
+                 relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name:
         """Read the next token and interpret it as a DNS name.
 
         Raises dns.exception.SyntaxError if not a name.
@@ -648,7 +655,7 @@ class Tokenizer:
         token = self.get()
         return self.as_name(token, origin, relativize, relativize_to)
 
-    def get_eol_as_token(self):
+    def get_eol_as_token(self) -> Token:
         """Read the next token and raise an exception if it isn't EOL or
         EOF.
 
@@ -662,10 +669,10 @@ class Tokenizer:
                                                       token.value))
         return token
 
-    def get_eol(self):
+    def get_eol(self) -> str:
         return self.get_eol_as_token().value
 
-    def get_ttl(self):
+    def get_ttl(self) -> int:
         """Read the next token and interpret it as a DNS TTL.
 
         Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an
index d725492432877205735decb844b65d50e1e2410b..ccb557cec3f3eeaa4bc02c557f4182d508bc9591 100644 (file)
@@ -1,9 +1,12 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+from typing import Callable, List, Optional, Tuple, Union
+
 import collections
 
 import dns.exception
 import dns.name
+import dns.node
 import dns.rdataclass
 import dns.rdataset
 import dns.rdatatype
@@ -13,11 +16,11 @@ import dns.ttl
 
 
 class TransactionManager:
-    def reader(self):
+    def reader(self) -> 'Transaction':
         """Begin a read-only transaction."""
         raise NotImplementedError  # pragma: no cover
 
-    def writer(self, replacement=False):
+    def writer(self, replacement=False) -> 'Transaction':
         """Begin a writable transaction.
 
         *replacement*, a ``bool``.  If `True`, the content of the
@@ -27,7 +30,7 @@ class TransactionManager:
         """
         raise NotImplementedError  # pragma: no cover
 
-    def origin_information(self):
+    def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
         """Returns a tuple
 
             (absolute_origin, relativize, effective_origin)
@@ -52,12 +55,12 @@ class TransactionManager:
         """
         raise NotImplementedError  # pragma: no cover
 
-    def get_class(self):
+    def get_class(self) -> dns.rdataclass.RdataClass:
         """The class of the transaction manager.
         """
         raise NotImplementedError  # pragma: no cover
 
-    def from_wire_origin(self):
+    def from_wire_origin(self) -> Optional[dns.name.Name]:
         """Origin to use in from_wire() calls.
         """
         (absolute_origin, relativize, _) = self.origin_information()
@@ -90,22 +93,33 @@ def _ensure_immutable_node(node):
     return dns.node.ImmutableNode(node)
 
 
+CheckPutRdatasetType = Callable[['Transaction', dns.name.Name, dns.rdataset.Rdataset], None]
+CheckDeleteRdatasetType = Callable[['Transaction', dns.name.Name,
+                                    dns.rdatatype.RdataType, dns.rdatatype.RdataType], None]
+CheckDeleteNameType = Callable[['Transaction', dns.name.Name], None]
+
+
 class Transaction:
 
-    def __init__(self, manager, replacement=False, read_only=False):
+    def __init__(self, manager: TransactionManager, replacement=False, read_only=False):
         self.manager = manager
         self.replacement = replacement
         self.read_only = read_only
         self._ended = False
-        self._check_put_rdataset = []
-        self._check_delete_rdataset = []
-        self._check_delete_name = []
+        self._check_put_rdataset: List[CheckPutRdatasetType]= []
+        self._check_delete_rdataset: List[CheckDeleteRdatasetType] = []
+        self._check_delete_name: List[CheckDeleteNameType] = []
 
     #
     # This is the high level API
     #
+    # Note that we currently use non-immutable types in the return type signature to avoid
+    # covariance problems, e.g. if the caller has a List[Rdataset], mypy will be unhappy if we
+    # return an ImmutableRdataset.
 
-    def get(self, name, rdtype, covers=dns.rdatatype.NONE):
+    def get(self, name: Optional[Union[dns.name.Name,str]],
+            rdtype: Union[dns.rdatatype.RdataType, str],
+            covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rdataset.Rdataset:
         """Return the rdataset associated with *name*, *rdtype*, and *covers*,
         or `None` if not found.
 
@@ -115,10 +129,11 @@ class Transaction:
         if isinstance(name, str):
             name = dns.name.from_text(name, None)
         rdtype = dns.rdatatype.RdataType.make(rdtype)
+        covers = dns.rdatatype.RdataType.make(covers)
         rdataset = self._get_rdataset(name, rdtype, covers)
         return _ensure_immutable_rdataset(rdataset)
 
-    def get_node(self, name):
+    def get_node(self, name) -> dns.node.Node:
         """Return the node at *name*, if any.
 
         Returns an immutable node or ``None``.
@@ -210,7 +225,7 @@ class Transaction:
         self._check_read_only()
         return self._delete(True, args)
 
-    def name_exists(self, name):
+    def name_exists(self, name: Union[dns.name.Name, str]) -> bool:
         """Does the specified name exist?"""
         self._check_ended()
         if isinstance(name, str):
@@ -253,7 +268,7 @@ class Transaction:
         self._check_ended()
         return self._iterate_rdatasets()
 
-    def changed(self):
+    def changed(self) -> bool:
         """Has this transaction changed anything?
 
         For read-only transactions, the result is always `False`.
@@ -289,7 +304,7 @@ class Transaction:
         """
         self._end(False)
 
-    def check_put_rdataset(self, check):
+    def check_put_rdataset(self, check: CheckPutRdatasetType):
         """Call *check* before putting (storing) an rdataset.
 
         The function is called with the transaction, the name, and the rdataset.
@@ -301,7 +316,7 @@ class Transaction:
         """
         self._check_put_rdataset.append(check)
 
-    def check_delete_rdataset(self, check):
+    def check_delete_rdataset(self, check: CheckDeleteRdatasetType):
         """Call *check* before deleting an rdataset.
 
         The function is called with the transaction, the name, the rdatatype,
@@ -314,7 +329,7 @@ class Transaction:
         """
         self._check_delete_rdataset.append(check)
 
-    def check_delete_name(self, check):
+    def check_delete_name(self, check: CheckDeleteNameType):
         """Call *check* before putting (storing) an rdataset.
 
         The function is called with the transaction and the name.
index 788581c91359a2b9d21906d92f8483fd8a50d671..06a1bd09af2db4837040515f1adcceb04bb2d268 100644 (file)
 
 """A place to store TSIG keys."""
 
+from typing import Any, Dict, Union
+
 import base64
 
 import dns.name
 import dns.tsig
 
 
-def from_text(textring):
+def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
     """Convert a dictionary containing (textual DNS name, base64 secret)
     pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
     a dictionary containing (textual DNS name, (algorithm, base64 secret))
@@ -32,16 +34,16 @@ def from_text(textring):
 
     keyring = {}
     for (name, value) in textring.items():
-        name = dns.name.from_text(name)
+        kname = dns.name.from_text(name)
         if isinstance(value, str):
-            keyring[name] = dns.tsig.Key(name, value).secret
+            keyring[kname] = dns.tsig.Key(kname, value).secret
         else:
             (algorithm, secret) = value
-            keyring[name] = dns.tsig.Key(name, secret, algorithm)
+            keyring[kname] = dns.tsig.Key(kname, secret, algorithm)
     return keyring
 
 
-def to_text(keyring):
+def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]:
     """Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs
     into a text keyring which has (textual DNS name, (textual algorithm,
     base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes)
@@ -52,14 +54,14 @@ def to_text(keyring):
     def b64encode(secret):
         return base64.encodebytes(secret).decode().rstrip()
     for (name, key) in keyring.items():
-        name = name.to_text()
+        tname = name.to_text()
         if isinstance(key, bytes):
-            textring[name] = b64encode(key)
+            textring[tname] = b64encode(key)
         else:
             if isinstance(key.secret, bytes):
                 text_secret = b64encode(key.secret)
             else:
                 text_secret = str(key.secret)
 
-            textring[name] = (key.algorithm.to_text(), text_secret)
+            textring[tname] = (key.algorithm.to_text(), text_secret)
     return textring
diff --git a/dns/tsigkeyring.pyi b/dns/tsigkeyring.pyi
deleted file mode 100644 (file)
index b5d51e1..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-from typing import Dict
-from . import name
-
-def from_text(textring : Dict[str,str]) -> Dict[name.Name,bytes]:
-    ...
-def to_text(keyring : Dict[name.Name,bytes]) -> Dict[str, str]:
-    ...
index df92b2b60fd3bd4c68fde582cdd0c701b5f1f388..9f5730e710111397f894156e480a0ac0c95d7494 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS TTL conversion."""
 
+from typing import Union
+
 import dns.exception
 
 # Technically TTLs are supposed to be between 0 and 2**31 - 1, with values
@@ -31,7 +33,7 @@ class BadTTL(dns.exception.SyntaxError):
     """DNS TTL value is not well-formed."""
 
 
-def from_text(text):
+def from_text(text: str) -> int:
     """Convert the text form of a TTL to an integer.
 
     The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported.
@@ -81,7 +83,7 @@ def from_text(text):
     return total
 
 
-def make(value):
+def make(value: Union[int, str]) -> int:
     if isinstance(value, int):
         return value
     elif isinstance(value, str):
index 9a047553ab594f62d55f15c0cfd04fbbf3f47b91..5df0cc783a0a8250108a7d0cc675e10fb1154221 100644 (file)
@@ -17,6 +17,7 @@
 
 """DNS Dynamic Update Support"""
 
+from typing import Any, Optional, Union
 
 import dns.message
 import dns.name
@@ -41,11 +42,14 @@ class UpdateSection(dns.enum.IntEnum):
 
 class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
 
-    _section_enum = UpdateSection
+    # ignore the mypy error here as we mean to use a different enum
+    _section_enum = UpdateSection  # type: ignore
 
-    def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None,
-                 keyname=None, keyalgorithm=dns.tsig.default_algorithm,
-                 id=None):
+    def __init__(self, zone: Optional[Union[dns.name.Name, str]]=None,
+                 rdclass=dns.rdataclass.IN,
+                 keyring: Optional[Any]=None, keyname: Optional[dns.name.Name]=None,
+                 keyalgorithm=dns.tsig.default_algorithm,
+                 id: Optional[int]=None):
         """Initialize a new DNS Update object.
 
         See the documentation of the Message class for a complete
@@ -152,7 +156,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                                              self.origin)
                     self._add_rr(name, ttl, rd, section=section)
 
-    def add(self, name, *args):
+    def add(self, name: Union[dns.name.Name, str], *args):
         """Add records.
 
         The first argument is always a name.  The other
@@ -167,7 +171,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
 
         self._add(False, self.update, name, *args)
 
-    def delete(self, name, *args):
+    def delete(self, name: Union[dns.name.Name, str], *args):
         """Delete records.
 
         The first argument is always a name.  The other
@@ -187,31 +191,31 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
         if len(args) == 0:
             self.find_rrset(self.update, name, dns.rdataclass.ANY,
                             dns.rdatatype.ANY, dns.rdatatype.NONE,
-                            dns.rdatatype.ANY, True, True)
+                            dns.rdataclass.ANY, True, True)
         elif isinstance(args[0], dns.rdataset.Rdataset):
             for rds in args:
                 for rd in rds:
                     self._add_rr(name, 0, rd, dns.rdataclass.NONE)
         else:
-            args = list(args)
-            if isinstance(args[0], dns.rdata.Rdata):
-                for rd in args:
+            largs = list(args)
+            if isinstance(largs[0], dns.rdata.Rdata):
+                for rd in largs:
                     self._add_rr(name, 0, rd, dns.rdataclass.NONE)
             else:
-                rdtype = dns.rdatatype.RdataType.make(args.pop(0))
-                if len(args) == 0:
+                rdtype = dns.rdatatype.RdataType.make(largs.pop(0))
+                if len(largs) == 0:
                     self.find_rrset(self.update, name,
                                     self.zone_rdclass, rdtype,
                                     dns.rdatatype.NONE,
                                     dns.rdataclass.ANY,
                                     True, True)
                 else:
-                    for s in args:
+                    for s in largs:
                         rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s,
                                                  self.origin)
                         self._add_rr(name, 0, rd, dns.rdataclass.NONE)
 
-    def replace(self, name, *args):
+    def replace(self, name: Union[dns.name.Name, str], *args):
         """Replace records.
 
         The first argument is always a name.  The other
@@ -229,7 +233,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
 
         self._add(True, self.update, name, *args)
 
-    def present(self, name, *args):
+    def present(self, name: Union[dns.name.Name, str], *args):
         """Require that an owner name (and optionally an rdata type,
         or specific rdataset) exists as a prerequisite to the
         execution of the update.
@@ -256,9 +260,11 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                 len(args) > 1:
             if not isinstance(args[0], dns.rdataset.Rdataset):
                 # Add a 0 TTL
-                args = list(args)
-                args.insert(0, 0)
-            self._add(False, self.prerequisite, name, *args)
+                largs = list(args)
+                largs.insert(0, 0)
+                self._add(False, self.prerequisite, name, *largs)
+            else:
+                self._add(False, self.prerequisite, name, *args)
         else:
             rdtype = dns.rdatatype.RdataType.make(args[0])
             self.find_rrset(self.prerequisite, name,
@@ -266,7 +272,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                             dns.rdatatype.NONE, None,
                             True, True)
 
-    def absent(self, name, rdtype=None):
+    def absent(self, name: Union[dns.name.Name, str], rdtype=None):
         """Require that an owner name (and optionally an rdata type) does
         not exist as a prerequisite to the execution of the update."""
 
diff --git a/dns/update.pyi b/dns/update.pyi
deleted file mode 100644 (file)
index eeac059..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Optional,Dict,Union,Any
-
-from . import message, tsig, rdataclass, name
-
-class Update(message.Message):
-    def __init__(self, zone : Union[name.Name, str], rdclass : Union[int,str] = rdataclass.IN, keyring : Optional[Dict[name.Name,bytes]] = None,
-                 keyname : Optional[name.Name] = None, keyalgorithm : Optional[name.Name] = tsig.default_algorithm) -> None:
-        self.id : int
-    def add(self, name : Union[str,name.Name], *args : Any):
-        ...
-    def delete(self, name, *args : Any):
-        ...
-    def replace(self, name : Union[str,name.Name], *args : Any):
-        ...
-    def present(self, name : Union[str,name.Name], *args : Any):
-        ...
-    def absent(self, name : Union[str,name.Name], rdtype=None):
-        """Require that an owner name (and optionally an rdata type) does
-        not exist as a prerequisite to the execution of the update."""
-    def to_wire(self, origin : Optional[name.Name] = None, max_size=65535, **kw) -> bytes:
-        ...
index a7e1204bbdb89608448fe7d0672866086db5885d..02316c822a799e90de89b8f014f3fcfe26653ce1 100644 (file)
@@ -2,6 +2,8 @@
 
 """DNS Versioned Zones."""
 
+from typing import Callable, Deque, Optional, Set, Union
+
 import collections
 try:
     import threading as _threading
@@ -38,8 +40,8 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
 
     node_factory = Node
 
-    def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
-                 pruning_policy=None):
+    def __init__(self, origin: Optional[Union[dns.name.Name, str]], rdclass=dns.rdataclass.IN, relativize=True,
+                 pruning_policy: Optional[Callable[['Zone', Version], Optional[bool]]]=None):
         """Initialize a versioned zone object.
 
         *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
@@ -51,26 +53,26 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
         *relativize*, a ``bool``, determine's whether domain names are
         relativized to the zone's origin.  The default is ``True``.
 
-        *pruning policy*, a function taking a `Version` and returning
-        a `bool`, or `None`.  Should the version be pruned?  If `None`,
+        *pruning policy*, a function taking a ``Zone`` and a ``Version`` and returning
+        a ``bool``, or ``None``.  Should the version be pruned?  If ``None``,
         the default policy, which retains one version is used.
         """
         super().__init__(origin, rdclass, relativize)
-        self._versions = collections.deque()
+        self._versions: Deque[Version] = collections.deque()
         self._version_lock = _threading.Lock()
         if pruning_policy is None:
             self._pruning_policy = self._default_pruning_policy
         else:
             self._pruning_policy = pruning_policy
-        self._write_txn = None
-        self._write_event = None
-        self._write_waiters = collections.deque()
-        self._readers = set()
+        self._write_txn: Optional[Transaction] = None
+        self._write_event: Optional[_threading.Event] = None
+        self._write_waiters: Deque[_threading.Event] = collections.deque()
+        self._readers: Set[Transaction] = set()
         self._commit_version_unlocked(None,
                                       WritableVersion(self, replacement=True),
                                       origin)
 
-    def reader(self, id=None, serial=None):  # pylint: disable=arguments-differ
+    def reader(self, id: Optional[int]=None, serial: Optional[int]=None) -> Transaction:  # pylint: disable=arguments-differ
         if id is not None and serial is not None:
             raise ValueError('cannot specify both id and serial')
         with self._version_lock:
@@ -86,6 +88,7 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
                 if self.relativize:
                     oname = dns.name.empty
                 else:
+                    assert self.origin is not None
                     oname = self.origin
                 version = None
                 for v in reversed(self._versions):
@@ -103,7 +106,7 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
             self._readers.add(txn)
             return txn
 
-    def writer(self, replacement=False):
+    def writer(self, replacement=False) -> Transaction:
         event = None
         while True:
             with self._version_lock:
@@ -178,21 +181,21 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
               self._pruning_policy(self, self._versions[0]):
             self._versions.popleft()
 
-    def set_max_versions(self, max_versions):
+    def set_max_versions(self, max_versions: Optional[int]):
         """Set a pruning policy that retains up to the specified number
         of versions
         """
         if max_versions is not None and max_versions < 1:
             raise ValueError('max versions must be at least 1')
         if max_versions is None:
-            def policy(*_):
+            def policy(zone, _):  # pylint: disable=unused-argument
                 return False
         else:
             def policy(zone, _):
                 return len(zone._versions) > max_versions
         self.set_pruning_policy(policy)
 
-    def set_pruning_policy(self, policy):
+    def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]):
         """Set the pruning policy for the zone.
 
         The *policy* function takes a `Version` and returns `True` if
index 572e27e708d59aa9c75e9ab92d90deace8e07616..d3317a59349bb35c16b5eaa2e4e060fc72d0575a 100644 (file)
@@ -1,5 +1,7 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+from typing import Optional, Tuple
+
 import contextlib
 import struct
 
@@ -7,7 +9,7 @@ import dns.exception
 import dns.name
 
 class Parser:
-    def __init__(self, wire, current=0):
+    def __init__(self, wire: bytes, current=0):
         self.wire = wire
         self.current = 0
         self.end = len(self.wire)
@@ -18,7 +20,8 @@ class Parser:
     def remaining(self):
         return self.end - self.current
 
-    def get_bytes(self, size):
+    def get_bytes(self, size=int) -> bytes:
+        assert size >= 0
         if size > self.remaining():
             raise dns.exception.FormError
         output = self.wire[self.current:self.current + size]
@@ -26,35 +29,35 @@ class Parser:
         self.furthest = max(self.furthest, self.current)
         return output
 
-    def get_counted_bytes(self, length_size=1):
+    def get_counted_bytes(self, length_size=1) -> bytes:
         length = int.from_bytes(self.get_bytes(length_size), 'big')
         return self.get_bytes(length)
 
-    def get_remaining(self):
+    def get_remaining(self) -> bytes:
         return self.get_bytes(self.remaining())
 
-    def get_uint8(self):
+    def get_uint8(self) -> int:
         return struct.unpack('!B', self.get_bytes(1))[0]
 
-    def get_uint16(self):
+    def get_uint16(self) -> int:
         return struct.unpack('!H', self.get_bytes(2))[0]
 
-    def get_uint32(self):
+    def get_uint32(self) -> int:
         return struct.unpack('!I', self.get_bytes(4))[0]
 
-    def get_uint48(self):
+    def get_uint48(self) -> int:
         return int.from_bytes(self.get_bytes(6), 'big')
 
-    def get_struct(self, format):
+    def get_struct(self, format: str) -> Tuple:
         return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
 
-    def get_name(self, origin=None):
+    def get_name(self, origin: Optional['dns.name.Name']=None) -> 'dns.name.Name':
         name = dns.name.from_wire_parser(self)
         if origin:
             name = name.relativize(origin)
         return name
 
-    def seek(self, where):
+    def seek(self, where: int):
         # Note that seeking to the end is OK!  (If you try to read
         # after such a seek, you'll get an exception as expected.)
         if where < 0 or where > self.end:
@@ -62,7 +65,8 @@ class Parser:
         self.current = where
 
     @contextlib.contextmanager
-    def restrict_to(self, size):
+    def restrict_to(self, size: int):
+        assert size >= 0
         if size > self.remaining():
             raise dns.exception.FormError
         saved_end = self.end
index 2ef1b0a713be6f3aae4065bfa08103820ef445fa..618eac2f2d6b8b4459c8c6fae6f17bf2fd7cd3b8 100644 (file)
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+from typing import Any, List, Optional, Tuple
+
 import dns.exception
 import dns.message
 import dns.name
 import dns.rcode
 import dns.serial
+import dns.rdataset
 import dns.rdatatype
+import dns.transaction
+import dns.tsig
 import dns.zone
 
 
@@ -46,8 +51,8 @@ class Inbound:
     State machine for zone transfers.
     """
 
-    def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR,
-                 serial=None, is_udp=False):
+    def __init__(self, txn_manager: dns.transaction.TransactionManager, rdtype=dns.rdatatype.AXFR,
+                 serial: Optional[int]=None, is_udp=False):
         """Initialize an inbound zone transfer.
 
         *txn_manager* is a :py:class:`dns.transaction.TransactionManager`.
@@ -61,7 +66,7 @@ class Inbound:
         XFR.
         """
         self.txn_manager = txn_manager
-        self.txn = None
+        self.txn: Optional[dns.transaction.Transaction] = None
         self.rdtype = rdtype
         if rdtype == dns.rdatatype.IXFR:
             if serial is None:
@@ -71,12 +76,12 @@ class Inbound:
         self.serial = serial
         self.is_udp = is_udp
         (_, _, self.origin) = txn_manager.origin_information()
-        self.soa_rdataset = None
+        self.soa_rdataset: Optional[dns.rdataset.Rdataset] = None
         self.done = False
         self.expecting_SOA = False
         self.delete_mode = False
 
-    def process_message(self, message):
+    def process_message(self, message: dns.message.Message) -> bool:
         """Process one message in the transfer.
 
         The message should have the same relativization as was specified when
@@ -146,6 +151,7 @@ class Inbound:
             rdataset = rrset
             if self.done:
                 raise dns.exception.FormError("answers after final SOA")
+            assert self.txn is not None  # for mypy
             if rdataset.rdtype == dns.rdatatype.SOA and \
                name == self.origin:
                 #
@@ -238,11 +244,11 @@ class Inbound:
         return False
 
 
-def make_query(txn_manager, serial=0,
-               use_edns=None, ednsflags=None, payload=None,
-               request_payload=None, options=None,
-               keyring=None, keyname=None,
-               keyalgorithm=dns.tsig.default_algorithm):
+def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional[int]=0,
+               use_edns=None, ednsflags: Optional[int]=None, payload: Optional[int]=None,
+               request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None,
+               keyring: Any=None, keyname: Optional[dns.name.Name]=None,
+               keyalgorithm=dns.tsig.default_algorithm) -> Tuple[dns.message.QueryMessage, Optional[int]]:
     """Make an AXFR or IXFR query.
 
     *txn_manager* is a ``dns.transaction.TransactionManager``, typically a
@@ -263,6 +269,8 @@ def make_query(txn_manager, serial=0,
     Returns a `(query, serial)` tuple.
     """
     (zone_origin, _, origin) = txn_manager.origin_information()
+    if zone_origin is None:
+        raise ValueError('no zone origin')
     if serial is None:
         rdtype = dns.rdatatype.AXFR
     elif not isinstance(serial, int):
@@ -293,15 +301,17 @@ def make_query(txn_manager, serial=0,
         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
     return (q, serial)
 
-def extract_serial_from_query(query):
+def extract_serial_from_query(query: dns.message.Message) -> Optional[int]:
     """Extract the SOA serial number from query if it is an IXFR and return
     it, otherwise return None.
 
     *query* is a dns.message.QueryMessage that is an IXFR or AXFR request.
 
     Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have
-    an appropriate SOA RRset in the authority section."""
-
+    an appropriate SOA RRset in the authority section.
+    """
+    if not isinstance(query, dns.message.QueryMessage):
+        raise ValueError('query not a QueryMessage')
     question = query.question[0]
     if question.rdtype == dns.rdatatype.AXFR:
         return None
index 6a154cedc8039377eb2723f31050c9e0625fd48e..a9a400775c1e029749497031ab39ae5db4f6a074 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS Zones."""
 
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
+
 import contextlib
 import hashlib
 import io
@@ -30,6 +32,7 @@ import dns.node
 import dns.rdataclass
 import dns.rdatatype
 import dns.rdata
+import dns.rdataset
 import dns.rdtypes.ANY.SOA
 import dns.rdtypes.ANY.ZONEMD
 import dns.rrset
@@ -38,6 +41,7 @@ import dns.transaction
 import dns.ttl
 import dns.grange
 import dns.zonefile
+from dns.zonetypes import DigestScheme, DigestHashAlgorithm, _digest_hashers
 
 
 class BadZone(dns.exception.DNSException):
@@ -80,33 +84,6 @@ class DigestVerificationFailure(dns.exception.DNSException):
     """The ZONEMD digest failed to verify."""
 
 
-class DigestScheme(dns.enum.IntEnum):
-    """ZONEMD Scheme"""
-
-    SIMPLE = 1
-
-    @classmethod
-    def _maximum(cls):
-        return 255
-
-
-class DigestHashAlgorithm(dns.enum.IntEnum):
-    """ZONEMD Hash Algorithm"""
-
-    SHA384 = 1
-    SHA512 = 2
-
-    @classmethod
-    def _maximum(cls):
-        return 255
-
-
-_digest_hashers = {
-    DigestHashAlgorithm.SHA384: hashlib.sha384,
-    DigestHashAlgorithm.SHA512: hashlib.sha512,
-}
-
-
 class Zone(dns.transaction.TransactionManager):
 
     """A DNS zone.
@@ -123,7 +100,8 @@ class Zone(dns.transaction.TransactionManager):
 
     __slots__ = ['rdclass', 'origin', 'nodes', 'relativize']
 
-    def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True):
+    def __init__(self, origin: Optional[Union[dns.name.Name, str]],
+                 rdclass=dns.rdataclass.IN, relativize=True):
         """Initialize a zone object.
 
         *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
@@ -146,7 +124,7 @@ class Zone(dns.transaction.TransactionManager):
                 raise ValueError("origin parameter must be an absolute name")
         self.origin = origin
         self.rdclass = rdclass
-        self.nodes = {}
+        self.nodes: Dict[dns.name.Name, dns.node.Node] = {}
         self.relativize = relativize
 
     def __eq__(self, other):
@@ -172,17 +150,27 @@ class Zone(dns.transaction.TransactionManager):
 
         return not self.__eq__(other)
 
-    def _validate_name(self, name):
+    def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
         if isinstance(name, str):
             name = dns.name.from_text(name, None)
         elif not isinstance(name, dns.name.Name):
             raise KeyError("name parameter must be convertible to a DNS name")
         if name.is_absolute():
+            if self.origin is None:
+                # This should probably never happen as other code (e.g.
+                # _rr_line) will notice the lack of an origin before us, but
+                # we check just in case!
+                raise KeyError('no zone origin is defined')
             if not name.is_subdomain(self.origin):
                 raise KeyError(
                     "name parameter must be a subdomain of the zone origin")
             if self.relativize:
                 name = name.relativize(self.origin)
+        elif not self.relativize:
+            # We have a relative name in a non-relative zone, so derelativize.
+            if self.origin is None:
+                raise KeyError('no zone origin is defined')
+            name = name.derelativize(self.origin)
         return name
 
     def __getitem__(self, key):
@@ -217,7 +205,7 @@ class Zone(dns.transaction.TransactionManager):
         key = self._validate_name(key)
         return key in self.nodes
 
-    def find_node(self, name, create=False):
+    def find_node(self, name: Union[dns.name.Name, str], create=False):
         """Find a node in the zone, possibly creating it.
 
         *name*: the name of the node to find.
@@ -243,7 +231,7 @@ class Zone(dns.transaction.TransactionManager):
             self.nodes[name] = node
         return node
 
-    def get_node(self, name, create=False):
+    def get_node(self, name: Union[dns.name.Name, str], create=False):
         """Get a node in the zone, possibly creating it.
 
         This method is like ``find_node()``, except it returns None instead
@@ -270,7 +258,7 @@ class Zone(dns.transaction.TransactionManager):
             node = None
         return node
 
-    def delete_node(self, name):
+    def delete_node(self, name: Union[dns.name.Name, str]):
         """Delete the specified node if it exists.
 
         *name*: the name of the node to find.
@@ -285,8 +273,10 @@ class Zone(dns.transaction.TransactionManager):
         if name in self.nodes:
             del self.nodes[name]
 
-    def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
-                      create=False):
+    def find_rdataset(self, name: Union[dns.name.Name, str],
+                      rdtype: Union[dns.rdatatype.RdataType, str],
+                      covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE,
+                      create=False) -> dns.rdataset.Rdataset:
         """Look for an rdataset with the specified name and type in the zone,
         and return an rdataset encapsulating it.
 
@@ -300,9 +290,9 @@ class Zone(dns.transaction.TransactionManager):
         name must be a subdomain of the zone's origin.  If ``zone.relativize``
         is ``True``, then the name will be relativized.
 
-        *rdtype*, an ``int`` or ``str``, the rdata type desired.
+        *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
 
-        *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+        *covers*, a ``dns.rdatatype.RdataType`` or ``str`` the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -323,13 +313,12 @@ class Zone(dns.transaction.TransactionManager):
 
         name = self._validate_name(name)
         rdtype = dns.rdatatype.RdataType.make(rdtype)
-        if covers is not None:
-            covers = dns.rdatatype.RdataType.make(covers)
+        covers = dns.rdatatype.RdataType.make(covers)
         node = self.find_node(name, create)
         return node.find_rdataset(self.rdclass, rdtype, covers, create)
 
     def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
-                     create=False):
+                     create=False) -> Optional[dns.rdataset.Rdataset]:
         """Look for an rdataset with the specified name and type in the zone.
 
         This method is like ``find_rdataset()``, except it returns None instead
@@ -344,9 +333,9 @@ class Zone(dns.transaction.TransactionManager):
         name must be a subdomain of the zone's origin.  If ``zone.relativize``
         is ``True``, then the name will be relativized.
 
-        *rdtype*, an ``int`` or ``str``, the rdata type desired.
+        *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
 
-        *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+        *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -371,7 +360,9 @@ class Zone(dns.transaction.TransactionManager):
             rdataset = None
         return rdataset
 
-    def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
+    def delete_rdataset(self, name: Union[dns.name.Name, str],
+                        rdtype: Union[dns.rdatatype.RdataType, str],
+                        covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE):
         """Delete the rdataset matching *rdtype* and *covers*, if it
         exists at the node specified by *name*.
 
@@ -386,9 +377,9 @@ class Zone(dns.transaction.TransactionManager):
         name must be a subdomain of the zone's origin.  If ``zone.relativize``
         is ``True``, then the name will be relativized.
 
-        *rdtype*, an ``int`` or ``str``, the rdata type desired.
+        *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
 
-        *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+        *covers*, a ``dns.rdatatype.RdataType`` or ``str`` or ``None``, the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -401,15 +392,15 @@ class Zone(dns.transaction.TransactionManager):
 
         name = self._validate_name(name)
         rdtype = dns.rdatatype.RdataType.make(rdtype)
-        if covers is not None:
-            covers = dns.rdatatype.RdataType.make(covers)
+        covers = dns.rdatatype.RdataType.make(covers)
         node = self.get_node(name)
         if node is not None:
             node.delete_rdataset(self.rdclass, rdtype, covers)
             if len(node) == 0:
                 self.delete_node(name)
 
-    def replace_rdataset(self, name, replacement):
+    def replace_rdataset(self, name: Union[dns.name.Name, str],
+                         replacement: dns.rdataset.Rdataset):
         """Replace an rdataset at name.
 
         It is not an error if there is no rdataset matching I{replacement}.
@@ -433,7 +424,9 @@ class Zone(dns.transaction.TransactionManager):
         node = self.find_node(name, True)
         node.replace_rdataset(replacement)
 
-    def find_rrset(self, name, rdtype, covers=dns.rdatatype.NONE):
+    def find_rrset(self, name: Union[dns.name.Name, str],
+                   rdtype: Union[dns.rdatatype.RdataType, str],
+                   covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rrset.RRset:
         """Look for an rdataset with the specified name and type in the zone,
         and return an RRset encapsulating it.
 
@@ -451,9 +444,9 @@ class Zone(dns.transaction.TransactionManager):
         name must be a subdomain of the zone's origin.  If ``zone.relativize``
         is ``True``, then the name will be relativized.
 
-        *rdtype*, an ``int`` or ``str``, the rdata type desired.
+        *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
 
-        *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+        *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -472,16 +465,17 @@ class Zone(dns.transaction.TransactionManager):
         Returns a ``dns.rrset.RRset`` or ``None``.
         """
 
-        name = self._validate_name(name)
-        rdtype = dns.rdatatype.RdataType.make(rdtype)
-        if covers is not None:
-            covers = dns.rdatatype.RdataType.make(covers)
-        rdataset = self.nodes[name].find_rdataset(self.rdclass, rdtype, covers)
-        rrset = dns.rrset.RRset(name, self.rdclass, rdtype, covers)
+        vname = self._validate_name(name)
+        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+        the_covers = dns.rdatatype.RdataType.make(covers)
+        rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers)
+        rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers)
         rrset.update(rdataset)
         return rrset
 
-    def get_rrset(self, name, rdtype, covers=dns.rdatatype.NONE):
+    def get_rrset(self, name: Union[dns.name.Name, str],
+                  rdtype: Union[dns.rdatatype.RdataType, str],
+                  covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> Optional[dns.rrset.RRset]:
         """Look for an rdataset with the specified name and type in the zone,
         and return an RRset encapsulating it.
 
@@ -498,9 +492,9 @@ class Zone(dns.transaction.TransactionManager):
         name must be a subdomain of the zone's origin.  If ``zone.relativize``
         is ``True``, then the name will be relativized.
 
-        *rdtype*, an ``int`` or ``str``, the rdata type desired.
+        *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired.
 
-        *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+        *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -526,15 +520,15 @@ class Zone(dns.transaction.TransactionManager):
         return rrset
 
     def iterate_rdatasets(self, rdtype=dns.rdatatype.ANY,
-                          covers=dns.rdatatype.NONE):
+                          covers=dns.rdatatype.NONE) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
         """Return a generator which yields (name, rdataset) tuples for
         all rdatasets in the zone which have the specified *rdtype*
         and *covers*.  If *rdtype* is ``dns.rdatatype.ANY``, the default,
         then all rdatasets will be matched.
 
-        *rdtype*, an ``int`` or ``str``, the rdata type desired.
+        *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired.
 
-        *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+        *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -546,8 +540,7 @@ class Zone(dns.transaction.TransactionManager):
         """
 
         rdtype = dns.rdatatype.RdataType.make(rdtype)
-        if covers is not None:
-            covers = dns.rdatatype.RdataType.make(covers)
+        covers = dns.rdatatype.RdataType.make(covers)
         for (name, node) in self.items():
             for rds in node:
                 if rdtype == dns.rdatatype.ANY or \
@@ -555,15 +548,15 @@ class Zone(dns.transaction.TransactionManager):
                     yield (name, rds)
 
     def iterate_rdatas(self, rdtype=dns.rdatatype.ANY,
-                       covers=dns.rdatatype.NONE):
+                       covers=dns.rdatatype.NONE) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]:
         """Return a generator which yields (name, ttl, rdata) tuples for
         all rdatas in the zone which have the specified *rdtype*
         and *covers*.  If *rdtype* is ``dns.rdatatype.ANY``, the default,
         then all rdatas will be matched.
 
-        *rdtype*, an ``int`` or ``str``, the rdata type desired.
+        *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired.
 
-        *covers*, an ``int`` or ``str`` or ``None``, the covered type.
+        *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type.
         Usually this value is ``dns.rdatatype.NONE``, but if the
         rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
         then the covers value will be the rdata type the SIG/RRSIG
@@ -575,8 +568,7 @@ class Zone(dns.transaction.TransactionManager):
         """
 
         rdtype = dns.rdatatype.RdataType.make(rdtype)
-        if covers is not None:
-            covers = dns.rdatatype.RdataType.make(covers)
+        covers = dns.rdatatype.RdataType.make(covers)
         for (name, node) in self.items():
             for rds in node:
                 if rdtype == dns.rdatatype.ANY or \
@@ -584,7 +576,7 @@ class Zone(dns.transaction.TransactionManager):
                     for rdata in rds:
                         yield (name, rds.ttl, rdata)
 
-    def to_file(self, f, sorted=True, relativize=True, nl=None,
+    def to_file(self, f: Any, sorted=True, relativize=True, nl: Optional[str]=None,
                 want_comments=False, want_origin=False):
         """Write a zone to a file.
 
@@ -634,6 +626,7 @@ class Zone(dns.transaction.TransactionManager):
                 nl = nl.decode()
 
             if want_origin:
+                assert self.origin is not None
                 l = '$ORIGIN ' + self.origin.to_text()
                 l_b = l.encode(file_enc)
                 try:
@@ -661,7 +654,7 @@ class Zone(dns.transaction.TransactionManager):
                     f.write(l)
                     f.write(nl)
 
-    def to_text(self, sorted=True, relativize=True, nl=None,
+    def to_text(self, sorted=True, relativize=True, nl: Optional[str]=None,
                 want_comments=False, want_origin=False):
         """Return a zone's text as though it were written to a file.
 
@@ -713,7 +706,7 @@ class Zone(dns.transaction.TransactionManager):
         if self.get_rdataset(name, dns.rdatatype.NS) is None:
             raise NoNS
 
-    def get_soa(self, txn=None):
+    def get_soa(self, txn: Optional[dns.transaction.Transaction]=None):
         """Get the zone SOA RR.
 
         Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
@@ -723,7 +716,12 @@ class Zone(dns.transaction.TransactionManager):
         if self.relativize:
             origin_name = dns.name.empty
         else:
+            if self.origin is None:
+                # get_soa() has been called very early, and there must not be
+                # an SOA if there is no origin.
+                raise NoSOA
             origin_name = self.origin
+        soa: Optional[dns.rdataset.Rdataset]
         if txn:
             soa = txn.get(origin_name, dns.rdatatype.SOA)
         else:
@@ -732,7 +730,7 @@ class Zone(dns.transaction.TransactionManager):
             raise NoSOA
         return soa[0]
 
-    def _compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE):
+    def _compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> bytes:
         hashinfo = _digest_hashers.get(hash_algorithm)
         if not hashinfo:
             raise UnsupportedDigestHashAlgorithm
@@ -742,6 +740,7 @@ class Zone(dns.transaction.TransactionManager):
         if self.relativize:
             origin_name = dns.name.empty
         else:
+            assert self.origin is not None
             origin_name = self.origin
         hasher = hashinfo()
         for (name, node) in sorted(self.items()):
@@ -760,11 +759,7 @@ class Zone(dns.transaction.TransactionManager):
                     hasher.update(rrnamebuf + rrfixed + rrlen + rdata)
         return hasher.digest()
 
-    def compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE):
-        if self.relativize:
-            origin_name = dns.name.empty
-        else:
-            origin_name = self.origin
+    def compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD:
         serial = self.get_soa().serial
         digest = self._compute_digest(hash_algorithm, scheme)
         return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass,
@@ -772,13 +767,15 @@ class Zone(dns.transaction.TransactionManager):
                                              serial, scheme, hash_algorithm,
                                              digest)
 
-    def verify_digest(self, zonemd=None):
+    def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None):
+        digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]]
         if zonemd:
             digests = [zonemd]
         else:
-            digests = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD)
-            if digests is None:
+            rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD)
+            if rds is None:
                 raise NoDigest
+            digests = rds
         for digest in digests:
             try:
                 computed = self._compute_digest(digest.hash_algorithm,
@@ -791,16 +788,17 @@ class Zone(dns.transaction.TransactionManager):
 
     # TransactionManager methods
 
-    def reader(self):
+    def reader(self) -> 'Transaction':
         return Transaction(self, False,
                            Version(self, 1, self.nodes, self.origin))
 
-    def writer(self, replacement=False):
+    def writer(self, replacement=False) -> 'Transaction':
         txn = Transaction(self, replacement)
         txn._setup_version()
         return txn
 
-    def origin_information(self):
+    def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
+        effective: Optional[dns.name.Name]
         if self.relativize:
             effective = dns.name.empty
         else:
@@ -878,7 +876,9 @@ class ImmutableVersionedNode(VersionedNode):
 
 
 class Version:
-    def __init__(self, zone, id, nodes=None, origin=None):
+    def __init__(self, zone: Zone, id: int,
+                 nodes: Optional[Dict[dns.name.Name, dns.node.Node]]=None,
+                 origin: Optional[dns.name.Name]=None):
         self.zone = zone
         self.id = id
         if nodes is not None:
@@ -887,7 +887,7 @@ class Version:
             self.nodes = {}
         self.origin = origin
 
-    def _validate_name(self, name):
+    def _validate_name(self, name: dns.name.Name):
         if name.is_absolute():
             if self.origin is None:
                 # This should probably never happen as other code (e.g.
@@ -898,13 +898,19 @@ class Version:
                 raise KeyError("name is not a subdomain of the zone origin")
             if self.zone.relativize:
                 name = name.relativize(self.origin)
+        elif not self.zone.relativize:
+            # We have a relative name in a non-relative zone, so derelativize.
+            if self.origin is None:
+                raise KeyError('no zone origin is defined')
+            name = name.derelativize(self.origin)
         return name
 
-    def get_node(self, name):
+    def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
         name = self._validate_name(name)
         return self.nodes.get(name)
 
-    def get_rdataset(self, name, rdtype, covers):
+    def get_rdataset(self, name: dns.name.Name, rdtype: dns.rdatatype.RdataType,
+                     covers: dns.rdatatype.RdataType) -> Optional[dns.rdataset.Rdataset]:
         node = self.get_node(name)
         if node is None:
             return None
@@ -915,7 +921,7 @@ class Version:
 
 
 class WritableVersion(Version):
-    def __init__(self, zone, replacement=False):
+    def __init__(self, zone: Zone, replacement=False):
         # The zone._versions_lock must be held by our caller in a versioned
         # zone.
         id = zone._get_next_version_id()
@@ -929,9 +935,9 @@ class WritableVersion(Version):
         # We have to copy the zone origin as it may be None in the first
         # version, and we don't want to mutate the zone until we commit.
         self.origin = zone.origin
-        self.changed = set()
+        self.changed: Set[dns.name.Name] = set()
 
-    def _maybe_cow(self, name):
+    def _maybe_cow(self, name: dns.name.Name):
         name = self._validate_name(name)
         node = self.nodes.get(name)
         if node is None or name not in self.changed:
@@ -941,7 +947,9 @@ class WritableVersion(Version):
                 # code used new_node.id != self.id for the "do we need to CoW?"
                 # test.  Now we use the changed set as this works with both
                 # regular zones and versioned zones.
-                new_node.id = self.id
+                #
+                # We ignore the mypy error as this is safe but it doesn't see it.
+                new_node.id = self.id   # type: ignore
             if node is not None:
                 # moo!  copy on write!
                 new_node.rdatasets.extend(node.rdatasets)
@@ -951,17 +959,18 @@ class WritableVersion(Version):
         else:
             return node
 
-    def delete_node(self, name):
+    def delete_node(self, name: dns.name.Name):
         name = self._validate_name(name)
         if name in self.nodes:
             del self.nodes[name]
             self.changed.add(name)
 
-    def put_rdataset(self, name, rdataset):
+    def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset):
         node = self._maybe_cow(name)
         node.replace_rdataset(rdataset)
 
-    def delete_rdataset(self, name, rdtype, covers):
+    def delete_rdataset(self, name: dns.name.Name, rdtype:dns.rdatatype.RdataType,
+                        covers: dns.rdatatype.RdataType):
         node = self._maybe_cow(name)
         node.delete_rdataset(self.zone.rdclass, rdtype, covers)
         if len(node) == 0:
@@ -970,7 +979,7 @@ class WritableVersion(Version):
 
 @dns.immutable.immutable
 class ImmutableVersion(Version):
-    def __init__(self, version):
+    def __init__(self, version: WritableVersion):
         # We tell super() that it's a replacement as we don't want it
         # to copy the nodes, as we're about to do that with an
         # immutable Dict.
@@ -985,7 +994,9 @@ class ImmutableVersion(Version):
             # it might not exist if we deleted it in the version
             if node:
                 version.nodes[name] = ImmutableVersionedNode(node)
-        self.nodes = dns.immutable.Dict(version.nodes, True)
+        # We're changing the type of the nodes dictionary here on purpose, so
+        # we ignore the mypy error.
+        self.nodes = dns.immutable.Dict(version.nodes, True)  # type: ignore
 
 
 class Transaction(dns.transaction.Transaction):
@@ -1066,9 +1077,11 @@ class Transaction(dns.transaction.Transaction):
         return (absolute, relativize, effective)
 
 
-def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
-              relativize=True, zone_factory=Zone, filename=None,
-              allow_include=False, check_origin=True, idna_codec=None):
+def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None,
+              rdclass=dns.rdataclass.IN,
+              relativize=True, zone_factory=Zone, filename: Optional[str]=None,
+              allow_include=False, check_origin=True,
+              idna_codec: Optional[dns.name.IDNACodec]=None) -> Zone:
     """Build a zone object from a zone file format string.
 
     *text*, a ``str``, the zone file format input.
@@ -1077,7 +1090,7 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
     of the zone; if not specified, the first ``$ORIGIN`` statement in the
     zone file will determine the origin of the zone.
 
-    *rdclass*, an ``int``, the zone's rdata class; the default is class IN.
+    *rdclass*, a ``dns.rdataclass.RdataClass``, the zone's rdata class; the default is class IN.
 
     *relativize*, a ``bool``, determine's whether domain names are
     relativized to the zone's origin.  The default is ``True``.
@@ -1132,9 +1145,10 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
     return zone
 
 
-def from_file(f, origin=None, rdclass=dns.rdataclass.IN,
-              relativize=True, zone_factory=Zone, filename=None,
-              allow_include=True, check_origin=True):
+def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None,
+              rdclass=dns.rdataclass.IN,
+              relativize=True, zone_factory=Zone, filename: Optional[str]=None,
+              allow_include=True, check_origin=True) -> Zone:
     """Read a zone file and build a zone object.
 
     *f*, a file or ``str``.  If *f* is a string, it is treated
@@ -1184,6 +1198,7 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN,
             f = stack.enter_context(open(f))
         return from_text(f, origin, rdclass, relativize, zone_factory,
                          filename, allow_include, check_origin)
+    assert False  # make mypy happy
 
 
 def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):
diff --git a/dns/zone.pyi b/dns/zone.pyi
deleted file mode 100644 (file)
index 272814f..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-from typing import Generator, Optional, Union, Tuple, Iterable, Callable, Any, Iterator, TextIO, BinaryIO, Dict
-from . import rdata, zone, rdataclass, name, rdataclass, message, rdatatype, exception, node, rdataset, rrset, rdatatype
-
-class BadZone(exception.DNSException): ...
-class NoSOA(BadZone): ...
-class NoNS(BadZone): ...
-class UnknownOrigin(BadZone): ...
-
-class Zone:
-    def __getitem__(self, key : str) -> node.Node:
-        ...
-    def __init__(self, origin : Union[str,name.Name], rdclass : int = rdataclass.IN, relativize : bool = True) -> None:
-        self.nodes : Dict[str,node.Node]
-        self.origin = origin
-    def values(self):
-        return self.nodes.values()
-    def iterate_rdatas(self, rdtype : Union[int,str] = rdatatype.ANY, covers : Union[int,str] = None) -> Iterable[Tuple[name.Name, int, rdata.Rdata]]:
-        ...
-    def __iter__(self) -> Iterator[str]:
-        ...
-    def get_node(self, name : Union[name.Name,str], create=False) -> Optional[node.Node]:
-        ...
-    def find_rrset(self, name : Union[str,name.Name], rdtype : Union[int,str], covers=rdatatype.NONE) -> rrset.RRset:
-        ...
-    def find_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE,
-                      create=False) -> rdataset.Rdataset:
-        ...
-    def get_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, create=False) -> Optional[rdataset.Rdataset]:
-        ...
-    def get_rrset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> Optional[rrset.RRset]:
-        ...
-    def replace_rdataset(self, name : Union[str,name.Name], replacement : rdataset.Rdataset) -> None:
-        ...
-    def delete_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> None:
-        ...
-    def iterate_rdatasets(self, rdtype : Union[str,int] =rdatatype.ANY,
-                          covers : Union[str,int] =rdatatype.NONE):
-        ...
-    def to_file(self, f : Union[TextIO, BinaryIO, str], sorted=True, relativize=True, nl : Optional[bytes] = None):
-        ...
-    def to_text(self, sorted=True, relativize=True, nl : Optional[str] = None) -> str:
-        ...
-
-def from_xfr(xfr : Generator[Any,Any,message.Message], zone_factory : Callable[..., zone.Zone] = zone.Zone, relativize=True, check_origin=True):
-    ...
-
-def from_text(text : str, origin : Optional[Union[str,name.Name]] = None, rdclass : int = rdataclass.IN,
-              relativize=True, zone_factory : Callable[...,zone.Zone] = zone.Zone, filename : Optional[str] = None,
-              allow_include=False, check_origin=True) -> zone.Zone:
-    ...
-
-def from_file(f, origin : Optional[Union[str,name.Name]] = None, rdclass=rdataclass.IN,
-              relativize=True, zone_factory : Callable[..., zone.Zone] = Zone, filename : Optional[str] = None,
-              allow_include=True, check_origin=True) -> zone.Zone:
-    ...
index 53b40880bc916c9f0a3ace8c04060a57ded76e7b..605131dcce6615fac6b5f87b2c829d0cdad27db1 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS Zones."""
 
+from typing import Any, List, Optional, Tuple, Union
+
 import re
 import sys
 
@@ -61,14 +63,27 @@ def _check_cname_and_other_data(txn, name, rdataset):
     # adding the rdataset is ok
 
 
+SavedStateType = Tuple[dns.tokenizer.Tokenizer,
+                       Optional[dns.name.Name],   # current_origin
+                       Optional[dns.name.Name],   # last_name
+                       Optional[str],             # current_file
+                       int,                       # last_ttl
+                       bool,                      # last_ttl_known
+                       int,                       # default_ttl
+                       bool]                      # default_ttl_known
+
+
 class Reader:
 
     """Read a DNS zone file into a transaction."""
 
-    def __init__(self, tok, rdclass, txn, allow_include=False,
-                 allow_directives=True, force_name=None,
-                 force_ttl=None, force_rdclass=None, force_rdtype=None,
-                 default_ttl=None):
+    def __init__(self, tok: dns.tokenizer.Tokenizer, rdclass: dns.rdataclass.RdataClass,
+                 txn: dns.transaction.Transaction, allow_include=False,
+                 allow_directives=True, force_name: Optional[dns.name.Name]=None,
+                 force_ttl: Optional[int]=None,
+                 force_rdclass: Optional[dns.rdataclass.RdataClass]=None,
+                 force_rdtype: Optional[dns.rdatatype.RdataType]=None,
+                 default_ttl: Optional[int]=None):
         self.tok = tok
         (self.zone_origin, self.relativize, _) = \
             txn.manager.origin_information()
@@ -86,7 +101,7 @@ class Reader:
         self.last_name = self.current_origin
         self.zone_rdclass = rdclass
         self.txn = txn
-        self.saved_state = []
+        self.saved_state: List[SavedStateType] = []
         self.current_file = None
         self.allow_include = allow_include
         self.allow_directives = allow_directives
@@ -548,10 +563,16 @@ class RRSetsReaderManager(dns.transaction.TransactionManager):
         self.rrsets = rrsets
 
 
-def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN,
-                default_rdclass=dns.rdataclass.IN,
-                rdtype=None, default_ttl=None, idna_codec=None,
-                origin=dns.name.root, relativize=False):
+def read_rrsets(text: Any,
+                name: Optional[Union[dns.name.Name, str]]=None,
+                ttl: Optional[int]=None,
+                rdclass: Optional[Union[dns.rdataclass.RdataClass, str]]=dns.rdataclass.IN,
+                default_rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+                rdtype: Optional[Union[dns.rdatatype.RdataType, str]]=None,
+                default_ttl: Optional[Union[int, str]]=None,
+                idna_codec: Optional[dns.name.IDNACodec]=None,
+                origin: Optional[Union[dns.name.Name, str]]=dns.name.root,
+                relativize=False) -> List[dns.rrset.RRset]:
     """Read one or more rrsets from the specified text, possibly subject
     to restrictions.
 
@@ -610,15 +631,19 @@ def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN,
     if isinstance(default_ttl, str):
         default_ttl = dns.ttl.from_text(default_ttl)
     if rdclass is not None:
-        rdclass = dns.rdataclass.RdataClass.make(rdclass)
-    default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
+        the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    else:
+        the_rdclass = None
+    the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
     if rdtype is not None:
-        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+    else:
+        the_rdtype = None
     manager = RRSetsReaderManager(origin, relativize, default_rdclass)
     with manager.writer(True) as txn:
         tok = dns.tokenizer.Tokenizer(text, '<input>', idna_codec=idna_codec)
-        reader = Reader(tok, default_rdclass, txn, allow_directives=False,
-                        force_name=name, force_ttl=ttl, force_rdclass=rdclass,
-                        force_rdtype=rdtype, default_ttl=default_ttl)
+        reader = Reader(tok, the_default_rdclass, txn, allow_directives=False,
+                        force_name=name, force_ttl=ttl, force_rdclass=the_rdclass,
+                        force_rdtype=the_rdtype, default_ttl=default_ttl)
         reader.read()
     return manager.rrsets
diff --git a/dns/zonetypes.py b/dns/zonetypes.py
new file mode 100644 (file)
index 0000000..195ee2e
--- /dev/null
@@ -0,0 +1,37 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""Common zone-related types."""
+
+# This is a separate file to avoid import circularity between dns.zone and
+# the implementation of the ZONEMD type.
+
+import hashlib
+
+import dns.enum
+
+
+class DigestScheme(dns.enum.IntEnum):
+    """ZONEMD Scheme"""
+
+    SIMPLE = 1
+
+    @classmethod
+    def _maximum(cls):
+        return 255
+
+
+class DigestHashAlgorithm(dns.enum.IntEnum):
+    """ZONEMD Hash Algorithm"""
+
+    SHA384 = 1
+    SHA512 = 2
+
+    @classmethod
+    def _maximum(cls):
+        return 255
+
+
+_digest_hashers = {
+    DigestHashAlgorithm.SHA384: hashlib.sha384,
+    DigestHashAlgorithm.SHA512: hashlib.sha512,
+}
index 19107cb52cda988087b5e57ad39fc02060d3781a..348deefb903faca8bdc23e86010f2cffc59acaa8 100644 (file)
@@ -16,6 +16,5 @@ Dnspython Manual
    async
    exceptions
    utilities
-   typing
    threads
    examples
diff --git a/doc/typing.rst b/doc/typing.rst
deleted file mode 100644 (file)
index 1325f10..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-.. _typing:
-
-A Note on Typing
-----------------
-
-Dnspython has partial support for type annotations in separate .pyi
-files.  Type information will not be integrated into the main files
-until major LTS versions of various Linux distributions containing 3.6
-are beyond their support times.  Improvements to the .pyi files are
-welcome during this time.
index a0ba7e30a7b1d6a4244ba6caac412c79615311b1..de66885af9b0a919918ad5161e6b02269c7146de 100644 (file)
--- a/mypy.ini
+++ b/mypy.ini
@@ -2,3 +2,9 @@
 
 [mypy-requests_toolbelt.*]
 ignore_missing_imports = True
+
+[mypy-curio]
+ignore_missing_imports = True
+
+[mypy-trio]
+ignore_missing_imports = True
index d4d76275bc05266d4a93ef856f61531d08b113e1..e81899256745eae9d4162ab08efc22ab18ad725a 100644 (file)
@@ -15,6 +15,8 @@
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+from typing import Any
+
 import unittest
 
 import dns.dnssec
@@ -22,6 +24,8 @@ import dns.name
 import dns.rdata
 import dns.rdataclass
 import dns.rdatatype
+import dns.rdtypes.ANY.CDS
+import dns.rdtypes.ANY.DS
 import dns.rrset
 
 # pylint: disable=line-too-long
@@ -472,6 +476,7 @@ class DNSSECMakeDSTestCase(unittest.TestCase):
         self.assertEqual(good_ds, good_ds_mnemonic)
 
     def testMakeExampleSHA1DS(self):  # type: () -> None
+        algorithm: Any
         for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1):
             ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
             self.assertEqual(ds, example_ds_sha1)
@@ -479,11 +484,13 @@ class DNSSECMakeDSTestCase(unittest.TestCase):
             self.assertEqual(ds, example_ds_sha1)
 
     def testMakeExampleSHA256DS(self):  # type: () -> None
+        algorithm: Any
         for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256):
             ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
             self.assertEqual(ds, example_ds_sha256)
 
     def testMakeExampleSHA384DS(self):  # type: () -> None
+        algorithm: Any
         for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384):
             ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
             self.assertEqual(ds, example_ds_sha384)
@@ -493,6 +500,7 @@ class DNSSECMakeDSTestCase(unittest.TestCase):
         self.assertEqual(ds, good_ds)
 
     def testInvalidAlgorithm(self):  # type: () -> None
+        algorithm: Any
         for algorithm in (10, 'shax'):
             with self.assertRaises(dns.dnssec.UnsupportedAlgorithm):
                 ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
@@ -508,6 +516,7 @@ class DNSSECMakeDSTestCase(unittest.TestCase):
         for rdtype in digest_types:
             rd = dns.rdata.from_text(dns.rdataclass.IN, rdtype,
                                      f'18673 3 5 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7')
+            assert isinstance(rd, dns.rdtypes.ANY.DS.DS) or isinstance(rd, dns.rdtypes.ANY.CDS.CDS)
             self.assertEqual(rd.digest_type, 5)
             self.assertEqual(rd.digest, bytes.fromhex('71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7'))
 
index f91d7e65a5dbc15da0bbeff5462efff59897b49d..45f83793ea5c4c4d281c14ed4a523dd0ce034aa0 100644 (file)
@@ -89,7 +89,7 @@ class NameTestCase(unittest.TestCase):
             try:
                 dns.name.from_text(t)
             except Exception:
-                self.fail("good test '%s' raised an exception" % t)
+                self.fail("good test '%r' raised an exception" % t)
         for t in bad:
             caught = False
             try:
@@ -97,7 +97,7 @@ class NameTestCase(unittest.TestCase):
             except Exception:
                 caught = True
             if not caught:
-                self.fail("bad test '%s' did not raise an exception" % t)
+                self.fail("bad test '%r' did not raise an exception" % t)
 
     def testImmutable1(self):
         def bad():
@@ -106,7 +106,7 @@ class NameTestCase(unittest.TestCase):
 
     def testImmutable2(self):
         def bad():
-            self.origin.labels[0] = 'foo'
+            self.origin.labels[0] = 'foo'  # type: ignore
         self.assertRaises(TypeError, bad)
 
     def testAbs1(self):
@@ -879,7 +879,7 @@ class NameTestCase(unittest.TestCase):
 
     def testReverseIPv6(self):
         e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.')
-        n = dns.reversename.from_address(b'::1')
+        n = dns.reversename.from_address('::1')
         self.assertEqual(e, n)
 
     def testReverseIPv6MappedIpv4(self):
@@ -906,7 +906,7 @@ class NameTestCase(unittest.TestCase):
     def testReverseIPv6AlternateOrigin(self):
         e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.foo.bar.')
         origin = dns.name.from_text('foo.bar')
-        n = dns.reversename.from_address(b'::1', v6_origin=origin)
+        n = dns.reversename.from_address('::1', v6_origin=origin)
         self.assertEqual(e, n)
 
     def testForwardIPv4(self):
@@ -980,12 +980,12 @@ class NameTestCase(unittest.TestCase):
 
     def testFromUnicodeNotString(self):
         def bad():
-            dns.name.from_unicode(b'123')
+            dns.name.from_unicode(b'123')  # type: ignore
         self.assertRaises(ValueError, bad)
 
     def testFromUnicodeBadOrigin(self):
         def bad():
-            dns.name.from_unicode('example', 123)
+            dns.name.from_unicode('example', 123)  # type: ignore
         self.assertRaises(ValueError, bad)
 
     def testFromUnicodeEmptyLabel(self):
@@ -998,17 +998,17 @@ class NameTestCase(unittest.TestCase):
 
     def testFromTextNotString(self):
         def bad():
-            dns.name.from_text(123)
+            dns.name.from_text(123)  # type: ignore
         self.assertRaises(ValueError, bad)
 
     def testFromTextBadOrigin(self):
         def bad():
-            dns.name.from_text('example', 123)
+            dns.name.from_text('example', 123)  # type: ignore
         self.assertRaises(ValueError, bad)
 
     def testFromWireNotBytes(self):
         def bad():
-            dns.name.from_wire(123, 0)
+            dns.name.from_wire(123, 0)  # type: ignore
         self.assertRaises(ValueError, bad)
 
     def testBadPunycode(self):
@@ -1035,7 +1035,7 @@ class NameTestCase(unittest.TestCase):
             c.encode('Königsgäßchen')
         with self.assertRaises(dns.name.NoIDNA2008):
             c = dns.name.IDNA2008Codec(strict_decode=True)
-            c.decode('xn--eckwd4c7c.xn--zckzah.')
+            c.decode(b'xn--eckwd4c7c.xn--zckzah.')
         dns.name.have_idna_2008 = True
 
     @unittest.skipUnless(dns.name.have_idna_2008,
index 4be695a0b1d2d20dd629aca1523e29f2cd86ad69..76754dde88c31946c9daf3f994f9f2b9abf2a146 100644 (file)
@@ -1,6 +1,7 @@
 
 import dns.rdata
 import dns.rdataset
+import dns.rdtypes.IN.SRV
 
 
 def test_processing_order_shuffle():
@@ -42,6 +43,7 @@ def test_processing_order_priority_weighted():
         for j in range(3):
             assert rds[j] in po
         assert rds[0] == po[0]
+        assert isinstance(po[1], dns.rdtypes.IN.SRV.SRV)
         if po[1].weight == 90:
             weight_90_count += 1
         else:
index 88b1e58ebe2b76cc666a111a28f28a6ab430f97f..473c73339011f13a8f4dda018234d98a97cfb37e 100644 (file)
@@ -499,6 +499,13 @@ class ZoneTestCase(unittest.TestCase):
         rds = z.get_rdataset('@', 'loc')
         self.assertTrue(rds is None)
 
+    def testGetRdatasetWithRelativeNameFromAbsoluteZone(self):
+        z = dns.zone.from_text(example_text, 'example.', relativize=False)
+        rds = z.get_rdataset(dns.name.empty, 'soa')
+        self.assertIsNotNone(rds)
+        exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo.example. bar.example. 1 2 3 4 5')
+        self.assertEqual(rds, exrds)
+
     def testGetRRset1(self):
         z = dns.zone.from_text(example_text, 'example.', relativize=True)
         rrs = z.get_rrset('@', 'soa')
@@ -1077,7 +1084,6 @@ class VersionedZoneTestCase(unittest.TestCase):
             self.assertTrue(soa.rdtype, dns.rdatatype.SOA)
             self.assertEqual(soa.serial, 1)
 
-
     def testGetSoaEmptyZone(self):
         z = dns.zone.Zone('example.')
         with self.assertRaises(dns.zone.NoSOA):