From: Bob Halley Date: Tue, 15 Mar 2022 15:37:20 +0000 (-0700) Subject: black autoformatting X-Git-Tag: v2.3.0rc1~93 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b1d2332687adbecc0acbb4e623124f783f859d9e;p=thirdparty%2Fdnspython.git black autoformatting --- diff --git a/.flake8 b/.flake8 index 809116b6..39a47307 100644 --- a/.flake8 +++ b/.flake8 @@ -1,21 +1,3 @@ [flake8] -ignore = - # Prefer emacs indentation of continued lines - E126, - E127, - E129, - # Whitespace round parameter '=' can be excessive - E252, - # Multiple # in a comment is OK - E266, - # Not excited by the "two blank lines" rule - E302, - E305, - # or the one blank line rule - E306, - # Ambigious variables are ok. - E741, - # Lines ending with binary operators are OK - W504, - -max-line-length = 120 +extend-ignore = W503, E203, E266, E741, F401 +max-line-length = 88 diff --git a/README.md b/README.md index 282713e5..324d4ace 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![License: ISC](https://img.shields.io/badge/License-ISC-brightgreen.svg)](https://opensource.org/licenses/ISC) [![Coverage](https://codecov.io/github/rthalley/dnspython/coverage.svg?branch=master)](https://codecov.io/github/rthalley/dnspython) [![LGTM Grade](https://img.shields.io/lgtm/grade/python/github/rthalley/dnspython)](https://lgtm.com/projects/g/rthalley/dnspython/) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) ## INTRODUCTION diff --git a/dns/__init__.py b/dns/__init__.py index a620f975..196be22d 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -18,51 +18,51 @@ """dnspython DNS toolkit""" __all__ = [ - 'asyncbackend', - 'asyncquery', - 'asyncresolver', - 'dnssec', - 'dnssectypes', - 'e164', - 'edns', - 'entropy', - 'exception', - 'flags', - 'immutable', - 'inet', - 'ipv4', - 'ipv6', - 'message', - 'name', - 'namedict', - 'node', - 'opcode', - 'query', - 'rcode', - 'rdata', - 'rdataclass', - 'rdataset', - 'rdatatype', - 'renderer', - 'resolver', - 'reversename', - 'rrset', - 'serial', - 'set', - 'tokenizer', - 'transaction', - 'tsig', - 'tsigkeyring', - 'ttl', - 'rdtypes', - 'update', - 'version', - 'versioned', - 'wire', - 'xfr', - 'zone', - 'zonetypes', - 'zonefile', + "asyncbackend", + "asyncquery", + "asyncresolver", + "dnssec", + "dnssectypes", + "e164", + "edns", + "entropy", + "exception", + "flags", + "immutable", + "inet", + "ipv4", + "ipv6", + "message", + "name", + "namedict", + "node", + "opcode", + "query", + "rcode", + "rdata", + "rdataclass", + "rdataset", + "rdatatype", + "renderer", + "resolver", + "reversename", + "rrset", + "serial", + "set", + "tokenizer", + "transaction", + "tsig", + "tsigkeyring", + "ttl", + "rdtypes", + "update", + "version", + "versioned", + "wire", + "xfr", + "zone", + "zonetypes", + "zonefile", ] from dns.version import version as __version__ # noqa diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index 674bf6ea..ff24604f 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -3,6 +3,7 @@ # This is a nullcontext for both sync and async. 3.7 has a nullcontext, # but it is only for sync use. + class NullContext: def __init__(self, enter_result=None): self.enter_result = enter_result @@ -23,6 +24,7 @@ class NullContext: # These are declared here so backends can import them without creating # circular dependencies with dns.asyncbackend. + class Socket: # pragma: no cover async def close(self): pass @@ -59,13 +61,21 @@ class StreamSocket(Socket): # pragma: no cover raise NotImplementedError -class Backend: # pragma: no cover +class Backend: # pragma: no cover def name(self): - return 'unknown' - - async def make_socket(self, af, socktype, proto=0, - source=None, destination=None, timeout=None, - ssl_context=None, server_hostname=None): + return "unknown" + + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): raise NotImplementedError def datagram_connection_required(self): diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 10917774..50bde1dd 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -10,7 +10,8 @@ import dns._asyncbackend import dns.exception -_is_win32 = sys.platform == 'win32' +_is_win32 = sys.platform == "win32" + def _get_running_loop(): try: @@ -76,10 +77,10 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): self.protocol.close() async def getpeername(self): - return self.transport.get_extra_info('peername') + return self.transport.get_extra_info("peername") async def getsockname(self): - return self.transport.get_extra_info('sockname') + return self.transport.get_extra_info("sockname") class StreamSocket(dns._asyncbackend.StreamSocket): @@ -93,8 +94,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket): return await _maybe_wait_for(self.writer.drain(), timeout) async def recv(self, size, timeout): - return await _maybe_wait_for(self.reader.read(size), - timeout) + return await _maybe_wait_for(self.reader.read(size), timeout) async def close(self): self.writer.close() @@ -104,43 +104,60 @@ class StreamSocket(dns._asyncbackend.StreamSocket): pass async def getpeername(self): - return self.writer.get_extra_info('peername') + return self.writer.get_extra_info("peername") async def getsockname(self): - return self.writer.get_extra_info('sockname') + return self.writer.get_extra_info("sockname") class Backend(dns._asyncbackend.Backend): def name(self): - return 'asyncio' - - async def make_socket(self, af, socktype, proto=0, - source=None, destination=None, timeout=None, - ssl_context=None, server_hostname=None): - if destination is None and socktype == socket.SOCK_DGRAM and \ - _is_win32: - raise NotImplementedError('destinationless datagram sockets ' - 'are not supported by asyncio ' - 'on Windows') + return "asyncio" + + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): + if destination is None and socktype == socket.SOCK_DGRAM and _is_win32: + raise NotImplementedError( + "destinationless datagram sockets " + "are not supported by asyncio " + "on Windows" + ) loop = _get_running_loop() if socktype == socket.SOCK_DGRAM: transport, protocol = await loop.create_datagram_endpoint( - _DatagramProtocol, source, family=af, - proto=proto, remote_addr=destination) + _DatagramProtocol, + source, + family=af, + proto=proto, + remote_addr=destination, + ) return DatagramSocket(af, transport, protocol) elif socktype == socket.SOCK_STREAM: (r, w) = await _maybe_wait_for( - asyncio.open_connection(destination[0], - destination[1], - ssl=ssl_context, - family=af, - proto=proto, - local_addr=source, - server_hostname=server_hostname), - timeout) + asyncio.open_connection( + destination[0], + destination[1], + ssl=ssl_context, + family=af, + proto=proto, + local_addr=source, + server_hostname=server_hostname, + ), + timeout, + ) return StreamSocket(af, r, w) - raise NotImplementedError('unsupported socket ' + - f'type {socktype}') # pragma: no cover + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover async def sleep(self, interval): await asyncio.sleep(interval) diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index 3f22b5d3..765d6471 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -32,7 +32,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def sendto(self, what, destination, timeout): async with _maybe_timeout(timeout): return await self.socket.sendto(what, destination) - raise dns.exception.Timeout(timeout=timeout) # pragma: no cover lgtm[py/unreachable-statement] + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] async def recvfrom(self, size, timeout): async with _maybe_timeout(timeout): @@ -76,11 +78,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket): class Backend(dns._asyncbackend.Backend): def name(self): - return 'curio' - - async def make_socket(self, af, socktype, proto=0, - source=None, destination=None, timeout=None, - ssl_context=None, server_hostname=None): + return "curio" + + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): if socktype == socket.SOCK_DGRAM: s = curio.socket.socket(af, socktype, proto) try: @@ -96,13 +106,17 @@ class Backend(dns._asyncbackend.Backend): else: source_addr = None async with _maybe_timeout(timeout): - s = await curio.open_connection(destination[0], destination[1], - ssl=ssl_context, - source_addr=source_addr, - server_hostname=server_hostname) + s = await curio.open_connection( + destination[0], + destination[1], + ssl=ssl_context, + source_addr=source_addr, + server_hostname=server_hostname, + ) return StreamSocket(s) - raise NotImplementedError('unsupported socket ' + - f'type {socktype}') # pragma: no cover + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover async def sleep(self, interval): await curio.sleep(interval) diff --git a/dns/_immutable_ctx.py b/dns/_immutable_ctx.py index ececdbeb..63c0a2d3 100644 --- a/dns/_immutable_ctx.py +++ b/dns/_immutable_ctx.py @@ -8,7 +8,7 @@ import contextvars import inspect -_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False) +_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False) class _Immutable: @@ -41,6 +41,7 @@ def _immutable_init(f): f(*args, **kwargs) finally: _in__init__.reset(previous) + nf.__signature__ = inspect.signature(f) return nf @@ -50,7 +51,7 @@ def immutable(cls): # Some ancestor already has the mixin, so just make sure we keep # following the __init__ protocol. cls.__init__ = _immutable_init(cls.__init__) - if hasattr(cls, '__setstate__'): + if hasattr(cls, "__setstate__"): cls.__setstate__ = _immutable_init(cls.__setstate__) ncls = cls else: @@ -63,7 +64,8 @@ def immutable(cls): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if hasattr(cls, '__setstate__'): + if hasattr(cls, "__setstate__"): + @_immutable_init def __setstate__(self, *args, **kwargs): super().__setstate__(*args, **kwargs) diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 8a337e9d..b0c02103 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -32,7 +32,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def sendto(self, what, destination, timeout): with _maybe_timeout(timeout): return await self.socket.sendto(what, destination) - raise dns.exception.Timeout(timeout=timeout) # pragma: no cover lgtm[py/unreachable-statement] + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] async def recvfrom(self, size, timeout): with _maybe_timeout(timeout): @@ -83,11 +85,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket): class Backend(dns._asyncbackend.Backend): def name(self): - return 'trio' - - async def make_socket(self, af, socktype, proto=0, source=None, - destination=None, timeout=None, - ssl_context=None, server_hostname=None): + return "trio" + + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): s = trio.socket.socket(af, socktype, proto) stream = None try: @@ -107,14 +117,16 @@ class Backend(dns._asyncbackend.Backend): if ssl_context: tls = True try: - stream = trio.SSLStream(stream, ssl_context, - server_hostname=server_hostname) + stream = trio.SSLStream( + stream, ssl_context, server_hostname=server_hostname + ) except Exception: # pragma: no cover await stream.aclose() raise return StreamSocket(af, stream, tls) - raise NotImplementedError('unsupported socket ' + - f'type {socktype}') # pragma: no cover + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover async def sleep(self, interval): await trio.sleep(interval) diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py index ffd6d674..c7565a99 100644 --- a/dns/asyncbackend.py +++ b/dns/asyncbackend.py @@ -6,7 +6,12 @@ import dns.exception # pylint: disable=unused-import -from dns._asyncbackend import Socket, DatagramSocket, StreamSocket, Backend # noqa: F401 lgtm[py/unused-import] +from dns._asyncbackend import ( + Socket, + DatagramSocket, + StreamSocket, + Backend, +) # noqa: F401 lgtm[py/unused-import] # pylint: enable=unused-import @@ -17,6 +22,7 @@ _backends: Dict[str, Backend] = {} # Allow sniffio import to be disabled for testing purposes _no_sniffio = False + class AsyncLibraryNotFoundError(dns.exception.DNSException): pass @@ -33,17 +39,20 @@ def get_backend(name: str) -> Backend: backend = _backends.get(name) if backend: return backend - if name == 'trio': + if name == "trio": import dns._trio_backend + backend = dns._trio_backend.Backend() - elif name == 'curio': + elif name == "curio": import dns._curio_backend + backend = dns._curio_backend.Backend() - elif name == 'asyncio': + elif name == "asyncio": import dns._asyncio_backend + backend = dns._asyncio_backend.Backend() else: - raise NotImplementedError(f'unimplemented async backend {name}') + raise NotImplementedError(f"unimplemented async backend {name}") _backends[name] = backend return backend @@ -60,23 +69,25 @@ def sniff() -> str: if _no_sniffio: raise ImportError import sniffio + try: return sniffio.current_async_library() except sniffio.AsyncLibraryNotFoundError: - raise AsyncLibraryNotFoundError('sniffio cannot determine ' + - 'async library') + raise AsyncLibraryNotFoundError( + "sniffio cannot determine " + "async library" + ) except ImportError: import asyncio + try: asyncio.get_running_loop() - return 'asyncio' + return "asyncio" except RuntimeError: - raise AsyncLibraryNotFoundError('no async library detected') + raise AsyncLibraryNotFoundError("no async library detected") def get_default_backend() -> Backend: - """Get the default backend, initializing it if necessary. - """ + """Get the default backend, initializing it if necessary.""" if _default_backend: return _default_backend diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 977f0d41..28e124d7 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -34,8 +34,16 @@ import dns.rdataclass import dns.rdatatype import dns.transaction -from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \ - UDPMode, _have_httpx, _have_http2, NoDOH +from dns.query import ( + _compute_times, + _matches_destination, + BadResponse, + ssl, + UDPMode, + _have_httpx, + _have_http2, + NoDOH, +) if _have_httpx: import httpx @@ -50,11 +58,11 @@ def _source_tuple(af, address, port): if address or port: if address is None: if af == socket.AF_INET: - address = '0.0.0.0' + address = "0.0.0.0" elif af == socket.AF_INET6: - address = '::' + address = "::" else: - raise NotImplementedError(f'unknown address family {af}') + raise NotImplementedError(f"unknown address family {af}") return (address, port) else: return None @@ -69,9 +77,12 @@ def _timeout(expiration, now=None): return None -async def send_udp(sock: dns.asyncbackend.DatagramSocket, - what: Union[dns.message.Message, bytes], destination: Any, - expiration: Optional[float]=None) -> Tuple[int, float]: +async def send_udp( + sock: dns.asyncbackend.DatagramSocket, + what: Union[dns.message.Message, bytes], + destination: Any, + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified UDP socket. *sock*, a ``dns.asyncbackend.DatagramSocket``. @@ -95,11 +106,17 @@ async def send_udp(sock: dns.asyncbackend.DatagramSocket, return (n, sent_time) -async def receive_udp(sock: dns.asyncbackend.DatagramSocket, - destination: Optional[Any]=None, expiration: Optional[float]=None, - ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, - keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=b'', - ignore_trailing: bool=False, raise_on_truncation: bool=False) -> Any: +async def receive_udp( + sock: dns.asyncbackend.DatagramSocket, + destination: Optional[Any] = None, + expiration: Optional[float] = None, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, + raise_on_truncation: bool = False, +) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``dns.asyncbackend.DatagramSocket``. @@ -108,24 +125,39 @@ async def receive_udp(sock: dns.asyncbackend.DatagramSocket, parameters, exceptions, and return type of this method. """ - wire = b'' + wire = b"" while 1: (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) - if _matches_destination(sock.family, from_address, destination, - ignore_unexpected): + if _matches_destination( + sock.family, from_address, destination, ignore_unexpected + ): break received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - raise_on_truncation=raise_on_truncation) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation, + ) return (r, received_time, from_address) -async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, - source: Optional[str]=None, source_port: int=0, - ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - raise_on_truncation: bool=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None, - backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message: + +async def udp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + sock: Optional[dns.asyncbackend.DatagramSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via UDP. *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, @@ -156,16 +188,20 @@ async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, dtuple = (where, port) else: dtuple = None - s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, - dtuple) + s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple) assert s is not None await send_udp(s, wire, destination, expiration) - (r, received_time, _) = await receive_udp(s, destination, expiration, - ignore_unexpected, - one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing, - raise_on_truncation) + (r, received_time, _) = await receive_udp( + s, + destination, + expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, + q.mac, + ignore_trailing, + raise_on_truncation, + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse @@ -174,12 +210,21 @@ async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, if not sock and s: await s.close() -async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, - source: Optional[str]=None, source_port: int=0, - ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - udp_sock: Optional[dns.asyncbackend.DatagramSocket]=None, - tcp_sock: Optional[dns.asyncbackend.StreamSocket]=None, - backend: Optional[dns.asyncbackend.Backend]=None) -> Tuple[dns.message.Message, bool]: + +async def udp_with_fallback( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None, + tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -201,20 +246,42 @@ async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optiona method. """ try: - response = await udp(q, where, timeout, port, source, source_port, - ignore_unexpected, one_rr_per_rrset, - ignore_trailing, True, udp_sock, backend) + response = await udp( + q, + where, + timeout, + port, + source, + source_port, + ignore_unexpected, + one_rr_per_rrset, + ignore_trailing, + True, + udp_sock, + backend, + ) return (response, False) except dns.message.Truncated: - response = await tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, tcp_sock, - backend) + response = await tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + tcp_sock, + backend, + ) return (response, True) -async def send_tcp(sock: dns.asyncbackend.StreamSocket, - what: Union[dns.message.Message, bytes], - expiration: Optional[float]=None) -> Tuple[int, float]: +async def send_tcp( + sock: dns.asyncbackend.StreamSocket, + what: Union[dns.message.Message, bytes], + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified TCP socket. *sock*, a ``dns.asyncbackend.StreamSocket``. @@ -241,21 +308,24 @@ async def _read_exactly(sock, count, expiration): """Read the specified number of bytes from stream. Keep trying until we either get the desired amount, or we hit EOF. """ - s = b'' + s = b"" while count > 0: n = await sock.recv(count, _timeout(expiration)) - if n == b'': + if n == b"": raise EOFError count = count - len(n) s = s + n return s -async def receive_tcp(sock: dns.asyncbackend.StreamSocket, - expiration: Optional[float]=None, one_rr_per_rrset: bool=False, - keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, - request_mac: Optional[bytes]=b'', - ignore_trailing: bool=False) -> Tuple[dns.message.Message, float]: +async def receive_tcp( + sock: dns.asyncbackend.StreamSocket, + expiration: Optional[float] = None, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, +) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``dns.asyncbackend.StreamSocket``. @@ -268,17 +338,28 @@ async def receive_tcp(sock: dns.asyncbackend.StreamSocket, (l,) = struct.unpack("!H", ldata) wire = await _read_exactly(sock, l, expiration) received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) return (r, received_time) -async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, - source: Optional[str]=None, source_port: int=0, - one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - sock: Optional[dns.asyncbackend.StreamSocket]=None, - backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message: +async def tcp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[dns.asyncbackend.StreamSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TCP. *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the @@ -313,13 +394,14 @@ async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, dtuple = (where, port) if not backend: backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, - dtuple, timeout) + s = await backend.make_socket( + af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout + ) assert s is not None await send_tcp(s, wire, expiration) - (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing) + (r, received_time) = await receive_tcp( + s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse @@ -328,13 +410,21 @@ async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, if not sock and s: await s.close() -async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port: int=853, source: Optional[str]=None, source_port: int=0, - one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - sock: Optional[dns.asyncbackend.StreamSocket]=None, - backend: Optional[dns.asyncbackend.Backend]=None, - ssl_context: Optional[ssl.SSLContext]=None, - server_hostname: Optional[str]=None) -> dns.message.Message: + +async def tls( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[dns.asyncbackend.StreamSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TLS. *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket @@ -367,15 +457,32 @@ async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, dtuple = (where, port) if not backend: backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, - dtuple, timeout, ssl_context, - server_hostname) + s = await backend.make_socket( + af, + socket.SOCK_STREAM, + 0, + stuple, + dtuple, + timeout, + ssl_context, + server_hostname, + ) else: s = sock try: timeout = _timeout(expiration) - response = await tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, s, backend) + response = await tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + s, + backend, + ) end_time = time.time() response.time = end_time - begin_time return response @@ -383,11 +490,21 @@ async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, if not sock and s: await s.close() -async def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port: int=443, source: Optional[str]=None, source_port: int=0, - one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - client: Optional[httpx.AsyncClient]=None, - path: str='/dns-query', post: bool=True, verify: bool=True) -> dns.message.Message: + +async def https( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 443, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + client: Optional[httpx.AsyncClient] = None, + path: str = "/dns-query", + post: bool = True, + verify: bool = True, +) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. *client*, a ``httpx.AsyncClient``. If provided, the client to use for @@ -401,7 +518,7 @@ async def https(q: dns.message.Message, where: str, timeout: Optional[float]=Non """ if not _have_httpx: - raise NoDOH('httpx is not available.') # pragma: no cover + raise NoDOH("httpx is not available.") # pragma: no cover wire = q.to_wire() try: @@ -409,14 +526,12 @@ async def https(q: dns.message.Message, where: str, timeout: Optional[float]=Non except ValueError: af = None transport = None - headers = { - "accept": "application/dns-message" - } + headers = {"accept": "application/dns-message"} if af is not None: if af == socket.AF_INET: - url = 'https://{}:{}{}'.format(where, port, path) + url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: - url = 'https://[{}]:{}{}'.format(where, port, path) + url = "https://[{}]:{}{}".format(where, port, path) else: url = where if source is not None: @@ -426,24 +541,29 @@ async def https(q: dns.message.Message, where: str, timeout: Optional[float]=Non client_to_close = None try: if not client: - client = httpx.AsyncClient(http1=True, http2=_have_http2, - verify=verify, transport=transport) + client = httpx.AsyncClient( + http1=True, http2=_have_http2, verify=verify, transport=transport + ) client_to_close = client # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: - headers.update({ - "content-type": "application/dns-message", - "content-length": str(len(wire)) - }) - response = await client.post(url, headers=headers, content=wire, - timeout=timeout) + headers.update( + { + "content-type": "application/dns-message", + "content-length": str(len(wire)), + } + ) + response = await client.post( + url, headers=headers, content=wire, timeout=timeout + ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes - response = await client.get(url, headers=headers, timeout=timeout, - params={"dns": twire}) + response = await client.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) finally: if client_to_close: await client_to_close.aclose() @@ -451,25 +571,37 @@ async def https(q: dns.message.Message, where: str, timeout: Optional[float]=Non # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes if response.status_code < 200 or response.status_code > 299: - raise ValueError('{} responded with status code {}' - '\nResponse body: {!r}'.format(where, - response.status_code, - response.content)) - r = dns.message.from_wire(response.content, - keyring=q.keyring, - request_mac=q.request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + raise ValueError( + "{} responded with status code {}" + "\nResponse body: {!r}".format( + where, response.status_code, response.content + ) + ) + r = dns.message.from_wire( + response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) r.time = response.elapsed.total_seconds() if not q.is_response(r): raise BadResponse return r -async def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, - query: Optional[dns.message.Message]=None, - port: int=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, - source: Optional[str]=None, source_port: int=0, udp_mode: UDPMode=UDPMode.NEVER, - backend: Optional[dns.asyncbackend.Backend]=None) -> None: + +async def inbound_xfr( + where: str, + txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message] = None, + port: int = 53, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + udp_mode: UDPMode = UDPMode.NEVER, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> None: """Conduct an inbound transfer and apply it via a transaction from the txn_manager. @@ -502,42 +634,48 @@ async def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManage is_udp = False if not backend: backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket(af, sock_type, 0, stuple, dtuple, - _timeout(expiration)) + s = await backend.make_socket( + af, sock_type, 0, stuple, dtuple, _timeout(expiration) + ) async with s: if is_udp: await s.sendto(wire, dtuple, _timeout(expiration)) else: tcpmsg = struct.pack("!H", len(wire)) + wire await s.sendall(tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, - is_udp) as inbound: + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: done = False tsig_ctx = None while not done: (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or \ - (expiration is not None and mexpiration > expiration): + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): mexpiration = expiration if is_udp: destination = _lltuple((where, port), af) while True: timeout = _timeout(mexpiration) - (rwire, from_address) = await s.recvfrom(65535, - timeout) - if _matches_destination(af, from_address, - destination, True): + (rwire, from_address) = await s.recvfrom(65535, timeout) + if _matches_destination( + af, from_address, destination, True + ): break else: ldata = await _read_exactly(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) rwire = await _read_exactly(s, l, mexpiration) - is_ixfr = (rdtype == dns.rdatatype.IXFR) - r = dns.message.from_wire(rwire, keyring=query.keyring, - request_mac=query.mac, xfr=True, - origin=origin, tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr) + is_ixfr = rdtype == dns.rdatatype.IXFR + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) try: done = inbound.process_message(r) except dns.xfr.UseTCP: diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py index e196dbbc..14a25a3b 100644 --- a/dns/asyncresolver.py +++ b/dns/asyncresolver.py @@ -42,13 +42,19 @@ _tcp = dns.asyncquery.tcp class Resolver(dns.resolver.BaseResolver): """Asynchronous DNS stub resolver.""" - async def resolve(self, qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, - rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp: bool=False, source: Optional[str]=None, - raise_on_no_answer: bool=True, source_port: int=0, - lifetime: Optional[float]=None, search: Optional[bool]=None, - backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer: + async def resolve( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + ) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, @@ -59,8 +65,9 @@ class Resolver(dns.resolver.BaseResolver): type of this method. """ - resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, - raise_on_no_answer, search) + resolution = dns.resolver._Resolution( + self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search + ) if not backend: backend = dns.asyncbackend.get_default_backend() start = time.time() @@ -79,25 +86,34 @@ class Resolver(dns.resolver.BaseResolver): (nameserver, port, tcp, backoff) = resolution.next_nameserver() if backoff: await backend.sleep(backoff) - timeout = self._compute_timeout(start, lifetime, - resolution.errors) + timeout = self._compute_timeout(start, lifetime, resolution.errors) try: if dns.inet.is_address(nameserver): if tcp: - response = await _tcp(request, nameserver, - timeout, port, - source, source_port, - backend=backend) + response = await _tcp( + request, + nameserver, + timeout, + port, + source, + source_port, + backend=backend, + ) else: - response = await _udp(request, nameserver, - timeout, port, - source, source_port, - raise_on_truncation=True, - backend=backend) + response = await _udp( + request, + nameserver, + timeout, + port, + source, + source_port, + raise_on_truncation=True, + backend=backend, + ) else: - response = await dns.asyncquery.https(request, - nameserver, - timeout=timeout) + response = await dns.asyncquery.https( + request, nameserver, timeout=timeout + ) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -109,7 +125,9 @@ class Resolver(dns.resolver.BaseResolver): if answer is not None: return answer - async def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Dict[str, Any]) -> dns.resolver.Answer: + async def resolve_address( + self, ipaddr: str, *args: Any, **kwargs: Dict[str, Any] + ) -> dns.resolver.Answer: """Use an asynchronous resolver to run a reverse query for PTR records. @@ -129,10 +147,11 @@ class Resolver(dns.resolver.BaseResolver): # in the kwargs more than once. modified_kwargs: Dict[str, Any] = {} modified_kwargs.update(kwargs) - modified_kwargs['rdtype'] = dns.rdatatype.PTR - modified_kwargs['rdclass'] = dns.rdataclass.IN - return await self.resolve(dns.reversename.from_address(ipaddr), - *args, **modified_kwargs) + modified_kwargs["rdtype"] = dns.rdatatype.PTR + modified_kwargs["rdclass"] = dns.rdataclass.IN + return await self.resolve( + dns.reversename.from_address(ipaddr), *args, **modified_kwargs + ) # pylint: disable=redefined-outer-name @@ -180,13 +199,18 @@ def reset_default_resolver() -> None: default_resolver = Resolver() -async def resolve(qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, - rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp: bool=False, source: Optional[str]=None, - raise_on_no_answer: bool=True, source_port: int=0, - lifetime: Optional[float]=None, search: Optional[bool]=None, - backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer: +async def resolve( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. This is a convenience function that uses the default resolver @@ -196,13 +220,23 @@ async def resolve(qname: Union[dns.name.Name, str], information on the parameters. """ - return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp, - source, raise_on_no_answer, - source_port, lifetime, search, - backend) - - -async def resolve_address(ipaddr: str, *args: Any, **kwargs: Dict[str, Any]) -> dns.resolver.Answer: + return await get_default_resolver().resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + backend, + ) + + +async def resolve_address( + ipaddr: str, *args: Any, **kwargs: Dict[str, Any] +) -> dns.resolver.Answer: """Use a resolver to run a reverse query for PTR records. See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more @@ -211,6 +245,7 @@ async def resolve_address(ipaddr: str, *args: Any, **kwargs: Dict[str, Any]) -> return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) + async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. @@ -220,10 +255,14 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return await get_default_resolver().canonical_name(name) -async def zone_for_name(name: Union[dns.name.Name, str], - rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, - tcp: bool=False, resolver: Optional[Resolver]=None, - backend: Optional[dns.asyncbackend.Backend]=None) -> dns.name.Name: + +async def zone_for_name( + name: Union[dns.name.Name, str], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + tcp: bool = False, + resolver: Optional[Resolver] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.name.Name: """Find the name of the zone which contains the specified name. See :py:func:`dns.resolver.Resolver.zone_for_name` for more @@ -238,8 +277,9 @@ async def zone_for_name(name: Union[dns.name.Name, str], raise NotAbsolute(name) while True: try: - answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass, - tcp, backend=backend) + answer = await resolver.resolve( + name, dns.rdatatype.SOA, rdclass, tcp, backend=backend + ) assert answer.rrset is not None if answer.rrset.name == name: return name diff --git a/dns/dnssec.py b/dns/dnssec.py index 331f4afc..b325f9f8 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -83,17 +83,19 @@ def key_id(key: DNSKEY) -> int: else: total = 0 for i in range(len(rdata) // 2): - total += (rdata[2 * i] << 8) + \ - rdata[2 * i + 1] + total += (rdata[2 * i] << 8) + rdata[2 * i + 1] if len(rdata) % 2 != 0: total += rdata[len(rdata) - 1] << 8 - total += ((total >> 16) & 0xffff) - return total & 0xffff + total += (total >> 16) & 0xFFFF + return total & 0xFFFF -def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata, - algorithm: Union[DSDigest, str], - origin: Optional[dns.name.Name]=None) -> DS: +def make_ds( + name: Union[dns.name.Name, str], + key: dns.rdata.Rdata, + algorithm: Union[DSDigest, str], + origin: Optional[dns.name.Name] = None, +) -> DS: """Create a DS record for a DNSSEC key. *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record. @@ -118,7 +120,7 @@ def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata, except Exception: raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) if not isinstance(key, DNSKEY): - raise ValueError('key is not a DNSKEY') + raise ValueError("key is not a DNSKEY") if algorithm == DSDigest.SHA1: dshash = hashlib.sha1() elif algorithm == DSDigest.SHA256: @@ -136,15 +138,16 @@ def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata, dshash.update(key.to_wire(origin=origin)) digest = dshash.digest() - dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \ - digest - ds = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, - len(dsrdata)) + dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest + ds = dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, len(dsrdata) + ) return cast(DS, ds) -def _find_candidate_keys(keys: Dict[dns.name.Name, Union[dns.rdataset.Rdataset, dns.node.Node]], - rrsig: RRSIG) -> Optional[List[DNSKEY]]: +def _find_candidate_keys( + keys: Dict[dns.name.Name, Union[dns.rdataset.Rdataset, dns.node.Node]], rrsig: RRSIG +) -> Optional[List[DNSKEY]]: value = keys.get(rrsig.signer) if isinstance(value, dns.node.Node): rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY) @@ -152,14 +155,21 @@ def _find_candidate_keys(keys: Dict[dns.name.Name, Union[dns.rdataset.Rdataset, rdataset = value if rdataset is None: return None - return [cast(DNSKEY, rd) for rd in rdataset if - rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag] + return [ + cast(DNSKEY, rd) + for rd in rdataset + if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag + ] def _is_rsa(algorithm: int) -> bool: - return algorithm in (Algorithm.RSAMD5, Algorithm.RSASHA1, - Algorithm.RSASHA1NSEC3SHA1, Algorithm.RSASHA256, - Algorithm.RSASHA512) + return algorithm in ( + Algorithm.RSAMD5, + Algorithm.RSASHA1, + Algorithm.RSASHA1NSEC3SHA1, + Algorithm.RSASHA256, + Algorithm.RSASHA512, + ) def _is_dsa(algorithm: int) -> bool: @@ -183,8 +193,12 @@ def _is_md5(algorithm: int) -> bool: def _is_sha1(algorithm: int) -> bool: - return algorithm in (Algorithm.DSA, Algorithm.RSASHA1, - Algorithm.DSANSEC3SHA1, Algorithm.RSASHA1NSEC3SHA1) + return algorithm in ( + Algorithm.DSA, + Algorithm.RSASHA1, + Algorithm.DSANSEC3SHA1, + Algorithm.RSASHA1NSEC3SHA1, + ) def _is_sha256(algorithm: int) -> bool: @@ -215,35 +229,36 @@ def _make_hash(algorithm: int) -> Any: if algorithm == Algorithm.ED448: return hashes.SHAKE256(114) - raise ValidationFailure('unknown hash for algorithm %u' % algorithm) + raise ValidationFailure("unknown hash for algorithm %u" % algorithm) def _bytes_to_long(b: bytes) -> int: - return int.from_bytes(b, 'big') + return int.from_bytes(b, "big") def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: keyptr: bytes if _is_rsa(key.algorithm): - # we ignore because mypy is confused and thinks key.key is a str for unknown reasons. + # we ignore because mypy is confused and thinks key.key is a str for unknown + # reasons. keyptr = key.key - (bytes_,) = struct.unpack('!B', keyptr[0:1]) + (bytes_,) = struct.unpack("!B", keyptr[0:1]) keyptr = keyptr[1:] if bytes_ == 0: - (bytes_,) = struct.unpack('!H', keyptr[0:2]) + (bytes_,) = struct.unpack("!H", keyptr[0:2]) keyptr = keyptr[2:] rsa_e = keyptr[0:bytes_] rsa_n = keyptr[bytes_:] try: rsa_public_key = rsa.RSAPublicNumbers( - _bytes_to_long(rsa_e), - _bytes_to_long(rsa_n)).public_key(default_backend()) + _bytes_to_long(rsa_e), _bytes_to_long(rsa_n) + ).public_key(default_backend()) except ValueError: - raise ValidationFailure('invalid public key') + raise ValidationFailure("invalid public key") rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) elif _is_dsa(key.algorithm): keyptr = key.key - (t,) = struct.unpack('!B', keyptr[0:1]) + (t,) = struct.unpack("!B", keyptr[0:1]) keyptr = keyptr[1:] octets = 64 + t * 8 dsa_q = keyptr[0:20] @@ -257,11 +272,11 @@ def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) dsa_public_key = dsa.DSAPublicNumbers( _bytes_to_long(dsa_y), dsa.DSAParameterNumbers( - _bytes_to_long(dsa_p), - _bytes_to_long(dsa_q), - _bytes_to_long(dsa_g))).public_key(default_backend()) + _bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g) + ), + ).public_key(default_backend()) except ValueError: - raise ValidationFailure('invalid public key') + raise ValidationFailure("invalid public key") dsa_public_key.verify(sig, data, chosen_hash) elif _is_ecdsa(key.algorithm): keyptr = key.key @@ -273,14 +288,13 @@ def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) curve = ec.SECP384R1() octets = 48 ecdsa_x = keyptr[0:octets] - ecdsa_y = keyptr[octets:octets * 2] + ecdsa_y = keyptr[octets : octets * 2] try: ecdsa_public_key = ec.EllipticCurvePublicNumbers( - curve=curve, - x=_bytes_to_long(ecdsa_x), - y=_bytes_to_long(ecdsa_y)).public_key(default_backend()) + curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y) + ).public_key(default_backend()) except ValueError: - raise ValidationFailure('invalid public key') + raise ValidationFailure("invalid public key") ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash)) elif _is_eddsa(key.algorithm): keyptr = key.key @@ -292,20 +306,24 @@ def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) try: eddsa_public_key = loader.from_public_bytes(keyptr) except ValueError: - raise ValidationFailure('invalid public key') + raise ValidationFailure("invalid public key") eddsa_public_key.verify(sig, data) elif _is_gost(key.algorithm): raise UnsupportedAlgorithm( - 'algorithm "%s" not supported by dnspython' % - algorithm_to_text(key.algorithm)) + 'algorithm "%s" not supported by dnspython' + % algorithm_to_text(key.algorithm) + ) else: - raise ValidationFailure('unknown algorithm %u' % key.algorithm) + raise ValidationFailure("unknown algorithm %u" % key.algorithm) -def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], - rrsig: RRSIG, - keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], - origin: Optional[dns.name.Name]=None, now: Optional[float]=None) -> None: +def _validate_rrsig( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsig: RRSIG, + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name] = None, + now: Optional[float] = None, +) -> None: """Validate an RRset against a single signature rdata, throwing an exception if validation is not successful. @@ -340,7 +358,7 @@ def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdata candidate_keys = _find_candidate_keys(keys, rrsig) if candidate_keys is None: - raise ValidationFailure('unknown key') + raise ValidationFailure("unknown key") # For convenience, allow the rrset to be specified as a (name, # rdataset) tuple as well as a proper rrset @@ -354,15 +372,14 @@ def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdata if now is None: now = time.time() if rrsig.expiration < now: - raise ValidationFailure('expired') + raise ValidationFailure("expired") if rrsig.inception > now: - raise ValidationFailure('not yet valid') + raise ValidationFailure("not yet valid") if _is_dsa(rrsig.algorithm): sig_r = rrsig.signature[1:21] sig_s = rrsig.signature[21:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), - _bytes_to_long(sig_s)) + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) elif _is_ecdsa(rrsig.algorithm): if rrsig.algorithm == Algorithm.ECDSAP256SHA256: octets = 32 @@ -370,34 +387,32 @@ def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdata octets = 48 sig_r = rrsig.signature[0:octets] sig_s = rrsig.signature[octets:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), - _bytes_to_long(sig_s)) + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) else: sig = rrsig.signature - data = b'' + data = b"" data += rrsig.to_wire(origin=origin)[:18] data += rrsig.signer.to_digestable(origin) # Derelativize the name before considering labels. if not rrname.is_absolute(): if origin is None: - raise ValidationFailure('relative RR name without an origin specified') + raise ValidationFailure("relative RR name without an origin specified") rrname = rrname.derelativize(origin) if len(rrname) - 1 < rrsig.labels: - raise ValidationFailure('owner name longer than RRSIG labels') + raise ValidationFailure("owner name longer than RRSIG labels") elif rrsig.labels < len(rrname) - 1: suffix = rrname.split(rrsig.labels + 1)[1] - rrname = dns.name.from_text('*', suffix) + rrname = dns.name.from_text("*", suffix) rrnamebuf = rrname.to_digestable() - rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass, - rrsig.original_ttl) + rrfixed = struct.pack("!HHI", rdataset.rdtype, rdataset.rdclass, rrsig.original_ttl) rdatas = [rdata.to_digestable(origin) for rdata in rdataset] for rdata in sorted(rdatas): data += rrnamebuf data += rrfixed - rrlen = struct.pack('!H', len(rdata)) + rrlen = struct.pack("!H", len(rdata)) data += rrlen data += rdata @@ -411,13 +426,16 @@ def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdata # this happens on an individual validation failure continue # nothing verified -- raise failure: - raise ValidationFailure('verify failure') + raise ValidationFailure("verify failure") -def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], - rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], - keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], - origin: Optional[dns.name.Name]=None, now: Optional[float]=None) -> None: +def _validate( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name] = None, + now: Optional[float] = None, +) -> None: """Validate an RRset against a signature RRset, throwing an exception if none of the signatures validate. @@ -468,7 +486,7 @@ def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rd for rrsig in rrsigrdataset: if not isinstance(rrsig, RRSIG): - raise ValidationFailure('expected an RRSIG') + raise ValidationFailure("expected an RRSIG") try: _validate_rrsig(rrset, rrsig, keys, origin, now) return @@ -477,8 +495,12 @@ def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rd raise ValidationFailure("no RRSIGs validated") -def nsec3_hash(domain: Union[dns.name.Name, str], salt: Optional[Union[str, bytes]], - iterations: int, algorithm: Union[int, str]) -> str: +def nsec3_hash( + domain: Union[dns.name.Name, str], + salt: Optional[Union[str, bytes]], + iterations: int, + algorithm: Union[int, str], +) -> str: """ Calculate the NSEC3 hash, according to https://tools.ietf.org/html/rfc5155#section-5 @@ -510,7 +532,7 @@ def nsec3_hash(domain: Union[dns.name.Name, str], salt: Optional[Union[str, byte raise ValueError("Wrong hash algorithm (only SHA1 is supported)") if salt is None: - salt_encoded = b'' + salt_encoded = b"" elif isinstance(salt, str): if len(salt) % 2 == 0: salt_encoded = bytes.fromhex(salt) @@ -535,8 +557,9 @@ def nsec3_hash(domain: Union[dns.name.Name, str], salt: Optional[Union[str, byte def _need_pyca(*args, **kwargs): - raise ImportError("DNSSEC validation requires " + - "python cryptography") # pragma: no cover + raise ImportError( + "DNSSEC validation requires " + "python cryptography" + ) # pragma: no cover try: @@ -555,8 +578,8 @@ except ImportError: # pragma: no cover validate_rrsig = _need_pyca _have_pyca = False else: - validate = _validate # type: ignore - validate_rrsig = _validate_rrsig # type: ignore + validate = _validate # type: ignore + validate_rrsig = _validate_rrsig # type: ignore _have_pyca = True ### BEGIN generated Algorithm constants diff --git a/dns/e164.py b/dns/e164.py index 6e34ae5d..453736d4 100644 --- a/dns/e164.py +++ b/dns/e164.py @@ -24,10 +24,12 @@ import dns.name import dns.resolver #: The public E.164 domain. -public_enum_domain = dns.name.from_text('e164.arpa.') +public_enum_domain = dns.name.from_text("e164.arpa.") -def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain) -> dns.name.Name: +def from_e164( + text: str, origin: Optional[dns.name.Name] = public_enum_domain +) -> dns.name.Name: """Convert an E.164 number in textual form into a Name object whose value is the ENUM domain name for that number. @@ -44,11 +46,14 @@ def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain) -> parts = [d for d in text if d.isdigit()] parts.reverse() - return dns.name.from_text('.'.join(parts), origin=origin) + return dns.name.from_text(".".join(parts), origin=origin) -def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=public_enum_domain, - want_plus_prefix: bool=True) -> str: +def to_e164( + name: dns.name.Name, + origin: Optional[dns.name.Name] = public_enum_domain, + want_plus_prefix: bool = True, +) -> str: """Convert an ENUM domain name into an E.164 number. Note that dnspython does not have any information about preferred @@ -72,16 +77,19 @@ def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=public_enum_dom name = name.relativize(origin) dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1] if len(dlabels) != len(name.labels): - raise dns.exception.SyntaxError('non-digit labels in ENUM domain name') + raise dns.exception.SyntaxError("non-digit labels in ENUM domain name") dlabels.reverse() - text = b''.join(dlabels) + text = b"".join(dlabels) if want_plus_prefix: - text = b'+' + text + text = b"+" + text return text.decode() -def query(number: str, domains: Iterable[Union[dns.name.Name, str]], - resolver: Optional[dns.resolver.Resolver]=None) -> dns.resolver.Answer: +def query( + number: str, + domains: Iterable[Union[dns.name.Name, str]], + resolver: Optional[dns.resolver.Resolver] = None, +) -> dns.resolver.Answer: """Look for NAPTR RRs for the specified number in the specified domains. e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.']) @@ -102,7 +110,7 @@ def query(number: str, domains: Iterable[Union[dns.name.Name, str]], domain = dns.name.from_text(domain) qname = dns.e164.from_e164(number, domain) try: - return resolver.resolve(qname, 'NAPTR') + return resolver.resolve(qname, "NAPTR") except dns.resolver.NXDOMAIN as e: e_nx += e raise e_nx diff --git a/dns/edns.py b/dns/edns.py index d4dca55a..64436cde 100644 --- a/dns/edns.py +++ b/dns/edns.py @@ -69,7 +69,7 @@ class Option: """ self.otype = OptionType.make(otype) - def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: """Convert an option to wire format. Returns a ``bytes`` or ``None``. @@ -78,7 +78,7 @@ class Option: raise NotImplementedError # pragma: no cover @classmethod - def from_wire_parser(cls, otype: OptionType, parser: 'dns.wire.Parser') -> 'Option': + def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option": """Build an EDNS option object from wire format. *otype*, a ``dns.edns.OptionType``, is the option type. @@ -118,26 +118,22 @@ class Option: return self._cmp(other) != 0 def __lt__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) < 0 def __le__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) <= 0 def __ge__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) >= 0 def __gt__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) > 0 @@ -157,7 +153,7 @@ class GenericOption(Option): # lgtm[py/missing-equals] super().__init__(otype) self.data = dns.rdata.Rdata._as_bytes(data, True) - def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: if file: file.write(self.data) return None @@ -168,14 +164,16 @@ class GenericOption(Option): # lgtm[py/missing-equals] return "Generic %d" % self.otype @classmethod - def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option: + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: return cls(otype, parser.get_remaining()) class ECSOption(Option): # lgtm[py/missing-equals] """EDNS Client Subnet (ECS, RFC7871)""" - def __init__(self, address: str, srclen: Optional[int]=None, scopelen: int=0): + def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0): """*address*, a ``str``, is the client address information. *srclen*, an ``int``, the source prefix length, which is the @@ -204,7 +202,7 @@ class ECSOption(Option): # lgtm[py/missing-equals] srclen = dns.rdata.Rdata._as_int(srclen, 0, 32) scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32) else: # pragma: no cover (this will never happen) - raise ValueError('Bad address family') + raise ValueError("Bad address family") assert srclen is not None self.address = address @@ -219,13 +217,11 @@ class ECSOption(Option): # lgtm[py/missing-equals] self.addrdata = addrdata[:nbytes] nbits = srclen % 8 if nbits != 0: - last = struct.pack('B', - ord(self.addrdata[-1:]) & (0xff << (8 - nbits))) + last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits))) self.addrdata = self.addrdata[:-1] + last def to_text(self) -> str: - return "ECS {}/{} scope/{}".format(self.address, self.srclen, - self.scopelen) + return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen) @staticmethod def from_text(text: str) -> Option: @@ -251,7 +247,7 @@ class ECSOption(Option): # lgtm[py/missing-equals] >>> # it understands results from `dns.edns.ECSOption.to_text()` >>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32') """ - optional_prefix = 'ECS' + optional_prefix = "ECS" tokens = text.split() ecs_text = None if len(tokens) == 1: @@ -262,29 +258,32 @@ class ECSOption(Option): # lgtm[py/missing-equals] ecs_text = tokens[1] else: raise ValueError('could not parse ECS from "{}"'.format(text)) - n_slashes = ecs_text.count('/') + n_slashes = ecs_text.count("/") if n_slashes == 1: - address, tsrclen = ecs_text.split('/') - tscope = '0' + address, tsrclen = ecs_text.split("/") + tscope = "0" elif n_slashes == 2: - address, tsrclen, tscope = ecs_text.split('/') + address, tsrclen, tscope = ecs_text.split("/") else: raise ValueError('could not parse ECS from "{}"'.format(text)) try: scope = int(tscope) except ValueError: - raise ValueError('invalid scope ' + - '"{}": scope must be an integer'.format(tscope)) + raise ValueError( + "invalid scope " + '"{}": scope must be an integer'.format(tscope) + ) try: srclen = int(tsrclen) except ValueError: - raise ValueError('invalid srclen ' + - '"{}": srclen must be an integer'.format(tsrclen)) + raise ValueError( + "invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen) + ) return ECSOption(address, srclen, scope) - def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: - value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) + - self.addrdata) + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: + value = ( + struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata + ) if file: file.write(value) return None @@ -292,18 +291,20 @@ class ECSOption(Option): # lgtm[py/missing-equals] return value @classmethod - def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option: - family, src, scope = parser.get_struct('!HBB') + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: + family, src, scope = parser.get_struct("!HBB") addrlen = int(math.ceil(src / 8.0)) prefix = parser.get_bytes(addrlen) if family == 1: pad = 4 - addrlen - addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad) + addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad) elif family == 2: pad = 16 - addrlen - addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad) + addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad) else: - raise ValueError('unsupported family') + raise ValueError("unsupported family") return cls(addr, src, scope) @@ -343,7 +344,7 @@ class EDECode(dns.enum.IntEnum): class EDEOption(Option): # lgtm[py/missing-equals] """Extended DNS Error (EDE, RFC8914)""" - def __init__(self, code: Union[EDECode, str], text: Optional[str]=None): + def __init__(self, code: Union[EDECode, str], text: Optional[str] = None): """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the extended error. @@ -355,19 +356,19 @@ class EDEOption(Option): # lgtm[py/missing-equals] self.code = EDECode.make(code) if text is not None and not isinstance(text, str): - raise ValueError('text must be string or None') + raise ValueError("text must be string or None") self.text = text def to_text(self) -> str: - output = f'EDE {self.code}' + output = f"EDE {self.code}" if self.text is not None: - output += f': {self.text}' + output += f": {self.text}" return output - def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: - value = struct.pack('!H', self.code) + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: + value = struct.pack("!H", self.code) if self.text is not None: - value += self.text.encode('utf8') + value += self.text.encode("utf8") if file: file.write(value) @@ -376,14 +377,16 @@ class EDEOption(Option): # lgtm[py/missing-equals] return value @classmethod - def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option: + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: the_code = EDECode.make(parser.get_uint16()) text = parser.get_remaining() if text: if text[-1] == 0: # text MAY be null-terminated text = text[:-1] - btext = text.decode('utf8') + btext = text.decode("utf8") else: btext = None @@ -409,7 +412,9 @@ def get_option_class(otype: OptionType) -> Any: return cls -def option_from_wire_parser(otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option: +def option_from_wire_parser( + otype: Union[OptionType, str], parser: "dns.wire.Parser" +) -> Option: """Build an EDNS option object from wire format. *otype*, an ``int``, is the option type. @@ -424,7 +429,9 @@ def option_from_wire_parser(otype: Union[OptionType, str], parser: 'dns.wire.Par return cls.from_wire_parser(otype, parser) -def option_from_wire(otype: Union[OptionType, str], wire: bytes, current: int, olen: int) -> Option: +def option_from_wire( + otype: Union[OptionType, str], wire: bytes, current: int, olen: int +) -> Option: """Build an EDNS option object from wire format. *otype*, an ``int``, is the option type. @@ -442,6 +449,7 @@ def option_from_wire(otype: Union[OptionType, str], wire: bytes, current: int, o with parser.restrict_to(olen): return option_from_wire_parser(otype, parser) + def register_type(implementation: Any, otype: OptionType) -> None: """Register the implementation of an option type. @@ -452,6 +460,7 @@ def register_type(implementation: Any, otype: OptionType) -> None: _type_to_class[otype] = implementation + ### BEGIN generated OptionType constants NSID = OptionType.NSID diff --git a/dns/entropy.py b/dns/entropy.py index 7da2e04a..50103562 100644 --- a/dns/entropy.py +++ b/dns/entropy.py @@ -21,10 +21,11 @@ import os import hashlib import random import time + try: import threading as _threading except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore + import dummy_threading as _threading # type: ignore class EntropyPool: @@ -34,14 +35,14 @@ class EntropyPool: # leaving this code doesn't hurt anything as the library code # is used if present. - def __init__(self, seed: Optional[bytes]=None): + def __init__(self, seed: Optional[bytes] = None): self.pool_index = 0 self.digest: Optional[bytearray] = None self.next_byte = 0 self.lock = _threading.Lock() self.hash = hashlib.sha1() self.hash_len = 20 - self.pool = bytearray(b'\0' * self.hash_len) + self.pool = bytearray(b"\0" * self.hash_len) if seed is not None: self._stir(seed) self.seeded = True @@ -54,7 +55,7 @@ class EntropyPool: for c in entropy: if self.pool_index == self.hash_len: self.pool_index = 0 - b = c & 0xff + b = c & 0xFF self.pool[self.pool_index] ^= b self.pool_index += 1 @@ -68,7 +69,7 @@ class EntropyPool: seed = os.urandom(16) except Exception: # pragma: no cover try: - with open('/dev/urandom', 'rb', 0) as r: + with open("/dev/urandom", "rb", 0) as r: seed = r.read(16) except Exception: seed = str(time.time()).encode() @@ -99,7 +100,7 @@ class EntropyPool: def random_between(self, first: int, last: int) -> int: size = last - first + 1 if size > 4294967296: - raise ValueError('too big') + raise ValueError("too big") if size > 65536: rand = self.random_32 max = 4294967295 @@ -111,6 +112,7 @@ class EntropyPool: max = 255 return first + size * rand() // (max + 1) + pool = EntropyPool() system_random: Optional[Any] @@ -119,12 +121,14 @@ try: except Exception: # pragma: no cover system_random = None + def random_16() -> int: if system_random is not None: return system_random.randrange(0, 65536) else: return pool.random_16() + def between(first: int, last: int) -> int: if system_random is not None: return system_random.randrange(first, last + 1) diff --git a/dns/enum.py b/dns/enum.py index b822dd51..9c674883 100644 --- a/dns/enum.py +++ b/dns/enum.py @@ -17,6 +17,7 @@ import enum + class IntEnum(enum.IntEnum): @classmethod def _check_value(cls, value): @@ -33,8 +34,8 @@ class IntEnum(enum.IntEnum): except KeyError: pass prefix = cls._prefix() - if text.startswith(prefix) and text[len(prefix):].isdigit(): - value = int(text[len(prefix):]) + if text.startswith(prefix) and text[len(prefix) :].isdigit(): + value = int(text[len(prefix) :]) cls._check_value(value) try: return cls(value) @@ -83,7 +84,7 @@ class IntEnum(enum.IntEnum): @classmethod def _prefix(cls): - return '' + return "" @classmethod def _unknown_exception_class(cls): diff --git a/dns/exception.py b/dns/exception.py index aa0144d4..3b2f1cdc 100644 --- a/dns/exception.py +++ b/dns/exception.py @@ -73,14 +73,15 @@ class DNSException(Exception): For sanity we do not allow to mix old and new behavior.""" if args or kwargs: - assert bool(args) != bool(kwargs), \ - 'keyword arguments are mutually exclusive with positional args' + assert bool(args) != bool( + kwargs + ), "keyword arguments are mutually exclusive with positional args" def _check_kwargs(self, **kwargs): if kwargs: - assert set(kwargs.keys()) == self.supp_kwargs, \ - 'following set of keyword args is required: %s' % ( - self.supp_kwargs) + assert ( + set(kwargs.keys()) == self.supp_kwargs + ), "following set of keyword args is required: %s" % (self.supp_kwargs) return kwargs def _fmt_kwargs(self, **kwargs): @@ -129,10 +130,12 @@ class TooBig(DNSException): class Timeout(DNSException): """The DNS operation timed out.""" - supp_kwargs = {'timeout'} + + supp_kwargs = {"timeout"} fmt = "The DNS operation timed out after {timeout:.3f} seconds" - # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -145,7 +148,6 @@ class ExceptionWrapper: return self def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is not None and not isinstance(exc_val, - self.exception_class): + if exc_type is not None and not isinstance(exc_val, self.exception_class): raise self.exception_class(str(exc_val)) from exc_val return False diff --git a/dns/flags.py b/dns/flags.py index 6fe1afd3..b21b8e3b 100644 --- a/dns/flags.py +++ b/dns/flags.py @@ -23,6 +23,7 @@ import enum # Standard DNS flags + class Flag(enum.IntFlag): #: Query Response QR = 0x8000 @@ -42,6 +43,7 @@ class Flag(enum.IntFlag): # EDNS flags + class EDNSFlag(enum.IntFlag): #: DNSSEC answer OK DO = 0x8000 @@ -60,7 +62,7 @@ def _to_text(flags: int, enum_class: Any) -> str: for k, v in enum_class.__members__.items(): if flags & v != 0: text_flags.append(k) - return ' '.join(text_flags) + return " ".join(text_flags) def from_text(text: str) -> int: @@ -102,6 +104,7 @@ def edns_to_text(flags: int) -> str: return _to_text(flags, EDNSFlag) + ### BEGIN generated Flag constants QR = Flag.QR diff --git a/dns/grange.py b/dns/grange.py index ebb64d2d..3a52278f 100644 --- a/dns/grange.py +++ b/dns/grange.py @@ -21,6 +21,7 @@ from typing import Tuple import dns + def from_text(text: str) -> Tuple[int, int, int]: """Convert the text form of a range in a ``$GENERATE`` statement to an integer. @@ -33,22 +34,22 @@ def from_text(text: str) -> Tuple[int, int, int]: start = -1 stop = -1 step = 1 - cur = '' + cur = "" state = 0 # state 0 1 2 # x - y / z - if text and text[0] == '-': + if text and text[0] == "-": raise dns.exception.SyntaxError("Start cannot be a negative number") for c in text: - if c == '-' and state == 0: + if c == "-" and state == 0: start = int(cur) - cur = '' + cur = "" state = 1 - elif c == '/': + elif c == "/": stop = int(cur) - cur = '' + cur = "" state = 2 elif c.isdigit(): cur += c @@ -66,6 +67,6 @@ def from_text(text: str) -> Tuple[int, int, int]: assert step >= 1 assert start >= 0 if start > stop: - raise dns.exception.SyntaxError('start must be <= stop') + raise dns.exception.SyntaxError("start must be <= stop") return (start, stop, step) diff --git a/dns/immutable.py b/dns/immutable.py index 8a426210..38fbe597 100644 --- a/dns/immutable.py +++ b/dns/immutable.py @@ -9,7 +9,7 @@ from dns._immutable_ctx import immutable @immutable class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] - def __init__(self, dictionary: Any, no_copy: bool=False): + def __init__(self, dictionary: Any, no_copy: bool = False): """Make an immutable dictionary from the specified dictionary. If *no_copy* is `True`, then *dictionary* will be wrapped instead @@ -30,7 +30,7 @@ class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] h = 0 for key in sorted(self._odict.keys()): h ^= hash(key) - object.__setattr__(self, '_hash', h) + object.__setattr__(self, "_hash", h) # this does return an int, but pylint doesn't figure that out return self._hash diff --git a/dns/inet.py b/dns/inet.py index b3ed9995..11180c96 100644 --- a/dns/inet.py +++ b/dns/inet.py @@ -137,7 +137,9 @@ def is_address(text: str) -> bool: return False -def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None) -> Any: +def low_level_address_tuple( + high_tuple: Tuple[str, int], af: Optional[int] = None +) -> Any: """Given a "high-level" address tuple, i.e. an (address, port) return the appropriate "low-level" address tuple suitable for use in socket calls. @@ -152,13 +154,13 @@ def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None) if af == AF_INET: return (address, port) elif af == AF_INET6: - i = address.find('%') + i = address.find("%") if i < 0: # no scope, shortcut! return (address, port, 0, 0) # try to avoid getaddrinfo() addrpart = address[:i] - scope = address[i + 1:] + scope = address[i + 1 :] if scope.isdigit(): return (addrpart, port, 0, int(scope)) try: @@ -168,4 +170,4 @@ def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None) ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) return tup else: - raise NotImplementedError(f'unknown address family {af}') + raise NotImplementedError(f"unknown address family {af}") diff --git a/dns/ipv4.py b/dns/ipv4.py index fddad1b1..b8e148f3 100644 --- a/dns/ipv4.py +++ b/dns/ipv4.py @@ -23,6 +23,7 @@ import struct import dns.exception + def inet_ntoa(address: bytes) -> str: """Convert an IPv4 address in binary form to text form. @@ -33,8 +34,8 @@ def inet_ntoa(address: bytes) -> str: if len(address) != 4: raise dns.exception.SyntaxError - return ('%u.%u.%u.%u' % (address[0], address[1], - address[2], address[3])) + return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3]) + def inet_aton(text: Union[str, bytes]) -> bytes: """Convert an IPv4 address in text form to binary form. @@ -48,17 +49,17 @@ def inet_aton(text: Union[str, bytes]) -> bytes: btext = text.encode() else: btext = text - parts = btext.split(b'.') + parts = btext.split(b".") if len(parts) != 4: raise dns.exception.SyntaxError for part in parts: if not part.isdigit(): raise dns.exception.SyntaxError - if len(part) > 1 and part[0] == ord('0'): + if len(part) > 1 and part[0] == ord("0"): # No leading zeros raise dns.exception.SyntaxError try: b = [int(part) for part in parts] - return struct.pack('BBBB', *b) + return struct.pack("BBBB", *b) except Exception: raise dns.exception.SyntaxError diff --git a/dns/ipv6.py b/dns/ipv6.py index 9e6e8b6a..fbd49623 100644 --- a/dns/ipv6.py +++ b/dns/ipv6.py @@ -25,7 +25,8 @@ import binascii import dns.exception import dns.ipv4 -_leading_zero = re.compile(r'0+([0-9a-f]+)') +_leading_zero = re.compile(r"0+([0-9a-f]+)") + def inet_ntoa(address: bytes) -> str: """Convert an IPv6 address in binary form to text form. @@ -43,7 +44,7 @@ def inet_ntoa(address: bytes) -> str: i = 0 l = len(hex) while i < l: - chunk = hex[i:i + 4].decode() + chunk = hex[i : i + 4].decode() # strip leading zeros. we do this with an re instead of # with lstrip() because lstrip() didn't support chars until # python 2.2.2 @@ -60,7 +61,7 @@ def inet_ntoa(address: bytes) -> str: start = -1 last_was_zero = False for i in range(8): - if chunks[i] != '0': + if chunks[i] != "0": if last_was_zero: end = i current_len = end - start @@ -78,27 +79,30 @@ def inet_ntoa(address: bytes) -> str: best_start = start best_len = current_len if best_len > 1: - if best_start == 0 and \ - (best_len == 6 or - best_len == 5 and chunks[5] == 'ffff'): + if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"): # We have an embedded IPv4 address if best_len == 6: - prefix = '::' + prefix = "::" else: - prefix = '::ffff:' + prefix = "::ffff:" thex = prefix + dns.ipv4.inet_ntoa(address[12:]) else: - thex = ':'.join(chunks[:best_start]) + '::' + \ - ':'.join(chunks[best_start + best_len:]) + thex = ( + ":".join(chunks[:best_start]) + + "::" + + ":".join(chunks[best_start + best_len :]) + ) else: - thex = ':'.join(chunks) + thex = ":".join(chunks) return thex -_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$') -_colon_colon_start = re.compile(br'::.*') -_colon_colon_end = re.compile(br'.*::$') -def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) -> bytes: +_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$") +_colon_colon_start = re.compile(rb"::.*") +_colon_colon_end = re.compile(rb".*::$") + + +def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes: """Convert an IPv6 address in text form to binary form. *text*, a ``str``, the IPv6 address in textual form. @@ -118,30 +122,32 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) -> bytes: btext = text if ignore_scope: - parts = btext.split(b'%') + parts = btext.split(b"%") l = len(parts) if l == 2: btext = parts[0] elif l > 2: raise dns.exception.SyntaxError - if btext == b'': + if btext == b"": raise dns.exception.SyntaxError - elif btext.endswith(b':') and not btext.endswith(b'::'): + elif btext.endswith(b":") and not btext.endswith(b"::"): raise dns.exception.SyntaxError - elif btext.startswith(b':') and not btext.startswith(b'::'): + elif btext.startswith(b":") and not btext.startswith(b"::"): raise dns.exception.SyntaxError - elif btext == b'::': - btext = b'0::' + elif btext == b"::": + btext = b"0::" # # Get rid of the icky dot-quad syntax if we have it. # m = _v4_ending.match(btext) if m is not None: b = dns.ipv4.inet_aton(m.group(2)) - btext = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), - b[0], b[1], b[2], - b[3])).encode() + btext = ( + "{}:{:02x}{:02x}:{:02x}{:02x}".format( + m.group(1).decode(), b[0], b[1], b[2], b[3] + ) + ).encode() # # Try to turn '::' into ':'; if no match try to # turn '::' into ':' @@ -156,29 +162,29 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) -> bytes: # # Now canonicalize into 8 chunks of 4 hex digits each # - chunks = btext.split(b':') + chunks = btext.split(b":") l = len(chunks) if l > 8: raise dns.exception.SyntaxError seen_empty = False canonical: List[bytes] = [] for c in chunks: - if c == b'': + if c == b"": if seen_empty: raise dns.exception.SyntaxError seen_empty = True for _ in range(0, 8 - l + 1): - canonical.append(b'0000') + canonical.append(b"0000") else: lc = len(c) if lc > 4: raise dns.exception.SyntaxError if lc != 4: - c = (b'0' * (4 - lc)) + c + c = (b"0" * (4 - lc)) + c canonical.append(c) if l < 8 and not seen_empty: raise dns.exception.SyntaxError - btext = b''.join(canonical) + btext = b"".join(canonical) # # Finally we can go to binary. @@ -188,7 +194,9 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) -> bytes: except (binascii.Error, TypeError): raise dns.exception.SyntaxError -_mapped_prefix = b'\x00' * 10 + b'\xff\xff' + +_mapped_prefix = b"\x00" * 10 + b"\xff\xff" + def is_mapped(address: bytes) -> bool: """Is the specified address a mapped IPv4 address? diff --git a/dns/message.py b/dns/message.py index 0e1e4336..967fefea 100644 --- a/dns/message.py +++ b/dns/message.py @@ -73,9 +73,10 @@ class UnknownTSIGKey(dns.exception.DNSException): class Truncated(dns.exception.DNSException): """The truncated flag is set.""" - supp_kwargs = {'message'} + supp_kwargs = {"message"} - # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -84,7 +85,7 @@ class Truncated(dns.exception.DNSException): Returns a ``dns.message.Message``. """ - return self.kwargs['message'] + return self.kwargs["message"] class NotQueryResponse(dns.exception.DNSException): @@ -98,12 +99,14 @@ class ChainTooLong(dns.exception.DNSException): class AnswerForNXDOMAIN(dns.exception.DNSException): """The rcode is NXDOMAIN but an answer was found.""" + class NoPreviousName(dns.exception.SyntaxError): """No previous name was known.""" class MessageSection(dns.enum.IntEnum): """Message sections""" + QUESTION = 0 ANSWER = 1 AUTHORITY = 2 @@ -123,18 +126,24 @@ class MessageError: DEFAULT_EDNS_PAYLOAD = 1232 MAX_CHAIN = 16 -IndexKeyType = Tuple[int, dns.name.Name, dns.rdataclass.RdataClass, - dns.rdatatype.RdataType, Optional[dns.rdatatype.RdataType], - Optional[dns.rdataclass.RdataClass]] +IndexKeyType = Tuple[ + int, + dns.name.Name, + dns.rdataclass.RdataClass, + dns.rdatatype.RdataType, + Optional[dns.rdatatype.RdataType], + Optional[dns.rdataclass.RdataClass], +] IndexType = Dict[IndexKeyType, dns.rrset.RRset] SectionType = Union[int, List[dns.rrset.RRset]] + class Message: """A DNS message.""" _section_enum = MessageSection - def __init__(self, id: Optional[int]=None): + def __init__(self, id: Optional[int] = None): if id is None: self.id = dns.entropy.random_16() else: @@ -145,7 +154,7 @@ class Message: self.request_payload = 0 self.keyring: Any = None self.tsig: Optional[dns.rrset.RRset] = None - self.request_mac = b'' + self.request_mac = b"" self.xfr = False self.origin: Optional[dns.name.Name] = None self.tsig_ctx: Optional[Any] = None @@ -155,7 +164,7 @@ class Message: @property def question(self) -> List[dns.rrset.RRset]: - """ The question section.""" + """The question section.""" return self.sections[0] @question.setter @@ -164,7 +173,7 @@ class Message: @property def answer(self) -> List[dns.rrset.RRset]: - """ The answer section.""" + """The answer section.""" return self.sections[1] @answer.setter @@ -173,7 +182,7 @@ class Message: @property def authority(self) -> List[dns.rrset.RRset]: - """ The authority section.""" + """The authority section.""" return self.sections[2] @authority.setter @@ -182,7 +191,7 @@ class Message: @property def additional(self) -> List[dns.rrset.RRset]: - """ The additional data section.""" + """The additional data section.""" return self.sections[3] @additional.setter @@ -190,13 +199,17 @@ class Message: self.sections[3] = v def __repr__(self): - return '' + return "" def __str__(self): return self.to_text() - def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, - **kw: Dict[str, Any]) -> str: + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any], + ) -> str: """Convert the message to text. The *origin*, *relativize*, and any other keyword @@ -206,23 +219,22 @@ class Message: """ s = io.StringIO() - s.write('id %d\n' % self.id) - s.write('opcode %s\n' % dns.opcode.to_text(self.opcode())) - s.write('rcode %s\n' % dns.rcode.to_text(self.rcode())) - s.write('flags %s\n' % dns.flags.to_text(self.flags)) + s.write("id %d\n" % self.id) + s.write("opcode %s\n" % dns.opcode.to_text(self.opcode())) + s.write("rcode %s\n" % dns.rcode.to_text(self.rcode())) + s.write("flags %s\n" % dns.flags.to_text(self.flags)) if self.edns >= 0: - s.write('edns %s\n' % self.edns) + s.write("edns %s\n" % self.edns) if self.ednsflags != 0: - s.write('eflags %s\n' % - dns.flags.edns_to_text(self.ednsflags)) - s.write('payload %d\n' % self.payload) + s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags)) + s.write("payload %d\n" % self.payload) for opt in self.options: - s.write('option %s\n' % opt.to_text()) + s.write("option %s\n" % opt.to_text()) for (name, which) in self._section_enum.__members__.items(): - s.write(f';{name}\n') + s.write(f";{name}\n") for rrset in self.section_from_number(which): s.write(rrset.to_text(origin, relativize, **kw)) - s.write('\n') + s.write("\n") # # We strip off the final \n so the caller can print the result without # doing weird things to get around eccentricities in Python print @@ -256,20 +268,25 @@ class Message: def __ne__(self, other): return not self.__eq__(other) - def is_response(self, other: 'Message') -> bool: + def is_response(self, other: "Message") -> bool: """Is *other*, also a ``dns.message.Message``, a response to this message? Returns a ``bool``. """ - if other.flags & dns.flags.QR == 0 or \ - self.id != other.id or \ - dns.opcode.from_flags(self.flags) != \ - dns.opcode.from_flags(other.flags): + if ( + other.flags & dns.flags.QR == 0 + or self.id != other.id + or dns.opcode.from_flags(self.flags) != dns.opcode.from_flags(other.flags) + ): return False - if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL, - dns.rcode.NOTIMP, dns.rcode.REFUSED}: + if other.rcode() in { + dns.rcode.FORMERR, + dns.rcode.SERVFAIL, + dns.rcode.NOTIMP, + dns.rcode.REFUSED, + }: # We don't check the question section in these cases if # the other question section is empty, even though they # still really ought to have a question section. @@ -303,7 +320,7 @@ class Message: for i, our_section in enumerate(self.sections): if section is our_section: return self._section_enum(i) - raise ValueError('unknown section') + raise ValueError("unknown section") def section_from_number(self, number: int) -> List[dns.rrset.RRset]: """Return the section list associated with the specified section @@ -320,15 +337,17 @@ class Message: section = self._section_enum.make(number) return self.sections[section] - def find_rrset(self, - section: SectionType, - name: dns.name.Name, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - deleting: Optional[dns.rdataclass.RdataClass]=None, - create: bool=False, - force_unique: bool=False) -> dns.rrset.RRset: + def find_rrset( + self, + section: SectionType, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass] = None, + create: bool = False, + force_unique: bool = False, + ) -> dns.rrset.RRset: """Find the RRset with the given attributes in the specified section. *section*, an ``int`` section number, or one of the section @@ -378,8 +397,7 @@ class Message: return rrset else: for rrset in the_section: - if rrset.full_match(name, rdclass, rdtype, covers, - deleting): + if rrset.full_match(name, rdclass, rdtype, covers, deleting): return rrset if not create: raise KeyError @@ -389,15 +407,17 @@ class Message: self.index[key] = rrset return rrset - def get_rrset(self, - section: SectionType, - name: dns.name.Name, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - deleting: Optional[dns.rdataclass.RdataClass]=None, - create: bool=False, - force_unique: bool=False) -> Optional[dns.rrset.RRset]: + def get_rrset( + self, + section: SectionType, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass] = None, + create: bool = False, + force_unique: bool = False, + ) -> Optional[dns.rrset.RRset]: """Get the RRset with the given attributes in the specified section. If the RRset is not found, None is returned. @@ -433,14 +453,21 @@ class Message: """ try: - rrset = self.find_rrset(section, name, rdclass, rdtype, covers, - deleting, create, force_unique) + rrset = self.find_rrset( + section, name, rdclass, rdtype, covers, deleting, create, force_unique + ) except KeyError: rrset = None return rrset - def to_wire(self, origin: Optional[dns.name.Name]=None, max_size: int=0, - multi: bool=False, tsig_ctx: Optional[Any]=None, **kw: Dict[str, Any]) -> bytes: + def to_wire( + self, + origin: Optional[dns.name.Name] = None, + max_size: int = 0, + multi: bool = False, + tsig_ctx: Optional[Any] = None, + **kw: Dict[str, Any], + ) -> bytes: """Return a string containing the message in DNS compressed wire format. @@ -490,13 +517,15 @@ class Message: r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) r.write_header() if self.tsig is not None: - (new_tsig, ctx) = dns.tsig.sign(r.get_wire(), - self.keyring, - self.tsig[0], - int(time.time()), - self.request_mac, - tsig_ctx, - multi) + (new_tsig, ctx) = dns.tsig.sign( + r.get_wire(), + self.keyring, + self.tsig[0], + int(time.time()), + self.request_mac, + tsig_ctx, + multi, + ) self.tsig.clear() self.tsig.add(new_tsig) r.add_rrset(dns.renderer.ADDITIONAL, self.tsig) @@ -506,17 +535,32 @@ class Message: return r.get_wire() @staticmethod - def _make_tsig(keyname, algorithm, time_signed, fudge, mac, original_id, - error, other): - tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, - algorithm, time_signed, fudge, mac, - original_id, error, other) + def _make_tsig( + keyname, algorithm, time_signed, fudge, mac, original_id, error, other + ): + tsig = dns.rdtypes.ANY.TSIG.TSIG( + dns.rdataclass.ANY, + dns.rdatatype.TSIG, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) return dns.rrset.from_rdata(keyname, 0, tsig) - def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name, str]]=None, - fudge: int=300, original_id: Optional[int]=None, tsig_error: int=0, - other_data: bytes=b'', - algorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> None: + def use_tsig( + self, + keyring: Any, + keyname: Optional[Union[dns.name.Name, str]] = None, + fudge: int = 300, + original_id: Optional[int] = None, + tsig_error: int = 0, + other_data: bytes = b"", + algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + ) -> None: """When sending, a TSIG signature using the specified key should be added. @@ -570,8 +614,16 @@ class Message: self.keyring = key if original_id is None: original_id = self.id - self.tsig = self._make_tsig(keyname, self.keyring.algorithm, 0, fudge, - b'', original_id, tsig_error, other_data) + self.tsig = self._make_tsig( + keyname, + self.keyring.algorithm, + 0, + fudge, + b"", + original_id, + tsig_error, + other_data, + ) @property def keyname(self) -> Optional[dns.name.Name]: @@ -607,13 +659,17 @@ class Message: @staticmethod def _make_opt(flags=0, payload=DEFAULT_EDNS_PAYLOAD, options=None): - opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, - options or ()) + opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, options or ()) return dns.rrset.from_rdata(dns.name.root, int(flags), opt) - def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0, payload: int=DEFAULT_EDNS_PAYLOAD, - request_payload: Optional[int]=None, - options: Optional[List[dns.edns.Option]]=None) -> None: + def use_edns( + self, + edns: Optional[Union[int, bool]] = 0, + ednsflags: int = 0, + payload: int = DEFAULT_EDNS_PAYLOAD, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + ) -> None: """Configure EDNS behavior. *edns*, an ``int``, is the EDNS level to use. Specifying @@ -645,7 +701,7 @@ class Message: else: # make sure the EDNS version in ednsflags agrees with edns ednsflags &= 0xFF00FFFF - ednsflags |= (edns << 16) + ednsflags |= edns << 16 if options is None: options = [] self.opt = self._make_opt(ednsflags, payload, options) @@ -656,7 +712,7 @@ class Message: @property def edns(self) -> int: if self.opt: - return (self.ednsflags & 0xff0000) >> 16 + return (self.ednsflags & 0xFF0000) >> 16 else: return -1 @@ -688,7 +744,7 @@ class Message: else: return () - def want_dnssec(self, wanted: bool=True) -> None: + def want_dnssec(self, wanted: bool = True) -> None: """Enable or disable 'DNSSEC desired' flag in requests. *wanted*, a ``bool``. If ``True``, then DNSSEC data is @@ -746,16 +802,20 @@ class Message: # pylint: enable=unused-argument - def _parse_special_rr_header(self, section, count, position, - name, rdclass, rdtype): + def _parse_special_rr_header(self, section, count, position, name, rdclass, rdtype): if rdtype == dns.rdatatype.OPT: - if section != MessageSection.ADDITIONAL or self.opt or \ - name != dns.name.root: + if ( + section != MessageSection.ADDITIONAL + or self.opt + or name != dns.name.root + ): raise BadEDNS elif rdtype == dns.rdatatype.TSIG: - if section != MessageSection.ADDITIONAL or \ - rdclass != dns.rdatatype.ANY or \ - position != count - 1: + if ( + section != MessageSection.ADDITIONAL + or rdclass != dns.rdatatype.ANY + or position != count - 1 + ): raise BadTSIG return (rdclass, rdtype, None, False) @@ -778,8 +838,14 @@ class ChainingResult: The ``cnames`` attribute is a list of all the CNAME RRSets followed to get to the canonical name. """ - def __init__(self, canonical_name: dns.name.Name, answer: Optional[dns.rrset.RRset], - minimum_ttl: int, cnames: List[dns.rrset.RRset]): + + def __init__( + self, + canonical_name: dns.name.Name, + answer: Optional[dns.rrset.RRset], + minimum_ttl: int, + cnames: List[dns.rrset.RRset], + ): self.canonical_name = canonical_name self.answer = answer self.minimum_ttl = minimum_ttl @@ -815,16 +881,17 @@ class QueryMessage(Message): cnames = [] while count < MAX_CHAIN: try: - answer = self.find_rrset(self.answer, qname, question.rdclass, - question.rdtype) + answer = self.find_rrset( + self.answer, qname, question.rdclass, question.rdtype + ) min_ttl = min(min_ttl, answer.ttl) break except KeyError: if question.rdtype != dns.rdatatype.CNAME: try: - crrset = self.find_rrset(self.answer, qname, - question.rdclass, - dns.rdatatype.CNAME) + crrset = self.find_rrset( + self.answer, qname, question.rdclass, dns.rdatatype.CNAME + ) cnames.append(crrset) min_ttl = min(min_ttl, crrset.ttl) for rd in crrset: @@ -849,9 +916,9 @@ class QueryMessage(Message): # Look for an SOA RR whose owner name is a superdomain # of qname. try: - srrset = self.find_rrset(self.authority, auname, - question.rdclass, - dns.rdatatype.SOA) + srrset = self.find_rrset( + self.authority, auname, question.rdclass, dns.rdatatype.SOA + ) min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum) break except KeyError: @@ -915,9 +982,17 @@ class _WireReader: raising them. """ - def __init__(self, wire, initialize_message, question_only=False, - one_rr_per_rrset=False, ignore_trailing=False, - keyring=None, multi=False, continue_on_error=False): + def __init__( + self, + wire, + initialize_message, + question_only=False, + one_rr_per_rrset=False, + ignore_trailing=False, + keyring=None, + multi=False, + continue_on_error=False, + ): self.parser = dns.wire.Parser(wire) self.message = None self.initialize_message = initialize_message @@ -937,12 +1012,13 @@ class _WireReader: section = self.message.sections[section_number] for _ in range(qcount): qname = self.parser.get_name(self.message.origin) - (rdtype, rdclass) = self.parser.get_struct('!HH') - (rdclass, rdtype, _, _) = \ - self.message._parse_rr_header(section_number, qname, rdclass, - rdtype) - self.message.find_rrset(section, qname, rdclass, rdtype, - create=True, force_unique=True) + (rdtype, rdclass) = self.parser.get_struct("!HH") + (rdclass, rdtype, _, _) = self.message._parse_rr_header( + section_number, qname, rdclass, rdtype + ) + self.message.find_rrset( + section, qname, rdclass, rdtype, create=True, force_unique=True + ) def _add_error(self, e): self.errors.append(MessageError(e, self.parser.current)) @@ -964,16 +1040,20 @@ class _WireReader: name = absolute_name.relativize(self.message.origin) else: name = absolute_name - (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct('!HHIH') + (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct("!HHIH") if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG): - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_special_rr_header(section_number, - count, i, name, - rdclass, rdtype) + ( + rdclass, + rdtype, + deleting, + empty, + ) = self.message._parse_special_rr_header( + section_number, count, i, name, rdclass, rdtype + ) else: - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_rr_header(section_number, - name, rdclass, rdtype) + (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) try: rdata_start = self.parser.current if empty: @@ -983,9 +1063,9 @@ class _WireReader: covers = dns.rdatatype.NONE else: with self.parser.restrict_to(rdlen): - rd = dns.rdata.from_wire_parser(rdclass, rdtype, - self.parser, - self.message.origin) + rd = dns.rdata.from_wire_parser( + rdclass, rdtype, self.parser, self.message.origin + ) covers = rd.covers() if self.message.xfr and rdtype == dns.rdatatype.SOA: force_unique = True @@ -993,8 +1073,7 @@ class _WireReader: self.message.opt = dns.rrset.from_rdata(name, ttl, rd) elif rdtype == dns.rdatatype.TSIG: if self.keyring is None: - raise UnknownTSIGKey('got signed message without ' - 'keyring') + raise UnknownTSIGKey("got signed message without " "keyring") if isinstance(self.keyring, dict): key = self.keyring.get(absolute_name) if isinstance(key, bytes): @@ -1006,25 +1085,31 @@ class _WireReader: if key is None: raise UnknownTSIGKey("key '%s' unknown" % name) self.message.keyring = key - self.message.tsig_ctx = \ - dns.tsig.validate(self.parser.wire, - key, - absolute_name, - rd, - int(time.time()), - self.message.request_mac, - rr_start, - self.message.tsig_ctx, - self.multi) - self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, - rd) + self.message.tsig_ctx = dns.tsig.validate( + self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi, + ) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) else: - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, - force_unique) + rrset = self.message.find_rrset( + section, + name, + rdclass, + rdtype, + covers, + deleting, + True, + force_unique, + ) if rd is not None: - if ttl > 0x7fffffff: + if ttl > 0x7FFFFFFF: ttl = 0 rrset.add(rd, ttl) except Exception as e: @@ -1040,14 +1125,16 @@ class _WireReader: if self.parser.remaining() < 12: raise ShortHeader - (id, flags, qcount, ancount, aucount, adcount) = \ - self.parser.get_struct('!HHHHHH') + (id, flags, qcount, ancount, aucount, adcount) = self.parser.get_struct( + "!HHHHHH" + ) factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) self.message = factory(id=id) self.message.flags = dns.flags.Flag(flags) self.initialize_message(self.message) - self.one_rr_per_rrset = \ - self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) + self.one_rr_per_rrset = self.message._get_one_rr_per_rrset( + self.one_rr_per_rrset + ) try: self._get_question(MessageSection.QUESTION, qcount) if self.question_only: @@ -1057,8 +1144,7 @@ class _WireReader: self._get_section(MessageSection.ADDITIONAL, adcount) if not self.ignore_trailing and self.parser.remaining() != 0: raise TrailingJunk - if self.multi and self.message.tsig_ctx and \ - not self.message.had_tsig: + if self.multi and self.message.tsig_ctx and not self.message.had_tsig: self.message.tsig_ctx.update(self.parser.wire) except Exception as e: if self.continue_on_error: @@ -1068,73 +1154,78 @@ class _WireReader: return self.message -def from_wire(wire: bytes, keyring: Optional[Any]=None, request_mac: Optional[bytes]=b'', - xfr: bool=False, origin: Optional[dns.name.Name]=None, - tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]]=None, - multi: bool=False, question_only: bool=False, one_rr_per_rrset: bool=False, - ignore_trailing: bool=False, raise_on_truncation: bool=False, - continue_on_error: bool=False) -> Message: +def from_wire( + wire: bytes, + keyring: Optional[Any] = None, + request_mac: Optional[bytes] = b"", + xfr: bool = False, + origin: Optional[dns.name.Name] = None, + tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, + multi: bool = False, + question_only: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + continue_on_error: bool = False, +) -> Message: """Convert a DNS wire format message into a message object. - *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the - message is signed. + *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message + is signed. - *request_mac*, a ``bytes`` or ``None``. If the message is a response to a TSIG-signed - request, *request_mac* should be set to the MAC of that request. + *request_mac*, a ``bytes`` or ``None``. If the message is a response to a + TSIG-signed request, *request_mac* should be set to the MAC of that request. - *xfr*, a ``bool``, should be set to ``True`` if this message is part of a - zone transfer. + *xfr*, a ``bool``, should be set to ``True`` if this message is part of a zone + transfer. *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone - transfer, *origin* should be the origin name of the zone. If not ``None``, - names will be relativized to the origin. + transfer, *origin* should be the origin name of the zone. If not ``None``, names + will be relativized to the origin. - *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the - ongoing TSIG context, used when validating zone transfers. + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the ongoing TSIG + context, used when validating zone transfers. - *multi*, a ``bool``, should be set to ``True`` if this message is part of a - multiple message sequence. + *multi*, a ``bool``, should be set to ``True`` if this message is part of a multiple + message sequence. - *question_only*, a ``bool``. If ``True``, read only up to the end of the - question section. + *question_only*, a ``bool``. If ``True``, read only up to the end of the question + section. - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of - the message. + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + message. - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the - TC bit is set. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the TC bit is + set. - *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even - if errors occur. Erroneous rdata will be ignored. Errors will be - accumulated as a list of MessageError objects in the message's ``errors`` - attribute. This option is recommended only for DNS analysis tools, or for - use in a server as part of an error handling path. The default is - ``False``. + *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even if + errors occur. Erroneous rdata will be ignored. Errors will be accumulated as a + list of MessageError objects in the message's ``errors`` attribute. This option is + recommended only for DNS analysis tools, or for use in a server as part of an error + handling path. The default is ``False``. - Raises ``dns.message.ShortHeader`` if the message is less than 12 octets - long. + Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. - Raises ``dns.message.TrailingJunk`` if there were octets in the message past - the end of the proper DNS message, and *ignore_trailing* is ``False``. + Raises ``dns.message.TrailingJunk`` if there were octets in the message past the end + of the proper DNS message, and *ignore_trailing* is ``False``. Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or occurred more than once. - Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of - the additional data section. + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of the + additional data section. - Raises ``dns.message.Truncated`` if the TC flag is set and - *raise_on_truncation* is ``True``. + Raises ``dns.message.Truncated`` if the TC flag is set and *raise_on_truncation* is + ``True``. Returns a ``dns.message.Message``. """ # We permit None for request_mac solely for backwards compatibility if request_mac is None: - request_mac = b'' + request_mac = b"" def initialize_message(message): message.request_mac = request_mac @@ -1142,14 +1233,24 @@ def from_wire(wire: bytes, keyring: Optional[Any]=None, request_mac: Optional[by message.origin = origin message.tsig_ctx = tsig_ctx - reader = _WireReader(wire, initialize_message, question_only, - one_rr_per_rrset, ignore_trailing, keyring, multi, - continue_on_error) + reader = _WireReader( + wire, + initialize_message, + question_only, + one_rr_per_rrset, + ignore_trailing, + keyring, + multi, + continue_on_error, + ) try: m = reader.read() except dns.exception.FormError: - if reader.message and (reader.message.flags & dns.flags.TC) and \ - raise_on_truncation: + if ( + reader.message + and (reader.message.flags & dns.flags.TC) + and raise_on_truncation + ): raise Truncated(message=reader.message) else: raise @@ -1177,8 +1278,15 @@ class _TextReader: relativize_to: the origin to relativize to. """ - def __init__(self, text, idna_codec, one_rr_per_rrset=False, - origin=None, relativize=True, relativize_to=None): + def __init__( + self, + text, + idna_codec, + one_rr_per_rrset=False, + origin=None, + relativize=True, + relativize_to=None, + ): self.message = None self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec) self.last_name = None @@ -1199,19 +1307,19 @@ class _TextReader: token = self.tok.get() what = token.value - if what == 'id': + if what == "id": self.id = self.tok.get_int() - elif what == 'flags': + elif what == "flags": while True: token = self.tok.get() if not token.is_identifier(): self.tok.unget(token) break self.flags = self.flags | dns.flags.from_text(token.value) - elif what == 'edns': + elif what == "edns": self.edns = self.tok.get_int() self.ednsflags = self.ednsflags | (self.edns << 16) - elif what == 'eflags': + elif what == "eflags": if self.edns < 0: self.edns = 0 while True: @@ -1219,17 +1327,16 @@ class _TextReader: if not token.is_identifier(): self.tok.unget(token) break - self.ednsflags = self.ednsflags | \ - dns.flags.edns_from_text(token.value) - elif what == 'payload': + self.ednsflags = self.ednsflags | dns.flags.edns_from_text(token.value) + elif what == "payload": self.payload = self.tok.get_int() if self.edns < 0: self.edns = 0 - elif what == 'opcode': + elif what == "opcode": text = self.tok.get_string() self.opcode = dns.opcode.from_text(text) self.flags = self.flags | dns.opcode.to_flags(self.opcode) - elif what == 'rcode': + elif what == "rcode": text = self.tok.get_string() self.rcode = dns.rcode.from_text(text) else: @@ -1242,9 +1349,9 @@ class _TextReader: section = self.message.sections[section_number] token = self.tok.get(want_leading=True) if not token.is_whitespace(): - self.last_name = self.tok.as_name(token, self.message.origin, - self.relativize, - self.relativize_to) + self.last_name = self.tok.as_name( + token, self.message.origin, self.relativize, self.relativize_to + ) name = self.last_name if name is None: raise NoPreviousName @@ -1263,10 +1370,12 @@ class _TextReader: rdclass = dns.rdataclass.IN # Type rdtype = dns.rdatatype.from_text(token.value) - (rdclass, rdtype, _, _) = \ - self.message._parse_rr_header(section_number, name, rdclass, rdtype) - self.message.find_rrset(section, name, rdclass, rdtype, create=True, - force_unique=True) + (rdclass, rdtype, _, _) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) + self.message.find_rrset( + section, name, rdclass, rdtype, create=True, force_unique=True + ) self.tok.get_eol() def _rr_line(self, section_number): @@ -1278,9 +1387,9 @@ class _TextReader: # Name token = self.tok.get(want_leading=True) if not token.is_whitespace(): - self.last_name = self.tok.as_name(token, self.message.origin, - self.relativize, - self.relativize_to) + self.last_name = self.tok.as_name( + token, self.message.origin, self.relativize, self.relativize_to + ) name = self.last_name if name is None: raise NoPreviousName @@ -1309,8 +1418,9 @@ class _TextReader: rdclass = dns.rdataclass.IN # Type rdtype = dns.rdatatype.from_text(token.value) - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_rr_header(section_number, name, rdclass, rdtype) + (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) token = self.tok.get() if empty and not token.is_eol_or_eof(): raise dns.exception.SyntaxError @@ -1318,16 +1428,28 @@ class _TextReader: raise dns.exception.UnexpectedEnd if not token.is_eol_or_eof(): self.tok.unget(token) - rd = dns.rdata.from_text(rdclass, rdtype, self.tok, - self.message.origin, self.relativize, - self.relativize_to) + rd = dns.rdata.from_text( + rdclass, + rdtype, + self.tok, + self.message.origin, + self.relativize, + self.relativize_to, + ) covers = rd.covers() else: rd = None covers = dns.rdatatype.NONE - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, self.one_rr_per_rrset) + rrset = self.message.find_rrset( + section, + name, + rdclass, + rdtype, + covers, + deleting, + True, + self.one_rr_per_rrset, + ) if rd is not None: rrset.add(rd, ttl) @@ -1355,7 +1477,7 @@ class _TextReader: break if token.is_comment(): u = token.value.upper() - if u == 'HEADER': + if u == "HEADER": line_method = self._header_line if self.message: @@ -1370,8 +1492,9 @@ class _TextReader: # use the one we just created. if not self.message: self.message = message - self.one_rr_per_rrset = \ - message._get_one_rr_per_rrset(self.one_rr_per_rrset) + self.one_rr_per_rrset = message._get_one_rr_per_rrset( + self.one_rr_per_rrset + ) if section_number == MessageSection.QUESTION: line_method = self._question_line else: @@ -1388,9 +1511,14 @@ class _TextReader: return self.message -def from_text(text: str, idna_codec: Optional[dns.name.IDNACodec]=None, - one_rr_per_rrset: bool=False, origin: Optional[dns.name.Name]=None, - relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> Message: +def from_text( + text: str, + idna_codec: Optional[dns.name.IDNACodec] = None, + one_rr_per_rrset: bool = False, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> Message: """Convert the text format message into a message object. The reader stops after reading the first blank line in the input to @@ -1425,12 +1553,17 @@ def from_text(text: str, idna_codec: Optional[dns.name.IDNACodec]=None, # since it's an implementation detail. The official file # interface is from_file(). - reader = _TextReader(text, idna_codec, one_rr_per_rrset, origin, - relativize, relativize_to) + reader = _TextReader( + text, idna_codec, one_rr_per_rrset, origin, relativize, relativize_to + ) return reader.read() -def from_file(f: Any, idna_codec: Optional[dns.name.IDNACodec]=None, one_rr_per_rrset: bool=False) -> Message: +def from_file( + f: Any, + idna_codec: Optional[dns.name.IDNACodec] = None, + one_rr_per_rrset: bool = False, +) -> Message: """Read the next text format message from the specified file. Message blocks are separated by a single blank line. @@ -1459,14 +1592,20 @@ def from_file(f: Any, idna_codec: Optional[dns.name.IDNACodec]=None, one_rr_per_ assert False # for mypy lgtm[py/unreachable-statement] -def make_query(qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - use_edns: Optional[Union[int, bool]]=None, - want_dnssec: bool=False, ednsflags: Optional[int]=None, payload: Optional[int]=None, - request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None, - idna_codec: Optional[dns.name.IDNACodec]=None, id: Optional[int]=None, - flags: int=dns.flags.RD) -> QueryMessage: +def make_query( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + use_edns: Optional[Union[int, bool]] = None, + want_dnssec: bool = False, + ednsflags: Optional[int] = None, + payload: Optional[int] = None, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + id: Optional[int] = None, + flags: int = dns.flags.RD, +) -> QueryMessage: """Make a query message. The query name, type, and class may all be specified either @@ -1523,30 +1662,36 @@ def make_query(qname: Union[dns.name.Name, str], the_rdclass = dns.rdataclass.RdataClass.make(rdclass) m = QueryMessage(id=id) m.flags = dns.flags.Flag(flags) - m.find_rrset(m.question, qname, the_rdclass, the_rdtype, create=True, - force_unique=True) + m.find_rrset( + m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True + ) # only pass keywords on to use_edns if they have been set to a # non-None value. Setting a field will turn EDNS on if it hasn't # been configured. kwargs: Dict[str, Any] = {} if ednsflags is not None: - kwargs['ednsflags'] = ednsflags + kwargs["ednsflags"] = ednsflags if payload is not None: - kwargs['payload'] = payload + kwargs["payload"] = payload if request_payload is not None: - kwargs['request_payload'] = request_payload + kwargs["request_payload"] = request_payload if options is not None: - kwargs['options'] = options + kwargs["options"] = options if kwargs and use_edns is None: use_edns = 0 - kwargs['edns'] = use_edns + kwargs["edns"] = use_edns m.use_edns(**kwargs) m.want_dnssec(want_dnssec) return m -def make_response(query: Message, recursion_available: bool=False, our_payload: int=8192, - fudge: int=300, tsig_error: int=0) -> Message: +def make_response( + query: Message, + recursion_available: bool = False, + our_payload: int = 8192, + fudge: int = 300, + tsig_error: int = 0, +) -> Message: """Make a message which is a response for the specified query. The message returned is really a response skeleton; it has all of the infrastructure required of a response, but none of the @@ -1573,7 +1718,7 @@ def make_response(query: Message, recursion_available: bool=False, our_payload: """ if query.flags & dns.flags.QR: - raise dns.exception.FormError('specified query message is not a query') + raise dns.exception.FormError("specified query message is not a query") factory = _message_factory_from_opcode(query.opcode()) response = factory(id=query.id) response.flags = dns.flags.QR | (query.flags & dns.flags.RD) @@ -1584,11 +1729,19 @@ def make_response(query: Message, recursion_available: bool=False, our_payload: if query.edns >= 0: response.use_edns(0, 0, our_payload, query.payload) if query.had_tsig: - response.use_tsig(query.keyring, query.keyname, fudge, None, - tsig_error, b'', query.keyalgorithm) + response.use_tsig( + query.keyring, + query.keyname, + fudge, + None, + tsig_error, + b"", + query.keyalgorithm, + ) response.request_mac = query.mac return response + ### BEGIN generated MessageSection constants QUESTION = MessageSection.QUESTION diff --git a/dns/name.py b/dns/name.py index daf1259c..2ebda4a4 100644 --- a/dns/name.py +++ b/dns/name.py @@ -23,9 +23,11 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union import copy import struct -import encodings.idna # type: ignore +import encodings.idna # type: ignore + try: - import idna # type: ignore + import idna # type: ignore + have_idna_2008 = True except ImportError: # pragma: no cover have_idna_2008 = False @@ -36,7 +38,7 @@ import dns.exception import dns.immutable -CompressType = Dict['Name', int] +CompressType = Dict["Name", int] class NameRelation(dns.enum.IntEnum): @@ -111,6 +113,7 @@ class NoParent(dns.exception.DNSException): """An attempt was made to get the parent of the root name or the empty name.""" + class NoIDNA2008(dns.exception.DNSException): """IDNA 2008 processing was requested but the idna module is not available.""" @@ -119,10 +122,11 @@ class NoIDNA2008(dns.exception.DNSException): class IDNAException(dns.exception.DNSException): """IDNA processing raised an exception.""" - supp_kwargs = {'idna_exception'} + supp_kwargs = {"idna_exception"} fmt = "IDNA processing exception: {idna_exception}" - # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -130,6 +134,7 @@ class IDNAException(dns.exception.DNSException): _escaped = b'"().;\\@$' _escaped_text = '"().;\\@$' + def _escapify(label: Union[bytes, str]) -> str: """Escape the characters in label which need it. @returns: the escaped string @@ -137,23 +142,23 @@ def _escapify(label: Union[bytes, str]) -> str: if isinstance(label, bytes): # Ordinary DNS label mode. Escape special characters and values # < 0x20 or > 0x7f. - text = '' + text = "" for c in label: if c in _escaped: - text += '\\' + chr(c) + text += "\\" + chr(c) elif c > 0x20 and c < 0x7F: text += chr(c) else: - text += '\\%03d' % c + text += "\\%03d" % c return text # Unicode label mode. Escape only special characters and values < 0x20 - text = '' + text = "" for uc in label: if uc in _escaped_text: - text += '\\' + uc - elif uc <= '\x20': - text += '\\%03d' % ord(uc) + text += "\\" + uc + elif uc <= "\x20": + text += "\\%03d" % ord(uc) else: text += uc return text @@ -166,7 +171,7 @@ class IDNACodec: pass def is_idna(self, label: bytes) -> bool: - return label.lower().startswith(b'xn--') + return label.lower().startswith(b"xn--") def encode(self, label: str) -> bytes: raise NotImplementedError # pragma: no cover @@ -175,7 +180,7 @@ class IDNACodec: # We do not apply any IDNA policy on decode. if self.is_idna(label): try: - slabel = label[4:].decode('punycode') + slabel = label[4:].decode("punycode") return _escapify(slabel) except Exception as e: raise IDNAException(idna_exception=e) @@ -186,7 +191,7 @@ class IDNACodec: class IDNA2003Codec(IDNACodec): """IDNA 2003 encoder/decoder.""" - def __init__(self, strict_decode: bool=False): + def __init__(self, strict_decode: bool = False): """Initialize the IDNA 2003 encoder/decoder. *strict_decode* is a ``bool``. If `True`, then IDNA2003 checking @@ -200,8 +205,8 @@ class IDNA2003Codec(IDNACodec): def encode(self, label: str) -> bytes: """Encode *label*.""" - if label == '': - return b'' + if label == "": + return b"" try: return encodings.idna.ToASCII(label) except UnicodeError: @@ -211,8 +216,8 @@ class IDNA2003Codec(IDNACodec): """Decode *label*.""" if not self.strict_decode: return super().decode(label) - if label == b'': - return '' + if label == b"": + return "" try: return _escapify(encodings.idna.ToUnicode(label)) except Exception as e: @@ -220,11 +225,15 @@ class IDNA2003Codec(IDNACodec): class IDNA2008Codec(IDNACodec): - """IDNA 2008 encoder/decoder. - """ - - def __init__(self, uts_46: bool=False, transitional: bool=False, - allow_pure_ascii: bool=False, strict_decode: bool=False): + """IDNA 2008 encoder/decoder.""" + + def __init__( + self, + uts_46: bool = False, + transitional: bool = False, + allow_pure_ascii: bool = False, + strict_decode: bool = False, + ): """Initialize the IDNA 2008 encoder/decoder. *uts_46* is a ``bool``. If True, apply Unicode IDNA @@ -254,10 +263,10 @@ class IDNA2008Codec(IDNACodec): self.strict_decode = strict_decode def encode(self, label: str) -> bytes: - if label == '': - return b'' + if label == "": + return b"" if self.allow_pure_ascii and is_all_ascii(label): - encoded = label.encode('ascii') + encoded = label.encode("ascii") if len(encoded) > 63: raise LabelTooLong return encoded @@ -268,7 +277,7 @@ class IDNA2008Codec(IDNACodec): label = idna.uts46_remap(label, False, self.transitional) return idna.alabel(label) except idna.IDNAError as e: - if e.args[0] == 'Label too long': + if e.args[0] == "Label too long": raise LabelTooLong else: raise IDNAException(idna_exception=e) @@ -276,8 +285,8 @@ class IDNA2008Codec(IDNACodec): def decode(self, label: bytes) -> str: if not self.strict_decode: return super().decode(label) - if label == b'': - return '' + if label == b"": + return "" if not have_idna_2008: raise NoIDNA2008 try: @@ -288,6 +297,7 @@ class IDNA2008Codec(IDNACodec): except (idna.IDNAError, UnicodeError) as e: raise IDNAException(idna_exception=e) + IDNA_2003_Practical = IDNA2003Codec(False) IDNA_2003_Strict = IDNA2003Codec(True) IDNA_2003 = IDNA_2003_Practical @@ -297,6 +307,7 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True) IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False) IDNA_2008 = IDNA_2008_Practical + def _validate_labels(labels: Tuple[bytes, ...]) -> None: """Check for empty labels in the middle of a label sequence, labels that are too long, and for too many labels. @@ -318,7 +329,7 @@ def _validate_labels(labels: Tuple[bytes, ...]) -> None: total += ll + 1 if ll > 63: raise LabelTooLong - if i < 0 and label == b'': + if i < 0 and label == b"": i = j j += 1 if total > 255: @@ -350,11 +361,10 @@ class Name: of the class are immutable. """ - __slots__ = ['labels'] + __slots__ = ["labels"] def __init__(self, labels: Iterable[Union[bytes, str]]): - """*labels* is any iterable whose values are ``str`` or ``bytes``. - """ + """*labels* is any iterable whose values are ``str`` or ``bytes``.""" blabels = [_maybe_convert_to_binary(x) for x in labels] self.labels = tuple(blabels) @@ -368,10 +378,10 @@ class Name: def __getstate__(self): # Names can be pickled - return {'labels': self.labels} + return {"labels": self.labels} def __setstate__(self, state): - super().__setattr__('labels', state['labels']) + super().__setattr__("labels", state["labels"]) _validate_labels(self.labels) def is_absolute(self) -> bool: @@ -380,7 +390,7 @@ class Name: Returns a ``bool``. """ - return len(self.labels) > 0 and self.labels[-1] == b'' + return len(self.labels) > 0 and self.labels[-1] == b"" def is_wild(self) -> bool: """Is this name wild? (I.e. Is the least significant label '*'?) @@ -388,7 +398,7 @@ class Name: Returns a ``bool``. """ - return len(self.labels) > 0 and self.labels[0] == b'*' + return len(self.labels) > 0 and self.labels[0] == b"*" def __hash__(self) -> int: """Return a case-insensitive hash of the name. @@ -402,7 +412,7 @@ class Name: h += (h << 3) + c return h - def fullcompare(self, other: 'Name') -> Tuple[NameRelation, int, int]: + def fullcompare(self, other: "Name") -> Tuple[NameRelation, int, int]: """Compare two names, returning a 3-tuple ``(relation, order, nlabels)``. @@ -478,7 +488,7 @@ class Name: namereln = NameRelation.EQUAL return (namereln, order, nlabels) - def is_subdomain(self, other: 'Name') -> bool: + def is_subdomain(self, other: "Name") -> bool: """Is self a subdomain of other? Note that the notion of subdomain includes equality, e.g. @@ -492,7 +502,7 @@ class Name: return True return False - def is_superdomain(self, other: 'Name') -> bool: + def is_superdomain(self, other: "Name") -> bool: """Is self a superdomain of other? Note that the notion of superdomain includes equality, e.g. @@ -506,7 +516,7 @@ class Name: return True return False - def canonicalize(self) -> 'Name': + def canonicalize(self) -> "Name": """Return a name which is equal to the current name, but is in DNSSEC canonical form. """ @@ -550,12 +560,12 @@ class Name: return NotImplemented def __repr__(self): - return '' + return "" def __str__(self): return self.to_text(False) - def to_text(self, omit_final_dot: bool=False) -> str: + def to_text(self, omit_final_dot: bool = False) -> str: """Convert name to DNS text format. *omit_final_dot* is a ``bool``. If True, don't emit the final @@ -566,17 +576,19 @@ class Name: """ if len(self.labels) == 0: - return '@' - if len(self.labels) == 1 and self.labels[0] == b'': - return '.' + return "@" + if len(self.labels) == 1 and self.labels[0] == b"": + return "." if omit_final_dot and self.is_absolute(): l = self.labels[:-1] else: l = self.labels - s = '.'.join(map(_escapify, l)) + s = ".".join(map(_escapify, l)) return s - def to_unicode(self, omit_final_dot: bool=False, idna_codec: Optional[IDNACodec]=None) -> str: + def to_unicode( + self, omit_final_dot: bool = False, idna_codec: Optional[IDNACodec] = None + ) -> str: """Convert name to Unicode text format. IDN ACE labels are converted to Unicode. @@ -595,18 +607,18 @@ class Name: """ if len(self.labels) == 0: - return '@' - if len(self.labels) == 1 and self.labels[0] == b'': - return '.' + return "@" + if len(self.labels) == 1 and self.labels[0] == b"": + return "." if omit_final_dot and self.is_absolute(): l = self.labels[:-1] else: l = self.labels if idna_codec is None: idna_codec = IDNA_2003_Practical - return '.'.join([idna_codec.decode(x) for x in l]) + return ".".join([idna_codec.decode(x) for x in l]) - def to_digestable(self, origin: Optional['Name']=None) -> bytes: + def to_digestable(self, origin: Optional["Name"] = None) -> bytes: """Convert name to a format suitable for digesting in hashes. The name is canonicalized and converted to uncompressed wire @@ -627,8 +639,13 @@ class Name: assert digest is not None return digest - def to_wire(self, file: Optional[Any]=None, compress: Optional[CompressType]=None, - origin: Optional['Name']=None, canonicalize: bool=False) -> Optional[bytes]: + def to_wire( + self, + file: Optional[Any] = None, + compress: Optional[CompressType] = None, + origin: Optional["Name"] = None, + canonicalize: bool = False, + ) -> Optional[bytes]: """Convert name to wire format, possibly compressing it. *file* is the file where the name is emitted (typically an @@ -691,17 +708,17 @@ class Name: else: pos = None if pos is not None: - value = 0xc000 + pos - s = struct.pack('!H', value) + value = 0xC000 + pos + s = struct.pack("!H", value) file.write(s) break else: if compress is not None and len(n) > 1: pos = file.tell() - if pos <= 0x3fff: + if pos <= 0x3FFF: compress[n] = pos l = len(label) - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) if l > 0: if canonicalize: file.write(label.lower()) @@ -726,7 +743,7 @@ class Name: def __sub__(self, other): return self.relativize(other) - def split(self, depth: int) -> Tuple['Name', 'Name']: + def split(self, depth: int) -> Tuple["Name", "Name"]: """Split a name into a prefix and suffix names at the specified depth. *depth* is an ``int`` specifying the number of labels in the suffix @@ -743,11 +760,10 @@ class Name: elif depth == l: return (dns.name.empty, self) elif depth < 0 or depth > l: - raise ValueError( - 'depth must be >= 0 and <= the length of the name') - return (Name(self[: -depth]), Name(self[-depth:])) + raise ValueError("depth must be >= 0 and <= the length of the name") + return (Name(self[:-depth]), Name(self[-depth:])) - def concatenate(self, other: 'Name') -> 'Name': + def concatenate(self, other: "Name") -> "Name": """Return a new name which is the concatenation of self and other. Raises ``dns.name.AbsoluteConcatenation`` if the name is @@ -762,7 +778,7 @@ class Name: labels.extend(list(other.labels)) return Name(labels) - def relativize(self, origin: 'Name') -> 'Name': + def relativize(self, origin: "Name") -> "Name": """If the name is a subdomain of *origin*, return a new name which is the name relative to origin. Otherwise return the name. @@ -778,7 +794,7 @@ class Name: else: return self - def derelativize(self, origin: 'Name') -> 'Name': + def derelativize(self, origin: "Name") -> "Name": """If the name is a relative name, return a new name which is the concatenation of the name and origin. Otherwise return the name. @@ -794,7 +810,9 @@ class Name: else: return self - def choose_relativity(self, origin: Optional['Name']=None, relativize: bool=True) -> 'Name': + def choose_relativity( + self, origin: Optional["Name"] = None, relativize: bool = True + ) -> "Name": """Return a name with the relativity desired by the caller. If *origin* is ``None``, then the name is returned. @@ -813,7 +831,7 @@ class Name: else: return self - def parent(self) -> 'Name': + def parent(self) -> "Name": """Return the parent of the name. For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``. @@ -828,13 +846,17 @@ class Name: raise NoParent return Name(self.labels[1:]) + #: The root name, '.' -root = Name([b'']) +root = Name([b""]) #: The empty name. empty = Name([]) -def from_unicode(text: str, origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name: + +def from_unicode( + text: str, origin: Optional[Name] = root, idna_codec: Optional[IDNACodec] = None +) -> Name: """Convert unicode text into a Name object. Labels are encoded in IDN ACE form according to rules specified by @@ -857,17 +879,17 @@ def from_unicode(text: str, origin: Optional[Name]=root, idna_codec: Optional[ID if not (origin is None or isinstance(origin, Name)): raise ValueError("origin must be a Name or None") labels = [] - label = '' + label = "" escaping = False edigits = 0 total = 0 if idna_codec is None: idna_codec = IDNA_2003 - if text == '@': - text = '' + if text == "@": + text = "" if text: - if text in ['.', '\u3002', '\uff0e', '\uff61']: - return Name([b'']) # no Unicode "u" on this constant! + if text in [".", "\u3002", "\uff0e", "\uff61"]: + return Name([b""]) # no Unicode "u" on this constant! for c in text: if escaping: if edigits == 0: @@ -886,12 +908,12 @@ def from_unicode(text: str, origin: Optional[Name]=root, idna_codec: Optional[ID if edigits == 3: escaping = False label += chr(total) - elif c in ['.', '\u3002', '\uff0e', '\uff61']: + elif c in [".", "\u3002", "\uff0e", "\uff61"]: if len(label) == 0: raise EmptyLabel labels.append(idna_codec.encode(label)) - label = '' - elif c == '\\': + label = "" + elif c == "\\": escaping = True edigits = 0 total = 0 @@ -902,19 +924,25 @@ def from_unicode(text: str, origin: Optional[Name]=root, idna_codec: Optional[ID if len(label) > 0: labels.append(idna_codec.encode(label)) else: - labels.append(b'') + labels.append(b"") - if (len(labels) == 0 or labels[-1] != b'') and origin is not None: + if (len(labels) == 0 or labels[-1] != b"") and origin is not None: labels.extend(list(origin.labels)) return Name(labels) + def is_all_ascii(text: str) -> bool: for c in text: - if ord(c) > 0x7f: + if ord(c) > 0x7F: return False return True -def from_text(text: Union[bytes, str], origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name: + +def from_text( + text: Union[bytes, str], + origin: Optional[Name] = root, + idna_codec: Optional[IDNACodec] = None, +) -> Name: """Convert text into a Name object. *text*, a ``bytes`` or ``str``, is the text to convert into a name. @@ -941,23 +969,23 @@ def from_text(text: Union[bytes, str], origin: Optional[Name]=root, idna_codec: # # then it's still "all ASCII" even though the domain name has # codepoints > 127. - text = text.encode('ascii') + text = text.encode("ascii") if not isinstance(text, bytes): raise ValueError("input to from_text() must be a string") if not (origin is None or isinstance(origin, Name)): raise ValueError("origin must be a Name or None") labels = [] - label = b'' + label = b"" escaping = False edigits = 0 total = 0 - if text == b'@': - text = b'' + if text == b"@": + text = b"" if text: - if text == b'.': - return Name([b'']) + if text == b".": + return Name([b""]) for c in text: - byte_ = struct.pack('!B', c) + byte_ = struct.pack("!B", c) if escaping: if edigits == 0: if byte_.isdigit(): @@ -974,13 +1002,13 @@ def from_text(text: Union[bytes, str], origin: Optional[Name]=root, idna_codec: edigits += 1 if edigits == 3: escaping = False - label += struct.pack('!B', total) - elif byte_ == b'.': + label += struct.pack("!B", total) + elif byte_ == b".": if len(label) == 0: raise EmptyLabel labels.append(label) - label = b'' - elif byte_ == b'\\': + label = b"" + elif byte_ == b"\\": escaping = True edigits = 0 total = 0 @@ -991,14 +1019,16 @@ def from_text(text: Union[bytes, str], origin: Optional[Name]=root, idna_codec: if len(label) > 0: labels.append(label) else: - labels.append(b'') - if (len(labels) == 0 or labels[-1] != b'') and origin is not None: + labels.append(b"") + if (len(labels) == 0 or labels[-1] != b"") and origin is not None: labels.extend(list(origin.labels)) return Name(labels) + # we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other. -def from_wire_parser(parser: 'dns.wire.Parser') -> Name: + +def from_wire_parser(parser: "dns.wire.Parser") -> Name: """Convert possibly compressed wire format into a Name. *parser* is a dns.wire.Parser. @@ -1019,7 +1049,7 @@ def from_wire_parser(parser: 'dns.wire.Parser') -> Name: if count < 64: labels.append(parser.get_bytes(count)) elif count >= 192: - current = (count & 0x3f) * 256 + parser.get_uint8() + current = (count & 0x3F) * 256 + parser.get_uint8() if current >= biggest_pointer: raise BadPointer biggest_pointer = current @@ -1027,7 +1057,7 @@ def from_wire_parser(parser: 'dns.wire.Parser') -> Name: else: raise BadLabelType count = parser.get_uint8() - labels.append(b'') + labels.append(b"") return Name(labels) diff --git a/dns/namedict.py b/dns/namedict.py index ec0750ce..fe118a35 100644 --- a/dns/namedict.py +++ b/dns/namedict.py @@ -62,7 +62,7 @@ class NameDict(MutableMapping): def __setitem__(self, key, value): if not isinstance(key, dns.name.Name): - raise ValueError('NameDict key must be a name') + raise ValueError("NameDict key must be a name") self.__store[key] = value self.__update_max_depth(key) diff --git a/dns/node.py b/dns/node.py index 5270b53a..d870a299 100644 --- a/dns/node.py +++ b/dns/node.py @@ -37,26 +37,28 @@ _cname_types = { # "neutral" types can coexist with a CNAME and thus are not "other data" _neutral_types = { - dns.rdatatype.NSEC, # RFC 4035 section 2.5 + dns.rdatatype.NSEC, # RFC 4035 section 2.5 dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible! - dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007 + dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007 } + def _matches_type_or_its_signature(rdtypes, rdtype, covers): - return rdtype in rdtypes or \ - (rdtype == dns.rdatatype.RRSIG and covers in rdtypes) + return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes) @enum.unique class NodeKind(enum.Enum): - """Rdatasets in nodes - """ - REGULAR = 0 # a.k.a "other data" + """Rdatasets in nodes""" + + REGULAR = 0 # a.k.a "other data" NEUTRAL = 1 CNAME = 2 @classmethod - def classify(cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType) -> 'NodeKind': + def classify( + cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType + ) -> "NodeKind": if _matches_type_or_its_signature(_cname_types, rdtype, covers): return NodeKind.CNAME elif _matches_type_or_its_signature(_neutral_types, rdtype, covers): @@ -65,7 +67,7 @@ class NodeKind(enum.Enum): return NodeKind.REGULAR @classmethod - def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> 'NodeKind': + def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind": return cls.classify(rdataset.rdtype, rdataset.covers) @@ -86,7 +88,7 @@ class Node: deleted. """ - __slots__ = ['rdatasets'] + __slots__ = ["rdatasets"] def __init__(self): # the set of rdatasets, represented as a list. @@ -109,11 +111,11 @@ class Node: for rds in self.rdatasets: if len(rds) > 0: s.write(rds.to_text(name, **kw)) # type: ignore[arg-type] - s.write('\n') + s.write("\n") return s.getvalue()[:-1] def __repr__(self): - return '' + return "" def __eq__(self, other): # @@ -149,22 +151,28 @@ class Node: if len(self.rdatasets) > 0: kind = NodeKind.classify_rdataset(rdataset) if kind == NodeKind.CNAME: - self.rdatasets = [rds for rds in self.rdatasets if - NodeKind.classify_rdataset(rds) != - NodeKind.REGULAR] + self.rdatasets = [ + rds + for rds in self.rdatasets + if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR + ] elif kind == NodeKind.REGULAR: - self.rdatasets = [rds for rds in self.rdatasets if - NodeKind.classify_rdataset(rds) != - NodeKind.CNAME] + self.rdatasets = [ + rds + for rds in self.rdatasets + if NodeKind.classify_rdataset(rds) != NodeKind.CNAME + ] # Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to # edit self.rdatasets. self.rdatasets.append(rdataset) - def find_rdataset(self, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create: bool=False) -> dns.rdataset.Rdataset: + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: """Find an rdataset matching the specified properties in the current node. @@ -199,11 +207,13 @@ class Node: self._append_rdataset(rds) return rds - def get_rdataset(self, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create: bool=False) -> Optional[dns.rdataset.Rdataset]: + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: """Get an rdataset matching the specified properties in the current node. @@ -234,10 +244,12 @@ class Node: rds = None return rds - def delete_rdataset(self, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None: + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: """Delete the rdataset matching the specified properties in the current node. @@ -270,13 +282,14 @@ class Node: """ if not isinstance(replacement, dns.rdataset.Rdataset): - raise ValueError('replacement is not an rdataset') + raise ValueError("replacement is not an rdataset") if isinstance(replacement, dns.rrset.RRset): # RRsets are not good replacements as the match() method # is not compatible. replacement = replacement.to_rdataset() - self.delete_rdataset(replacement.rdclass, replacement.rdtype, - replacement.covers) + self.delete_rdataset( + replacement.rdclass, replacement.rdtype, replacement.covers + ) self._append_rdataset(replacement) def classify(self) -> NodeKind: @@ -312,28 +325,34 @@ class ImmutableNode(Node): [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] ) - def find_rdataset(self, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create: bool=False) -> dns.rdataset.Rdataset: + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: if create: raise TypeError("immutable") return super().find_rdataset(rdclass, rdtype, covers, False) - def get_rdataset(self, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create: bool=False) -> Optional[dns.rdataset.Rdataset]: + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: if create: raise TypeError("immutable") return super().get_rdataset(rdclass, rdtype, covers, False) - def delete_rdataset(self, - rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None: + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: raise TypeError("immutable") def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: diff --git a/dns/opcode.py b/dns/opcode.py index 971b62c8..78b43d2c 100644 --- a/dns/opcode.py +++ b/dns/opcode.py @@ -20,6 +20,7 @@ import dns.enum import dns.exception + class Opcode(dns.enum.IntEnum): #: Query QUERY = 0 @@ -104,6 +105,7 @@ def is_update(flags: int) -> bool: return from_flags(flags) == Opcode.UPDATE + ### BEGIN generated Opcode constants QUERY = Opcode.QUERY diff --git a/dns/query.py b/dns/query.py index 09d51078..2c3da4f8 100644 --- a/dns/query.py +++ b/dns/query.py @@ -46,6 +46,7 @@ try: import requests from requests_toolbelt.adapters.source import SourceAddressAdapter from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter + _have_requests = True except ImportError: # pragma: no cover _have_requests = False @@ -54,6 +55,7 @@ _have_httpx = False _have_http2 = False try: import httpx + _have_httpx = True try: # See if http2 support is available. @@ -69,8 +71,8 @@ have_doh = _have_requests or _have_httpx try: import ssl except ImportError: # pragma: no cover - class ssl: # type: ignore + class ssl: # type: ignore class WantReadException(Exception): pass @@ -85,12 +87,14 @@ except ImportError: # pragma: no cover @classmethod def create_default_context(cls, *args, **kwargs): - raise Exception('no ssl support') + raise Exception("no ssl support") + # Function used to create a socket. Can be overridden if needed in special # situations. socket_factory = socket.socket + class UnexpectedSource(dns.exception.DNSException): """A DNS query response came from an unexpected address or port.""" @@ -151,7 +155,8 @@ def _set_selector_class(selector_class): _selector_class = selector_class -if hasattr(selectors, 'PollSelector'): + +if hasattr(selectors, "PollSelector"): # Prefer poll() on platforms that support it because it has no # limits on the maximum value of a file descriptor (plus it will # be more efficient for high values). @@ -188,18 +193,20 @@ def _matches_destination(af, from_address, destination, ignore_unexpected): # sent to destination. if not destination: return True - if _addresses_equal(af, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _addresses_equal(af, from_address, destination) or ( + dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:] + ): return True elif ignore_unexpected: return False - raise UnexpectedSource(f'got a response from {from_address} instead of ' - f'{destination}') + raise UnexpectedSource( + f"got a response from {from_address} instead of " f"{destination}" + ) -def _destination_and_source(where, port, source, source_port, - where_must_be_address=True): +def _destination_and_source( + where, port, source, source_port, where_must_be_address=True +): # Apply defaults and compute destination and source tuples # suitable for use in connect(), sendto(), or bind(). af = None @@ -216,8 +223,9 @@ def _destination_and_source(where, port, source, source_port, if af: # We know the destination af, so source had better agree! if saf != af: - raise ValueError('different address families for source ' + - 'and destination') + raise ValueError( + "different address families for source " + "and destination" + ) else: # We didn't know the destination af, but we know the source, # so that's our af. @@ -227,12 +235,11 @@ def _destination_and_source(where, port, source, source_port, # need to return a source, and we need to use the appropriate # wildcard address as the address. if af == socket.AF_INET: - source = '0.0.0.0' + source = "0.0.0.0" elif af == socket.AF_INET6: - source = '::' + source = "::" else: - raise ValueError('source_port specified but address family is ' - 'unknown') + raise ValueError("source_port specified but address family is " "unknown") # Convert high-level (address, port) tuples into low-level address # tuples. if destination: @@ -241,6 +248,7 @@ def _destination_and_source(where, port, source, source_port, source = dns.inet.low_level_address_tuple((source, source_port), af) return (af, destination, source) + def _make_socket(af, type, source, ssl_context=None, server_hostname=None): s = socket_factory(af, type) try: @@ -249,19 +257,33 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None): s.bind(source) if ssl_context: # LGTM gets a false positive here, as our default context is OK - return ssl_context.wrap_socket(s, do_handshake_on_connect=False, # lgtm[py/insecure-protocol] - server_hostname=server_hostname) + return ssl_context.wrap_socket( + s, + do_handshake_on_connect=False, # lgtm[py/insecure-protocol] + server_hostname=server_hostname, + ) else: return s except Exception: s.close() raise -def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port: int=443, source: Optional[str]=None, source_port: int=0, - one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - session: Optional[Any]=None, path: str='/dns-query', post: bool=True, - bootstrap_address: Optional[str]=None, verify: bool=True) -> dns.message.Message: + +def https( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 443, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + session: Optional[Any] = None, + path: str = "/dns-query", + post: bool = True, + bootstrap_address: Optional[str] = None, + verify: bool = True, +) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. *q*, a ``dns.message.Message``, the query to send. @@ -304,29 +326,26 @@ def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, """ if not have_doh: - raise NoDOH('Neither httpx nor requests is available.') # pragma: no cover + raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover _httpx_ok = _have_httpx wire = q.to_wire() - (af, _, source) = _destination_and_source(where, port, source, source_port, - False) + (af, _, source) = _destination_and_source(where, port, source, source_port, False) transport_adapter = None transport = None - headers = { - "accept": "application/dns-message" - } + headers = {"accept": "application/dns-message"} if af is not None: if af == socket.AF_INET: - url = 'https://{}:{}{}'.format(where, port, path) + url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: - url = 'https://[{}]:{}{}'.format(where, port, path) + url = "https://[{}]:{}{}".format(where, port, path) elif bootstrap_address is not None: _httpx_ok = False split_url = urllib.parse.urlsplit(where) if split_url.hostname is None: - raise ValueError('DoH URL has no hostname') - headers['Host'] = split_url.hostname + raise ValueError("DoH URL has no hostname") + headers["Host"] = split_url.hostname url = where.replace(split_url.hostname, bootstrap_address) if _have_requests: transport_adapter = HostHeaderSSLAdapter() @@ -348,22 +367,29 @@ def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, else: _is_httpx = False if _is_httpx and not _httpx_ok: - raise NoDOH('Session is httpx, but httpx cannot be used for ' - 'the requested operation.') + raise NoDOH( + "Session is httpx, but httpx cannot be used for " + "the requested operation." + ) else: _is_httpx = _httpx_ok if not _httpx_ok and not _have_requests: - raise NoDOH('Cannot use httpx for this operation, and ' - 'requests is not available.') + raise NoDOH( + "Cannot use httpx for this operation, and " "requests is not available." + ) with contextlib.ExitStack() as stack: if not session: if _is_httpx: - session = stack.enter_context(httpx.Client(http1=True, - http2=_have_http2, - verify=verify, - transport=transport)) + session = stack.enter_context( + httpx.Client( + http1=True, + http2=_have_http2, + verify=verify, + transport=transport, + ) + ) else: session = stack.enter_context(requests.sessions.Session()) @@ -373,45 +399,56 @@ def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: - headers.update({ - "content-type": "application/dns-message", - "content-length": str(len(wire)) - }) + headers.update( + { + "content-type": "application/dns-message", + "content-length": str(len(wire)), + } + ) if _is_httpx: - response = session.post(url, headers=headers, content=wire, - timeout=timeout) + response = session.post( + url, headers=headers, content=wire, timeout=timeout + ) else: - response = session.post(url, headers=headers, data=wire, - timeout=timeout, verify=verify) + response = session.post( + url, headers=headers, data=wire, timeout=timeout, verify=verify + ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") if _is_httpx: twire = wire.decode() # httpx does a repr() if we give it bytes - response = session.get(url, headers=headers, - timeout=timeout, - params={"dns": twire}) + response = session.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) else: - response = session.get(url, headers=headers, - timeout=timeout, verify=verify, - params={"dns": wire}) + response = session.get( + url, + headers=headers, + timeout=timeout, + verify=verify, + params={"dns": wire}, + ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes if response.status_code < 200 or response.status_code > 299: - raise ValueError('{} responded with status code {}' - '\nResponse body: {}'.format(where, - response.status_code, - response.content)) - r = dns.message.from_wire(response.content, - keyring=q.keyring, - request_mac=q.request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + raise ValueError( + "{} responded with status code {}" + "\nResponse body: {}".format(where, response.status_code, response.content) + ) + r = dns.message.from_wire( + response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) r.time = response.elapsed.total_seconds() if not q.is_response(r): raise BadResponse return r + def _udp_recv(sock, max_size, expiration): """Reads a datagram from the socket. A Timeout exception will be raised if the operation is not completed @@ -439,8 +476,12 @@ def _udp_send(sock, data, destination, expiration): _wait_for_writable(sock, expiration) -def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: Any, - expiration: Optional[float]=None) -> Tuple[int, float]: +def send_udp( + sock: Any, + what: Union[dns.message.Message, bytes], + destination: Any, + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified UDP socket. *sock*, a ``socket``. @@ -464,10 +505,17 @@ def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: An return (n, sent_time) -def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional[float]=None, - ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, - keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=b'', - ignore_trailing: bool=False, raise_on_truncation: bool=False) -> Any: +def receive_udp( + sock: Any, + destination: Optional[Any] = None, + expiration: Optional[float] = None, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, + raise_on_truncation: bool = False, +) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``socket``. @@ -509,26 +557,41 @@ def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional the message arrived from. """ - wire = b'' + wire = b"" while True: (wire, from_address) = _udp_recv(sock, 65535, expiration) - if _matches_destination(sock.family, from_address, destination, - ignore_unexpected): + if _matches_destination( + sock.family, from_address, destination, ignore_unexpected + ): break received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - raise_on_truncation=raise_on_truncation) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation, + ) if destination: return (r, received_time) else: return (r, received_time, from_address) -def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, - source: Optional[str]=None, source_port: int=0, - ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - raise_on_truncation: bool=False, sock: Optional[Any]=None) -> dns.message.Message: + +def udp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + sock: Optional[Any] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via UDP. *q*, a ``dns.message.Message``, the query to send @@ -568,8 +631,9 @@ def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: """ wire = q.to_wire() - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) (begin_time, expiration) = _compute_times(timeout) with contextlib.ExitStack() as stack: if sock: @@ -577,21 +641,39 @@ def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: else: s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source)) send_udp(s, wire, destination, expiration) - (r, received_time) = receive_udp(s, destination, expiration, - ignore_unexpected, one_rr_per_rrset, - q.keyring, q.mac, ignore_trailing, - raise_on_truncation) + (r, received_time) = receive_udp( + s, + destination, + expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, + q.mac, + ignore_trailing, + raise_on_truncation, + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r - assert False # help mypy figure out we can't get here lgtm[py/unreachable-statement] - -def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, - source: Optional[str]=None, source_port: int=0, - ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - udp_sock: Optional[Any]=None, - tcp_sock: Optional[Any]=None) -> Tuple[dns.message.Message, bool]: + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) + + +def udp_with_fallback( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + udp_sock: Optional[Any] = None, + tcp_sock: Optional[Any] = None, +) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -635,26 +717,46 @@ def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[floa if and only if TCP was used. """ try: - response = udp(q, where, timeout, port, source, source_port, - ignore_unexpected, one_rr_per_rrset, - ignore_trailing, True, udp_sock) + response = udp( + q, + where, + timeout, + port, + source, + source_port, + ignore_unexpected, + one_rr_per_rrset, + ignore_trailing, + True, + udp_sock, + ) return (response, False) except dns.message.Truncated: - response = tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, tcp_sock) + response = tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + tcp_sock, + ) return (response, True) + def _net_read(sock, count, expiration): """Read the specified number of bytes from sock. Keep trying until we either get the desired amount, or we hit EOF. A Timeout exception will be raised if the operation is not completed by the expiration time. """ - s = b'' + s = b"" while count > 0: try: n = sock.recv(count) - if n == b'': + if n == b"": raise EOFError count -= len(n) s += n @@ -681,8 +783,11 @@ def _net_write(sock, data, expiration): _wait_for_readable(sock, expiration) -def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], - expiration: Optional[float]=None) -> Tuple[int, float]: +def send_tcp( + sock: Any, + what: Union[dns.message.Message, bytes], + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified TCP socket. *sock*, a ``socket``. @@ -709,10 +814,15 @@ def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) -def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset: bool=False, - keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, - request_mac: Optional[bytes]=b'', - ignore_trailing: bool=False) -> Tuple[dns.message.Message, float]: + +def receive_tcp( + sock: Any, + expiration: Optional[float] = None, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, +) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``socket``. @@ -742,11 +852,16 @@ def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset: b (l,) = struct.unpack("!H", ldata) wire = _net_read(sock, l, expiration) received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) return (r, received_time) + def _connect(s, address, expiration): err = s.connect_ex(address) if err == 0: @@ -758,10 +873,17 @@ def _connect(s, address, expiration): raise OSError(err, os.strerror(err)) -def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, - source: Optional[str]=None, source_port: int=0, - one_rr_per_rrset: bool=False, ignore_trailing: bool=False, - sock: Optional[Any]=None) -> dns.message.Message: +def tcp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[Any] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TCP. *q*, a ``dns.message.Message``, the query to send @@ -800,20 +922,22 @@ def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: if sock: s = sock else: - (af, destination, source) = _destination_and_source(where, port, - source, - source_port) - s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM, - source)) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM, source)) _connect(s, destination, expiration) send_tcp(s, wire, expiration) - (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, - q.keyring, q.mac, ignore_trailing) + (r, received_time) = receive_tcp( + s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r - assert False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) def _tls_handshake(s, expiration): @@ -827,11 +951,19 @@ def _tls_handshake(s, expiration): _wait_for_writable(s, expiration) -def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port: int=853, source: Optional[str]=None, source_port: int=0, - one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock: Optional[ssl.SSLSocket]=None, - ssl_context: Optional[ssl.SSLContext]=None, - server_hostname: Optional[str]=None) -> dns.message.Message: +def tls( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[ssl.SSLSocket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TLS. *q*, a ``dns.message.Message``, the query to send @@ -878,13 +1010,23 @@ def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, # # If a socket was provided, there's no special TLS handling needed. # - return tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, sock) + return tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + sock, + ) wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) if ssl_context is None and not sock: # LGTM complains about this because the default might permit TLS < 1.2 # for compatibility, but the python documentation says that explicit @@ -897,28 +1039,45 @@ def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, if server_hostname is None: ssl_context.check_hostname = False - with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context, - server_hostname=server_hostname) as s: + with _make_socket( + af, + socket.SOCK_STREAM, + source, + ssl_context=ssl_context, + server_hostname=server_hostname, + ) as s: _connect(s, destination, expiration) _tls_handshake(s, expiration) send_tcp(s, wire, expiration) - (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, - q.keyring, q.mac, ignore_trailing) + (r, received_time) = receive_tcp( + s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r - assert False # help mypy figure out we can't get here lgtm[py/unreachable-statement] - -def xfr(where: str, zone: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.AXFR, - rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - timeout: Optional[float]=None, port: int=53, - keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, - keyname: Optional[Union[dns.name.Name, str]]=None, relativize: bool=True, - lifetime: Optional[float]=None, source: Optional[str]=None, source_port: int=0, - serial: int=0, use_udp: bool=False, - keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> Any: + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) + + +def xfr( + where: str, + zone: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.AXFR, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + timeout: Optional[float] = None, + port: int = 53, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + keyname: Optional[Union[dns.name.Name, str]] = None, + relativize: bool = True, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + serial: int = 0, + use_udp: bool = False, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, +) -> Any: """Return a generator for the responses to a zone transfer. *where*, a ``str`` containing an IPv4 or IPv6 address, where @@ -976,16 +1135,16 @@ def xfr(where: str, zone: Union[dns.name.Name, str], rdtype = dns.rdatatype.RdataType.make(rdtype) q = dns.message.make_query(zone, rdtype, rdclass) if rdtype == dns.rdatatype.IXFR: - rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA', - '. . %u 0 0 0 0' % serial) + rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial) q.authority.append(rrset) if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) wire = q.to_wire() - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) if use_udp and rdtype != dns.rdatatype.IXFR: - raise ValueError('cannot do a UDP AXFR') + raise ValueError("cannot do a UDP AXFR") sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM with _make_socket(af, sock_type, source) as s: (_, expiration) = _compute_times(lifetime) @@ -1009,8 +1168,9 @@ def xfr(where: str, zone: Union[dns.name.Name, str], tsig_ctx = None while not done: (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or \ - (expiration is not None and mexpiration > expiration): + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): mexpiration = expiration if use_udp: (wire, _) = _udp_recv(s, 65535, mexpiration) @@ -1018,11 +1178,17 @@ def xfr(where: str, zone: Union[dns.name.Name, str], ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) wire = _net_read(s, l, mexpiration) - is_ixfr = (rdtype == dns.rdatatype.IXFR) - r = dns.message.from_wire(wire, keyring=q.keyring, - request_mac=q.mac, xfr=True, - origin=origin, tsig_ctx=tsig_ctx, - multi=True, one_rr_per_rrset=is_ixfr) + is_ixfr = rdtype == dns.rdatatype.IXFR + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=True, + one_rr_per_rrset=is_ixfr, + ) rcode = r.rcode() if rcode != dns.rcode.NOERROR: raise TransferError(rcode) @@ -1030,8 +1196,7 @@ def xfr(where: str, zone: Union[dns.name.Name, str], answer_index = 0 if soa_rrset is None: if not r.answer or r.answer[0].name != oname: - raise dns.exception.FormError( - "No answer or RRset not for qname") + raise dns.exception.FormError("No answer or RRset not for qname") rrset = r.answer[0] if rrset.rdtype != dns.rdatatype.SOA: raise dns.exception.FormError("first RRset is not an SOA") @@ -1055,8 +1220,7 @@ def xfr(where: str, zone: Union[dns.name.Name, str], if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname: if expecting_SOA: if rrset[0].serial != serial: - raise dns.exception.FormError( - "IXFR base serial mismatch") + raise dns.exception.FormError("IXFR base serial mismatch") expecting_SOA = False elif rdtype == dns.rdatatype.IXFR: delete_mode = not delete_mode @@ -1065,9 +1229,10 @@ def xfr(where: str, zone: Union[dns.name.Name, str], # finished. If this is an IXFR we also check that we're # seeing the record in the expected part of the response. # - if rrset == soa_rrset and \ - (rdtype == dns.rdatatype.AXFR or - (rdtype == dns.rdatatype.IXFR and delete_mode)): + if rrset == soa_rrset and ( + rdtype == dns.rdatatype.AXFR + or (rdtype == dns.rdatatype.IXFR and delete_mode) + ): done = True elif expecting_SOA: # @@ -1089,15 +1254,23 @@ class UDPMode(enum.IntEnum): TRY_FIRST means "try to use UDP but fall back to TCP if needed" ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" """ + NEVER = 0 TRY_FIRST = 1 ONLY = 2 -def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, - query: Optional[dns.message.Message]=None, - port: int=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, - source: Optional[str]=None, source_port: int=0, udp_mode: UDPMode=UDPMode.NEVER) -> None: +def inbound_xfr( + where: str, + txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message] = None, + port: int = 53, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + udp_mode: UDPMode = UDPMode.NEVER, +) -> None: """Conduct an inbound transfer and apply it via a transaction from the txn_manager. @@ -1142,8 +1315,9 @@ def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, is_ixfr = rdtype == dns.rdatatype.IXFR origin = txn_manager.from_wire_origin() wire = query.to_wire() - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) (_, expiration) = _compute_times(lifetime) retry = True while retry: @@ -1161,14 +1335,14 @@ def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, else: tcpmsg = struct.pack("!H", len(wire)) + wire _net_write(s, tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, - is_udp) as inbound: + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: done = False tsig_ctx = None while not done: (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or \ - (expiration is not None and mexpiration > expiration): + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): mexpiration = expiration if is_udp: (rwire, _) = _udp_recv(s, 65535, mexpiration) @@ -1176,11 +1350,16 @@ def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) rwire = _net_read(s, l, mexpiration) - r = dns.message.from_wire(rwire, keyring=query.keyring, - request_mac=query.mac, xfr=True, - origin=origin, tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) try: done = inbound.process_message(r) except dns.xfr.UseTCP: diff --git a/dns/rcode.py b/dns/rcode.py index 16e1ed4b..8e6386f8 100644 --- a/dns/rcode.py +++ b/dns/rcode.py @@ -22,6 +22,7 @@ from typing import Tuple import dns.enum import dns.exception + class Rcode(dns.enum.IntEnum): #: No error NOERROR = 0 @@ -104,7 +105,7 @@ def from_flags(flags: int, ednsflags: int) -> Rcode: Returns a ``dns.rcode.Rcode``. """ - value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) + value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0) return Rcode.make(value) @@ -119,13 +120,13 @@ def to_flags(value: Rcode) -> Tuple[int, int]: """ if value < 0 or value > 4095: - raise ValueError('rcode must be >= 0 and <= 4095') - v = value & 0xf - ev = (value & 0xff0) << 20 + raise ValueError("rcode must be >= 0 and <= 4095") + v = value & 0xF + ev = (value & 0xFF0) << 20 return (v, ev) -def to_text(value: Rcode, tsig: bool=False) -> str: +def to_text(value: Rcode, tsig: bool = False) -> str: """Convert rcode into text. *value*, a ``dns.rcode.Rcode``, the rcode. @@ -136,9 +137,10 @@ def to_text(value: Rcode, tsig: bool=False) -> str: """ if tsig and value == Rcode.BADVERS: - return 'BADSIG' + return "BADSIG" return Rcode.to_text(value) + ### BEGIN generated Rcode constants NOERROR = Rcode.NOERROR diff --git a/dns/rdata.py b/dns/rdata.py index 155e1248..dc2ad97a 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -57,21 +57,22 @@ class NoRelativeRdataOrdering(dns.exception.DNSException): """ -def _wordbreak(data, chunksize=_chunksize, separator=b' '): +def _wordbreak(data, chunksize=_chunksize, separator=b" "): """Break a binary string into chunks of chunksize characters separated by a space. """ if not chunksize: return data.decode() - return separator.join([data[i:i + chunksize] - for i - in range(0, len(data), chunksize)]).decode() + return separator.join( + [data[i : i + chunksize] for i in range(0, len(data), chunksize)] + ).decode() # pylint: disable=unused-argument -def _hexify(data, chunksize=_chunksize, separator=b' ', **kw): + +def _hexify(data, chunksize=_chunksize, separator=b" ", **kw): """Convert a binary string into its hex encoding, broken up into chunks of chunksize characters separated by a separator. """ @@ -79,17 +80,19 @@ def _hexify(data, chunksize=_chunksize, separator=b' ', **kw): return _wordbreak(binascii.hexlify(data), chunksize, separator) -def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw): +def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw): """Convert a binary string into its base64 encoding, broken up into chunks of chunksize characters separated by a separator. """ return _wordbreak(base64.b64encode(data), chunksize, separator) + # pylint: enable=unused-argument __escaped = b'"\\' + def _escapify(qstring): """Escape the characters in a quoted string which need it.""" @@ -98,14 +101,14 @@ def _escapify(qstring): if not isinstance(qstring, bytearray): qstring = bytearray(qstring) - text = '' + text = "" for c in qstring: if c in __escaped: - text += '\\' + chr(c) + text += "\\" + chr(c) elif c >= 0x20 and c < 0x7F: text += chr(c) else: - text += '\\%03d' % c + text += "\\%03d" % c return text @@ -116,9 +119,10 @@ def _truncate_bitmap(what): for i in range(len(what) - 1, -1, -1): if what[i] != 0: - return what[0: i + 1] + return what[0 : i + 1] return what[0:1] + # So we don't have to edit all the rdata classes... _constify = dns.immutable.constify @@ -127,7 +131,7 @@ _constify = dns.immutable.constify class Rdata: """Base class for all DNS rdata types.""" - __slots__ = ['rdclass', 'rdtype', 'rdcomment'] + __slots__ = ["rdclass", "rdtype", "rdcomment"] def __init__(self, rdclass, rdtype): """Initialize an rdata. @@ -142,8 +146,9 @@ class Rdata: self.rdcomment: Optional[str] = None def _get_all_slots(self): - return itertools.chain.from_iterable(getattr(cls, '__slots__', []) - for cls in self.__class__.__mro__) + return itertools.chain.from_iterable( + getattr(cls, "__slots__", []) for cls in self.__class__.__mro__ + ) def __getstate__(self): # We used to try to do a tuple of all slots here, but it @@ -162,10 +167,10 @@ class Rdata: def __setstate__(self, state): for slot, val in state.items(): object.__setattr__(self, slot, val) - if not hasattr(self, 'rdcomment'): + if not hasattr(self, "rdcomment"): # Pickled rdata from 2.0.x might not have a rdcomment, so add # it if needed. - object.__setattr__(self, 'rdcomment', None) + object.__setattr__(self, "rdcomment", None) def covers(self) -> dns.rdatatype.RdataType: """Return the type a Rdata covers. @@ -191,7 +196,12 @@ class Rdata: return self.covers() << 16 | self.rdtype - def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw: Dict[str, Any]) -> str: + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: """Convert an rdata to text format. Returns a ``str``. @@ -199,12 +209,22 @@ class Rdata: raise NotImplementedError # pragma: no cover - def _to_wire(self, file: Optional[Any], compress: Optional[dns.name.CompressType]=None, - origin: Optional[dns.name.Name]=None, canonicalize: bool=False) -> bytes: + def _to_wire( + self, + file: Optional[Any], + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + canonicalize: bool = False, + ) -> bytes: raise NotImplementedError # pragma: no cover - def to_wire(self, file: Optional[Any]=None, compress: Optional[dns.name.CompressType]=None, - origin: Optional[dns.name.Name]=None, canonicalize: bool=False) -> bytes: + def to_wire( + self, + file: Optional[Any] = None, + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + canonicalize: bool = False, + ) -> bytes: """Convert an rdata to wire format. Returns a ``bytes`` or ``None``. @@ -217,15 +237,18 @@ class Rdata: self._to_wire(f, compress, origin, canonicalize) return f.getvalue() - def to_generic(self, origin: Optional[dns.name.Name]=None) -> 'dns.rdata.GenericRdata': + def to_generic( + self, origin: Optional[dns.name.Name] = None + ) -> "dns.rdata.GenericRdata": """Creates a dns.rdata.GenericRdata equivalent of this rdata. Returns a ``dns.rdata.GenericRdata``. """ - return dns.rdata.GenericRdata(self.rdclass, self.rdtype, - self.to_wire(origin=origin)) + return dns.rdata.GenericRdata( + self.rdclass, self.rdtype, self.to_wire(origin=origin) + ) - def to_digestable(self, origin: Optional[dns.name.Name]=None) -> bytes: + def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes: """Convert rdata to a format suitable for digesting in hashes. This is also the DNSSEC canonical form. @@ -237,12 +260,19 @@ class Rdata: def __repr__(self): covers = self.covers() if covers == dns.rdatatype.NONE: - ctext = '' + ctext = "" else: - ctext = '(' + dns.rdatatype.to_text(covers) + ')' - return '' + ctext = "(" + dns.rdatatype.to_text(covers) + ")" + return ( + "" + ) def __str__(self): return self.to_text() @@ -323,27 +353,39 @@ class Rdata: return not self.__eq__(other) def __lt__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) < 0 def __le__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) <= 0 def __ge__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) >= 0 def __gt__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) > 0 @@ -351,19 +393,28 @@ class Rdata: return hash(self.to_digestable(dns.name.root)) @classmethod - def from_text(cls, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize: bool=True, - relativize_to: Optional[dns.name.Name]=None) -> 'Rdata': + def from_text( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + tok: dns.tokenizer.Tokenizer, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + ) -> "Rdata": raise NotImplementedError # pragma: no cover @classmethod - def from_wire_parser(cls, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> 'Rdata': + def from_wire_parser( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + parser: dns.wire.Parser, + origin: Optional[dns.name.Name] = None, + ) -> "Rdata": raise NotImplementedError # pragma: no cover - def replace(self, **kwargs: Dict[str, Any]) -> 'Rdata': + def replace(self, **kwargs: Dict[str, Any]) -> "Rdata": """ Create a new Rdata instance based on the instance replace was invoked on. It is possible to pass different parameters to @@ -381,14 +432,20 @@ class Rdata: # Ensure that all of the arguments correspond to valid fields. # Don't allow rdclass or rdtype to be changed, though. for key in kwargs: - if key == 'rdcomment': + if key == "rdcomment": continue if key not in parameters: - raise AttributeError("'{}' object has no attribute '{}'" - .format(self.__class__.__name__, key)) - if key in ('rdclass', 'rdtype'): - raise AttributeError("Cannot overwrite '{}' attribute '{}'" - .format(self.__class__.__name__, key)) + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, key + ) + ) + if key in ("rdclass", "rdtype"): + raise AttributeError( + "Cannot overwrite '{}' attribute '{}'".format( + self.__class__.__name__, key + ) + ) # Construct the parameter list. For each field, use the value in # kwargs if present, and the current value otherwise. @@ -398,9 +455,9 @@ class Rdata: rd = self.__class__(*args) # The comment is not set in the constructor, so give it special # handling. - rdcomment = kwargs.get('rdcomment', self.rdcomment) + rdcomment = kwargs.get("rdcomment", self.rdcomment) if rdcomment is not None: - object.__setattr__(rd, 'rdcomment', rdcomment) + object.__setattr__(rd, "rdcomment", rdcomment) return rd # Type checking and conversion helpers. These are class methods as @@ -415,8 +472,13 @@ class Rdata: return dns.rdatatype.RdataType.make(value) @classmethod - def _as_bytes(cls, value: Any, encode: bool=False, max_length: Optional[int]=None, - empty_ok: bool=True) -> bytes: + def _as_bytes( + cls, + value: Any, + encode: bool = False, + max_length: Optional[int] = None, + empty_ok: bool = True, + ) -> bytes: if encode and isinstance(value, str): bvalue = value.encode() elif isinstance(value, bytearray): @@ -424,11 +486,11 @@ class Rdata: elif isinstance(value, bytes): bvalue = value else: - raise ValueError('not bytes') + raise ValueError("not bytes") if max_length is not None and len(bvalue) > max_length: - raise ValueError('too long') + raise ValueError("too long") if not empty_ok and len(bvalue) == 0: - raise ValueError('empty bytes not allowed') + raise ValueError("empty bytes not allowed") return bvalue @classmethod @@ -439,49 +501,49 @@ class Rdata: if isinstance(value, str): return dns.name.from_text(value) elif not isinstance(value, dns.name.Name): - raise ValueError('not a name') + raise ValueError("not a name") return value @classmethod def _as_uint8(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 255: - raise ValueError('not a uint8') + raise ValueError("not a uint8") return value @classmethod def _as_uint16(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 65535: - raise ValueError('not a uint16') + raise ValueError("not a uint16") return value @classmethod def _as_uint32(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 4294967295: - raise ValueError('not a uint32') + raise ValueError("not a uint32") return value @classmethod def _as_uint48(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 281474976710655: - raise ValueError('not a uint48') + raise ValueError("not a uint48") return value @classmethod def _as_int(cls, value, low=None, high=None): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if low is not None and value < low: - raise ValueError('value too small') + raise ValueError("value too small") if high is not None and value > high: - raise ValueError('value too large') + raise ValueError("value too large") return value @classmethod @@ -493,7 +555,7 @@ class Rdata: elif isinstance(value, bytes): return dns.ipv4.inet_ntoa(value) else: - raise ValueError('not an IPv4 address') + raise ValueError("not an IPv4 address") @classmethod def _as_ipv6_address(cls, value): @@ -504,14 +566,14 @@ class Rdata: elif isinstance(value, bytes): return dns.ipv6.inet_ntoa(value) else: - raise ValueError('not an IPv6 address') + raise ValueError("not an IPv6 address") @classmethod def _as_bool(cls, value): if isinstance(value, bool): return value else: - raise ValueError('not a boolean') + raise ValueError("not a boolean") @classmethod def _as_ttl(cls, value): @@ -520,7 +582,7 @@ class Rdata: elif isinstance(value, str): return dns.ttl.from_text(value) else: - raise ValueError('not a TTL') + raise ValueError("not a TTL") @classmethod def _as_tuple(cls, value, as_value): @@ -541,6 +603,7 @@ class Rdata: random.shuffle(items) return items + @dns.immutable.immutable class GenericRdata(Rdata): @@ -550,28 +613,32 @@ class GenericRdata(Rdata): implementation. It implements the DNS "unknown RRs" scheme. """ - __slots__ = ['data'] + __slots__ = ["data"] def __init__(self, rdclass, rdtype, data): super().__init__(rdclass, rdtype) self.data = data - def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw: Dict[str, Any]) -> str: - return r'\# %d ' % len(self.data) + _hexify(self.data, **kw) + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: + return r"\# %d " % len(self.data) + _hexify(self.data, **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): token = tok.get() - if not token.is_identifier() or token.value != r'\#': - raise dns.exception.SyntaxError( - r'generic rdata does not start with \#') + if not token.is_identifier() or token.value != r"\#": + raise dns.exception.SyntaxError(r"generic rdata does not start with \#") length = tok.get_int() hex = tok.concatenate_remaining_identifiers(True).encode() data = binascii.unhexlify(hex) if len(data) != length: - raise dns.exception.SyntaxError( - 'generic rdata hex data has wrong length') + raise dns.exception.SyntaxError("generic rdata hex data has wrong length") return cls(rdclass, rdtype, data) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -581,8 +648,12 @@ class GenericRdata(Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): return cls(rdclass, rdtype, parser.get_remaining()) -_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = {} -_module_prefix = 'dns.rdtypes' + +_rdata_classes: Dict[ + Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any +] = {} +_module_prefix = "dns.rdtypes" + def get_rdata_class(rdclass, rdtype): cls = _rdata_classes.get((rdclass, rdtype)) @@ -591,16 +662,16 @@ def get_rdata_class(rdclass, rdtype): if not cls: rdclass_text = dns.rdataclass.to_text(rdclass) rdtype_text = dns.rdatatype.to_text(rdtype) - rdtype_text = rdtype_text.replace('-', '_') + rdtype_text = rdtype_text.replace("-", "_") try: - mod = import_module('.'.join([_module_prefix, - rdclass_text, rdtype_text])) + mod = import_module( + ".".join([_module_prefix, rdclass_text, rdtype_text]) + ) cls = getattr(mod, rdtype_text) _rdata_classes[(rdclass, rdtype)] = cls except ImportError: try: - mod = import_module('.'.join([_module_prefix, - 'ANY', rdtype_text])) + mod = import_module(".".join([_module_prefix, "ANY", rdtype_text])) cls = getattr(mod, rdtype_text) _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls _rdata_classes[(rdclass, rdtype)] = cls @@ -612,12 +683,15 @@ def get_rdata_class(rdclass, rdtype): return cls -def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], - rdtype: Union[dns.rdatatype.RdataType, str], - tok: Union[dns.tokenizer.Tokenizer, str], - origin: Optional[dns.name.Name]=None, - relativize: bool=True, relativize_to: Optional[dns.name.Name]=None, - idna_codec: Optional[dns.name.IDNACodec]=None) -> Rdata: +def from_text( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + tok: Union[dns.tokenizer.Tokenizer, str], + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, +) -> Rdata: """Build an rdata object from text format. This function attempts to dynamically load a class which @@ -665,17 +739,18 @@ def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], # peek at first token token = tok.get() tok.unget(token) - if token.is_identifier() and \ - token.value == r'\#': + if token.is_identifier() and token.value == r"\#": # # Known type using the generic syntax. Extract the # wire form from the generic syntax, and then run # from_wire on it. # - grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, - relativize, relativize_to) - rdata = from_wire(rdclass, rdtype, grdata.data, 0, - len(grdata.data), origin) + grdata = GenericRdata.from_text( + rdclass, rdtype, tok, origin, relativize, relativize_to + ) + rdata = from_wire( + rdclass, rdtype, grdata.data, 0, len(grdata.data), origin + ) # # If this comparison isn't equal, then there must have been # compressed names in the wire format, which is an error, @@ -683,21 +758,27 @@ def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], # rwire = rdata.to_wire() if rwire != grdata.data: - raise dns.exception.SyntaxError('compressed data in ' - 'generic syntax form ' - 'of known rdatatype') + raise dns.exception.SyntaxError( + "compressed data in " + "generic syntax form " + "of known rdatatype" + ) if rdata is None: - rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize, - relativize_to) + rdata = cls.from_text( + rdclass, rdtype, tok, origin, relativize, relativize_to + ) token = tok.get_eol_as_token() if token.comment is not None: - object.__setattr__(rdata, 'rdcomment', token.comment) + object.__setattr__(rdata, "rdcomment", token.comment) return rdata -def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str], - rdtype: Union[dns.rdatatype.RdataType, str], - parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> Rdata: +def from_wire_parser( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + parser: dns.wire.Parser, + origin: Optional[dns.name.Name] = None, +) -> Rdata: """Build an rdata object from wire format This function attempts to dynamically load a class which @@ -728,10 +809,14 @@ def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str], return cls.from_wire_parser(rdclass, rdtype, parser, origin) -def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str], - rdtype: Union[dns.rdatatype.RdataType, str], - wire: bytes, current: int, rdlen: int, - origin: Optional[dns.name.Name]=None) -> Rdata: +def from_wire( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + wire: bytes, + current: int, + rdlen: int, + origin: Optional[dns.name.Name] = None, +) -> Rdata: """Build an rdata object from wire format This function attempts to dynamically load a class which @@ -765,13 +850,21 @@ def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str], class RdatatypeExists(dns.exception.DNSException): """DNS rdatatype already exists.""" - supp_kwargs = {'rdclass', 'rdtype'} - fmt = "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + \ - "already exists." + + supp_kwargs = {"rdclass", "rdtype"} + fmt = ( + "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + + "already exists." + ) -def register_type(implementation: Any, rdtype: int, rdtype_text: str, is_singleton: bool=False, - rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN) -> None: +def register_type( + implementation: Any, + rdtype: int, + rdtype_text: str, + is_singleton: bool = False, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, +) -> None: """Dynamically register a module to handle an rdatatype. *implementation*, a module implementing the type in the usual dnspython @@ -797,6 +890,7 @@ def register_type(implementation: Any, rdtype: int, rdtype_text: str, is_singlet raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) except ValueError: pass - _rdata_classes[(rdclass, the_rdtype)] = getattr(implementation, - rdtype_text.replace('-', '_')) + _rdata_classes[(rdclass, the_rdtype)] = getattr( + implementation, rdtype_text.replace("-", "_") + ) dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) diff --git a/dns/rdataclass.py b/dns/rdataclass.py index 28670548..89b85a79 100644 --- a/dns/rdataclass.py +++ b/dns/rdataclass.py @@ -20,8 +20,10 @@ import dns.enum import dns.exception + class RdataClass(dns.enum.IntEnum): """DNS Rdata Class""" + RESERVED0 = 0 IN = 1 INTERNET = IN @@ -100,6 +102,7 @@ def is_metaclass(rdclass: RdataClass) -> bool: return True return False + ### BEGIN generated RdataClass constants RESERVED0 = RdataClass.RESERVED0 diff --git a/dns/rdataset.py b/dns/rdataset.py index c4b86445..072e7f72 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -17,7 +17,7 @@ """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" -from typing import Any, cast, Collection, Dict, Iterable, List, Optional, Union +from typing import Any, cast, Collection, Dict, List, Optional, Union import io import random @@ -49,11 +49,15 @@ class Rdataset(dns.set.Set): """A DNS rdataset.""" - __slots__ = ['rdclass', 'rdtype', 'covers', 'ttl'] + __slots__ = ["rdclass", "rdtype", "covers", "ttl"] - def __init__(self, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, ttl: int=0): + def __init__( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ttl: int = 0, + ): """Create a new rdataset of the specified class and type. *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass. @@ -94,7 +98,9 @@ class Rdataset(dns.set.Set): elif ttl < self.ttl: self.ttl = ttl - def add(self, rd: dns.rdata.Rdata, ttl: Optional[int]=None) -> None: # pylint: disable=arguments-differ + def add( + self, rd: dns.rdata.Rdata, ttl: Optional[int] = None + ) -> None: # pylint: disable=arguments-differ """Add the specified rdata to the rdataset. If the optional *ttl* parameter is supplied, then @@ -121,8 +127,7 @@ class Rdataset(dns.set.Set): raise IncompatibleTypes if ttl is not None: self.update_ttl(ttl) - if self.rdtype == dns.rdatatype.RRSIG or \ - self.rdtype == dns.rdatatype.SIG: + if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG: covers = rd.covers() if len(self) == 0 and self.covers == dns.rdatatype.NONE: self.covers = covers @@ -153,19 +158,26 @@ class Rdataset(dns.set.Set): def _rdata_repr(self): def maybe_truncate(s): if len(s) > 100: - return s[:100] + '...' + return s[:100] + "..." return s - return '[%s]' % ', '.join('<%s>' % maybe_truncate(str(rr)) - for rr in self) + + return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self) def __repr__(self): if self.covers == 0: - ctext = '' + ctext = "" else: - ctext = '(' + dns.rdatatype.to_text(self.covers) + ')' - return '' + ctext = "(" + dns.rdatatype.to_text(self.covers) + ")" + return ( + "" + ) def __str__(self): return self.to_text() @@ -173,20 +185,26 @@ class Rdataset(dns.set.Set): def __eq__(self, other): if not isinstance(other, Rdataset): return False - if self.rdclass != other.rdclass or \ - self.rdtype != other.rdtype or \ - self.covers != other.covers: + if ( + self.rdclass != other.rdclass + or self.rdtype != other.rdtype + or self.covers != other.covers + ): return False return super().__eq__(other) def __ne__(self, other): return not self.__eq__(other) - def to_text(self, name: Optional[dns.name.Name]=None, - origin: Optional[dns.name.Name]=None, - relativize: bool=True, - override_rdclass: Optional[dns.rdataclass.RdataClass]=None, - want_comments: bool=False, **kw: Dict[str, Any]) -> str: + def to_text( + self, + name: Optional[dns.name.Name] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + override_rdclass: Optional[dns.rdataclass.RdataClass] = None, + want_comments: bool = False, + **kw: Dict[str, Any], + ) -> str: """Convert the rdataset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information @@ -215,10 +233,10 @@ class Rdataset(dns.set.Set): if name is not None: name = name.choose_relativity(origin, relativize) ntext = str(name) - pad = ' ' + pad = " " else: - ntext = '' - pad = '' + ntext = "" + pad = "" s = io.StringIO() if override_rdclass is not None: rdclass = override_rdclass @@ -230,31 +248,46 @@ class Rdataset(dns.set.Set): # some dynamic updates, so we don't need to print out the TTL # (which is meaningless anyway). # - s.write('{}{}{} {}\n'.format(ntext, pad, - dns.rdataclass.to_text(rdclass), - dns.rdatatype.to_text(self.rdtype))) + s.write( + "{}{}{} {}\n".format( + ntext, + pad, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype), + ) + ) else: for rd in self: - extra = '' + extra = "" if want_comments: if rd.rdcomment: - extra = f' ;{rd.rdcomment}' - s.write('%s%s%d %s %s %s%s\n' % - (ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass), - dns.rdatatype.to_text(self.rdtype), - rd.to_text(origin=origin, relativize=relativize, - **kw), - extra)) + extra = f" ;{rd.rdcomment}" + s.write( + "%s%s%d %s %s %s%s\n" + % ( + ntext, + pad, + self.ttl, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype), + rd.to_text(origin=origin, relativize=relativize, **kw), + extra, + ) + ) # # We strip off the final \n for the caller's convenience in printing # return s.getvalue()[:-1] - def to_wire(self, name: dns.name.Name, file: Any, - compress: Optional[dns.name.CompressType]=None, - origin: Optional[dns.name.Name]=None, - override_rdclass: Optional[dns.rdataclass.RdataClass]=None, - want_shuffle: bool=True) -> int: + def to_wire( + self, + name: dns.name.Name, + file: Any, + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + override_rdclass: Optional[dns.rdataclass.RdataClass] = None, + want_shuffle: bool = True, + ) -> int: """Convert the rdataset to wire format. *name*, a ``dns.name.Name`` is the owner name to use. @@ -299,8 +332,7 @@ class Rdataset(dns.set.Set): l = self for rd in l: name.to_wire(file, compress, origin) - stuff = struct.pack("!HHIH", self.rdtype, rdclass, - self.ttl, 0) + stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0) file.write(stuff) start = file.tell() rd.to_wire(file, compress, origin) @@ -312,15 +344,16 @@ class Rdataset(dns.set.Set): file.seek(0, io.SEEK_END) return len(self) - def match(self, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType) -> bool: + def match( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> bool: """Returns ``True`` if this rdataset matches the specified class, type, and covers. """ - if self.rdclass == rdclass and \ - self.rdtype == rdtype and \ - self.covers == covers: + if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers: return True return False @@ -349,46 +382,47 @@ class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] def __init__(self, rdataset: Rdataset): """Create an immutable rdataset from the specified rdataset.""" - super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers, - rdataset.ttl) + super().__init__( + rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl + ) self.items = dns.immutable.Dict(rdataset.items) def update_ttl(self, ttl): - raise TypeError('immutable') + raise TypeError("immutable") def add(self, rd, ttl=None): - raise TypeError('immutable') + raise TypeError("immutable") def union_update(self, other): - raise TypeError('immutable') + raise TypeError("immutable") def intersection_update(self, other): - raise TypeError('immutable') + raise TypeError("immutable") def update(self, other): - raise TypeError('immutable') + raise TypeError("immutable") def __delitem__(self, i): - raise TypeError('immutable') + raise TypeError("immutable") # lgtm complains about these not raising ArithmeticError, but there is # precedent for overrides of these methods in other classes to raise # TypeError, and it seems like the better exception. def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method] - raise TypeError('immutable') + raise TypeError("immutable") def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method] - raise TypeError('immutable') + raise TypeError("immutable") def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method] - raise TypeError('immutable') + raise TypeError("immutable") def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method] - raise TypeError('immutable') + raise TypeError("immutable") def clear(self): - raise TypeError('immutable') + raise TypeError("immutable") def __copy__(self): return ImmutableRdataset(super().copy()) @@ -409,12 +443,16 @@ class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] return ImmutableRdataset(super().symmetric_difference(other)) -def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str], - rdtype: Union[dns.rdatatype.RdataType, str], - ttl: int, text_rdatas: Collection[str], - idna_codec: Optional[dns.name.IDNACodec]=None, - origin: Optional[dns.name.Name]=None, - relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset: +def from_text_list( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, + text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> Rdataset: """Create an rdataset with the specified class, type, and TTL, and with the specified list of rdatas in text format. @@ -438,15 +476,19 @@ def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str], r = Rdataset(the_rdclass, the_rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, - relativize_to, idna_codec) + rd = dns.rdata.from_text( + r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec + ) r.add(rd) return r -def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], - rdtype: Union[dns.rdatatype.RdataType, str], - ttl: int, *text_rdatas: Any) -> Rdataset: +def from_text( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, + *text_rdatas: Any, +) -> Rdataset: """Create an rdataset with the specified class, type, and TTL, and with the specified rdatas in text format. diff --git a/dns/rdatatype.py b/dns/rdatatype.py index aded5bdb..0a2854da 100644 --- a/dns/rdatatype.py +++ b/dns/rdatatype.py @@ -22,8 +22,10 @@ from typing import Dict import dns.enum import dns.exception + class RdataType(dns.enum.IntEnum): """DNS Rdata Type""" + TYPE0 = 0 NONE = 0 A = 1 @@ -122,13 +124,19 @@ class RdataType(dns.enum.IntEnum): def _unknown_exception_class(cls): return UnknownRdatatype + _registered_by_text: Dict[str, RdataType] = {} _registered_by_value: Dict[RdataType, str] = {} _metatypes = {RdataType.OPT} -_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME, - RdataType.NSEC, RdataType.CNAME} +_singletons = { + RdataType.SOA, + RdataType.NXT, + RdataType.DNAME, + RdataType.NSEC, + RdataType.CNAME, +} class UnknownRdatatype(dns.exception.DNSException): @@ -150,7 +158,7 @@ def from_text(text: str) -> RdataType: Returns a ``dns.rdatatype.RdataType``. """ - text = text.upper().replace('-', '_') + text = text.upper().replace("-", "_") try: return RdataType.from_text(text) except UnknownRdatatype: @@ -176,7 +184,7 @@ def to_text(value: RdataType) -> str: registered_text = _registered_by_value.get(value) if registered_text: text = registered_text - return text.replace('_', '-') + return text.replace("_", "-") def is_metatype(rdtype: RdataType) -> bool: @@ -211,8 +219,11 @@ def is_singleton(rdtype: RdataType) -> bool: return True return False + # pylint: disable=redefined-outer-name -def register_type(rdtype: RdataType, rdtype_text: str, is_singleton: bool=False) -> None: +def register_type( + rdtype: RdataType, rdtype_text: str, is_singleton: bool = False +) -> None: """Dynamically register an rdatatype. *rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register. @@ -228,6 +239,7 @@ def register_type(rdtype: RdataType, rdtype_text: str, is_singleton: bool=False) if is_singleton: _singletons.add(rdtype) + ### BEGIN generated RdataType constants TYPE0 = RdataType.TYPE0 diff --git a/dns/rdtypes/ANY/AMTRELAY.py b/dns/rdtypes/ANY/AMTRELAY.py index 9f093dee..dfe7abc3 100644 --- a/dns/rdtypes/ANY/AMTRELAY.py +++ b/dns/rdtypes/ANY/AMTRELAY.py @@ -23,7 +23,7 @@ import dns.rdtypes.util class Relay(dns.rdtypes.util.Gateway): - name = 'AMTRELAY relay' + name = "AMTRELAY relay" @property def relay(self): @@ -37,10 +37,11 @@ class AMTRELAY(dns.rdata.Rdata): # see: RFC 8777 - __slots__ = ['precedence', 'discovery_optional', 'relay_type', 'relay'] + __slots__ = ["precedence", "discovery_optional", "relay_type", "relay"] - def __init__(self, rdclass, rdtype, precedence, discovery_optional, - relay_type, relay): + def __init__( + self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay + ): super().__init__(rdclass, rdtype) relay = Relay(relay_type, relay) self.precedence = self._as_uint8(precedence) @@ -50,37 +51,42 @@ class AMTRELAY(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) - return '%d %d %d %s' % (self.precedence, self.discovery_optional, - self.relay_type, relay) + return "%d %d %d %s" % ( + self.precedence, + self.discovery_optional, + self.relay_type, + relay, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): precedence = tok.get_uint8() discovery_optional = tok.get_uint8() if discovery_optional > 1: - raise dns.exception.SyntaxError('expecting 0 or 1') + raise dns.exception.SyntaxError("expecting 0 or 1") discovery_optional = bool(discovery_optional) relay_type = tok.get_uint8() - if relay_type > 0x7f: - raise dns.exception.SyntaxError('expecting an integer <= 127') - relay = Relay.from_text(relay_type, tok, origin, relativize, - relativize_to) - return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, - relay.relay) + if relay_type > 0x7F: + raise dns.exception.SyntaxError("expecting an integer <= 127") + relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to) + return cls( + rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): relay_type = self.relay_type | (self.discovery_optional << 7) header = struct.pack("!BB", self.precedence, relay_type) file.write(header) - Relay(self.relay_type, self.relay).to_wire(file, compress, origin, - canonicalize) + Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (precedence, relay_type) = parser.get_struct('!BB') + (precedence, relay_type) = parser.get_struct("!BB") discovery_optional = bool(relay_type >> 7) - relay_type &= 0x7f + relay_type &= 0x7F relay = Relay.from_wire_parser(relay_type, parser, origin) - return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, - relay.relay) + return cls( + rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay + ) diff --git a/dns/rdtypes/ANY/CAA.py b/dns/rdtypes/ANY/CAA.py index c86b45ea..8afb538c 100644 --- a/dns/rdtypes/ANY/CAA.py +++ b/dns/rdtypes/ANY/CAA.py @@ -30,7 +30,7 @@ class CAA(dns.rdata.Rdata): # see: RFC 6844 - __slots__ = ['flags', 'tag', 'value'] + __slots__ = ["flags", "tag", "value"] def __init__(self, rdclass, rdtype, flags, tag, value): super().__init__(rdclass, rdtype) @@ -41,23 +41,26 @@ class CAA(dns.rdata.Rdata): self.value = self._as_bytes(value) def to_text(self, origin=None, relativize=True, **kw): - return '%u %s "%s"' % (self.flags, - dns.rdata._escapify(self.tag), - dns.rdata._escapify(self.value)) + return '%u %s "%s"' % ( + self.flags, + dns.rdata._escapify(self.tag), + dns.rdata._escapify(self.value), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): flags = tok.get_uint8() tag = tok.get_string().encode() value = tok.get_string().encode() return cls(rdclass, rdtype, flags, tag, value) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!B', self.flags)) + file.write(struct.pack("!B", self.flags)) l = len(self.tag) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.tag) file.write(self.value) diff --git a/dns/rdtypes/ANY/CDNSKEY.py b/dns/rdtypes/ANY/CDNSKEY.py index 7ea8f2a9..869523fb 100644 --- a/dns/rdtypes/ANY/CDNSKEY.py +++ b/dns/rdtypes/ANY/CDNSKEY.py @@ -19,9 +19,15 @@ import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 lgtm[py/unused-import] +from dns.rdtypes.dnskeybase import ( + SEP, + REVOKE, + ZONE, +) # noqa: F401 lgtm[py/unused-import] + # pylint: enable=unused-import + @dns.immutable.immutable class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): diff --git a/dns/rdtypes/ANY/CERT.py b/dns/rdtypes/ANY/CERT.py index f8990ebe..1b0cbeca 100644 --- a/dns/rdtypes/ANY/CERT.py +++ b/dns/rdtypes/ANY/CERT.py @@ -25,29 +25,29 @@ import dns.rdata import dns.tokenizer _ctype_by_value = { - 1: 'PKIX', - 2: 'SPKI', - 3: 'PGP', - 4: 'IPKIX', - 5: 'ISPKI', - 6: 'IPGP', - 7: 'ACPKIX', - 8: 'IACPKIX', - 253: 'URI', - 254: 'OID', + 1: "PKIX", + 2: "SPKI", + 3: "PGP", + 4: "IPKIX", + 5: "ISPKI", + 6: "IPGP", + 7: "ACPKIX", + 8: "IACPKIX", + 253: "URI", + 254: "OID", } _ctype_by_name = { - 'PKIX': 1, - 'SPKI': 2, - 'PGP': 3, - 'IPKIX': 4, - 'ISPKI': 5, - 'IPGP': 6, - 'ACPKIX': 7, - 'IACPKIX': 8, - 'URI': 253, - 'OID': 254, + "PKIX": 1, + "SPKI": 2, + "PGP": 3, + "IPKIX": 4, + "ISPKI": 5, + "IPGP": 6, + "ACPKIX": 7, + "IACPKIX": 8, + "URI": 253, + "OID": 254, } @@ -72,10 +72,11 @@ class CERT(dns.rdata.Rdata): # see RFC 4398 - __slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate'] + __slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"] - def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm, - certificate): + def __init__( + self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate + ): super().__init__(rdclass, rdtype) self.certificate_type = self._as_uint16(certificate_type) self.key_tag = self._as_uint16(key_tag) @@ -84,24 +85,28 @@ class CERT(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): certificate_type = _ctype_to_text(self.certificate_type) - return "%s %d %s %s" % (certificate_type, self.key_tag, - dns.dnssectypes.Algorithm.to_text(self.algorithm), - dns.rdata._base64ify(self.certificate, **kw)) + return "%s %d %s %s" % ( + certificate_type, + self.key_tag, + dns.dnssectypes.Algorithm.to_text(self.algorithm), + dns.rdata._base64ify(self.certificate, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): certificate_type = _ctype_from_text(tok.get_string()) key_tag = tok.get_uint16() algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) b64 = tok.concatenate_remaining_identifiers().encode() certificate = base64.b64decode(b64) - return cls(rdclass, rdtype, certificate_type, key_tag, - algorithm, certificate) + return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - prefix = struct.pack("!HHB", self.certificate_type, self.key_tag, - self.algorithm) + prefix = struct.pack( + "!HHB", self.certificate_type, self.key_tag, self.algorithm + ) file.write(prefix) file.write(self.certificate) @@ -109,5 +114,4 @@ class CERT(dns.rdata.Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB") certificate = parser.get_remaining() - return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, - certificate) + return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) diff --git a/dns/rdtypes/ANY/CSYNC.py b/dns/rdtypes/ANY/CSYNC.py index 979028ae..f819c08c 100644 --- a/dns/rdtypes/ANY/CSYNC.py +++ b/dns/rdtypes/ANY/CSYNC.py @@ -27,7 +27,7 @@ import dns.rdtypes.util @dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): - type_name = 'CSYNC' + type_name = "CSYNC" @dns.immutable.immutable @@ -35,7 +35,7 @@ class CSYNC(dns.rdata.Rdata): """CSYNC record""" - __slots__ = ['serial', 'flags', 'windows'] + __slots__ = ["serial", "flags", "windows"] def __init__(self, rdclass, rdtype, serial, flags, windows): super().__init__(rdclass, rdtype) @@ -47,18 +47,19 @@ class CSYNC(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): text = Bitmap(self.windows).to_text() - return '%d %d%s' % (self.serial, self.flags, text) + return "%d %d%s" % (self.serial, self.flags, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): serial = tok.get_uint32() flags = tok.get_uint16() bitmap = Bitmap.from_text(tok) return cls(rdclass, rdtype, serial, flags, bitmap) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!IH', self.serial, self.flags)) + file.write(struct.pack("!IH", self.serial, self.flags)) Bitmap(self.windows).to_wire(file) @classmethod diff --git a/dns/rdtypes/ANY/DNSKEY.py b/dns/rdtypes/ANY/DNSKEY.py index cc0bf8cf..50fa05b7 100644 --- a/dns/rdtypes/ANY/DNSKEY.py +++ b/dns/rdtypes/ANY/DNSKEY.py @@ -19,9 +19,15 @@ import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 lgtm[py/unused-import] +from dns.rdtypes.dnskeybase import ( + SEP, + REVOKE, + ZONE, +) # noqa: F401 lgtm[py/unused-import] + # pylint: enable=unused-import + @dns.immutable.immutable class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): diff --git a/dns/rdtypes/ANY/GPOS.py b/dns/rdtypes/ANY/GPOS.py index 29fa8f8b..30aab321 100644 --- a/dns/rdtypes/ANY/GPOS.py +++ b/dns/rdtypes/ANY/GPOS.py @@ -26,19 +26,19 @@ import dns.tokenizer def _validate_float_string(what): if len(what) == 0: raise dns.exception.FormError - if what[0] == b'-'[0] or what[0] == b'+'[0]: + if what[0] == b"-"[0] or what[0] == b"+"[0]: what = what[1:] if what.isdigit(): return try: - (left, right) = what.split(b'.') + (left, right) = what.split(b".") except ValueError: raise dns.exception.FormError - if left == b'' and right == b'': + if left == b"" and right == b"": raise dns.exception.FormError - if not left == b'' and not left.decode().isdigit(): + if not left == b"" and not left.decode().isdigit(): raise dns.exception.FormError - if not right == b'' and not right.decode().isdigit(): + if not right == b"" and not right.decode().isdigit(): raise dns.exception.FormError @@ -49,18 +49,15 @@ class GPOS(dns.rdata.Rdata): # see: RFC 1712 - __slots__ = ['latitude', 'longitude', 'altitude'] + __slots__ = ["latitude", "longitude", "altitude"] def __init__(self, rdclass, rdtype, latitude, longitude, altitude): super().__init__(rdclass, rdtype) - if isinstance(latitude, float) or \ - isinstance(latitude, int): + if isinstance(latitude, float) or isinstance(latitude, int): latitude = str(latitude) - if isinstance(longitude, float) or \ - isinstance(longitude, int): + if isinstance(longitude, float) or isinstance(longitude, int): longitude = str(longitude) - if isinstance(altitude, float) or \ - isinstance(altitude, int): + if isinstance(altitude, float) or isinstance(altitude, int): altitude = str(altitude) latitude = self._as_bytes(latitude, True, 255) longitude = self._as_bytes(longitude, True, 255) @@ -73,19 +70,20 @@ class GPOS(dns.rdata.Rdata): self.altitude = altitude flat = self.float_latitude if flat < -90.0 or flat > 90.0: - raise dns.exception.FormError('bad latitude') + raise dns.exception.FormError("bad latitude") flong = self.float_longitude if flong < -180.0 or flong > 180.0: - raise dns.exception.FormError('bad longitude') + raise dns.exception.FormError("bad longitude") def to_text(self, origin=None, relativize=True, **kw): - return '{} {} {}'.format(self.latitude.decode(), - self.longitude.decode(), - self.altitude.decode()) + return "{} {} {}".format( + self.latitude.decode(), self.longitude.decode(), self.altitude.decode() + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): latitude = tok.get_string() longitude = tok.get_string() altitude = tok.get_string() @@ -94,15 +92,15 @@ class GPOS(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.latitude) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.latitude) l = len(self.longitude) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.longitude) l = len(self.altitude) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.altitude) @classmethod diff --git a/dns/rdtypes/ANY/HINFO.py b/dns/rdtypes/ANY/HINFO.py index cd049693..513c155a 100644 --- a/dns/rdtypes/ANY/HINFO.py +++ b/dns/rdtypes/ANY/HINFO.py @@ -30,7 +30,7 @@ class HINFO(dns.rdata.Rdata): # see: RFC 1035 - __slots__ = ['cpu', 'os'] + __slots__ = ["cpu", "os"] def __init__(self, rdclass, rdtype, cpu, os): super().__init__(rdclass, rdtype) @@ -38,12 +38,14 @@ class HINFO(dns.rdata.Rdata): self.os = self._as_bytes(os, True, 255) def to_text(self, origin=None, relativize=True, **kw): - return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu), - dns.rdata._escapify(self.os)) + return '"{}" "{}"'.format( + dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os) + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): cpu = tok.get_string(max_length=255) os = tok.get_string(max_length=255) return cls(rdclass, rdtype, cpu, os) @@ -51,11 +53,11 @@ class HINFO(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.cpu) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.cpu) l = len(self.os) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.os) @classmethod diff --git a/dns/rdtypes/ANY/HIP.py b/dns/rdtypes/ANY/HIP.py index e887359b..01fec822 100644 --- a/dns/rdtypes/ANY/HIP.py +++ b/dns/rdtypes/ANY/HIP.py @@ -32,7 +32,7 @@ class HIP(dns.rdata.Rdata): # see: RFC 5205 - __slots__ = ['hit', 'algorithm', 'key', 'servers'] + __slots__ = ["hit", "algorithm", "key", "servers"] def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): super().__init__(rdclass, rdtype) @@ -43,18 +43,19 @@ class HIP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): hit = binascii.hexlify(self.hit).decode() - key = base64.b64encode(self.key).replace(b'\n', b'').decode() - text = '' + key = base64.b64encode(self.key).replace(b"\n", b"").decode() + text = "" servers = [] for server in self.servers: servers.append(server.choose_relativity(origin, relativize)) if len(servers) > 0: - text += (' ' + ' '.join((x.to_unicode() for x in servers))) - return '%u %s %s%s' % (self.algorithm, hit, key, text) + text += " " + " ".join((x.to_unicode() for x in servers)) + return "%u %s %s%s" % (self.algorithm, hit, key, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() hit = binascii.unhexlify(tok.get_string().encode()) key = base64.b64decode(tok.get_string().encode()) @@ -75,7 +76,7 @@ class HIP(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (lh, algorithm, lk) = parser.get_struct('!BBH') + (lh, algorithm, lk) = parser.get_struct("!BBH") hit = parser.get_bytes(lh) key = parser.get_bytes(lk) servers = [] diff --git a/dns/rdtypes/ANY/ISDN.py b/dns/rdtypes/ANY/ISDN.py index b9a49adb..536a35d6 100644 --- a/dns/rdtypes/ANY/ISDN.py +++ b/dns/rdtypes/ANY/ISDN.py @@ -30,7 +30,7 @@ class ISDN(dns.rdata.Rdata): # see: RFC 1183 - __slots__ = ['address', 'subaddress'] + __slots__ = ["address", "subaddress"] def __init__(self, rdclass, rdtype, address, subaddress): super().__init__(rdclass, rdtype) @@ -39,31 +39,33 @@ class ISDN(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): if self.subaddress: - return '"{}" "{}"'.format(dns.rdata._escapify(self.address), - dns.rdata._escapify(self.subaddress)) + return '"{}" "{}"'.format( + dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress) + ) else: return '"%s"' % dns.rdata._escapify(self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() tokens = tok.get_remaining(max_tokens=1) if len(tokens) >= 1: subaddress = tokens[0].unescape().value else: - subaddress = '' + subaddress = "" return cls(rdclass, rdtype, address, subaddress) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.address) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.address) l = len(self.subaddress) if l > 0: assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.subaddress) @classmethod @@ -72,5 +74,5 @@ class ISDN(dns.rdata.Rdata): if parser.remaining() > 0: subaddress = parser.get_counted_bytes() else: - subaddress = b'' + subaddress = b"" return cls(rdclass, rdtype, address, subaddress) diff --git a/dns/rdtypes/ANY/L32.py b/dns/rdtypes/ANY/L32.py index 038fc3a3..14be01f9 100644 --- a/dns/rdtypes/ANY/L32.py +++ b/dns/rdtypes/ANY/L32.py @@ -13,7 +13,7 @@ class L32(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'locator32'] + __slots__ = ["preference", "locator32"] def __init__(self, rdclass, rdtype, preference, locator32): super().__init__(rdclass, rdtype) @@ -21,17 +21,18 @@ class L32(dns.rdata.Rdata): self.locator32 = self._as_ipv4_address(locator32) def to_text(self, origin=None, relativize=True, **kw): - return f'{self.preference} {self.locator32}' + return f"{self.preference} {self.locator32}" @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() nodeid = tok.get_identifier() return cls(rdclass, rdtype, preference, nodeid) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) + file.write(struct.pack("!H", self.preference)) file.write(dns.ipv4.inet_aton(self.locator32)) @classmethod diff --git a/dns/rdtypes/ANY/L64.py b/dns/rdtypes/ANY/L64.py index aab36a82..d083d403 100644 --- a/dns/rdtypes/ANY/L64.py +++ b/dns/rdtypes/ANY/L64.py @@ -13,33 +13,33 @@ class L64(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'locator64'] + __slots__ = ["preference", "locator64"] def __init__(self, rdclass, rdtype, preference, locator64): super().__init__(rdclass, rdtype) self.preference = self._as_uint16(preference) if isinstance(locator64, bytes): if len(locator64) != 8: - raise ValueError('invalid locator64') - self.locator64 = dns.rdata._hexify(locator64, 4, b':') + raise ValueError("invalid locator64") + self.locator64 = dns.rdata._hexify(locator64, 4, b":") else: - dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':') + dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":") self.locator64 = locator64 def to_text(self, origin=None, relativize=True, **kw): - return f'{self.preference} {self.locator64}' + return f"{self.preference} {self.locator64}" @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() locator64 = tok.get_identifier() return cls(rdclass, rdtype, preference, locator64) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) - file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, - 4, 4, ':')) + file.write(struct.pack("!H", self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":")) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/dns/rdtypes/ANY/LOC.py b/dns/rdtypes/ANY/LOC.py index c9398994..52c97532 100644 --- a/dns/rdtypes/ANY/LOC.py +++ b/dns/rdtypes/ANY/LOC.py @@ -93,15 +93,15 @@ def _decode_size(what, desc): def _check_coordinate_list(value, low, high): if value[0] < low or value[0] > high: - raise ValueError(f'not in range [{low}, {high}]') + raise ValueError(f"not in range [{low}, {high}]") if value[1] < 0 or value[1] > 59: - raise ValueError('bad minutes value') + raise ValueError("bad minutes value") if value[2] < 0 or value[2] > 59: - raise ValueError('bad seconds value') + raise ValueError("bad seconds value") if value[3] < 0 or value[3] > 999: - raise ValueError('bad milliseconds value') + raise ValueError("bad milliseconds value") if value[4] != 1 and value[4] != -1: - raise ValueError('bad hemisphere value') + raise ValueError("bad hemisphere value") @dns.immutable.immutable @@ -111,12 +111,26 @@ class LOC(dns.rdata.Rdata): # see: RFC 1876 - __slots__ = ['latitude', 'longitude', 'altitude', 'size', - 'horizontal_precision', 'vertical_precision'] - - def __init__(self, rdclass, rdtype, latitude, longitude, altitude, - size=_default_size, hprec=_default_hprec, - vprec=_default_vprec): + __slots__ = [ + "latitude", + "longitude", + "altitude", + "size", + "horizontal_precision", + "vertical_precision", + ] + + def __init__( + self, + rdclass, + rdtype, + latitude, + longitude, + altitude, + size=_default_size, + hprec=_default_hprec, + vprec=_default_vprec, + ): """Initialize a LOC record instance. The parameters I{latitude} and I{longitude} may be either a 4-tuple @@ -145,34 +159,44 @@ class LOC(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): if self.latitude[4] > 0: - lat_hemisphere = 'N' + lat_hemisphere = "N" else: - lat_hemisphere = 'S' + lat_hemisphere = "S" if self.longitude[4] > 0: - long_hemisphere = 'E' + long_hemisphere = "E" else: - long_hemisphere = 'W' + long_hemisphere = "W" text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % ( - self.latitude[0], self.latitude[1], - self.latitude[2], self.latitude[3], lat_hemisphere, - self.longitude[0], self.longitude[1], self.longitude[2], - self.longitude[3], long_hemisphere, - self.altitude / 100.0 + self.latitude[0], + self.latitude[1], + self.latitude[2], + self.latitude[3], + lat_hemisphere, + self.longitude[0], + self.longitude[1], + self.longitude[2], + self.longitude[3], + long_hemisphere, + self.altitude / 100.0, ) # do not print default values - if self.size != _default_size or \ - self.horizontal_precision != _default_hprec or \ - self.vertical_precision != _default_vprec: + if ( + self.size != _default_size + or self.horizontal_precision != _default_hprec + or self.vertical_precision != _default_vprec + ): text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( - self.size / 100.0, self.horizontal_precision / 100.0, - self.vertical_precision / 100.0 + self.size / 100.0, + self.horizontal_precision / 100.0, + self.vertical_precision / 100.0, ) return text @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): latitude = [0, 0, 0, 0, 1] longitude = [0, 0, 0, 0, 1] size = _default_size @@ -184,16 +208,14 @@ class LOC(dns.rdata.Rdata): if t.isdigit(): latitude[1] = int(t) t = tok.get_string() - if '.' in t: - (seconds, milliseconds) = t.split('.') + if "." in t: + (seconds, milliseconds) = t.split(".") if not seconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad latitude seconds value') + raise dns.exception.SyntaxError("bad latitude seconds value") latitude[2] = int(seconds) l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad latitude milliseconds value') + raise dns.exception.SyntaxError("bad latitude milliseconds value") if l == 1: m = 100 elif l == 2: @@ -205,26 +227,24 @@ class LOC(dns.rdata.Rdata): elif t.isdigit(): latitude[2] = int(t) t = tok.get_string() - if t == 'S': + if t == "S": latitude[4] = -1 - elif t != 'N': - raise dns.exception.SyntaxError('bad latitude hemisphere value') + elif t != "N": + raise dns.exception.SyntaxError("bad latitude hemisphere value") longitude[0] = tok.get_int() t = tok.get_string() if t.isdigit(): longitude[1] = int(t) t = tok.get_string() - if '.' in t: - (seconds, milliseconds) = t.split('.') + if "." in t: + (seconds, milliseconds) = t.split(".") if not seconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad longitude seconds value') + raise dns.exception.SyntaxError("bad longitude seconds value") longitude[2] = int(seconds) l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad longitude milliseconds value') + raise dns.exception.SyntaxError("bad longitude milliseconds value") if l == 1: m = 100 elif l == 2: @@ -236,64 +256,75 @@ class LOC(dns.rdata.Rdata): elif t.isdigit(): longitude[2] = int(t) t = tok.get_string() - if t == 'W': + if t == "W": longitude[4] = -1 - elif t != 'E': - raise dns.exception.SyntaxError('bad longitude hemisphere value') + elif t != "E": + raise dns.exception.SyntaxError("bad longitude hemisphere value") t = tok.get_string() - if t[-1] == 'm': - t = t[0: -1] - altitude = float(t) * 100.0 # m -> cm + if t[-1] == "m": + t = t[0:-1] + altitude = float(t) * 100.0 # m -> cm tokens = tok.get_remaining(max_tokens=3) if len(tokens) >= 1: value = tokens[0].unescape().value - if value[-1] == 'm': - value = value[0: -1] - size = float(value) * 100.0 # m -> cm + if value[-1] == "m": + value = value[0:-1] + size = float(value) * 100.0 # m -> cm if len(tokens) >= 2: value = tokens[1].unescape().value - if value[-1] == 'm': - value = value[0: -1] - hprec = float(value) * 100.0 # m -> cm + if value[-1] == "m": + value = value[0:-1] + hprec = float(value) * 100.0 # m -> cm if len(tokens) >= 3: value = tokens[2].unescape().value - if value[-1] == 'm': - value = value[0: -1] - vprec = float(value) * 100.0 # m -> cm + if value[-1] == "m": + value = value[0:-1] + vprec = float(value) * 100.0 # m -> cm # Try encoding these now so we raise if they are bad _encode_size(size, "size") _encode_size(hprec, "horizontal precision") _encode_size(vprec, "vertical precision") - return cls(rdclass, rdtype, latitude, longitude, altitude, - size, hprec, vprec) + return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - milliseconds = (self.latitude[0] * 3600000 + - self.latitude[1] * 60000 + - self.latitude[2] * 1000 + - self.latitude[3]) * self.latitude[4] + milliseconds = ( + self.latitude[0] * 3600000 + + self.latitude[1] * 60000 + + self.latitude[2] * 1000 + + self.latitude[3] + ) * self.latitude[4] latitude = 0x80000000 + milliseconds - milliseconds = (self.longitude[0] * 3600000 + - self.longitude[1] * 60000 + - self.longitude[2] * 1000 + - self.longitude[3]) * self.longitude[4] + milliseconds = ( + self.longitude[0] * 3600000 + + self.longitude[1] * 60000 + + self.longitude[2] * 1000 + + self.longitude[3] + ) * self.longitude[4] longitude = 0x80000000 + milliseconds altitude = int(self.altitude) + 10000000 size = _encode_size(self.size, "size") hprec = _encode_size(self.horizontal_precision, "horizontal precision") vprec = _encode_size(self.vertical_precision, "vertical precision") - wire = struct.pack("!BBBBIII", 0, size, hprec, vprec, latitude, - longitude, altitude) + wire = struct.pack( + "!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude + ) file.write(wire) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (version, size, hprec, vprec, latitude, longitude, altitude) = \ - parser.get_struct("!BBBBIII") + ( + version, + size, + hprec, + vprec, + latitude, + longitude, + altitude, + ) = parser.get_struct("!BBBBIII") if version != 0: raise dns.exception.FormError("LOC version not zero") if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: @@ -312,8 +343,7 @@ class LOC(dns.rdata.Rdata): size = _decode_size(size, "size") hprec = _decode_size(hprec, "horizontal precision") vprec = _decode_size(vprec, "vertical precision") - return cls(rdclass, rdtype, latitude, longitude, altitude, - size, hprec, vprec) + return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) @property def float_latitude(self): diff --git a/dns/rdtypes/ANY/LP.py b/dns/rdtypes/ANY/LP.py index a4adffb3..8a7c5125 100644 --- a/dns/rdtypes/ANY/LP.py +++ b/dns/rdtypes/ANY/LP.py @@ -13,7 +13,7 @@ class LP(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'fqdn'] + __slots__ = ["preference", "fqdn"] def __init__(self, rdclass, rdtype, preference, fqdn): super().__init__(rdclass, rdtype) @@ -22,17 +22,18 @@ class LP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): fqdn = self.fqdn.choose_relativity(origin, relativize) - return '%d %s' % (self.preference, fqdn) + return "%d %s" % (self.preference, fqdn) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() fqdn = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, preference, fqdn) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) + file.write(struct.pack("!H", self.preference)) self.fqdn.to_wire(file, compress, origin, canonicalize) @classmethod diff --git a/dns/rdtypes/ANY/NID.py b/dns/rdtypes/ANY/NID.py index 74951bbf..ad54aca3 100644 --- a/dns/rdtypes/ANY/NID.py +++ b/dns/rdtypes/ANY/NID.py @@ -13,32 +13,33 @@ class NID(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'nodeid'] + __slots__ = ["preference", "nodeid"] def __init__(self, rdclass, rdtype, preference, nodeid): super().__init__(rdclass, rdtype) self.preference = self._as_uint16(preference) if isinstance(nodeid, bytes): if len(nodeid) != 8: - raise ValueError('invalid nodeid') - self.nodeid = dns.rdata._hexify(nodeid, 4, b':') + raise ValueError("invalid nodeid") + self.nodeid = dns.rdata._hexify(nodeid, 4, b":") else: - dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':') + dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":") self.nodeid = nodeid def to_text(self, origin=None, relativize=True, **kw): - return f'{self.preference} {self.nodeid}' + return f"{self.preference} {self.nodeid}" @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() nodeid = tok.get_identifier() return cls(rdclass, rdtype, preference, nodeid) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) - file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ':')) + file.write(struct.pack("!H", self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":")) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/dns/rdtypes/ANY/NSEC.py b/dns/rdtypes/ANY/NSEC.py index dc31f4c4..7af7b77f 100644 --- a/dns/rdtypes/ANY/NSEC.py +++ b/dns/rdtypes/ANY/NSEC.py @@ -25,7 +25,7 @@ import dns.rdtypes.util @dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): - type_name = 'NSEC' + type_name = "NSEC" @dns.immutable.immutable @@ -33,7 +33,7 @@ class NSEC(dns.rdata.Rdata): """NSEC record""" - __slots__ = ['next', 'windows'] + __slots__ = ["next", "windows"] def __init__(self, rdclass, rdtype, next, windows): super().__init__(rdclass, rdtype) @@ -45,11 +45,12 @@ class NSEC(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): next = self.next.choose_relativity(origin, relativize) text = Bitmap(self.windows).to_text() - return '{}{}'.format(next, text) + return "{}{}".format(next, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): next = tok.get_name(origin, relativize, relativize_to) windows = Bitmap.from_text(tok) return cls(rdclass, rdtype, next, windows) diff --git a/dns/rdtypes/ANY/NSEC3.py b/dns/rdtypes/ANY/NSEC3.py index 14242bda..6eae16e0 100644 --- a/dns/rdtypes/ANY/NSEC3.py +++ b/dns/rdtypes/ANY/NSEC3.py @@ -26,10 +26,12 @@ import dns.rdatatype import dns.rdtypes.util -b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV', - b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') -b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', - b'0123456789ABCDEFGHIJKLMNOPQRSTUV') +b32_hex_to_normal = bytes.maketrans( + b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" +) +b32_normal_to_hex = bytes.maketrans( + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV" +) # hash algorithm constants SHA1 = 1 @@ -40,7 +42,7 @@ OPTOUT = 1 @dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): - type_name = 'NSEC3' + type_name = "NSEC3" @dns.immutable.immutable @@ -48,10 +50,11 @@ class NSEC3(dns.rdata.Rdata): """NSEC3 record""" - __slots__ = ['algorithm', 'flags', 'iterations', 'salt', 'next', 'windows'] + __slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"] - def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt, - next, windows): + def __init__( + self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows + ): super().__init__(rdclass, rdtype) self.algorithm = self._as_uint8(algorithm) self.flags = self._as_uint8(flags) @@ -63,38 +66,41 @@ class NSEC3(dns.rdata.Rdata): self.windows = tuple(windows.windows) def to_text(self, origin=None, relativize=True, **kw): - next = base64.b32encode(self.next).translate( - b32_normal_to_hex).lower().decode() - if self.salt == b'': - salt = '-' + next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() + if self.salt == b"": + salt = "-" else: salt = binascii.hexlify(self.salt).decode() text = Bitmap(self.windows).to_text() - return '%u %u %u %s %s%s' % (self.algorithm, self.flags, - self.iterations, salt, next, text) + return "%u %u %u %s %s%s" % ( + self.algorithm, + self.flags, + self.iterations, + salt, + next, + text, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() flags = tok.get_uint8() iterations = tok.get_uint16() salt = tok.get_string() - if salt == '-': - salt = b'' + if salt == "-": + salt = b"" else: - salt = binascii.unhexlify(salt.encode('ascii')) - next = tok.get_string().encode( - 'ascii').upper().translate(b32_hex_to_normal) + salt = binascii.unhexlify(salt.encode("ascii")) + next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal) next = base64.b32decode(next) bitmap = Bitmap.from_text(tok) - return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, - bitmap) + return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.salt) - file.write(struct.pack("!BBHB", self.algorithm, self.flags, - self.iterations, l)) + file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) file.write(self.salt) l = len(self.next) file.write(struct.pack("!B", l)) @@ -103,9 +109,8 @@ class NSEC3(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (algorithm, flags, iterations) = parser.get_struct('!BBH') + (algorithm, flags, iterations) = parser.get_struct("!BBH") salt = parser.get_counted_bytes() next = parser.get_counted_bytes() bitmap = Bitmap.from_wire_parser(parser) - return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, - bitmap) + return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) diff --git a/dns/rdtypes/ANY/NSEC3PARAM.py b/dns/rdtypes/ANY/NSEC3PARAM.py index 299bf6ed..1b7269a0 100644 --- a/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/dns/rdtypes/ANY/NSEC3PARAM.py @@ -28,7 +28,7 @@ class NSEC3PARAM(dns.rdata.Rdata): """NSEC3PARAM record""" - __slots__ = ['algorithm', 'flags', 'iterations', 'salt'] + __slots__ = ["algorithm", "flags", "iterations", "salt"] def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): super().__init__(rdclass, rdtype) @@ -38,34 +38,33 @@ class NSEC3PARAM(dns.rdata.Rdata): self.salt = self._as_bytes(salt, True, 255) def to_text(self, origin=None, relativize=True, **kw): - if self.salt == b'': - salt = '-' + if self.salt == b"": + salt = "-" else: salt = binascii.hexlify(self.salt).decode() - return '%u %u %u %s' % (self.algorithm, self.flags, self.iterations, - salt) + return "%u %u %u %s" % (self.algorithm, self.flags, self.iterations, salt) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() flags = tok.get_uint8() iterations = tok.get_uint16() salt = tok.get_string() - if salt == '-': - salt = '' + if salt == "-": + salt = "" else: salt = binascii.unhexlify(salt.encode()) return cls(rdclass, rdtype, algorithm, flags, iterations, salt) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.salt) - file.write(struct.pack("!BBHB", self.algorithm, self.flags, - self.iterations, l)) + file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) file.write(self.salt) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (algorithm, flags, iterations) = parser.get_struct('!BBH') + (algorithm, flags, iterations) = parser.get_struct("!BBH") salt = parser.get_counted_bytes() return cls(rdclass, rdtype, algorithm, flags, iterations, salt) diff --git a/dns/rdtypes/ANY/OPENPGPKEY.py b/dns/rdtypes/ANY/OPENPGPKEY.py index dcfa028d..e5e25727 100644 --- a/dns/rdtypes/ANY/OPENPGPKEY.py +++ b/dns/rdtypes/ANY/OPENPGPKEY.py @@ -22,6 +22,7 @@ import dns.immutable import dns.rdata import dns.tokenizer + @dns.immutable.immutable class OPENPGPKEY(dns.rdata.Rdata): @@ -37,8 +38,9 @@ class OPENPGPKEY(dns.rdata.Rdata): return dns.rdata._base64ify(self.key, chunksize=None, **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) return cls(rdclass, rdtype, key) diff --git a/dns/rdtypes/ANY/OPT.py b/dns/rdtypes/ANY/OPT.py index 69b8fe75..36d4c7c6 100644 --- a/dns/rdtypes/ANY/OPT.py +++ b/dns/rdtypes/ANY/OPT.py @@ -26,12 +26,13 @@ import dns.rdata # We don't implement from_text, and that's ok. # pylint: disable=abstract-method + @dns.immutable.immutable class OPT(dns.rdata.Rdata): """OPT record""" - __slots__ = ['options'] + __slots__ = ["options"] def __init__(self, rdclass, rdtype, options): """Initialize an OPT rdata. @@ -45,10 +46,12 @@ class OPT(dns.rdata.Rdata): """ super().__init__(rdclass, rdtype) + def as_option(option): if not isinstance(option, dns.edns.Option): - raise ValueError('option is not a dns.edns.option') + raise ValueError("option is not a dns.edns.option") return option + self.options = self._as_tuple(options, as_option) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -58,13 +61,13 @@ class OPT(dns.rdata.Rdata): file.write(owire) def to_text(self, origin=None, relativize=True, **kw): - return ' '.join(opt.to_text() for opt in self.options) + return " ".join(opt.to_text() for opt in self.options) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): options = [] while parser.remaining() > 0: - (otype, olen) = parser.get_struct('!HH') + (otype, olen) = parser.get_struct("!HH") with parser.restrict_to(olen): opt = dns.edns.option_from_wire_parser(otype, parser) options.append(opt) diff --git a/dns/rdtypes/ANY/RP.py b/dns/rdtypes/ANY/RP.py index a4e2297d..c0c316b5 100644 --- a/dns/rdtypes/ANY/RP.py +++ b/dns/rdtypes/ANY/RP.py @@ -28,7 +28,7 @@ class RP(dns.rdata.Rdata): # see: RFC 1183 - __slots__ = ['mbox', 'txt'] + __slots__ = ["mbox", "txt"] def __init__(self, rdclass, rdtype, mbox, txt): super().__init__(rdclass, rdtype) @@ -41,8 +41,9 @@ class RP(dns.rdata.Rdata): return "{} {}".format(str(mbox), str(txt)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): mbox = tok.get_name(origin, relativize, relativize_to) txt = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, mbox, txt) diff --git a/dns/rdtypes/ANY/RRSIG.py b/dns/rdtypes/ANY/RRSIG.py index 82650c0f..3d5ad0f3 100644 --- a/dns/rdtypes/ANY/RRSIG.py +++ b/dns/rdtypes/ANY/RRSIG.py @@ -43,12 +43,11 @@ def sigtime_to_posixtime(what): hour = int(what[8:10]) minute = int(what[10:12]) second = int(what[12:14]) - return calendar.timegm((year, month, day, hour, minute, second, - 0, 0, 0)) + return calendar.timegm((year, month, day, hour, minute, second, 0, 0, 0)) def posixtime_to_sigtime(what): - return time.strftime('%Y%m%d%H%M%S', time.gmtime(what)) + return time.strftime("%Y%m%d%H%M%S", time.gmtime(what)) @dns.immutable.immutable @@ -56,13 +55,32 @@ class RRSIG(dns.rdata.Rdata): """RRSIG record""" - __slots__ = ['type_covered', 'algorithm', 'labels', 'original_ttl', - 'expiration', 'inception', 'key_tag', 'signer', - 'signature'] - - def __init__(self, rdclass, rdtype, type_covered, algorithm, labels, - original_ttl, expiration, inception, key_tag, signer, - signature): + __slots__ = [ + "type_covered", + "algorithm", + "labels", + "original_ttl", + "expiration", + "inception", + "key_tag", + "signer", + "signature", + ] + + def __init__( + self, + rdclass, + rdtype, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer, + signature, + ): super().__init__(rdclass, rdtype) self.type_covered = self._as_rdatatype(type_covered) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) @@ -78,7 +96,7 @@ class RRSIG(dns.rdata.Rdata): return self.type_covered def to_text(self, origin=None, relativize=True, **kw): - return '%s %d %d %d %s %s %d %s %s' % ( + return "%s %d %d %d %s %s %d %s %s" % ( dns.rdatatype.to_text(self.type_covered), self.algorithm, self.labels, @@ -87,12 +105,13 @@ class RRSIG(dns.rdata.Rdata): posixtime_to_sigtime(self.inception), self.key_tag, self.signer.choose_relativity(origin, relativize), - dns.rdata._base64ify(self.signature, **kw) + dns.rdata._base64ify(self.signature, **kw), ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): type_covered = dns.rdatatype.from_text(tok.get_string()) algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) labels = tok.get_int() @@ -103,22 +122,38 @@ class RRSIG(dns.rdata.Rdata): signer = tok.get_name(origin, relativize, relativize_to) b64 = tok.concatenate_remaining_identifiers().encode() signature = base64.b64decode(b64) - return cls(rdclass, rdtype, type_covered, algorithm, labels, - original_ttl, expiration, inception, key_tag, signer, - signature) + return cls( + rdclass, + rdtype, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer, + signature, + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack('!HBBIIIH', self.type_covered, - self.algorithm, self.labels, - self.original_ttl, self.expiration, - self.inception, self.key_tag) + header = struct.pack( + "!HBBIIIH", + self.type_covered, + self.algorithm, + self.labels, + self.original_ttl, + self.expiration, + self.inception, + self.key_tag, + ) file.write(header) self.signer.to_wire(file, None, origin, canonicalize) file.write(self.signature) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - header = parser.get_struct('!HBBIIIH') + header = parser.get_struct("!HBBIIIH") signer = parser.get_name(origin) signature = parser.get_remaining() return cls(rdclass, rdtype, *header, signer, signature) diff --git a/dns/rdtypes/ANY/SOA.py b/dns/rdtypes/ANY/SOA.py index 7ce88652..6f6fe58b 100644 --- a/dns/rdtypes/ANY/SOA.py +++ b/dns/rdtypes/ANY/SOA.py @@ -30,11 +30,11 @@ class SOA(dns.rdata.Rdata): # see: RFC 1035 - __slots__ = ['mname', 'rname', 'serial', 'refresh', 'retry', 'expire', - 'minimum'] + __slots__ = ["mname", "rname", "serial", "refresh", "retry", "expire", "minimum"] - def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry, - expire, minimum): + def __init__( + self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum + ): super().__init__(rdclass, rdtype) self.mname = self._as_name(mname) self.rname = self._as_name(rname) @@ -47,13 +47,20 @@ class SOA(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): mname = self.mname.choose_relativity(origin, relativize) rname = self.rname.choose_relativity(origin, relativize) - return '%s %s %d %d %d %d %d' % ( - mname, rname, self.serial, self.refresh, self.retry, - self.expire, self.minimum) + return "%s %s %d %d %d %d %d" % ( + mname, + rname, + self.serial, + self.refresh, + self.retry, + self.expire, + self.minimum, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): mname = tok.get_name(origin, relativize, relativize_to) rname = tok.get_name(origin, relativize, relativize_to) serial = tok.get_uint32() @@ -61,18 +68,20 @@ class SOA(dns.rdata.Rdata): retry = tok.get_ttl() expire = tok.get_ttl() minimum = tok.get_ttl() - return cls(rdclass, rdtype, mname, rname, serial, refresh, retry, - expire, minimum) + return cls( + rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.mname.to_wire(file, compress, origin, canonicalize) self.rname.to_wire(file, compress, origin, canonicalize) - five_ints = struct.pack('!IIIII', self.serial, self.refresh, - self.retry, self.expire, self.minimum) + five_ints = struct.pack( + "!IIIII", self.serial, self.refresh, self.retry, self.expire, self.minimum + ) file.write(five_ints) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): mname = parser.get_name(origin) rname = parser.get_name(origin) - return cls(rdclass, rdtype, mname, rname, *parser.get_struct('!IIIII')) + return cls(rdclass, rdtype, mname, rname, *parser.get_struct("!IIIII")) diff --git a/dns/rdtypes/ANY/SSHFP.py b/dns/rdtypes/ANY/SSHFP.py index cc035195..58ffcbbc 100644 --- a/dns/rdtypes/ANY/SSHFP.py +++ b/dns/rdtypes/ANY/SSHFP.py @@ -30,10 +30,9 @@ class SSHFP(dns.rdata.Rdata): # See RFC 4255 - __slots__ = ['algorithm', 'fp_type', 'fingerprint'] + __slots__ = ["algorithm", "fp_type", "fingerprint"] - def __init__(self, rdclass, rdtype, algorithm, fp_type, - fingerprint): + def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint): super().__init__(rdclass, rdtype) self.algorithm = self._as_uint8(algorithm) self.fp_type = self._as_uint8(fp_type) @@ -41,16 +40,17 @@ class SSHFP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %s' % (self.algorithm, - self.fp_type, - dns.rdata._hexify(self.fingerprint, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %s" % ( + self.algorithm, + self.fp_type, + dns.rdata._hexify(self.fingerprint, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() fp_type = tok.get_uint8() fingerprint = tok.concatenate_remaining_identifiers().encode() diff --git a/dns/rdtypes/ANY/TKEY.py b/dns/rdtypes/ANY/TKEY.py index 59ffe039..070f03af 100644 --- a/dns/rdtypes/ANY/TKEY.py +++ b/dns/rdtypes/ANY/TKEY.py @@ -28,11 +28,28 @@ class TKEY(dns.rdata.Rdata): """TKEY Record""" - __slots__ = ['algorithm', 'inception', 'expiration', 'mode', 'error', - 'key', 'other'] - - def __init__(self, rdclass, rdtype, algorithm, inception, expiration, - mode, error, key, other=b''): + __slots__ = [ + "algorithm", + "inception", + "expiration", + "mode", + "error", + "key", + "other", + ] + + def __init__( + self, + rdclass, + rdtype, + algorithm, + inception, + expiration, + mode, + error, + key, + other=b"", + ): super().__init__(rdclass, rdtype) self.algorithm = self._as_name(algorithm) self.inception = self._as_uint32(inception) @@ -44,17 +61,23 @@ class TKEY(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): _algorithm = self.algorithm.choose_relativity(origin, relativize) - text = '%s %u %u %u %u %s' % (str(_algorithm), self.inception, - self.expiration, self.mode, self.error, - dns.rdata._base64ify(self.key, 0)) + text = "%s %u %u %u %u %s" % ( + str(_algorithm), + self.inception, + self.expiration, + self.mode, + self.error, + dns.rdata._base64ify(self.key, 0), + ) if len(self.other) > 0: - text += ' %s' % (dns.rdata._base64ify(self.other, 0)) + text += " %s" % (dns.rdata._base64ify(self.other, 0)) return text @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_name(relativize=False) inception = tok.get_uint32() expiration = tok.get_uint32() @@ -65,13 +88,15 @@ class TKEY(dns.rdata.Rdata): other_b64 = tok.concatenate_remaining_identifiers(True).encode() other = base64.b64decode(other_b64) - return cls(rdclass, rdtype, algorithm, inception, expiration, mode, - error, key, other) + return cls( + rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.algorithm.to_wire(file, compress, origin) - file.write(struct.pack("!IIHH", self.inception, self.expiration, - self.mode, self.error)) + file.write( + struct.pack("!IIHH", self.inception, self.expiration, self.mode, self.error) + ) file.write(struct.pack("!H", len(self.key))) file.write(self.key) file.write(struct.pack("!H", len(self.other))) @@ -85,8 +110,9 @@ class TKEY(dns.rdata.Rdata): key = parser.get_counted_bytes(2) other = parser.get_counted_bytes(2) - return cls(rdclass, rdtype, algorithm, inception, expiration, mode, - error, key, other) + return cls( + rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other + ) # Constants for the mode field - from RFC 2930: # 2.5 The Mode Field diff --git a/dns/rdtypes/ANY/TSIG.py b/dns/rdtypes/ANY/TSIG.py index b43a78f1..1ae87ebe 100644 --- a/dns/rdtypes/ANY/TSIG.py +++ b/dns/rdtypes/ANY/TSIG.py @@ -29,11 +29,28 @@ class TSIG(dns.rdata.Rdata): """TSIG record""" - __slots__ = ['algorithm', 'time_signed', 'fudge', 'mac', - 'original_id', 'error', 'other'] - - def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac, - original_id, error, other): + __slots__ = [ + "algorithm", + "time_signed", + "fudge", + "mac", + "original_id", + "error", + "other", + ] + + def __init__( + self, + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ): """Initialize a TSIG rdata. *rdclass*, an ``int`` is the rdataclass of the Rdata. @@ -67,45 +84,60 @@ class TSIG(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): algorithm = self.algorithm.choose_relativity(origin, relativize) error = dns.rcode.to_text(self.error, True) - text = f"{algorithm} {self.time_signed} {self.fudge} " + \ - f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \ - f"{self.original_id} {error} {len(self.other)}" + text = ( + f"{algorithm} {self.time_signed} {self.fudge} " + + f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + + f"{self.original_id} {error} {len(self.other)}" + ) if self.other: text += f" {dns.rdata._base64ify(self.other, 0)}" return text @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_name(relativize=False) time_signed = tok.get_uint48() fudge = tok.get_uint16() mac_len = tok.get_uint16() mac = base64.b64decode(tok.get_string()) if len(mac) != mac_len: - raise SyntaxError('invalid MAC') + raise SyntaxError("invalid MAC") original_id = tok.get_uint16() error = dns.rcode.from_text(tok.get_string()) other_len = tok.get_uint16() if other_len > 0: other = base64.b64decode(tok.get_string()) if len(other) != other_len: - raise SyntaxError('invalid other data') + raise SyntaxError("invalid other data") else: - other = b'' - return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, - original_id, error, other) + other = b"" + return cls( + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.algorithm.to_wire(file, None, origin, False) - file.write(struct.pack('!HIHH', - (self.time_signed >> 32) & 0xffff, - self.time_signed & 0xffffffff, - self.fudge, - len(self.mac))) + file.write( + struct.pack( + "!HIHH", + (self.time_signed >> 32) & 0xFFFF, + self.time_signed & 0xFFFFFFFF, + self.fudge, + len(self.mac), + ) + ) file.write(self.mac) - file.write(struct.pack('!HHH', self.original_id, self.error, - len(self.other))) + file.write(struct.pack("!HHH", self.original_id, self.error, len(self.other))) file.write(self.other) @classmethod @@ -114,7 +146,16 @@ class TSIG(dns.rdata.Rdata): time_signed = parser.get_uint48() fudge = parser.get_uint16() mac = parser.get_counted_bytes(2) - (original_id, error) = parser.get_struct('!HH') + (original_id, error) = parser.get_struct("!HH") other = parser.get_counted_bytes(2) - return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, - original_id, error, other) + return cls( + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) diff --git a/dns/rdtypes/ANY/URI.py b/dns/rdtypes/ANY/URI.py index 524fa1ba..b4c95a3b 100644 --- a/dns/rdtypes/ANY/URI.py +++ b/dns/rdtypes/ANY/URI.py @@ -32,7 +32,7 @@ class URI(dns.rdata.Rdata): # see RFC 7553 - __slots__ = ['priority', 'weight', 'target'] + __slots__ = ["priority", "weight", "target"] def __init__(self, rdclass, rdtype, priority, weight, target): super().__init__(rdclass, rdtype) @@ -43,12 +43,12 @@ class URI(dns.rdata.Rdata): raise dns.exception.SyntaxError("URI target cannot be empty") def to_text(self, origin=None, relativize=True, **kw): - return '%d %d "%s"' % (self.priority, self.weight, - self.target.decode()) + return '%d %d "%s"' % (self.priority, self.weight, self.target.decode()) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): priority = tok.get_uint16() weight = tok.get_uint16() target = tok.get().unescape() @@ -63,10 +63,10 @@ class URI(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (priority, weight) = parser.get_struct('!HH') + (priority, weight) = parser.get_struct("!HH") target = parser.get_remaining() if len(target) == 0: - raise dns.exception.FormError('URI target may not be empty') + raise dns.exception.FormError("URI target may not be empty") return cls(rdclass, rdtype, priority, weight, target) def _processing_priority(self): diff --git a/dns/rdtypes/ANY/X25.py b/dns/rdtypes/ANY/X25.py index 4f7230c0..06c14534 100644 --- a/dns/rdtypes/ANY/X25.py +++ b/dns/rdtypes/ANY/X25.py @@ -30,7 +30,7 @@ class X25(dns.rdata.Rdata): # see RFC 1183 - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -40,15 +40,16 @@ class X25(dns.rdata.Rdata): return '"%s"' % dns.rdata._escapify(self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() return cls(rdclass, rdtype, address) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.address) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.address) @classmethod diff --git a/dns/rdtypes/ANY/ZONEMD.py b/dns/rdtypes/ANY/ZONEMD.py index 75f99e5e..1f86ba49 100644 --- a/dns/rdtypes/ANY/ZONEMD.py +++ b/dns/rdtypes/ANY/ZONEMD.py @@ -16,7 +16,7 @@ class ZONEMD(dns.rdata.Rdata): # See RFC 8976 - __slots__ = ['serial', 'scheme', 'hash_algorithm', 'digest'] + __slots__ = ["serial", "scheme", "hash_algorithm", "digest"] def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest): super().__init__(rdclass, rdtype) @@ -26,25 +26,28 @@ class ZONEMD(dns.rdata.Rdata): self.digest = self._as_bytes(digest) if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2 - raise ValueError('scheme 0 is reserved') + raise ValueError("scheme 0 is reserved") if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3 - raise ValueError('hash_algorithm 0 is reserved') + raise ValueError("hash_algorithm 0 is reserved") hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm) if hasher and hasher().digest_size != len(self.digest): - raise ValueError('digest length inconsistent with hash algorithm') + raise ValueError("digest length inconsistent with hash algorithm") def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %d %s' % (self.serial, self.scheme, self.hash_algorithm, - dns.rdata._hexify(self.digest, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.serial, + self.scheme, + self.hash_algorithm, + dns.rdata._hexify(self.digest, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): serial = tok.get_uint32() scheme = tok.get_uint8() hash_algorithm = tok.get_uint8() @@ -53,8 +56,7 @@ class ZONEMD(dns.rdata.Rdata): return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack("!IBB", self.serial, self.scheme, - self.hash_algorithm) + header = struct.pack("!IBB", self.serial, self.scheme, self.hash_algorithm) file.write(header) file.write(self.digest) diff --git a/dns/rdtypes/ANY/__init__.py b/dns/rdtypes/ANY/__init__.py index 2cadcde3..3824a0a0 100644 --- a/dns/rdtypes/ANY/__init__.py +++ b/dns/rdtypes/ANY/__init__.py @@ -18,51 +18,51 @@ """Class ANY (generic) rdata type classes.""" __all__ = [ - 'AFSDB', - 'AMTRELAY', - 'AVC', - 'CAA', - 'CDNSKEY', - 'CDS', - 'CERT', - 'CNAME', - 'CSYNC', - 'DLV', - 'DNAME', - 'DNSKEY', - 'DS', - 'EUI48', - 'EUI64', - 'GPOS', - 'HINFO', - 'HIP', - 'ISDN', - 'L32', - 'L64', - 'LOC', - 'LP', - 'MX', - 'NID', - 'NINFO', - 'NS', - 'NSEC', - 'NSEC3', - 'NSEC3PARAM', - 'OPENPGPKEY', - 'OPT', - 'PTR', - 'RP', - 'RRSIG', - 'RT', - 'SMIMEA', - 'SOA', - 'SPF', - 'SSHFP', - 'TKEY', - 'TLSA', - 'TSIG', - 'TXT', - 'URI', - 'X25', - 'ZONEMD', + "AFSDB", + "AMTRELAY", + "AVC", + "CAA", + "CDNSKEY", + "CDS", + "CERT", + "CNAME", + "CSYNC", + "DLV", + "DNAME", + "DNSKEY", + "DS", + "EUI48", + "EUI64", + "GPOS", + "HINFO", + "HIP", + "ISDN", + "L32", + "L64", + "LOC", + "LP", + "MX", + "NID", + "NINFO", + "NS", + "NSEC", + "NSEC3", + "NSEC3PARAM", + "OPENPGPKEY", + "OPT", + "PTR", + "RP", + "RRSIG", + "RT", + "SMIMEA", + "SOA", + "SPF", + "SSHFP", + "TKEY", + "TLSA", + "TSIG", + "TXT", + "URI", + "X25", + "ZONEMD", ] diff --git a/dns/rdtypes/CH/A.py b/dns/rdtypes/CH/A.py index 828701b4..9905c7c9 100644 --- a/dns/rdtypes/CH/A.py +++ b/dns/rdtypes/CH/A.py @@ -20,6 +20,7 @@ import struct import dns.rdtypes.mxbase import dns.immutable + @dns.immutable.immutable class A(dns.rdata.Rdata): @@ -28,7 +29,7 @@ class A(dns.rdata.Rdata): # domain: the domain of the address # address: the 16-bit address - __slots__ = ['domain', 'address'] + __slots__ = ["domain", "address"] def __init__(self, rdclass, rdtype, domain, address): super().__init__(rdclass, rdtype) @@ -37,11 +38,12 @@ class A(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): domain = self.domain.choose_relativity(origin, relativize) - return '%s %o' % (domain, self.address) + return "%s %o" % (domain, self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): domain = tok.get_name(origin, relativize, relativize_to) address = tok.get_uint16(base=8) return cls(rdclass, rdtype, domain, address) diff --git a/dns/rdtypes/CH/__init__.py b/dns/rdtypes/CH/__init__.py index 7184a733..0760c26c 100644 --- a/dns/rdtypes/CH/__init__.py +++ b/dns/rdtypes/CH/__init__.py @@ -18,5 +18,5 @@ """Class CH rdata type classes.""" __all__ = [ - 'A', + "A", ] diff --git a/dns/rdtypes/IN/A.py b/dns/rdtypes/IN/A.py index 74b591ef..713d5eea 100644 --- a/dns/rdtypes/IN/A.py +++ b/dns/rdtypes/IN/A.py @@ -27,7 +27,7 @@ class A(dns.rdata.Rdata): """A record.""" - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -37,8 +37,9 @@ class A(dns.rdata.Rdata): return self.address @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_identifier() return cls(rdclass, rdtype, address) diff --git a/dns/rdtypes/IN/AAAA.py b/dns/rdtypes/IN/AAAA.py index 2d3ec902..f8237b44 100644 --- a/dns/rdtypes/IN/AAAA.py +++ b/dns/rdtypes/IN/AAAA.py @@ -27,7 +27,7 @@ class AAAA(dns.rdata.Rdata): """AAAA record.""" - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -37,8 +37,9 @@ class AAAA(dns.rdata.Rdata): return self.address @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_identifier() return cls(rdclass, rdtype, address) diff --git a/dns/rdtypes/IN/APL.py b/dns/rdtypes/IN/APL.py index ae94fb24..05e1689f 100644 --- a/dns/rdtypes/IN/APL.py +++ b/dns/rdtypes/IN/APL.py @@ -26,12 +26,13 @@ import dns.ipv6 import dns.rdata import dns.tokenizer + @dns.immutable.immutable class APLItem: """An APL list item.""" - __slots__ = ['family', 'negation', 'address', 'prefix'] + __slots__ = ["family", "negation", "address", "prefix"] def __init__(self, family, negation, address, prefix): self.family = dns.rdata.Rdata._as_uint16(family) @@ -67,12 +68,12 @@ class APLItem: if address[i] != 0: last = i + 1 break - address = address[0: last] + address = address[0:last] l = len(address) assert l < 128 if self.negation: l |= 0x80 - header = struct.pack('!HBB', self.family, self.prefix, l) + header = struct.pack("!HBB", self.family, self.prefix, l) file.write(header) file.write(address) @@ -84,32 +85,33 @@ class APL(dns.rdata.Rdata): # see: RFC 3123 - __slots__ = ['items'] + __slots__ = ["items"] def __init__(self, rdclass, rdtype, items): super().__init__(rdclass, rdtype) for item in items: if not isinstance(item, APLItem): - raise ValueError('item not an APLItem') + raise ValueError("item not an APLItem") self.items = tuple(items) def to_text(self, origin=None, relativize=True, **kw): - return ' '.join(map(str, self.items)) + return " ".join(map(str, self.items)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): items = [] for token in tok.get_remaining(): item = token.unescape().value - if item[0] == '!': + if item[0] == "!": negation = True item = item[1:] else: negation = False - (family, rest) = item.split(':', 1) + (family, rest) = item.split(":", 1) family = int(family) - (address, prefix) = rest.split('/', 1) + (address, prefix) = rest.split("/", 1) prefix = int(prefix) item = APLItem(family, negation, address, prefix) items.append(item) @@ -125,7 +127,7 @@ class APL(dns.rdata.Rdata): items = [] while parser.remaining() > 0: - header = parser.get_struct('!HBB') + header = parser.get_struct("!HBB") afdlen = header[2] if afdlen > 127: negation = True @@ -136,16 +138,16 @@ class APL(dns.rdata.Rdata): l = len(address) if header[0] == 1: if l < 4: - address += b'\x00' * (4 - l) + address += b"\x00" * (4 - l) elif header[0] == 2: if l < 16: - address += b'\x00' * (16 - l) + address += b"\x00" * (16 - l) else: # # This isn't really right according to the RFC, but it # seems better than throwing an exception # - address = codecs.encode(address, 'hex_codec') + address = codecs.encode(address, "hex_codec") item = APLItem(header[0], negation, address, header[1]) items.append(item) return cls(rdclass, rdtype, items) diff --git a/dns/rdtypes/IN/DHCID.py b/dns/rdtypes/IN/DHCID.py index c1c70b46..65f85897 100644 --- a/dns/rdtypes/IN/DHCID.py +++ b/dns/rdtypes/IN/DHCID.py @@ -29,7 +29,7 @@ class DHCID(dns.rdata.Rdata): # see: RFC 4701 - __slots__ = ['data'] + __slots__ = ["data"] def __init__(self, rdclass, rdtype, data): super().__init__(rdclass, rdtype) @@ -39,8 +39,9 @@ class DHCID(dns.rdata.Rdata): return dns.rdata._base64ify(self.data, **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): b64 = tok.concatenate_remaining_identifiers().encode() data = base64.b64decode(b64) return cls(rdclass, rdtype, data) diff --git a/dns/rdtypes/IN/HTTPS.py b/dns/rdtypes/IN/HTTPS.py index 6a67e8ed..7797fbaf 100644 --- a/dns/rdtypes/IN/HTTPS.py +++ b/dns/rdtypes/IN/HTTPS.py @@ -3,6 +3,7 @@ import dns.rdtypes.svcbbase import dns.immutable + @dns.immutable.immutable class HTTPS(dns.rdtypes.svcbbase.SVCBBase): """HTTPS record""" diff --git a/dns/rdtypes/IN/IPSECKEY.py b/dns/rdtypes/IN/IPSECKEY.py index d1d39438..1255739f 100644 --- a/dns/rdtypes/IN/IPSECKEY.py +++ b/dns/rdtypes/IN/IPSECKEY.py @@ -24,7 +24,8 @@ import dns.rdtypes.util class Gateway(dns.rdtypes.util.Gateway): - name = 'IPSECKEY gateway' + name = "IPSECKEY gateway" + @dns.immutable.immutable class IPSECKEY(dns.rdata.Rdata): @@ -33,10 +34,11 @@ class IPSECKEY(dns.rdata.Rdata): # see: RFC 4025 - __slots__ = ['precedence', 'gateway_type', 'algorithm', 'gateway', 'key'] + __slots__ = ["precedence", "gateway_type", "algorithm", "gateway", "key"] - def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm, - gateway, key): + def __init__( + self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key + ): super().__init__(rdclass, rdtype) gateway = Gateway(gateway_type, gateway) self.precedence = self._as_uint8(precedence) @@ -46,38 +48,45 @@ class IPSECKEY(dns.rdata.Rdata): self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): - gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, - relativize) - return '%d %d %d %s %s' % (self.precedence, self.gateway_type, - self.algorithm, gateway, - dns.rdata._base64ify(self.key, **kw)) + gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, relativize) + return "%d %d %d %s %s" % ( + self.precedence, + self.gateway_type, + self.algorithm, + gateway, + dns.rdata._base64ify(self.key, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): precedence = tok.get_uint8() gateway_type = tok.get_uint8() algorithm = tok.get_uint8() - gateway = Gateway.from_text(gateway_type, tok, origin, relativize, - relativize_to) + gateway = Gateway.from_text( + gateway_type, tok, origin, relativize, relativize_to + ) b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) - return cls(rdclass, rdtype, precedence, gateway_type, algorithm, - gateway.gateway, key) + return cls( + rdclass, rdtype, precedence, gateway_type, algorithm, gateway.gateway, key + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack("!BBB", self.precedence, self.gateway_type, - self.algorithm) + header = struct.pack("!BBB", self.precedence, self.gateway_type, self.algorithm) file.write(header) - Gateway(self.gateway_type, self.gateway).to_wire(file, compress, - origin, canonicalize) + Gateway(self.gateway_type, self.gateway).to_wire( + file, compress, origin, canonicalize + ) file.write(self.key) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - header = parser.get_struct('!BBB') + header = parser.get_struct("!BBB") gateway_type = header[1] gateway = Gateway.from_wire_parser(gateway_type, parser, origin) key = parser.get_remaining() - return cls(rdclass, rdtype, header[0], gateway_type, header[2], - gateway.gateway, key) + return cls( + rdclass, rdtype, header[0], gateway_type, header[2], gateway.gateway, key + ) diff --git a/dns/rdtypes/IN/NAPTR.py b/dns/rdtypes/IN/NAPTR.py index b107974d..1f1f5a12 100644 --- a/dns/rdtypes/IN/NAPTR.py +++ b/dns/rdtypes/IN/NAPTR.py @@ -27,7 +27,7 @@ import dns.rdtypes.util def _write_string(file, s): l = len(s) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(s) @@ -38,11 +38,11 @@ class NAPTR(dns.rdata.Rdata): # see: RFC 3403 - __slots__ = ['order', 'preference', 'flags', 'service', 'regexp', - 'replacement'] + __slots__ = ["order", "preference", "flags", "service", "regexp", "replacement"] - def __init__(self, rdclass, rdtype, order, preference, flags, service, - regexp, replacement): + def __init__( + self, rdclass, rdtype, order, preference, flags, service, regexp, replacement + ): super().__init__(rdclass, rdtype) self.flags = self._as_bytes(flags, True, 255) self.service = self._as_bytes(service, True, 255) @@ -53,24 +53,28 @@ class NAPTR(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): replacement = self.replacement.choose_relativity(origin, relativize) - return '%d %d "%s" "%s" "%s" %s' % \ - (self.order, self.preference, - dns.rdata._escapify(self.flags), - dns.rdata._escapify(self.service), - dns.rdata._escapify(self.regexp), - replacement) + return '%d %d "%s" "%s" "%s" %s' % ( + self.order, + self.preference, + dns.rdata._escapify(self.flags), + dns.rdata._escapify(self.service), + dns.rdata._escapify(self.regexp), + replacement, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): order = tok.get_uint16() preference = tok.get_uint16() flags = tok.get_string() service = tok.get_string() regexp = tok.get_string() replacement = tok.get_name(origin, relativize, relativize_to) - return cls(rdclass, rdtype, order, preference, flags, service, - regexp, replacement) + return cls( + rdclass, rdtype, order, preference, flags, service, regexp, replacement + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): two_ints = struct.pack("!HH", self.order, self.preference) @@ -82,14 +86,22 @@ class NAPTR(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (order, preference) = parser.get_struct('!HH') + (order, preference) = parser.get_struct("!HH") strings = [] for _ in range(3): s = parser.get_counted_bytes() strings.append(s) replacement = parser.get_name(origin) - return cls(rdclass, rdtype, order, preference, strings[0], strings[1], - strings[2], replacement) + return cls( + rdclass, + rdtype, + order, + preference, + strings[0], + strings[1], + strings[2], + replacement, + ) def _processing_priority(self): return (self.order, self.preference) diff --git a/dns/rdtypes/IN/NSAP.py b/dns/rdtypes/IN/NSAP.py index 23ae9b1a..be8581e6 100644 --- a/dns/rdtypes/IN/NSAP.py +++ b/dns/rdtypes/IN/NSAP.py @@ -30,7 +30,7 @@ class NSAP(dns.rdata.Rdata): # see: RFC 1706 - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -40,14 +40,15 @@ class NSAP(dns.rdata.Rdata): return "0x%s" % binascii.hexlify(self.address).decode() @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() - if address[0:2] != '0x': - raise dns.exception.SyntaxError('string does not start with 0x') - address = address[2:].replace('.', '') + if address[0:2] != "0x": + raise dns.exception.SyntaxError("string does not start with 0x") + address = address[2:].replace(".", "") if len(address) % 2 != 0: - raise dns.exception.SyntaxError('hexstring has odd length') + raise dns.exception.SyntaxError("hexstring has odd length") address = binascii.unhexlify(address.encode()) return cls(rdclass, rdtype, address) diff --git a/dns/rdtypes/IN/PX.py b/dns/rdtypes/IN/PX.py index 113d409c..b2216d6b 100644 --- a/dns/rdtypes/IN/PX.py +++ b/dns/rdtypes/IN/PX.py @@ -31,7 +31,7 @@ class PX(dns.rdata.Rdata): # see: RFC 2163 - __slots__ = ['preference', 'map822', 'mapx400'] + __slots__ = ["preference", "map822", "mapx400"] def __init__(self, rdclass, rdtype, preference, map822, mapx400): super().__init__(rdclass, rdtype) @@ -42,11 +42,12 @@ class PX(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): map822 = self.map822.choose_relativity(origin, relativize) mapx400 = self.mapx400.choose_relativity(origin, relativize) - return '%d %s %s' % (self.preference, map822, mapx400) + return "%d %s %s" % (self.preference, map822, mapx400) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() map822 = tok.get_name(origin, relativize, relativize_to) mapx400 = tok.get_name(origin, relativize, relativize_to) diff --git a/dns/rdtypes/IN/SRV.py b/dns/rdtypes/IN/SRV.py index 5b5ff422..8b0b6bf7 100644 --- a/dns/rdtypes/IN/SRV.py +++ b/dns/rdtypes/IN/SRV.py @@ -31,7 +31,7 @@ class SRV(dns.rdata.Rdata): # see: RFC 2782 - __slots__ = ['priority', 'weight', 'port', 'target'] + __slots__ = ["priority", "weight", "port", "target"] def __init__(self, rdclass, rdtype, priority, weight, port, target): super().__init__(rdclass, rdtype) @@ -42,12 +42,12 @@ class SRV(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) - return '%d %d %d %s' % (self.priority, self.weight, self.port, - target) + return "%d %d %d %s" % (self.priority, self.weight, self.port, target) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): priority = tok.get_uint16() weight = tok.get_uint16() port = tok.get_uint16() @@ -61,7 +61,7 @@ class SRV(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (priority, weight, port) = parser.get_struct('!HHH') + (priority, weight, port) = parser.get_struct("!HHH") target = parser.get_name(origin) return cls(rdclass, rdtype, priority, weight, port, target) diff --git a/dns/rdtypes/IN/SVCB.py b/dns/rdtypes/IN/SVCB.py index 14838e16..9a1ad101 100644 --- a/dns/rdtypes/IN/SVCB.py +++ b/dns/rdtypes/IN/SVCB.py @@ -3,6 +3,7 @@ import dns.rdtypes.svcbbase import dns.immutable + @dns.immutable.immutable class SVCB(dns.rdtypes.svcbbase.SVCBBase): """SVCB record""" diff --git a/dns/rdtypes/IN/WKS.py b/dns/rdtypes/IN/WKS.py index 264e45d3..a671e203 100644 --- a/dns/rdtypes/IN/WKS.py +++ b/dns/rdtypes/IN/WKS.py @@ -23,13 +23,14 @@ import dns.immutable import dns.rdata try: - _proto_tcp = socket.getprotobyname('tcp') - _proto_udp = socket.getprotobyname('udp') + _proto_tcp = socket.getprotobyname("tcp") + _proto_udp = socket.getprotobyname("udp") except OSError: # Fall back to defaults in case /etc/protocols is unavailable. _proto_tcp = 6 _proto_udp = 17 + @dns.immutable.immutable class WKS(dns.rdata.Rdata): @@ -37,7 +38,7 @@ class WKS(dns.rdata.Rdata): # see: RFC 1035 - __slots__ = ['address', 'protocol', 'bitmap'] + __slots__ = ["address", "protocol", "bitmap"] def __init__(self, rdclass, rdtype, address, protocol, bitmap): super().__init__(rdclass, rdtype) @@ -51,12 +52,13 @@ class WKS(dns.rdata.Rdata): for j in range(0, 8): if byte & (0x80 >> j): bits.append(str(i * 8 + j)) - text = ' '.join(bits) - return '%s %d %s' % (self.address, self.protocol, text) + text = " ".join(bits) + return "%s %d %s" % (self.address, self.protocol, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() protocol = tok.get_string() if protocol.isdigit(): @@ -87,7 +89,7 @@ class WKS(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(dns.ipv4.inet_aton(self.address)) - protocol = struct.pack('!B', self.protocol) + protocol = struct.pack("!B", self.protocol) file.write(protocol) file.write(self.bitmap) diff --git a/dns/rdtypes/IN/__init__.py b/dns/rdtypes/IN/__init__.py index d51b99e7..dcec4dd2 100644 --- a/dns/rdtypes/IN/__init__.py +++ b/dns/rdtypes/IN/__init__.py @@ -18,18 +18,18 @@ """Class IN rdata type classes.""" __all__ = [ - 'A', - 'AAAA', - 'APL', - 'DHCID', - 'HTTPS', - 'IPSECKEY', - 'KX', - 'NAPTR', - 'NSAP', - 'NSAP_PTR', - 'PX', - 'SRV', - 'SVCB', - 'WKS', + "A", + "AAAA", + "APL", + "DHCID", + "HTTPS", + "IPSECKEY", + "KX", + "NAPTR", + "NSAP", + "NSAP_PTR", + "PX", + "SRV", + "SVCB", + "WKS", ] diff --git a/dns/rdtypes/__init__.py b/dns/rdtypes/__init__.py index c3af264e..3997f84c 100644 --- a/dns/rdtypes/__init__.py +++ b/dns/rdtypes/__init__.py @@ -18,16 +18,16 @@ """DNS rdata type classes""" __all__ = [ - 'ANY', - 'IN', - 'CH', - 'dnskeybase', - 'dsbase', - 'euibase', - 'mxbase', - 'nsbase', - 'svcbbase', - 'tlsabase', - 'txtbase', - 'util' + "ANY", + "IN", + "CH", + "dnskeybase", + "dsbase", + "euibase", + "mxbase", + "nsbase", + "svcbbase", + "tlsabase", + "txtbase", + "util", ] diff --git a/dns/rdtypes/dnskeybase.py b/dns/rdtypes/dnskeybase.py index 832df2d7..1d17f70f 100644 --- a/dns/rdtypes/dnskeybase.py +++ b/dns/rdtypes/dnskeybase.py @@ -25,7 +25,8 @@ import dns.dnssectypes import dns.rdata # wildcard import -__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822 +__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822 + class Flag(enum.IntFlag): SEP = 0x0001 @@ -38,7 +39,7 @@ class DNSKEYBase(dns.rdata.Rdata): """Base class for rdata that is like a DNSKEY record""" - __slots__ = ['flags', 'protocol', 'algorithm', 'key'] + __slots__ = ["flags", "protocol", "algorithm", "key"] def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): super().__init__(rdclass, rdtype) @@ -48,12 +49,17 @@ class DNSKEYBase(dns.rdata.Rdata): self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): - return '%d %d %d %s' % (self.flags, self.protocol, self.algorithm, - dns.rdata._base64ify(self.key, **kw)) + return "%d %d %d %s" % ( + self.flags, + self.protocol, + self.algorithm, + dns.rdata._base64ify(self.key, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): flags = tok.get_uint16() protocol = tok.get_uint8() algorithm = tok.get_string() @@ -68,10 +74,10 @@ class DNSKEYBase(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - header = parser.get_struct('!HBB') + header = parser.get_struct("!HBB") key = parser.get_remaining() - return cls(rdclass, rdtype, header[0], header[1], header[2], - key) + return cls(rdclass, rdtype, header[0], header[1], header[2], key) + ### BEGIN generated Flag constants diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py index 3bf93acc..b6032b0f 100644 --- a/dns/rdtypes/dsbase.py +++ b/dns/rdtypes/dsbase.py @@ -29,9 +29,10 @@ class DSBase(dns.rdata.Rdata): """Base class for rdata that is like a DS record""" - __slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest'] + __slots__ = ["key_tag", "algorithm", "digest_type", "digest"] - # Digest types registry: https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml + # Digest types registry: + # https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml _digest_length_by_type = { 1: 20, # SHA-1, RFC 3658 Sec. 2.4 2: 32, # SHA-256, RFC 4509 Sec. 2.2 @@ -39,8 +40,7 @@ class DSBase(dns.rdata.Rdata): 4: 48, # SHA-384, RFC 6605 Sec. 2 } - def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, - digest): + def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest): super().__init__(rdclass, rdtype) self.key_tag = self._as_uint16(key_tag) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) @@ -48,34 +48,34 @@ class DSBase(dns.rdata.Rdata): self.digest = self._as_bytes(digest) try: if len(self.digest) != self._digest_length_by_type[self.digest_type]: - raise ValueError('digest length inconsistent with digest type') + raise ValueError("digest length inconsistent with digest type") except KeyError: if self.digest_type == 0: # reserved, RFC 3658 Sec. 2.4 - raise ValueError('digest type 0 is reserved') + raise ValueError("digest type 0 is reserved") def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %d %s' % (self.key_tag, self.algorithm, - self.digest_type, - dns.rdata._hexify(self.digest, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.key_tag, + self.algorithm, + self.digest_type, + dns.rdata._hexify(self.digest, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): key_tag = tok.get_uint16() algorithm = tok.get_string() digest_type = tok.get_uint8() digest = tok.concatenate_remaining_identifiers().encode() digest = binascii.unhexlify(digest) - return cls(rdclass, rdtype, key_tag, algorithm, digest_type, - digest) + return cls(rdclass, rdtype, key_tag, algorithm, digest_type, digest) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack("!HBB", self.key_tag, self.algorithm, - self.digest_type) + header = struct.pack("!HBB", self.key_tag, self.algorithm, self.digest_type) file.write(header) file.write(self.digest) diff --git a/dns/rdtypes/euibase.py b/dns/rdtypes/euibase.py index 48b69bd3..e524aea9 100644 --- a/dns/rdtypes/euibase.py +++ b/dns/rdtypes/euibase.py @@ -27,7 +27,7 @@ class EUIBase(dns.rdata.Rdata): # see: rfc7043.txt - __slots__ = ['eui'] + __slots__ = ["eui"] # define these in subclasses # byte_len = 6 # 0123456789ab (in hex) # text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab @@ -36,28 +36,30 @@ class EUIBase(dns.rdata.Rdata): super().__init__(rdclass, rdtype) self.eui = self._as_bytes(eui) if len(self.eui) != self.byte_len: - raise dns.exception.FormError('EUI%s rdata has to have %s bytes' - % (self.byte_len * 8, self.byte_len)) + raise dns.exception.FormError( + "EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len) + ) def to_text(self, origin=None, relativize=True, **kw): - return dns.rdata._hexify(self.eui, chunksize=2, separator=b'-', **kw) + return dns.rdata._hexify(self.eui, chunksize=2, separator=b"-", **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): text = tok.get_string() if len(text) != cls.text_len: raise dns.exception.SyntaxError( - 'Input text must have %s characters' % cls.text_len) + "Input text must have %s characters" % cls.text_len + ) for i in range(2, cls.byte_len * 3 - 1, 3): - if text[i] != '-': - raise dns.exception.SyntaxError('Dash expected at position %s' - % i) - text = text.replace('-', '') + if text[i] != "-": + raise dns.exception.SyntaxError("Dash expected at position %s" % i) + text = text.replace("-", "") try: data = binascii.unhexlify(text.encode()) except (ValueError, TypeError) as ex: - raise dns.exception.SyntaxError('Hex decoding error: %s' % str(ex)) + raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex)) return cls(rdclass, rdtype, data) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/dns/rdtypes/mxbase.py b/dns/rdtypes/mxbase.py index 56418234..b4b9b088 100644 --- a/dns/rdtypes/mxbase.py +++ b/dns/rdtypes/mxbase.py @@ -31,7 +31,7 @@ class MXBase(dns.rdata.Rdata): """Base class for rdata that is like an MX record.""" - __slots__ = ['preference', 'exchange'] + __slots__ = ["preference", "exchange"] def __init__(self, rdclass, rdtype, preference, exchange): super().__init__(rdclass, rdtype) @@ -40,11 +40,12 @@ class MXBase(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): exchange = self.exchange.choose_relativity(origin, relativize) - return '%d %s' % (self.preference, exchange) + return "%d %s" % (self.preference, exchange) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() exchange = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, preference, exchange) diff --git a/dns/rdtypes/nsbase.py b/dns/rdtypes/nsbase.py index b3e25506..ba7a2ab7 100644 --- a/dns/rdtypes/nsbase.py +++ b/dns/rdtypes/nsbase.py @@ -28,7 +28,7 @@ class NSBase(dns.rdata.Rdata): """Base class for rdata that is like an NS record.""" - __slots__ = ['target'] + __slots__ = ["target"] def __init__(self, rdclass, rdtype, target): super().__init__(rdclass, rdtype) @@ -39,8 +39,9 @@ class NSBase(dns.rdata.Rdata): return str(target) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): target = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, target) diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py index d2874996..a7bc2739 100644 --- a/dns/rdtypes/svcbbase.py +++ b/dns/rdtypes/svcbbase.py @@ -63,44 +63,48 @@ def _validate_key(key): if isinstance(key, bytes): # We decode to latin-1 so we get 0-255 as valid and do NOT interpret # UTF-8 sequences - key = key.decode('latin-1') + key = key.decode("latin-1") if isinstance(key, str): - if key.lower().startswith('key'): + if key.lower().startswith("key"): force_generic = True - if key[3:].startswith('0') and len(key) != 4: + if key[3:].startswith("0") and len(key) != 4: # key has leading zeros - raise ValueError('leading zeros in key') - key = key.replace('-', '_') + raise ValueError("leading zeros in key") + key = key.replace("-", "_") return (ParamKey.make(key), force_generic) + def key_to_text(key): - return ParamKey.to_text(key).replace('_', '-').lower() + return ParamKey.to_text(key).replace("_", "-").lower() + # Like rdata escapify, but escapes ',' too. _escaped = b'",\\' + def _escapify(qstring): - text = '' + text = "" for c in qstring: if c in _escaped: - text += '\\' + chr(c) + text += "\\" + chr(c) elif c >= 0x20 and c < 0x7F: text += chr(c) else: - text += '\\%03d' % c + text += "\\%03d" % c return text + def _unescape(value): - if value == '': + if value == "": return value - unescaped = b'' + unescaped = b"" l = len(value) i = 0 while i < l: c = value[i] i += 1 - if c == '\\': + if c == "\\": if i >= l: # pragma: no cover (can't happen via tokenizer get()) raise dns.exception.UnexpectedEnd c = value[i] @@ -119,7 +123,7 @@ def _unescape(value): codepoint = int(c) * 100 + int(c2) * 10 + int(c3) if codepoint > 255: raise dns.exception.SyntaxError - unescaped += b'%c' % (codepoint) + unescaped += b"%c" % (codepoint) continue unescaped += c.encode() return unescaped @@ -129,21 +133,21 @@ def _split(value): l = len(value) i = 0 items = [] - unescaped = b'' + unescaped = b"" while i < l: c = value[i] i += 1 - if c == ord('\\'): + if c == ord("\\"): if i >= l: # pragma: no cover (can't happen via tokenizer get()) raise dns.exception.UnexpectedEnd c = value[i] i += 1 - unescaped += b'%c' % (c) - elif c == ord(','): + unescaped += b"%c" % (c) + elif c == ord(","): items.append(unescaped) - unescaped = b'' + unescaped = b"" else: - unescaped += b'%c' % (c) + unescaped += b"%c" % (c) items.append(unescaped) return items @@ -159,8 +163,8 @@ class Param: @dns.immutable.immutable class GenericParam(Param): - """Generic SVCB parameter - """ + """Generic SVCB parameter""" + def __init__(self, value): self.value = dns.rdata.Rdata._as_bytes(value, True) @@ -198,19 +202,19 @@ class MandatoryParam(Param): prior_k = None for k in keys: if k == prior_k: - raise ValueError(f'duplicate key {k:d}') + raise ValueError(f"duplicate key {k:d}") prior_k = k if k == ParamKey.MANDATORY: - raise ValueError('listed the mandatory key as mandatory') + raise ValueError("listed the mandatory key as mandatory") self.keys = tuple(keys) @classmethod def from_value(cls, value): - keys = [k.encode() for k in value.split(',')] + keys = [k.encode() for k in value.split(",")] return cls(keys) def to_text(self): - return '"' + ','.join([key_to_text(key) for key in self.keys]) + '"' + return '"' + ",".join([key_to_text(key) for key in self.keys]) + '"' @classmethod def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 @@ -219,28 +223,29 @@ class MandatoryParam(Param): while parser.remaining() > 0: key = parser.get_uint16() if key < last_key: - raise dns.exception.FormError('manadatory keys not ascending') + raise dns.exception.FormError("manadatory keys not ascending") last_key = key keys.append(key) return cls(keys) def to_wire(self, file, origin=None): # pylint: disable=W0613 for key in self.keys: - file.write(struct.pack('!H', key)) + file.write(struct.pack("!H", key)) @dns.immutable.immutable class ALPNParam(Param): def __init__(self, ids): self.ids = dns.rdata.Rdata._as_tuple( - ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False)) + ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False) + ) @classmethod def from_value(cls, value): return cls(_split(_unescape(value))) def to_text(self): - value = ','.join([_escapify(id) for id in self.ids]) + value = ",".join([_escapify(id) for id in self.ids]) return '"' + dns.rdata._escapify(value.encode()) + '"' @classmethod @@ -253,7 +258,7 @@ class ALPNParam(Param): def to_wire(self, file, origin=None): # pylint: disable=W0613 for id in self.ids: - file.write(struct.pack('!B', len(id))) + file.write(struct.pack("!B", len(id))) file.write(id) @@ -269,10 +274,10 @@ class NoDefaultALPNParam(Param): @classmethod def from_value(cls, value): - if value is None or value == '': + if value is None or value == "": return None else: - raise ValueError('no-default-alpn with non-empty value') + raise ValueError("no-default-alpn with non-empty value") def to_text(self): raise NotImplementedError # pragma: no cover @@ -306,22 +311,23 @@ class PortParam(Param): return cls(port) def to_wire(self, file, origin=None): # pylint: disable=W0613 - file.write(struct.pack('!H', self.port)) + file.write(struct.pack("!H", self.port)) @dns.immutable.immutable class IPv4HintParam(Param): def __init__(self, addresses): self.addresses = dns.rdata.Rdata._as_tuple( - addresses, dns.rdata.Rdata._as_ipv4_address) + addresses, dns.rdata.Rdata._as_ipv4_address + ) @classmethod def from_value(cls, value): - addresses = value.split(',') + addresses = value.split(",") return cls(addresses) def to_text(self): - return '"' + ','.join(self.addresses) + '"' + return '"' + ",".join(self.addresses) + '"' @classmethod def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 @@ -340,15 +346,16 @@ class IPv4HintParam(Param): class IPv6HintParam(Param): def __init__(self, addresses): self.addresses = dns.rdata.Rdata._as_tuple( - addresses, dns.rdata.Rdata._as_ipv6_address) + addresses, dns.rdata.Rdata._as_ipv6_address + ) @classmethod def from_value(cls, value): - addresses = value.split(',') + addresses = value.split(",") return cls(addresses) def to_text(self): - return '"' + ','.join(self.addresses) + '"' + return '"' + ",".join(self.addresses) + '"' @classmethod def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 @@ -370,13 +377,13 @@ class ECHParam(Param): @classmethod def from_value(cls, value): - if '\\' in value: - raise ValueError('escape in ECH value') + if "\\" in value: + raise ValueError("escape in ECH value") value = base64.b64decode(value.encode()) return cls(value) def to_text(self): - b64 = base64.b64encode(self.ech).decode('ascii') + b64 = base64.b64encode(self.ech).decode("ascii") return f'"{b64}"' @classmethod @@ -407,7 +414,7 @@ def _validate_and_define(params, key, value): emptiness = cls.emptiness() if value is None: if emptiness == Emptiness.NEVER: - raise SyntaxError('value cannot be empty') + raise SyntaxError("value cannot be empty") value = cls.from_value(value) else: if force_generic: @@ -424,7 +431,7 @@ class SVCBBase(dns.rdata.Rdata): # see: draft-ietf-dnsop-svcb-https-01 - __slots__ = ['priority', 'target', 'params'] + __slots__ = ["priority", "target", "params"] def __init__(self, rdclass, rdtype, priority, target, params): super().__init__(rdclass, rdtype) @@ -443,12 +450,13 @@ class SVCBBase(dns.rdata.Rdata): # Note we have to say "not in" as we have None as a value # so a get() and a not None test would be wrong. if key not in params: - raise ValueError(f'key {key:d} declared mandatory but not ' - 'present') + raise ValueError( + f"key {key:d} declared mandatory but not " "present" + ) # The no-default-alpn parameter requires the alpn parameter. if ParamKey.NO_DEFAULT_ALPN in params: if ParamKey.ALPN not in params: - raise ValueError('no-default-alpn present, but alpn missing') + raise ValueError("no-default-alpn present, but alpn missing") def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) @@ -458,23 +466,24 @@ class SVCBBase(dns.rdata.Rdata): if value is None: params.append(key_to_text(key)) else: - kv = key_to_text(key) + '=' + value.to_text() + kv = key_to_text(key) + "=" + value.to_text() params.append(kv) if len(params) > 0: - space = ' ' + space = " " else: - space = '' - return '%d %s%s%s' % (self.priority, target, space, ' '.join(params)) + space = "" + return "%d %s%s%s" % (self.priority, target, space, " ".join(params)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): priority = tok.get_uint16() target = tok.get_name(origin, relativize, relativize_to) if priority == 0: token = tok.get() if not token.is_eol_or_eof(): - raise SyntaxError('parameters in AliasMode') + raise SyntaxError("parameters in AliasMode") tok.unget(token) params = {} while True: @@ -483,20 +492,20 @@ class SVCBBase(dns.rdata.Rdata): tok.unget(token) break if token.ttype != dns.tokenizer.IDENTIFIER: - raise SyntaxError('parameter is not an identifier') - equals = token.value.find('=') + raise SyntaxError("parameter is not an identifier") + equals = token.value.find("=") if equals == len(token.value) - 1: # 'key=', so next token should be a quoted string without # any intervening whitespace. key = token.value[:-1] token = tok.get(want_leading=True) if token.ttype != dns.tokenizer.QUOTED_STRING: - raise SyntaxError('whitespace after =') + raise SyntaxError("whitespace after =") value = token.value elif equals > 0: # key=value key = token.value[:equals] - value = token.value[equals + 1:] + value = token.value[equals + 1 :] elif equals == 0: # =key raise SyntaxError('parameter cannot start with "="') @@ -532,13 +541,13 @@ class SVCBBase(dns.rdata.Rdata): priority = parser.get_uint16() target = parser.get_name(origin) if priority == 0 and parser.remaining() != 0: - raise dns.exception.FormError('parameters in AliasMode') + raise dns.exception.FormError("parameters in AliasMode") params = {} prior_key = -1 while parser.remaining() > 0: key = parser.get_uint16() if key < prior_key: - raise dns.exception.FormError('keys not in order') + raise dns.exception.FormError("keys not in order") prior_key = key vlen = parser.get_uint16() pcls = _class_for_key.get(key, GenericParam) diff --git a/dns/rdtypes/tlsabase.py b/dns/rdtypes/tlsabase.py index 786fca55..a3fdc354 100644 --- a/dns/rdtypes/tlsabase.py +++ b/dns/rdtypes/tlsabase.py @@ -30,10 +30,9 @@ class TLSABase(dns.rdata.Rdata): # see: RFC 6698 - __slots__ = ['usage', 'selector', 'mtype', 'cert'] + __slots__ = ["usage", "selector", "mtype", "cert"] - def __init__(self, rdclass, rdtype, usage, selector, - mtype, cert): + def __init__(self, rdclass, rdtype, usage, selector, mtype, cert): super().__init__(rdclass, rdtype) self.usage = self._as_uint8(usage) self.selector = self._as_uint8(selector) @@ -42,17 +41,18 @@ class TLSABase(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %d %s' % (self.usage, - self.selector, - self.mtype, - dns.rdata._hexify(self.cert, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.usage, + self.selector, + self.mtype, + dns.rdata._hexify(self.cert, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): usage = tok.get_uint8() selector = tok.get_uint8() mtype = tok.get_uint8() diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py index afef98e4..d4cb9bb2 100644 --- a/dns/rdtypes/txtbase.py +++ b/dns/rdtypes/txtbase.py @@ -32,11 +32,14 @@ class TXTBase(dns.rdata.Rdata): """Base class for rdata that is like a TXT record (see RFC 1035).""" - __slots__ = ['strings'] - - def __init__(self, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - strings: Iterable[Union[bytes, str]]): + __slots__ = ["strings"] + + def __init__( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + strings: Iterable[Union[bytes, str]], + ): """Initialize a TXT-like rdata. *rdclass*, an ``int`` is the rdataclass of the Rdata. @@ -46,28 +49,41 @@ class TXTBase(dns.rdata.Rdata): *strings*, a tuple of ``bytes`` """ super().__init__(rdclass, rdtype) - self.strings: Tuple[bytes] = self._as_tuple(strings, lambda x: self._as_bytes(x, True, 255)) - - def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw: Dict[str, Any]) -> str: - txt = '' - prefix = '' + self.strings: Tuple[bytes] = self._as_tuple( + strings, lambda x: self._as_bytes(x, True, 255) + ) + + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: + txt = "" + prefix = "" for s in self.strings: txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s)) - prefix = ' ' + prefix = " " return txt @classmethod - def from_text(cls, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, - relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> dns.rdata.Rdata: + def from_text( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + tok: dns.tokenizer.Tokenizer, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + ) -> dns.rdata.Rdata: strings = [] for token in tok.get_remaining(): token = token.unescape_to_bytes() # The 'if' below is always true in the current code, but we # are leaving this check in in case things change some day. - if not (token.is_quoted_string() or - token.is_identifier()): # pragma: no cover + if not ( + token.is_quoted_string() or token.is_identifier() + ): # pragma: no cover raise dns.exception.SyntaxError("expected a string") if len(token.value) > 255: raise dns.exception.SyntaxError("string too long") @@ -80,7 +96,7 @@ class TXTBase(dns.rdata.Rdata): for s in self.strings: l = len(s) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(s) @classmethod diff --git a/dns/rdtypes/util.py b/dns/rdtypes/util.py index 9bf8f7e9..74596f05 100644 --- a/dns/rdtypes/util.py +++ b/dns/rdtypes/util.py @@ -28,6 +28,7 @@ import dns.rdata class Gateway: """A helper class for the IPSECKEY gateway and AMTRELAY relay fields""" + name = "" def __init__(self, type, gateway=None): @@ -67,15 +68,17 @@ class Gateway: raise ValueError(self._invalid_type(self.type)) # pragma: no cover @classmethod - def from_text(cls, gateway_type, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, gateway_type, tok, origin=None, relativize=True, relativize_to=None + ): if gateway_type in (0, 1, 2): gateway = tok.get_string() elif gateway_type == 3: gateway = tok.get_name(origin, relativize, relativize_to) else: raise dns.exception.SyntaxError( - cls._invalid_type(gateway_type)) # pragma: no cover + cls._invalid_type(gateway_type) + ) # pragma: no cover return cls(gateway_type, gateway) # pylint: disable=unused-argument @@ -90,6 +93,7 @@ class Gateway: self.gateway.to_wire(file, None, origin, False) else: raise ValueError(self._invalid_type(self.type)) # pragma: no cover + # pylint: enable=unused-argument @classmethod @@ -109,6 +113,7 @@ class Gateway: class Bitmap: """A helper class for the NSEC/NSEC3/CSYNC type bitmaps""" + type_name = "" def __init__(self, windows=None): @@ -136,7 +141,7 @@ class Bitmap: if byte & (0x80 >> j): rdtype = window * 256 + i * 8 + j bits.append(dns.rdatatype.to_text(rdtype)) - text += (' ' + ' '.join(bits)) + text += " " + " ".join(bits) return text @classmethod @@ -151,7 +156,7 @@ class Bitmap: window = 0 octets = 0 prior_rdtype = 0 - bitmap = bytearray(b'\0' * 32) + bitmap = bytearray(b"\0" * 32) windows = [] for rdtype in rdtypes: if rdtype == prior_rdtype: @@ -161,7 +166,7 @@ class Bitmap: if new_window != window: if octets != 0: windows.append((window, bytes(bitmap[0:octets]))) - bitmap = bytearray(b'\0' * 32) + bitmap = bytearray(b"\0" * 32) window = new_window offset = rdtype % 256 byte = offset // 8 @@ -174,7 +179,7 @@ class Bitmap: def to_wire(self, file): for (window, bitmap) in self.windows: - file.write(struct.pack('!BB', window, len(bitmap))) + file.write(struct.pack("!BB", window, len(bitmap))) file.write(bitmap) @classmethod @@ -193,6 +198,7 @@ def _priority_table(items): by_priority[rdata._processing_priority()].append(rdata) return by_priority + def priority_processing_order(iterable): items = list(iterable) if len(items) == 1: @@ -205,8 +211,10 @@ def priority_processing_order(iterable): ordered.extend(rdatas) return ordered + _no_weight = 0.1 + def weighted_processing_order(iterable): items = list(iterable) if len(items) == 1: @@ -215,8 +223,7 @@ def weighted_processing_order(iterable): ordered = [] for k in sorted(by_priority.keys()): rdatas = by_priority[k] - total = sum(rdata._processing_weight() or _no_weight - for rdata in rdatas) + total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) while len(rdatas) > 1: r = random.uniform(0, total) for (n, rdata) in enumerate(rdatas): @@ -230,15 +237,16 @@ def weighted_processing_order(iterable): ordered.append(rdatas[0]) return ordered + def parse_formatted_hex(formatted, num_chunks, chunk_size, separator): if len(formatted) != num_chunks * (chunk_size + 1) - 1: - raise ValueError('invalid formatted hex string') - value = b'' + raise ValueError("invalid formatted hex string") + value = b"" for _ in range(num_chunks): chunk = formatted[0:chunk_size] - value += int(chunk, 16).to_bytes(chunk_size // 2, 'big') + value += int(chunk, 16).to_bytes(chunk_size // 2, "big") formatted = formatted[chunk_size:] if len(formatted) > 0 and formatted[0] != separator: - raise ValueError('invalid formatted hex string') + raise ValueError("invalid formatted hex string") formatted = formatted[1:] return value diff --git a/dns/renderer.py b/dns/renderer.py index 4e4391cd..95e8bd3a 100644 --- a/dns/renderer.py +++ b/dns/renderer.py @@ -88,8 +88,8 @@ class Renderer: self.compress = {} self.section = QUESTION self.counts = [0, 0, 0, 0] - self.output.write(b'\x00' * 12) - self.mac = '' + self.output.write(b"\x00" * 12) + self.mac = "" def _rollback(self, where): """Truncate the output buffer at offset *where*, and remove any @@ -160,8 +160,7 @@ class Renderer: self._set_section(section) with self._track_size(): - n = rdataset.to_wire(name, self.output, self.compress, self.origin, - **kw) + n = rdataset.to_wire(name, self.output, self.compress, self.origin, **kw) self.counts[section] += n def add_edns(self, edns, ednsflags, payload, options=None): @@ -169,12 +168,21 @@ class Renderer: # make sure the EDNS version in ednsflags agrees with edns ednsflags &= 0xFF00FFFF - ednsflags |= (edns << 16) + ednsflags |= edns << 16 opt = dns.message.Message._make_opt(ednsflags, payload, options) self.add_rrset(ADDITIONAL, opt) - def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data, - request_mac, algorithm=dns.tsig.default_algorithm): + def add_tsig( + self, + keyname, + secret, + fudge, + id, + tsig_error, + other_data, + request_mac, + algorithm=dns.tsig.default_algorithm, + ): """Add a TSIG signature to the message.""" s = self.output.getvalue() @@ -183,15 +191,24 @@ class Renderer: key = secret else: key = dns.tsig.Key(keyname, secret, algorithm) - tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, - b'', id, tsig_error, other_data) - (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), - request_mac) + tsig = dns.message.Message._make_tsig( + keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data + ) + (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), request_mac) self._write_tsig(tsig, keyname) - def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error, - other_data, request_mac, - algorithm=dns.tsig.default_algorithm): + def add_multi_tsig( + self, + ctx, + keyname, + secret, + fudge, + id, + tsig_error, + other_data, + request_mac, + algorithm=dns.tsig.default_algorithm, + ): """Add a TSIG signature to the message. Unlike add_tsig(), this can be used for a series of consecutive DNS envelopes, e.g. for a zone transfer over TCP [RFC2845, 4.4]. @@ -206,10 +223,12 @@ class Renderer: key = secret else: key = dns.tsig.Key(keyname, secret, algorithm) - tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, - b'', id, tsig_error, other_data) - (tsig, ctx) = dns.tsig.sign(s, key, tsig[0], int(time.time()), - request_mac, ctx, True) + tsig = dns.message.Message._make_tsig( + keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data + ) + (tsig, ctx) = dns.tsig.sign( + s, key, tsig[0], int(time.time()), request_mac, ctx, True + ) self._write_tsig(tsig, keyname) return ctx @@ -217,17 +236,18 @@ class Renderer: self._set_section(ADDITIONAL) with self._track_size(): keyname.to_wire(self.output, self.compress, self.origin) - self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, - dns.rdataclass.ANY, 0, 0)) + self.output.write( + struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0) + ) rdata_start = self.output.tell() tsig.to_wire(self.output) after = self.output.tell() self.output.seek(rdata_start - 2) - self.output.write(struct.pack('!H', after - rdata_start)) + self.output.write(struct.pack("!H", after - rdata_start)) self.counts[ADDITIONAL] += 1 self.output.seek(10) - self.output.write(struct.pack('!H', self.counts[ADDITIONAL])) + self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) self.output.seek(0, io.SEEK_END) def write_header(self): @@ -239,9 +259,17 @@ class Renderer: """ self.output.seek(0) - self.output.write(struct.pack('!HHHHHH', self.id, self.flags, - self.counts[0], self.counts[1], - self.counts[2], self.counts[3])) + self.output.write( + struct.pack( + "!HHHHHH", + self.id, + self.flags, + self.counts[0], + self.counts[1], + self.counts[2], + self.counts[3], + ) + ) self.output.seek(0, io.SEEK_END) def get_wire(self): diff --git a/dns/resolver.py b/dns/resolver.py index 0b132532..f9303692 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -26,10 +26,11 @@ import sys import time import random import warnings + try: import threading as _threading except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore + import dummy_threading as _threading # type: ignore import dns.exception import dns.edns @@ -46,22 +47,24 @@ import dns.rdatatype import dns.reversename import dns.tsig -if sys.platform == 'win32': +if sys.platform == "win32": import dns.win32util + class NXDOMAIN(dns.exception.DNSException): """The DNS query name does not exist.""" - supp_kwargs = {'qnames', 'responses'} + + supp_kwargs = {"qnames", "responses"} fmt = None # we have our own __str__ implementation # pylint: disable=arguments-differ - # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _check_kwargs(self, qnames, - responses=None): + def _check_kwargs(self, qnames, responses=None): if not isinstance(qnames, (list, tuple, set)): raise AttributeError("qnames must be a list, tuple or set") if len(qnames) == 0: @@ -74,23 +77,23 @@ class NXDOMAIN(dns.exception.DNSException): return kwargs def __str__(self): - if 'qnames' not in self.kwargs: + if "qnames" not in self.kwargs: return super().__str__() - qnames = self.kwargs['qnames'] + qnames = self.kwargs["qnames"] if len(qnames) > 1: - msg = 'None of DNS query names exist' + msg = "None of DNS query names exist" else: - msg = 'The DNS query name does not exist' - qnames = ', '.join(map(str, qnames)) + msg = "The DNS query name does not exist" + qnames = ", ".join(map(str, qnames)) return "{}: {}".format(msg, qnames) @property def canonical_name(self): """Return the unresolved canonical name.""" - if 'qnames' not in self.kwargs: + if "qnames" not in self.kwargs: raise TypeError("parametrized exception required") - for qname in self.kwargs['qnames']: - response = self.kwargs['responses'][qname] + for qname in self.kwargs["qnames"]: + response = self.kwargs["responses"][qname] try: cname = response.canonical_name() if cname != qname: @@ -99,14 +102,14 @@ class NXDOMAIN(dns.exception.DNSException): # We can just eat this exception as it means there was # something wrong with the response. pass - return self.kwargs['qnames'][0] + return self.kwargs["qnames"][0] def __add__(self, e_nx): """Augment by results from another NXDOMAIN exception.""" - qnames0 = list(self.kwargs.get('qnames', [])) - responses0 = dict(self.kwargs.get('responses', {})) - responses1 = e_nx.kwargs.get('responses', {}) - for qname1 in e_nx.kwargs.get('qnames', []): + qnames0 = list(self.kwargs.get("qnames", [])) + responses0 = dict(self.kwargs.get("responses", {})) + responses1 = e_nx.kwargs.get("responses", {}) + for qname1 in e_nx.kwargs.get("qnames", []): if qname1 not in qnames0: qnames0.append(qname1) if qname1 in responses1: @@ -118,7 +121,7 @@ class NXDOMAIN(dns.exception.DNSException): Returns a list of ``dns.name.Name``. """ - return self.kwargs['qnames'] + return self.kwargs["qnames"] def responses(self): """A map from queried names to their NXDOMAIN responses. @@ -126,29 +129,34 @@ class NXDOMAIN(dns.exception.DNSException): Returns a dict mapping a ``dns.name.Name`` to a ``dns.message.Message``. """ - return self.kwargs['responses'] + return self.kwargs["responses"] def response(self, qname): """The response for query *qname*. Returns a ``dns.message.Message``. """ - return self.kwargs['responses'][qname] + return self.kwargs["responses"][qname] class YXDOMAIN(dns.exception.DNSException): """The DNS query name is too long after DNAME substitution.""" -ErrorTuple = Tuple[Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message]] +ErrorTuple = Tuple[ + Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message] +] def _errors_to_text(errors: List[ErrorTuple]) -> List[str]: """Turn a resolution errors trace into a list of text.""" texts = [] for err in errors: - texts.append('Server {} {} port {} answered {}'.format(err[0], - 'TCP' if err[1] else 'UDP', err[2], err[3])) + texts.append( + "Server {} {} port {} answered {}".format( + err[0], "TCP" if err[1] else "UDP", err[2], err[3] + ) + ) return texts @@ -157,16 +165,18 @@ class LifetimeTimeout(dns.exception.Timeout): msg = "The resolution lifetime expired." fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1] - supp_kwargs = {'timeout', 'errors'} + supp_kwargs = {"timeout", "errors"} - # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _fmt_kwargs(self, **kwargs): - srv_msgs = _errors_to_text(kwargs['errors']) - return super()._fmt_kwargs(timeout=kwargs['timeout'], - errors='; '.join(srv_msgs)) + srv_msgs = _errors_to_text(kwargs["errors"]) + return super()._fmt_kwargs( + timeout=kwargs["timeout"], errors="; ".join(srv_msgs) + ) # We added more detail to resolution timeouts, but they are still @@ -177,19 +187,20 @@ Timeout = LifetimeTimeout class NoAnswer(dns.exception.DNSException): """The DNS response does not contain an answer to the question.""" - fmt = 'The DNS response does not contain an answer ' + \ - 'to the question: {query}' - supp_kwargs = {'response'} - # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + fmt = "The DNS response does not contain an answer " + "to the question: {query}" + supp_kwargs = {"response"} + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _fmt_kwargs(self, **kwargs): - return super()._fmt_kwargs(query=kwargs['response'].question) + return super()._fmt_kwargs(query=kwargs["response"].question) def response(self): - return self.kwargs['response'] + return self.kwargs["response"] class NoNameservers(dns.exception.DNSException): @@ -203,16 +214,18 @@ class NoNameservers(dns.exception.DNSException): msg = "All nameservers failed to answer the query." fmt = "%s {query}: {errors}" % msg[:-1] - supp_kwargs = {'request', 'errors'} + supp_kwargs = {"request", "errors"} - # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _fmt_kwargs(self, **kwargs): - srv_msgs = _errors_to_text(kwargs['errors']) - return super()._fmt_kwargs(query=kwargs['request'].question, - errors='; '.join(srv_msgs)) + srv_msgs = _errors_to_text(kwargs["errors"]) + return super()._fmt_kwargs( + query=kwargs["request"].question, errors="; ".join(srv_msgs) + ) class NotAbsolute(dns.exception.DNSException): @@ -226,9 +239,11 @@ class NoRootSOA(dns.exception.DNSException): class NoMetaqueries(dns.exception.DNSException): """DNS metaqueries are not allowed.""" + class NoResolverConfiguration(dns.exception.DNSException): """Resolver configuration could not be read or specified no nameservers.""" + class Answer: """DNS stub resolver answer. @@ -245,9 +260,15 @@ class Answer: RRset's name might not be the query name. """ - def __init__(self, qname: dns.name.Name, rdtype: dns.rdatatype.RdataType, - rdclass: dns.rdataclass.RdataClass, response: dns.message.QueryMessage, - nameserver: Optional[str]=None, port: Optional[int]=None): + def __init__( + self, + qname: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + rdclass: dns.rdataclass.RdataClass, + response: dns.message.QueryMessage, + nameserver: Optional[str] = None, + port: Optional[int] = None, + ): self.qname = qname self.rdtype = rdtype self.rdclass = rdclass @@ -262,15 +283,15 @@ class Answer: self.expiration = time.time() + self.chaining_result.minimum_ttl def __getattr__(self, attr): # pragma: no cover - if attr == 'name': + if attr == "name": return self.rrset.name - elif attr == 'ttl': + elif attr == "ttl": return self.rrset.ttl - elif attr == 'covers': + elif attr == "covers": return self.rrset.covers - elif attr == 'rdclass': + elif attr == "rdclass": return self.rrset.rdclass - elif attr == 'rdtype': + elif attr == "rdtype": return self.rrset.rdtype else: raise AttributeError(attr) @@ -293,8 +314,7 @@ class Answer: class CacheStatistics: - """Cache Statistics - """ + """Cache Statistics""" def __init__(self, hits=0, misses=0): self.hits = hits @@ -304,7 +324,7 @@ class CacheStatistics: self.hits = 0 self.misses = 0 - def clone(self) -> 'CacheStatistics': + def clone(self) -> "CacheStatistics": return CacheStatistics(self.hits, self.misses) @@ -345,7 +365,7 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" - def __init__(self, cleaning_interval: float=300.0): + def __init__(self, cleaning_interval: float = 300.0): """*cleaning_interval*, a ``float`` is the number of seconds between periodic cleanings. """ @@ -374,8 +394,8 @@ class Cache(CacheBase): Returns None if no answer is cached for the key. - *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. Returns a ``dns.resolver.Answer`` or ``None``. """ @@ -392,8 +412,8 @@ class Cache(CacheBase): def put(self, key: CacheKey, value: Answer) -> None: """Associate key and value in the cache. - *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. *value*, a ``dns.resolver.Answer``, the answer. """ @@ -402,14 +422,14 @@ class Cache(CacheBase): self._maybe_clean() self.data[key] = value - def flush(self, key: Optional[CacheKey]=None) -> None: + def flush(self, key: Optional[CacheKey] = None) -> None: """Flush the cache. - If *key* is not ``None``, only that item is flushed. Otherwise - the entire cache is flushed. + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache + is flushed. - *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. """ with self.lock: @@ -452,7 +472,7 @@ class LRUCache(CacheBase): for a new one. """ - def __init__(self, max_size: int=100000): + def __init__(self, max_size: int = 100000): """*max_size*, an ``int``, is the maximum number of nodes to cache; it must be greater than 0. """ @@ -474,8 +494,8 @@ class LRUCache(CacheBase): Returns None if no answer is cached for the key. - *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. Returns a ``dns.resolver.Answer`` or ``None``. """ @@ -509,8 +529,8 @@ class LRUCache(CacheBase): def put(self, key: CacheKey, value: Answer) -> None: """Associate key and value in the cache. - *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. *value*, a ``dns.resolver.Answer``, the answer. """ @@ -528,14 +548,14 @@ class LRUCache(CacheBase): node.link_after(self.sentinel) self.data[key] = node - def flush(self, key: Optional[CacheKey]=None) -> None: + def flush(self, key: Optional[CacheKey] = None) -> None: """Flush the cache. - If *key* is not ``None``, only that item is flushed. Otherwise - the entire cache is flushed. + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache + is flushed. - *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. """ with self.lock: @@ -552,6 +572,7 @@ class LRUCache(CacheBase): gnode = next self.data = {} + class _Resolution: """Helper class for dns.resolver.Resolver.resolve(). @@ -564,10 +585,16 @@ class _Resolution: resolver data structures directly. """ - def __init__(self, resolver: 'BaseResolver', qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - rdclass: Union[dns.rdataclass.RdataClass, str], - tcp: bool, raise_on_no_answer: bool, search: Optional[bool]): + def __init__( + self, + resolver: "BaseResolver", + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str], + tcp: bool, + raise_on_no_answer: bool, + search: Optional[bool], + ): if isinstance(qname, str): qname = dns.name.from_text(qname, None) the_rdtype = dns.rdatatype.RdataType.make(rdtype) @@ -596,7 +623,9 @@ class _Resolution: self.request: Optional[dns.message.QueryMessage] = None self.backoff = 0.0 - def next_request(self) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]: + def next_request( + self, + ) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]: """Get the next request to send, and check the cache. Returns a (request, answer) tuple. At most one of request or @@ -611,32 +640,37 @@ class _Resolution: # Do we know the answer? if self.resolver.cache: - answer = self.resolver.cache.get((self.qname, self.rdtype, - self.rdclass)) + answer = self.resolver.cache.get( + (self.qname, self.rdtype, self.rdclass) + ) if answer is not None: if answer.rrset is None and self.raise_on_no_answer: raise NoAnswer(response=answer.response) else: return (None, answer) - answer = self.resolver.cache.get((self.qname, - dns.rdatatype.ANY, - self.rdclass)) - if answer is not None and \ - answer.response.rcode() == dns.rcode.NXDOMAIN: + answer = self.resolver.cache.get( + (self.qname, dns.rdatatype.ANY, self.rdclass) + ) + if answer is not None and answer.response.rcode() == dns.rcode.NXDOMAIN: # cached NXDOMAIN; record it and continue to next # name. self.nxdomain_responses[self.qname] = answer.response continue # Build the request - request = dns.message.make_query(self.qname, self.rdtype, - self.rdclass) + request = dns.message.make_query(self.qname, self.rdtype, self.rdclass) if self.resolver.keyname is not None: - request.use_tsig(self.resolver.keyring, self.resolver.keyname, - algorithm=self.resolver.keyalgorithm) - request.use_edns(self.resolver.edns, self.resolver.ednsflags, - self.resolver.payload, - options=self.resolver.ednsoptions) + request.use_tsig( + self.resolver.keyring, + self.resolver.keyname, + algorithm=self.resolver.keyalgorithm, + ) + request.use_edns( + self.resolver.edns, + self.resolver.ednsflags, + self.resolver.payload, + options=self.resolver.ednsoptions, + ) if self.resolver.flags is not None: request.flags = self.resolver.flags @@ -658,8 +692,7 @@ class _Resolution: # it's only NXDOMAINs as anything else would have returned # before now.) # - raise NXDOMAIN(qnames=self.qnames_to_try, - responses=self.nxdomain_responses) + raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses) def next_nameserver(self) -> Tuple[str, int, bool, float]: if self.retry_with_tcp: @@ -678,13 +711,15 @@ class _Resolution: self.backoff = min(self.backoff * 2, 2) self.nameserver = self.current_nameservers.pop(0) - self.port = self.resolver.nameserver_ports.get(self.nameserver, - self.resolver.port) + self.port = self.resolver.nameserver_ports.get( + self.nameserver, self.resolver.port + ) self.tcp_attempt = self.tcp return (self.nameserver, self.port, self.tcp_attempt, backoff) - def query_result(self, response: Optional[dns.message.Message], - ex: Optional[Exception]) -> Tuple[Optional[Answer], bool]: + def query_result( + self, response: Optional[dns.message.Message], ex: Optional[Exception] + ) -> Tuple[Optional[Answer], bool]: # # returns an (answer: Answer, end_loop: bool) tuple. # @@ -692,12 +727,15 @@ class _Resolution: if ex: # Exception during I/O or from_wire() assert response is None - self.errors.append((self.nameserver, self.tcp_attempt, self.port, - ex, response)) - if isinstance(ex, dns.exception.FormError) or \ - isinstance(ex, EOFError) or \ - isinstance(ex, OSError) or \ - isinstance(ex, NotImplementedError): + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, ex, response) + ) + if ( + isinstance(ex, dns.exception.FormError) + or isinstance(ex, EOFError) + or isinstance(ex, OSError) + or isinstance(ex, NotImplementedError) + ): # This nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) elif isinstance(ex, dns.message.Truncated): @@ -713,17 +751,23 @@ class _Resolution: rcode = response.rcode() if rcode == dns.rcode.NOERROR: try: - answer = Answer(self.qname, self.rdtype, self.rdclass, response, - self.nameserver, self.port) + answer = Answer( + self.qname, + self.rdtype, + self.rdclass, + response, + self.nameserver, + self.port, + ) except Exception as e: - self.errors.append((self.nameserver, self.tcp_attempt, - self.port, e, response)) + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, e, response) + ) # The nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) return (None, False) if self.resolver.cache: - self.resolver.cache.put((self.qname, self.rdtype, - self.rdclass), answer) + self.resolver.cache.put((self.qname, self.rdtype, self.rdclass), answer) if answer.rrset is None and self.raise_on_no_answer: raise NoAnswer(response=answer.response) return (answer, True) @@ -731,26 +775,29 @@ class _Resolution: # Further validate the response by making an Answer, even # if we aren't going to cache it. try: - answer = Answer(self.qname, dns.rdatatype.ANY, - dns.rdataclass.IN, response) + answer = Answer( + self.qname, dns.rdatatype.ANY, dns.rdataclass.IN, response + ) except Exception as e: - self.errors.append((self.nameserver, self.tcp_attempt, - self.port, e, response)) + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, e, response) + ) # The nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) return (None, False) self.nxdomain_responses[self.qname] = response if self.resolver.cache: - self.resolver.cache.put((self.qname, - dns.rdatatype.ANY, - self.rdclass), answer) + self.resolver.cache.put( + (self.qname, dns.rdatatype.ANY, self.rdclass), answer + ) # Make next_nameserver() return None, so caller breaks its # inner loop and calls next_request(). return (None, True) elif rcode == dns.rcode.YXDOMAIN: yex = YXDOMAIN() - self.errors.append((self.nameserver, self.tcp_attempt, - self.port, yex, response)) + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, yex, response) + ) raise yex else: # @@ -759,8 +806,15 @@ class _Resolution: # if rcode != dns.rcode.SERVFAIL or not self.resolver.retry_servfail: self.nameservers.remove(self.nameserver) - self.errors.append((self.nameserver, self.tcp_attempt, self.port, - dns.rcode.to_text(rcode), response)) + self.errors.append( + ( + self.nameserver, + self.tcp_attempt, + self.port, + dns.rcode.to_text(rcode), + response, + ) + ) return (None, False) @@ -791,7 +845,7 @@ class BaseResolver: rotate: bool ndots: Optional[int] - def __init__(self, filename: str='/etc/resolv.conf', configure: bool=True): + def __init__(self, filename: str = "/etc/resolv.conf", configure: bool = True): """*filename*, a ``str`` or file object, specifying a file in standard /etc/resolv.conf format. This parameter is meaningful only when *configure* is true and the platform is POSIX. @@ -805,7 +859,7 @@ class BaseResolver: self.reset() if configure: - if sys.platform == 'win32': + if sys.platform == "win32": self.read_registry() elif filename: self.read_resolv_conf(filename) @@ -859,10 +913,10 @@ class BaseResolver: f = stack.enter_context(open(f)) except OSError: # /etc/resolv.conf doesn't exist, can't be read, etc. - raise NoResolverConfiguration(f'cannot open {f}') + raise NoResolverConfiguration(f"cannot open {f}") for l in f: - if len(l) == 0 or l[0] == '#' or l[0] == ';': + if len(l) == 0 or l[0] == "#" or l[0] == ";": continue tokens = l.split() @@ -870,37 +924,37 @@ class BaseResolver: if len(tokens) < 2: continue - if tokens[0] == 'nameserver': + if tokens[0] == "nameserver": self.nameservers.append(tokens[1]) - elif tokens[0] == 'domain': + elif tokens[0] == "domain": self.domain = dns.name.from_text(tokens[1]) # domain and search are exclusive self.search = [] - elif tokens[0] == 'search': + elif tokens[0] == "search": # the last search wins self.search = [] for suffix in tokens[1:]: self.search.append(dns.name.from_text(suffix)) # We don't set domain as it is not used if # len(self.search) > 0 - elif tokens[0] == 'options': + elif tokens[0] == "options": for opt in tokens[1:]: - if opt == 'rotate': + if opt == "rotate": self.rotate = True - elif opt == 'edns0': + elif opt == "edns0": self.use_edns() - elif 'timeout' in opt: + elif "timeout" in opt: try: - self.timeout = int(opt.split(':')[1]) + self.timeout = int(opt.split(":")[1]) except (ValueError, IndexError): pass - elif 'ndots' in opt: + elif "ndots" in opt: try: - self.ndots = int(opt.split(':')[1]) + self.ndots = int(opt.split(":")[1]) except (ValueError, IndexError): pass if len(self.nameservers) == 0: - raise NoResolverConfiguration('no nameservers') + raise NoResolverConfiguration("no nameservers") def read_registry(self) -> None: """Extract resolver configuration from the Windows registry.""" @@ -913,8 +967,12 @@ class BaseResolver: except AttributeError: raise NotImplementedError - def _compute_timeout(self, start: float, lifetime: Optional[float]=None, - errors: Optional[List[ErrorTuple]]=None) -> float: + def _compute_timeout( + self, + start: float, + lifetime: Optional[float] = None, + errors: Optional[List[ErrorTuple]] = None, + ) -> float: lifetime = self.lifetime if lifetime is None else lifetime now = time.time() duration = now - start @@ -933,7 +991,9 @@ class BaseResolver: raise LifetimeTimeout(timeout=duration, errors=errors) return min(lifetime - duration, self.timeout) - def _get_qnames_to_try(self, qname: dns.name.Name, search: Optional[bool]) -> List[dns.name.Name]: + def _get_qnames_to_try( + self, qname: dns.name.Name, search: Optional[bool] + ) -> List[dns.name.Name]: # This is a separate method so we can unit test the search # rules without requiring the Internet. if search is None: @@ -972,8 +1032,12 @@ class BaseResolver: qnames_to_try.append(abs_qname) return qnames_to_try - def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name, str]]=None, - algorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> None: + def use_tsig( + self, + keyring: Any, + keyname: Optional[Union[dns.name.Name, str]] = None, + algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + ) -> None: """Add a TSIG signature to each query. The parameters are passed to ``dns.message.Message.use_tsig()``; @@ -984,9 +1048,13 @@ class BaseResolver: self.keyname = keyname self.keyalgorithm = algorithm - def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0, - payload: int=dns.message.DEFAULT_EDNS_PAYLOAD, - options: Optional[List[dns.edns.Option]]=None) -> None: + def use_edns( + self, + edns: Optional[Union[int, bool]] = 0, + ednsflags: int = 0, + payload: int = dns.message.DEFAULT_EDNS_PAYLOAD, + options: Optional[List[dns.edns.Option]] = None, + ) -> None: """Configure EDNS behavior. *edns*, an ``int``, is the EDNS level to use. Specifying @@ -1037,27 +1105,35 @@ class BaseResolver: for nameserver in nameservers: if not dns.inet.is_address(nameserver): try: - if urlparse(nameserver).scheme != 'https': + if urlparse(nameserver).scheme != "https": raise NotImplementedError except Exception: - raise ValueError(f'nameserver {nameserver} is not an ' - 'IP address or valid https URL') + raise ValueError( + f"nameserver {nameserver} is not an " + "IP address or valid https URL" + ) self._nameservers = nameservers else: - raise ValueError('nameservers must be a list' - ' (not a {})'.format(type(nameservers))) + raise ValueError( + "nameservers must be a list" " (not a {})".format(type(nameservers)) + ) class Resolver(BaseResolver): """DNS stub resolver.""" - def resolve(self, qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, - rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, - tcp: bool = False, source: Optional[str] = None, - raise_on_no_answer: bool = True, source_port: int = 0, - lifetime: Optional[float] = None, - search: Optional[bool] = None) -> Answer: # pylint: disable=arguments-differ + def resolve( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + ) -> Answer: # pylint: disable=arguments-differ """Query nameservers to find the answer to the question. The *qname*, *rdtype*, and *rdclass* parameters may be objects @@ -1109,8 +1185,9 @@ class Resolver(BaseResolver): """ - resolution = _Resolution(self, qname, rdtype, rdclass, tcp, - raise_on_no_answer, search) + resolution = _Resolution( + self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search + ) start = time.time() while True: (request, answer) = resolution.next_request() @@ -1127,27 +1204,30 @@ class Resolver(BaseResolver): (nameserver, port, tcp, backoff) = resolution.next_nameserver() if backoff: time.sleep(backoff) - timeout = self._compute_timeout(start, lifetime, - resolution.errors) + timeout = self._compute_timeout(start, lifetime, resolution.errors) try: if dns.inet.is_address(nameserver): if tcp: - response = dns.query.tcp(request, nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port) + response = dns.query.tcp( + request, + nameserver, + timeout=timeout, + port=port, + source=source, + source_port=source_port, + ) else: - response = dns.query.udp(request, - nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port, - raise_on_truncation=True) + response = dns.query.udp( + request, + nameserver, + timeout=timeout, + port=port, + source=source, + source_port=source_port, + raise_on_truncation=True, + ) else: - response = dns.query.https(request, nameserver, - timeout=timeout) + response = dns.query.https(request, nameserver, timeout=timeout) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -1159,11 +1239,17 @@ class Resolver(BaseResolver): if answer is not None: return answer - def query(self, qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, - rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0, - lifetime: Optional[float]=None) -> Answer: # pragma: no cover + def query( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + ) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. This method calls resolve() with ``search=True``, and is @@ -1171,13 +1257,26 @@ class Resolver(BaseResolver): dnspython. See the documentation for the resolve() method for further details. """ - warnings.warn('please use dns.resolver.Resolver.resolve() instead', - DeprecationWarning, stacklevel=2) - return self.resolve(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port, lifetime, - True) - - def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Dict[str, Any]) -> Answer: + warnings.warn( + "please use dns.resolver.Resolver.resolve() instead", + DeprecationWarning, + stacklevel=2, + ) + return self.resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + True, + ) + + def resolve_address( + self, ipaddr: str, *args: Any, **kwargs: Dict[str, Any] + ) -> Answer: """Use a resolver to run a reverse query for PTR records. This utilizes the resolve() method to perform a PTR lookup on the @@ -1195,10 +1294,11 @@ class Resolver(BaseResolver): # in the kwargs more than once. modified_kwargs: Dict[str, Any] = {} modified_kwargs.update(kwargs) - modified_kwargs['rdtype'] = dns.rdatatype.PTR - modified_kwargs['rdclass'] = dns.rdataclass.IN - return self.resolve(dns.reversename.from_address(ipaddr), - *args, **modified_kwargs) # type: ignore[arg-type] + modified_kwargs["rdtype"] = dns.rdatatype.PTR + modified_kwargs["rdclass"] = dns.rdataclass.IN + return self.resolve( + dns.reversename.from_address(ipaddr), *args, **modified_kwargs + ) # type: ignore[arg-type] # pylint: disable=redefined-outer-name @@ -1249,11 +1349,17 @@ def reset_default_resolver(): default_resolver = Resolver() -def resolve(qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, - rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0, - lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pragma: no cover +def resolve( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, +) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. @@ -1264,15 +1370,29 @@ def resolve(qname: Union[dns.name.Name, str], parameters. """ - return get_default_resolver().resolve(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port, - lifetime, search) - -def query(qname: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, - rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0, - lifetime: Optional[float]=None) -> Answer: # pragma: no cover + return get_default_resolver().resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + ) + + +def query( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, +) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. This method calls resolve() with ``search=True``, and is @@ -1280,11 +1400,20 @@ def query(qname: Union[dns.name.Name, str], dnspython. See the documentation for the resolve() method for further details. """ - warnings.warn('please use dns.resolver.resolve() instead', - DeprecationWarning, stacklevel=2) - return resolve(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port, lifetime, - True) + warnings.warn( + "please use dns.resolver.resolve() instead", DeprecationWarning, stacklevel=2 + ) + return resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + True, + ) def resolve_address(ipaddr: str, *args: Any, **kwargs: Dict[str, Any]) -> Answer: @@ -1307,9 +1436,13 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return get_default_resolver().canonical_name(name) -def zone_for_name(name: Union[dns.name.Name, str], rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, - tcp: bool=False, resolver: Optional[Resolver]=None, - lifetime: Optional[float]=None) -> dns.name.Name: +def zone_for_name( + name: Union[dns.name.Name, str], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + tcp: bool = False, + resolver: Optional[Resolver] = None, + lifetime: Optional[float] = None, +) -> dns.name.Name: """Find the name of the zone which contains the specified name. *name*, an absolute ``dns.name.Name`` or ``str``, the query name. @@ -1356,8 +1489,9 @@ def zone_for_name(name: Union[dns.name.Name, str], rdclass: dns.rdataclass.Rdata rlifetime = 0 else: rlifetime = None - answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp, - lifetime=rlifetime) + answer = resolver.resolve( + name, dns.rdatatype.SOA, rdclass, tcp, lifetime=rlifetime + ) assert answer.rrset is not None if answer.rrset.name == name: return name @@ -1386,6 +1520,7 @@ def zone_for_name(name: Union[dns.name.Name, str], rdclass: dns.rdataclass.Rdata except dns.name.NoParent: raise NoRootSOA + # # Support for overriding the system resolver for all python code in the # running process. @@ -1405,16 +1540,16 @@ _original_gethostbyname_ex = socket.gethostbyname_ex _original_gethostbyaddr = socket.gethostbyaddr -def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, - proto=0, flags=0): +def _getaddrinfo( + host=None, service=None, family=socket.AF_UNSPEC, socktype=0, proto=0, flags=0 +): if flags & socket.AI_NUMERICHOST != 0: # Short circuit directly into the system's getaddrinfo(). We're # not adding any value in this case, and this avoids infinite loops # because dns.query.* needs to call getaddrinfo() for IPv6 scoping # reasons. We will also do this short circuit below if we # discover that the host is an address literal. - return _original_getaddrinfo(host, service, family, socktype, proto, - flags) + return _original_getaddrinfo(host, service, family, socktype, proto, flags) if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0: # Not implemented. We raise a gaierror as opposed to a # NotImplementedError as it helps callers handle errors more @@ -1424,32 +1559,30 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, # no EAI_SYSTEM on Windows [Issue #416]. We didn't go for # EAI_BADFLAGS as the flags aren't bad, we just don't # implement them. - raise socket.gaierror(socket.EAI_FAIL, - 'Non-recoverable failure in name resolution') + raise socket.gaierror( + socket.EAI_FAIL, "Non-recoverable failure in name resolution" + ) if host is None and service is None: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") v6addrs = [] v4addrs = [] canonical_name = None # pylint: disable=redefined-outer-name # Is host None or an address literal? If so, use the system's # getaddrinfo(). if host is None: - return _original_getaddrinfo(host, service, family, socktype, - proto, flags) + return _original_getaddrinfo(host, service, family, socktype, proto, flags) try: # We don't care about the result of af_for_address(), we're just # calling it so it raises an exception if host is not an IPv4 or # IPv6 address. dns.inet.af_for_address(host) - return _original_getaddrinfo(host, service, family, socktype, - proto, flags) + return _original_getaddrinfo(host, service, family, socktype, proto, flags) except Exception: pass # Something needs resolution! try: if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - v6 = _resolver.resolve(host, dns.rdatatype.AAAA, - raise_on_no_answer=False) + v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False) # Note that setting host ensures we query the same name # for A as we did for AAAA. (This is just in case search lists # are active by default in the resolver configuration and @@ -1461,20 +1594,18 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, for rdata in v6.rrset: v6addrs.append(rdata.address) if family == socket.AF_INET or family == socket.AF_UNSPEC: - v4 = _resolver.resolve(host, dns.rdatatype.A, - raise_on_no_answer=False) + v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False) canonical_name = v4.canonical_name.to_text(True) if v4.rrset is not None: for rdata in v4.rrset: v4addrs.append(rdata.address) except dns.resolver.NXDOMAIN: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") except Exception: # We raise EAI_AGAIN here as the failure may be temporary # (e.g. a timeout) and EAI_SYSTEM isn't defined on Windows. # [Issue #416] - raise socket.gaierror(socket.EAI_AGAIN, - 'Temporary failure in name resolution') + raise socket.gaierror(socket.EAI_AGAIN, "Temporary failure in name resolution") port = None try: # Is it a port literal? @@ -1489,7 +1620,7 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, except Exception: pass if port is None: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") tuples = [] if socktype == 0: socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM] @@ -1498,21 +1629,23 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, if flags & socket.AI_CANONNAME != 0: cname = canonical_name else: - cname = '' + cname = "" if family == socket.AF_INET6 or family == socket.AF_UNSPEC: for addr in v6addrs: for socktype in socktypes: for proto in _protocols_for_socktype[socktype]: - tuples.append((socket.AF_INET6, socktype, proto, - cname, (addr, port, 0, 0))) + tuples.append( + (socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0)) + ) if family == socket.AF_INET or family == socket.AF_UNSPEC: for addr in v4addrs: for socktype in socktypes: for proto in _protocols_for_socktype[socktype]: - tuples.append((socket.AF_INET, socktype, proto, - cname, (addr, port))) + tuples.append( + (socket.AF_INET, socktype, proto, cname, (addr, port)) + ) if len(tuples) == 0: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") return tuples @@ -1525,31 +1658,29 @@ def _getnameinfo(sockaddr, flags=0): else: scope = None family = socket.AF_INET - tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, - socket.SOL_TCP, 0) + tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0) if len(tuples) > 1: - raise socket.error('sockaddr resolved to multiple addresses') + raise socket.error("sockaddr resolved to multiple addresses") addr = tuples[0][4][0] if flags & socket.NI_DGRAM: - pname = 'udp' + pname = "udp" else: - pname = 'tcp' + pname = "tcp" qname = dns.reversename.from_address(addr) if flags & socket.NI_NUMERICHOST == 0: try: - answer = _resolver.resolve(qname, 'PTR') + answer = _resolver.resolve(qname, "PTR") hostname = answer.rrset[0].target.to_text(True) except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): if flags & socket.NI_NAMEREQD: - raise socket.gaierror(socket.EAI_NONAME, - 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") hostname = addr if scope is not None: - hostname += '%' + str(scope) + hostname += "%" + str(scope) else: hostname = addr if scope is not None: - hostname += '%' + str(scope) + hostname += "%" + str(scope) if flags & socket.NI_NUMERICSERV: service = str(port) else: @@ -1576,8 +1707,9 @@ def _gethostbyname(name): def _gethostbyname_ex(name): aliases = [] addresses = [] - tuples = _getaddrinfo(name, 0, socket.AF_INET, socket.SOCK_STREAM, - socket.SOL_TCP, socket.AI_CANONNAME) + tuples = _getaddrinfo( + name, 0, socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME + ) canonical = tuples[0][3] for item in tuples: addresses.append(item[4][0]) @@ -1594,15 +1726,15 @@ def _gethostbyaddr(ip): try: dns.ipv4.inet_aton(ip) except Exception: - raise socket.gaierror(socket.EAI_NONAME, - 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") sockaddr = (ip, 80) family = socket.AF_INET (name, _) = _getnameinfo(sockaddr, socket.NI_NAMEREQD) aliases = [] addresses = [] - tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, - socket.AI_CANONNAME) + tuples = _getaddrinfo( + name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME + ) canonical = tuples[0][3] # We only want to include an address from the tuples if it's the # same as the one we asked about. We do this comparison in binary @@ -1617,7 +1749,7 @@ def _gethostbyaddr(ip): return (canonical, aliases, addresses) -def override_system_resolver(resolver: Optional[Resolver]=None) -> None: +def override_system_resolver(resolver: Optional[Resolver] = None) -> None: """Override the system resolver routines in the socket module with versions which use dnspython's resolver. diff --git a/dns/reversename.py b/dns/reversename.py index c25e77df..eb6a3b6b 100644 --- a/dns/reversename.py +++ b/dns/reversename.py @@ -23,12 +23,15 @@ import dns.name import dns.ipv6 import dns.ipv4 -ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.') -ipv6_reverse_domain = dns.name.from_text('ip6.arpa.') +ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.") +ipv6_reverse_domain = dns.name.from_text("ip6.arpa.") -def from_address(text: str, v4_origin: dns.name.Name=ipv4_reverse_domain, - v6_origin: dns.name.Name=ipv6_reverse_domain) -> dns.name.Name: +def from_address( + text: str, + v4_origin: dns.name.Name = ipv4_reverse_domain, + v6_origin: dns.name.Name = ipv6_reverse_domain, +) -> dns.name.Name: """Convert an IPv4 or IPv6 address in textual form into a Name object whose value is the reverse-map domain name of the address. @@ -51,20 +54,22 @@ def from_address(text: str, v4_origin: dns.name.Name=ipv4_reverse_domain, try: v6 = dns.ipv6.inet_aton(text) if dns.ipv6.is_mapped(v6): - parts = ['%d' % byte for byte in v6[12:]] + parts = ["%d" % byte for byte in v6[12:]] origin = v4_origin else: parts = [x for x in str(binascii.hexlify(v6).decode())] origin = v6_origin except Exception: - parts = ['%d' % - byte for byte in dns.ipv4.inet_aton(text)] + parts = ["%d" % byte for byte in dns.ipv4.inet_aton(text)] origin = v4_origin - return dns.name.from_text('.'.join(reversed(parts)), origin=origin) + return dns.name.from_text(".".join(reversed(parts)), origin=origin) -def to_address(name: dns.name.Name, v4_origin: dns.name.Name=ipv4_reverse_domain, - v6_origin: dns.name.Name=ipv6_reverse_domain) -> str: +def to_address( + name: dns.name.Name, + v4_origin: dns.name.Name = ipv4_reverse_domain, + v6_origin: dns.name.Name = ipv6_reverse_domain, +) -> str: """Convert a reverse map domain name into textual address form. *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name @@ -84,7 +89,7 @@ def to_address(name: dns.name.Name, v4_origin: dns.name.Name=ipv4_reverse_domain if name.is_subdomain(v4_origin): name = name.relativize(v4_origin) - text = b'.'.join(reversed(name.labels)) + text = b".".join(reversed(name.labels)) # run through inet_ntoa() to check syntax and make pretty. return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text)) elif name.is_subdomain(v6_origin): @@ -92,9 +97,9 @@ def to_address(name: dns.name.Name, v4_origin: dns.name.Name=ipv4_reverse_domain labels = list(reversed(name.labels)) parts = [] for i in range(0, len(labels), 4): - parts.append(b''.join(labels[i:i + 4])) - text = b':'.join(parts) + parts.append(b"".join(labels[i : i + 4])) + text = b":".join(parts) # run through inet_ntoa() to check syntax and make pretty. return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text)) else: - raise dns.exception.SyntaxError('unknown reverse-map address family') + raise dns.exception.SyntaxError("unknown reverse-map address family") diff --git a/dns/rrset.py b/dns/rrset.py index bfa47630..4217a04a 100644 --- a/dns/rrset.py +++ b/dns/rrset.py @@ -36,12 +36,16 @@ class RRset(dns.rdataset.Rdataset): name. """ - __slots__ = ['name', 'deleting'] - - def __init__(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - deleting: Optional[dns.rdataclass.RdataClass]=None): + __slots__ = ["name", "deleting"] + + def __init__( + self, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass] = None, + ): """Create a new RRset.""" super().__init__(rdclass, rdtype, covers) @@ -56,17 +60,26 @@ class RRset(dns.rdataset.Rdataset): def __repr__(self): if self.covers == 0: - ctext = '' + ctext = "" else: - ctext = '(' + dns.rdatatype.to_text(self.covers) + ')' + ctext = "(" + dns.rdatatype.to_text(self.covers) + ")" if self.deleting is not None: - dtext = ' delete=' + dns.rdataclass.to_text(self.deleting) + dtext = " delete=" + dns.rdataclass.to_text(self.deleting) else: - dtext = '' - return '' + dtext = "" + return ( + "" + ) def __str__(self): return self.to_text() @@ -79,7 +92,9 @@ class RRset(dns.rdataset.Rdataset): return False return super().__eq__(other) - def match(self, *args: Any, **kwargs: Dict[str, Any]) -> bool: # type: ignore[override] + def match( # type: ignore[override] + self, *args: Any, **kwargs: Dict[str, Any] + ) -> bool: """Does this rrset match the specified attributes? Behaves as :py:func:`full_match()` if the first argument is a @@ -96,9 +111,14 @@ class RRset(dns.rdataset.Rdataset): else: return super().match(*args, **kwargs) # type: ignore[arg-type] - def full_match(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass, - rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType, - deleting: Optional[dns.rdataclass.RdataClass]=None) -> bool: + def full_match( + self, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + deleting: Optional[dns.rdataclass.RdataClass] = None, + ) -> bool: """Returns ``True`` if this rrset matches the specified name, class, type, covers, and deletion state. """ @@ -110,7 +130,12 @@ class RRset(dns.rdataset.Rdataset): # pylint: disable=arguments-differ - def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw) -> str: # type: ignore + def to_text( # type: ignore[override] + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: """Convert the RRset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information @@ -127,11 +152,17 @@ class RRset(dns.rdataset.Rdataset): to *origin*. """ - return super().to_text(self.name, origin, relativize, - self.deleting, **kw) - - def to_wire(self, file: Any, compress: Optional[dns.name.CompressType]=None, # type: ignore - origin: Optional[dns.name.Name]=None, **kw) -> int: + return super().to_text( + self.name, origin, relativize, self.deleting, **kw # type: ignore + ) + + def to_wire( # type: ignore[override] + self, + file: Any, + compress: Optional[dns.name.CompressType] = None, # type: ignore + origin: Optional[dns.name.Name] = None, + **kw: Dict[str, Any] + ) -> int: """Convert the RRset to wire format. All keyword arguments are passed to ``dns.rdataset.to_wire()``; see @@ -140,8 +171,9 @@ class RRset(dns.rdataset.Rdataset): Returns an ``int``, the number of records emitted. """ - return super().to_wire(self.name, file, compress, origin, - self.deleting, **kw) + return super().to_wire( + self.name, file, compress, origin, self.deleting, **kw # type:ignore + ) # pylint: enable=arguments-differ @@ -153,13 +185,17 @@ class RRset(dns.rdataset.Rdataset): return dns.rdataset.from_rdata_list(self.ttl, list(self)) -def from_text_list(name: Union[dns.name.Name, str], ttl: int, - rdclass: Union[dns.rdataclass.RdataClass, str], - rdtype: Union[dns.rdatatype.RdataType, str], - text_rdatas: Collection[str], - idna_codec: Optional[dns.name.IDNACodec]=None, - origin: Optional[dns.name.Name]=None, relativize: bool=True, - relativize_to: Optional[dns.name.Name]=None) -> RRset: +def from_text_list( + name: Union[dns.name.Name, str], + ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> RRset: """Create an RRset with the specified name, TTL, class, and type, and with the specified list of rdatas in text format. @@ -185,29 +221,37 @@ def from_text_list(name: Union[dns.name.Name, str], ttl: int, r = RRset(name, the_rdclass, the_rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, - relativize_to, idna_codec) + rd = dns.rdata.from_text( + r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec + ) r.add(rd) return r -def from_text(name: Union[dns.name.Name, str], ttl: int, - rdclass: Union[dns.rdataclass.RdataClass, str], - rdtype: Union[dns.rdatatype.RdataType, str], - *text_rdatas: Any) -> RRset: +def from_text( + name: Union[dns.name.Name, str], + ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + *text_rdatas: Any +) -> RRset: """Create an RRset with the specified name, TTL, class, and type and with the specified rdatas in text format. Returns a ``dns.rrset.RRset`` object. """ - return from_text_list(name, ttl, rdclass, rdtype, - cast(Collection[str], text_rdatas)) + return from_text_list( + name, ttl, rdclass, rdtype, cast(Collection[str], text_rdatas) + ) -def from_rdata_list(name: Union[dns.name.Name, str], ttl: int, - rdatas: Collection[dns.rdata.Rdata], - idna_codec: Optional[dns.name.IDNACodec]=None) -> RRset: +def from_rdata_list( + name: Union[dns.name.Name, str], + ttl: int, + rdatas: Collection[dns.rdata.Rdata], + idna_codec: Optional[dns.name.IDNACodec] = None, +) -> RRset: """Create an RRset with the specified name and TTL, and with the specified list of rdata objects. @@ -234,7 +278,7 @@ def from_rdata_list(name: Union[dns.name.Name, str], ttl: int, return r -def from_rdata(name: Union[dns.name.Name, str], ttl:int, *rdatas: Any) -> RRset: +def from_rdata(name: Union[dns.name.Name, str], ttl: int, *rdatas: Any) -> RRset: """Create an RRset with the specified name and TTL, and with the specified rdata objects. diff --git a/dns/serial.py b/dns/serial.py index b4d264cb..3417299b 100644 --- a/dns/serial.py +++ b/dns/serial.py @@ -2,13 +2,14 @@ """Serial Number Arthimetic from RFC 1982""" + class Serial: - def __init__(self, value: int, bits: int=32): - self.value = value % 2 ** bits + def __init__(self, value: int, bits: int = 32): + self.value = value % 2**bits self.bits = bits def __repr__(self): - return f'dns.serial.Serial({self.value}, {self.bits})' + return f"dns.serial.Serial({self.value}, {self.bits})" def __eq__(self, other): if isinstance(other, int): @@ -29,11 +30,11 @@ class Serial: other = Serial(other, self.bits) elif not isinstance(other, Serial) or other.bits != self.bits: return NotImplemented - if self.value < other.value and \ - other.value - self.value < 2 ** (self.bits - 1): + if self.value < other.value and other.value - self.value < 2 ** (self.bits - 1): return True - elif self.value > other.value and \ - self.value - other.value > 2 ** (self.bits - 1): + elif self.value > other.value and self.value - other.value > 2 ** ( + self.bits - 1 + ): return True else: return False @@ -46,11 +47,11 @@ class Serial: other = Serial(other, self.bits) elif not isinstance(other, Serial) or other.bits != self.bits: return NotImplemented - if self.value < other.value and \ - other.value - self.value > 2 ** (self.bits - 1): + if self.value < other.value and other.value - self.value > 2 ** (self.bits - 1): return True - elif self.value > other.value and \ - self.value - other.value < 2 ** (self.bits - 1): + elif self.value > other.value and self.value - other.value < 2 ** ( + self.bits - 1 + ): return True else: return False @@ -69,7 +70,7 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v += delta - v = v % 2 ** self.bits + v = v % 2**self.bits return Serial(v, self.bits) def __iadd__(self, other): @@ -83,7 +84,7 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v += delta - v = v % 2 ** self.bits + v = v % 2**self.bits self.value = v return self @@ -98,7 +99,7 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v -= delta - v = v % 2 ** self.bits + v = v % 2**self.bits return Serial(v, self.bits) def __isub__(self, other): @@ -112,6 +113,6 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v -= delta - v = v % 2 ** self.bits + v = v % 2**self.bits self.value = v return self diff --git a/dns/set.py b/dns/set.py index a4e12b67..fa50ed97 100644 --- a/dns/set.py +++ b/dns/set.py @@ -28,7 +28,7 @@ class Set: ability is widely used in dnspython applications. """ - __slots__ = ['items'] + __slots__ = ["items"] def __init__(self, items=None): """Initialize the set. @@ -47,15 +47,13 @@ class Set: return "dns.set.Set(%s)" % repr(list(self.items.keys())) def add(self, item): - """Add an item to the set. - """ + """Add an item to the set.""" if item not in self.items: self.items[item] = None def remove(self, item): - """Remove an item from the set. - """ + """Remove an item from the set.""" try: del self.items[item] @@ -63,8 +61,7 @@ class Set: raise ValueError def discard(self, item): - """Remove an item from the set if present. - """ + """Remove an item from the set if present.""" self.items.pop(item, None) @@ -73,7 +70,7 @@ class Set: (k, _) = self.items.popitem() return k - def _clone(self) -> 'Set': + def _clone(self) -> "Set": """Make a (shallow) copy of the set. There is a 'clone protocol' that subclasses of this class @@ -86,8 +83,8 @@ class Set: subclasses. """ - if hasattr(self, '_clone_class'): - cls = self._clone_class # type: ignore + if hasattr(self, "_clone_class"): + cls = self._clone_class # type: ignore else: cls = self.__class__ obj = cls.__new__(cls) @@ -96,14 +93,12 @@ class Set: return obj def __copy__(self): - """Make a (shallow) copy of the set. - """ + """Make a (shallow) copy of the set.""" return self._clone() def copy(self): - """Make a (shallow) copy of the set. - """ + """Make a (shallow) copy of the set.""" return self._clone() @@ -113,7 +108,7 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") if self is other: # lgtm[py/comparison-using-is] return for item in other.items: @@ -125,7 +120,7 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") if self is other: # lgtm[py/comparison-using-is] return # we make a copy of the list so that we can remove items from @@ -140,7 +135,7 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") if self is other: # lgtm[py/comparison-using-is] self.items.clear() else: @@ -151,7 +146,7 @@ class Set: """Update the set, retaining only elements unique to both sets.""" if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") if self is other: # lgtm[py/comparison-using-is] self.items.clear() else: @@ -285,7 +280,7 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") for item in self.items: if item not in other.items: return False @@ -298,7 +293,7 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") for item in other.items: if item not in self.items: return False @@ -306,7 +301,7 @@ class Set: def isdisjoint(self, other): if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") for item in other.items: if item in self.items: return False diff --git a/dns/tokenizer.py b/dns/tokenizer.py index 275861c6..0551578a 100644 --- a/dns/tokenizer.py +++ b/dns/tokenizer.py @@ -17,7 +17,7 @@ """Tokenize DNS zone file format""" -from typing import Any, Optional, List, Tuple, Union +from typing import Any, Optional, List, Tuple import io import sys @@ -26,7 +26,7 @@ import dns.exception import dns.name import dns.ttl -_DELIMITERS = {' ', '\t', '\n', ';', '(', ')', '"'} +_DELIMITERS = {" ", "\t", "\n", ";", "(", ")", '"'} _QUOTING_DELIMITERS = {'"'} EOF = 0 @@ -50,8 +50,13 @@ class Token: has_escape: Does the token value contain escapes? """ - def __init__(self, ttype: int, value: Any='', has_escape: bool=False, - comment: Optional[str]=None): + def __init__( + self, + ttype: int, + value: Any = "", + has_escape: bool = False, + comment: Optional[str] = None, + ): """Initialize a token instance.""" self.ttype = ttype @@ -86,28 +91,26 @@ class Token: def __eq__(self, other): if not isinstance(other, Token): return False - return (self.ttype == other.ttype and - self.value == other.value) + return self.ttype == other.ttype and self.value == other.value def __ne__(self, other): if not isinstance(other, Token): return True - return (self.ttype != other.ttype or - self.value != other.value) + return self.ttype != other.ttype or self.value != other.value def __str__(self): return '%d "%s"' % (self.ttype, self.value) - def unescape(self) -> 'Token': + def unescape(self) -> "Token": if not self.has_escape: return self - unescaped = '' + unescaped = "" l = len(self.value) i = 0 while i < l: c = self.value[i] i += 1 - if c == '\\': + if c == "\\": if i >= l: # pragma: no cover (can't happen via get()) raise dns.exception.UnexpectedEnd c = self.value[i] @@ -130,7 +133,7 @@ class Token: unescaped += c return Token(self.ttype, unescaped) - def unescape_to_bytes(self) -> 'Token': + def unescape_to_bytes(self) -> "Token": # We used to use unescape() for TXT-like records, but this # caused problems as we'd process DNS escapes into Unicode code # points instead of byte values, and then a to_text() of the @@ -155,13 +158,13 @@ class Token: # # foo\226\128\139bar # - unescaped = b'' + unescaped = b"" l = len(self.value) i = 0 while i < l: c = self.value[i] i += 1 - if c == '\\': + if c == "\\": if i >= l: # pragma: no cover (can't happen via get()) raise dns.exception.UnexpectedEnd c = self.value[i] @@ -180,7 +183,7 @@ class Token: codepoint = int(c) * 100 + int(c2) * 10 + int(c3) if codepoint > 255: raise dns.exception.SyntaxError - unescaped += b'%c' % (codepoint) + unescaped += b"%c" % (codepoint) else: # Note that as mentioned above, if c is a Unicode # code point outside of the ASCII range, then this @@ -226,8 +229,12 @@ class Tokenizer: encoder/decoder is used. """ - def __init__(self, f: Any=sys.stdin, filename: Optional[str]=None, - idna_codec: Optional[dns.name.IDNACodec]=None): + def __init__( + self, + f: Any = sys.stdin, + filename: Optional[str] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + ): """Initialize a tokenizer instance. f: The file to tokenize. The default is sys.stdin. @@ -245,17 +252,17 @@ class Tokenizer: if isinstance(f, str): f = io.StringIO(f) if filename is None: - filename = '' + filename = "" elif isinstance(f, bytes): f = io.StringIO(f.decode()) if filename is None: - filename = '' + filename = "" else: if filename is None: if f is sys.stdin: - filename = '' + filename = "" else: - filename = '' + filename = "" self.file = f self.ungotten_char: Optional[str] = None self.ungotten_token: Optional[Token] = None @@ -272,17 +279,16 @@ class Tokenizer: self.idna_codec = idna_codec def _get_char(self) -> str: - """Read a character from input. - """ + """Read a character from input.""" if self.ungotten_char is None: if self.eof: - c = '' + c = "" else: c = self.file.read(1) - if c == '': + if c == "": self.eof = True - elif c == '\n': + elif c == "\n": self.line_number += 1 else: c = self.ungotten_char @@ -328,13 +334,13 @@ class Tokenizer: skipped = 0 while True: c = self._get_char() - if c != ' ' and c != '\t': - if (c != '\n') or not self.multiline: + if c != " " and c != "\t": + if (c != "\n") or not self.multiline: self._unget_char(c) return skipped skipped += 1 - def get(self, want_leading: bool=False, want_comment: bool=False) -> Token: + def get(self, want_leading: bool = False, want_comment: bool = False) -> Token: """Get the next token. want_leading: If True, return a WHITESPACE token if the @@ -363,21 +369,21 @@ class Tokenizer: return utoken skipped = self.skip_whitespace() if want_leading and skipped > 0: - return Token(WHITESPACE, ' ') - token = '' + return Token(WHITESPACE, " ") + token = "" ttype = IDENTIFIER has_escape = False while True: c = self._get_char() - if c == '' or c in self.delimiters: - if c == '' and self.quoting: + if c == "" or c in self.delimiters: + if c == "" and self.quoting: raise dns.exception.UnexpectedEnd - if token == '' and ttype != QUOTED_STRING: - if c == '(': + if token == "" and ttype != QUOTED_STRING: + if c == "(": self.multiline += 1 self.skip_whitespace() continue - elif c == ')': + elif c == ")": if self.multiline <= 0: raise dns.exception.SyntaxError self.multiline -= 1 @@ -394,28 +400,29 @@ class Tokenizer: self.delimiters = _DELIMITERS self.skip_whitespace() continue - elif c == '\n': - return Token(EOL, '\n') - elif c == ';': + elif c == "\n": + return Token(EOL, "\n") + elif c == ";": while 1: c = self._get_char() - if c == '\n' or c == '': + if c == "\n" or c == "": break token += c if want_comment: self._unget_char(c) return Token(COMMENT, token) - elif c == '': + elif c == "": if self.multiline: raise dns.exception.SyntaxError( - 'unbalanced parentheses') + "unbalanced parentheses" + ) return Token(EOF, comment=token) elif self.multiline: self.skip_whitespace() - token = '' + token = "" continue else: - return Token(EOL, '\n', comment=token) + return Token(EOL, "\n", comment=token) else: # This code exists in case we ever want a # delimiter to be returned. It never produces @@ -425,9 +432,9 @@ class Tokenizer: else: self._unget_char(c) break - elif self.quoting and c == '\n': - raise dns.exception.SyntaxError('newline in quoted string') - elif c == '\\': + elif self.quoting and c == "\n": + raise dns.exception.SyntaxError("newline in quoted string") + elif c == "\\": # # It's an escape. Put it and the next character into # the token; it will be checked later for goodness. @@ -435,12 +442,12 @@ class Tokenizer: token += c has_escape = True c = self._get_char() - if c == '' or (c == '\n' and not self.quoting): + if c == "" or (c == "\n" and not self.quoting): raise dns.exception.UnexpectedEnd token += c - if token == '' and ttype != QUOTED_STRING: + if token == "" and ttype != QUOTED_STRING: if self.multiline: - raise dns.exception.SyntaxError('unbalanced parentheses') + raise dns.exception.SyntaxError("unbalanced parentheses") ttype = EOF return Token(ttype, token, has_escape) @@ -478,7 +485,7 @@ class Tokenizer: # Helpers - def get_int(self, base: int=10) -> int: + def get_int(self, base: int = 10) -> int: """Read the next token and interpret it as an unsigned integer. Raises dns.exception.SyntaxError if not an unsigned integer. @@ -488,9 +495,9 @@ class Tokenizer: token = self.get().unescape() if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") if not token.value.isdigit(): - raise dns.exception.SyntaxError('expecting an integer') + raise dns.exception.SyntaxError("expecting an integer") return int(token.value, base) def get_uint8(self) -> int: @@ -505,10 +512,11 @@ class Tokenizer: value = self.get_int() if value < 0 or value > 255: raise dns.exception.SyntaxError( - '%d is not an unsigned 8-bit integer' % value) + "%d is not an unsigned 8-bit integer" % value + ) return value - def get_uint16(self, base: int=10) -> int: + def get_uint16(self, base: int = 10) -> int: """Read the next token and interpret it as a 16-bit unsigned integer. @@ -521,13 +529,15 @@ class Tokenizer: if value < 0 or value > 65535: if base == 8: raise dns.exception.SyntaxError( - '%o is not an octal unsigned 16-bit integer' % value) + "%o is not an octal unsigned 16-bit integer" % value + ) else: raise dns.exception.SyntaxError( - '%d is not an unsigned 16-bit integer' % value) + "%d is not an unsigned 16-bit integer" % value + ) return value - def get_uint32(self, base: int=10) -> int: + def get_uint32(self, base: int = 10) -> int: """Read the next token and interpret it as a 32-bit unsigned integer. @@ -539,10 +549,11 @@ class Tokenizer: value = self.get_int(base=base) if value < 0 or value > 4294967295: raise dns.exception.SyntaxError( - '%d is not an unsigned 32-bit integer' % value) + "%d is not an unsigned 32-bit integer" % value + ) return value - def get_uint48(self, base: int=10) -> int: + def get_uint48(self, base: int = 10) -> int: """Read the next token and interpret it as a 48-bit unsigned integer. @@ -554,10 +565,11 @@ class Tokenizer: value = self.get_int(base=base) if value < 0 or value > 281474976710655: raise dns.exception.SyntaxError( - '%d is not an unsigned 48-bit integer' % value) + "%d is not an unsigned 48-bit integer" % value + ) return value - def get_string(self, max_length: Optional[int]=None) -> str: + def get_string(self, max_length: Optional[int] = None) -> str: """Read the next token and interpret it as a string. Raises dns.exception.SyntaxError if not a string. @@ -569,7 +581,7 @@ class Tokenizer: token = self.get().unescape() if not (token.is_identifier() or token.is_quoted_string()): - raise dns.exception.SyntaxError('expecting a string') + raise dns.exception.SyntaxError("expecting a string") if max_length and len(token.value) > max_length: raise dns.exception.SyntaxError("string too long") return token.value @@ -584,10 +596,10 @@ class Tokenizer: token = self.get().unescape() if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") return token.value - def get_remaining(self, max_tokens: Optional[int]=None) -> List[Token]: + def get_remaining(self, max_tokens: Optional[int] = None) -> List[Token]: """Return the remaining tokens on the line, until an EOL or EOF is seen. max_tokens: If not None, stop after this number of tokens. @@ -606,7 +618,7 @@ class Tokenizer: break return tokens - def concatenate_remaining_identifiers(self, allow_empty: bool=False) -> str: + def concatenate_remaining_identifiers(self, allow_empty: bool = False) -> str: """Read the remaining tokens on the line, which should be identifiers. Raises dns.exception.SyntaxError if there are no remaining tokens, @@ -628,11 +640,16 @@ class Tokenizer: raise dns.exception.SyntaxError s += token.value if not (allow_empty or s): - raise dns.exception.SyntaxError('expecting another identifier') + raise dns.exception.SyntaxError("expecting another identifier") return s - def as_name(self, token: Token, origin: Optional[dns.name.Name]=None, - relativize: bool=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name: + def as_name( + self, + token: Token, + origin: Optional[dns.name.Name] = None, + relativize: bool = False, + relativize_to: Optional[dns.name.Name] = None, + ) -> dns.name.Name: """Try to interpret the token as a DNS name. Raises dns.exception.SyntaxError if not a name. @@ -640,12 +657,16 @@ class Tokenizer: Returns a dns.name.Name. """ if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") name = dns.name.from_text(token.value, origin, self.idna_codec) return name.choose_relativity(relativize_to or origin, relativize) - def get_name(self, origin: Optional[dns.name.Name]=None, relativize: bool=False, - relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name: + def get_name( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = False, + relativize_to: Optional[dns.name.Name] = None, + ) -> dns.name.Name: """Read the next token and interpret it as a DNS name. Raises dns.exception.SyntaxError if not a name. @@ -666,8 +687,8 @@ class Tokenizer: token = self.get() if not token.is_eol_or_eof(): raise dns.exception.SyntaxError( - 'expected EOL or EOF, got %d "%s"' % (token.ttype, - token.value)) + 'expected EOL or EOF, got %d "%s"' % (token.ttype, token.value) + ) return token def get_eol(self) -> str: @@ -684,5 +705,5 @@ class Tokenizer: token = self.get().unescape() if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") return dns.ttl.from_text(token.value) diff --git a/dns/transaction.py b/dns/transaction.py index 1e97c757..b0429dfb 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -16,11 +16,11 @@ import dns.ttl class TransactionManager: - def reader(self) -> 'Transaction': + def reader(self) -> "Transaction": """Begin a read-only transaction.""" raise NotImplementedError # pragma: no cover - def writer(self, replacement: bool=False) -> 'Transaction': + def writer(self, replacement: bool = False) -> "Transaction": """Begin a writable transaction. *replacement*, a ``bool``. If `True`, the content of the @@ -30,7 +30,9 @@ class TransactionManager: """ raise NotImplementedError # pragma: no cover - def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: + def origin_information( + self, + ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: """Returns a tuple (absolute_origin, relativize, effective_origin) @@ -56,13 +58,11 @@ class TransactionManager: raise NotImplementedError # pragma: no cover def get_class(self) -> dns.rdataclass.RdataClass: - """The class of the transaction manager. - """ + """The class of the transaction manager.""" raise NotImplementedError # pragma: no cover def from_wire_origin(self) -> Optional[dns.name.Name]: - """Origin to use in from_wire() calls. - """ + """Origin to use in from_wire() calls.""" (absolute_origin, relativize, _) = self.origin_information() if relativize: return absolute_origin @@ -87,39 +87,51 @@ def _ensure_immutable_rdataset(rdataset): return rdataset return dns.rdataset.ImmutableRdataset(rdataset) + def _ensure_immutable_node(node): if node is None or node.is_immutable(): return node return dns.node.ImmutableNode(node) -CheckPutRdatasetType = Callable[['Transaction', dns.name.Name, dns.rdataset.Rdataset], None] -CheckDeleteRdatasetType = Callable[['Transaction', dns.name.Name, - dns.rdatatype.RdataType, dns.rdatatype.RdataType], None] -CheckDeleteNameType = Callable[['Transaction', dns.name.Name], None] +CheckPutRdatasetType = Callable[ + ["Transaction", dns.name.Name, dns.rdataset.Rdataset], None +] +CheckDeleteRdatasetType = Callable[ + ["Transaction", dns.name.Name, dns.rdatatype.RdataType, dns.rdatatype.RdataType], + None, +] +CheckDeleteNameType = Callable[["Transaction", dns.name.Name], None] class Transaction: - - def __init__(self, manager: TransactionManager, replacement: bool=False, read_only: bool=False): + def __init__( + self, + manager: TransactionManager, + replacement: bool = False, + read_only: bool = False, + ): self.manager = manager self.replacement = replacement self.read_only = read_only self._ended = False - self._check_put_rdataset: List[CheckPutRdatasetType]= [] + self._check_put_rdataset: List[CheckPutRdatasetType] = [] self._check_delete_rdataset: List[CheckDeleteRdatasetType] = [] self._check_delete_name: List[CheckDeleteNameType] = [] # # This is the high level API # - # Note that we currently use non-immutable types in the return type signature to avoid - # covariance problems, e.g. if the caller has a List[Rdataset], mypy will be unhappy if we - # return an ImmutableRdataset. - - def get(self, name: Optional[Union[dns.name.Name,str]], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rdataset.Rdataset: + # Note that we currently use non-immutable types in the return type signature to + # avoid covariance problems, e.g. if the caller has a List[Rdataset], mypy will be + # unhappy if we return an ImmutableRdataset. + + def get( + self, + name: Optional[Union[dns.name.Name, str]], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> dns.rdataset.Rdataset: """Return the rdataset associated with *name*, *rdtype*, and *covers*, or `None` if not found. @@ -232,7 +244,12 @@ class Transaction: name = dns.name.from_text(name, None) return self._name_exists(name) - def update_serial(self, value: int=1, relative: bool=True, name: dns.name.Name=dns.name.empty) -> None: + def update_serial( + self, + value: int = 1, + relative: bool = True, + name: dns.name.Name = dns.name.empty, + ) -> None: """Update the serial number. *value*, an `int`, is an increment if *relative* is `True`, or the @@ -246,11 +263,10 @@ class Transaction: """ self._check_ended() if value < 0: - raise ValueError('negative update_serial() value') + raise ValueError("negative update_serial() value") if isinstance(name, str): name = dns.name.from_text(name, None) - rdataset = self._get_rdataset(name, dns.rdatatype.SOA, - dns.rdatatype.NONE) + rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE) if rdataset is None or len(rdataset) == 0: raise KeyError if relative: @@ -347,7 +363,7 @@ class Transaction: def _raise_if_not_empty(self, method, args): if len(args) != 0: - raise TypeError(f'extra parameters to {method}') + raise TypeError(f"extra parameters to {method}") def _rdataset_from_args(self, method, deleting, args): try: @@ -363,29 +379,29 @@ class Transaction: if isinstance(arg, int): ttl = arg if ttl > dns.ttl.MAX_TTL: - raise ValueError(f'{method}: TTL value too big') + raise ValueError(f"{method}: TTL value too big") else: - raise TypeError(f'{method}: expected a TTL') + raise TypeError(f"{method}: expected a TTL") arg = args.popleft() if isinstance(arg, dns.rdata.Rdata): rdataset = dns.rdataset.from_rdata(ttl, arg) else: - raise TypeError(f'{method}: expected an Rdata') + raise TypeError(f"{method}: expected an Rdata") return rdataset except IndexError: if deleting: return None else: # reraise - raise TypeError(f'{method}: expected more arguments') + raise TypeError(f"{method}: expected more arguments") def _add(self, replace, args): try: args = collections.deque(args) if replace: - method = 'replace()' + method = "replace()" else: - method = 'add()' + method = "add()" arg = args.popleft() if isinstance(arg, str): arg = dns.name.from_text(arg, None) @@ -399,44 +415,45 @@ class Transaction: # same and can't be stored in nodes, so convert. rdataset = rrset.to_rdataset() else: - raise TypeError(f'{method} requires a name or RRset ' + - 'as the first argument') + raise TypeError( + f"{method} requires a name or RRset " + "as the first argument" + ) if rdataset.rdclass != self.manager.get_class(): - raise ValueError(f'{method} has objects of wrong RdataClass') + raise ValueError(f"{method} has objects of wrong RdataClass") if rdataset.rdtype == dns.rdatatype.SOA: (_, _, origin) = self._origin_information() if name != origin: - raise ValueError(f'{method} has non-origin SOA') + raise ValueError(f"{method} has non-origin SOA") self._raise_if_not_empty(method, args) if not replace: - existing = self._get_rdataset(name, rdataset.rdtype, - rdataset.covers) + existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers) if existing is not None: if isinstance(existing, dns.rdataset.ImmutableRdataset): - trds = dns.rdataset.Rdataset(existing.rdclass, - existing.rdtype, - existing.covers) + trds = dns.rdataset.Rdataset( + existing.rdclass, existing.rdtype, existing.covers + ) trds.update(existing) existing = trds rdataset = existing.union(rdataset) self._checked_put_rdataset(name, rdataset) except IndexError: - raise TypeError(f'not enough parameters to {method}') + raise TypeError(f"not enough parameters to {method}") def _delete(self, exact, args): try: args = collections.deque(args) if exact: - method = 'delete_exact()' + method = "delete_exact()" else: - method = 'delete()' + method = "delete()" arg = args.popleft() if isinstance(arg, str): arg = dns.name.from_text(arg, None) if isinstance(arg, dns.name.Name): name = arg - if len(args) > 0 and (isinstance(args[0], int) or - isinstance(args[0], str)): + if len(args) > 0 and ( + isinstance(args[0], int) or isinstance(args[0], str) + ): # deleting by type and (optionally) covers rdtype = dns.rdatatype.RdataType.make(args.popleft()) if len(args) > 0: @@ -447,7 +464,7 @@ class Transaction: existing = self._get_rdataset(name, rdtype, covers) if existing is None: if exact: - raise DeleteNotExact(f'{method}: missing rdataset') + raise DeleteNotExact(f"{method}: missing rdataset") else: self._delete_rdataset(name, rdtype, covers) return @@ -457,34 +474,34 @@ class Transaction: rdataset = arg # rrsets are also rdatasets name = rdataset.name else: - raise TypeError(f'{method} requires a name or RRset ' + - 'as the first argument') + raise TypeError( + f"{method} requires a name or RRset " + "as the first argument" + ) self._raise_if_not_empty(method, args) if rdataset: if rdataset.rdclass != self.manager.get_class(): - raise ValueError(f'{method} has objects of wrong ' - 'RdataClass') - existing = self._get_rdataset(name, rdataset.rdtype, - rdataset.covers) + raise ValueError(f"{method} has objects of wrong " "RdataClass") + existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers) if existing is not None: if exact: intersection = existing.intersection(rdataset) if intersection != rdataset: - raise DeleteNotExact(f'{method}: missing rdatas') + raise DeleteNotExact(f"{method}: missing rdatas") rdataset = existing.difference(rdataset) if len(rdataset) == 0: - self._checked_delete_rdataset(name, rdataset.rdtype, - rdataset.covers) + self._checked_delete_rdataset( + name, rdataset.rdtype, rdataset.covers + ) else: self._checked_put_rdataset(name, rdataset) elif exact: - raise DeleteNotExact(f'{method}: missing rdataset') + raise DeleteNotExact(f"{method}: missing rdataset") else: if exact and not self._name_exists(name): - raise DeleteNotExact(f'{method}: name not known') + raise DeleteNotExact(f"{method}: name not known") self._checked_delete_name(name) except IndexError: - raise TypeError(f'not enough parameters to {method}') + raise TypeError(f"not enough parameters to {method}") def _check_ended(self): if self._ended: @@ -590,8 +607,7 @@ class Transaction: raise NotImplementedError # pragma: no cover def _iterate_rdatasets(self): - """Return an iterator that yields (name, rdataset) tuples. - """ + """Return an iterator that yields (name, rdataset) tuples.""" raise NotImplementedError # pragma: no cover def _get_node(self, name): diff --git a/dns/tsig.py b/dns/tsig.py index 50b2d47e..b3f52516 100644 --- a/dns/tsig.py +++ b/dns/tsig.py @@ -27,6 +27,7 @@ import dns.rdataclass import dns.name import dns.rcode + class BadTime(dns.exception.DNSException): """The current time is not within the TSIG's validity time.""" @@ -97,10 +98,11 @@ class GSSTSig: In order to avoid a direct GSSAPI dependency, the keyring holds a ref to the GSSAPI object required, rather than the key itself. """ + def __init__(self, gssapi_context): self.gssapi_context = gssapi_context - self.data = b'' - self.name = 'gss-tsig' + self.data = b"" + self.name = "gss-tsig" def update(self, data): self.data += data @@ -139,9 +141,9 @@ class GSSTSigAdapter: # client to complete the GSSAPI negotiation before attempting # to verify the signed response to a TKEY message exchange try: - rrset = message.find_rrset(message.answer, keyname, - dns.rdataclass.ANY, - dns.rdatatype.TKEY) + rrset = message.find_rrset( + message.answer, keyname, dns.rdataclass.ANY, dns.rdatatype.TKEY + ) if rrset: token = rrset[0].key gssapi_context = key.secret @@ -172,8 +174,9 @@ class HMACTSig: try: hashinfo = self._hashes[algorithm] except KeyError: - raise NotImplementedError(f"TSIG algorithm {algorithm} " + - "is not supported") + raise NotImplementedError( + f"TSIG algorithm {algorithm} " + "is not supported" + ) # create the HMAC context if isinstance(hashinfo, tuple): @@ -184,7 +187,7 @@ class HMACTSig: self.size = None self.name = self.hmac_context.name if self.size: - self.name += f'-{self.size}' + self.name += f"-{self.size}" def update(self, data): return self.hmac_context.update(data) @@ -203,8 +206,7 @@ class HMACTSig: raise BadSignature -def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, - multi=None): +def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None): """Return a context containing the TSIG rdata for the input parameters @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object @raises ValueError: I{other_data} is too long @@ -215,25 +217,25 @@ def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, if first: ctx = get_context(key) if request_mac: - ctx.update(struct.pack('!H', len(request_mac))) + ctx.update(struct.pack("!H", len(request_mac))) ctx.update(request_mac) - ctx.update(struct.pack('!H', rdata.original_id)) + ctx.update(struct.pack("!H", rdata.original_id)) ctx.update(wire[2:]) if first: ctx.update(key.name.to_digestable()) - ctx.update(struct.pack('!H', dns.rdataclass.ANY)) - ctx.update(struct.pack('!I', 0)) + ctx.update(struct.pack("!H", dns.rdataclass.ANY)) + ctx.update(struct.pack("!I", 0)) if time is None: time = rdata.time_signed - upper_time = (time >> 32) & 0xffff - lower_time = time & 0xffffffff - time_encoded = struct.pack('!HIH', upper_time, lower_time, rdata.fudge) + upper_time = (time >> 32) & 0xFFFF + lower_time = time & 0xFFFFFFFF + time_encoded = struct.pack("!HIH", upper_time, lower_time, rdata.fudge) other_len = len(rdata.other) if other_len > 65535: - raise ValueError('TSIG Other Data is > 65535 bytes') + raise ValueError("TSIG Other Data is > 65535 bytes") if first: ctx.update(key.algorithm.to_digestable() + time_encoded) - ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other) + ctx.update(struct.pack("!HH", rdata.error, other_len) + rdata.other) else: ctx.update(time_encoded) return ctx @@ -246,7 +248,7 @@ def _maybe_start_digest(key, mac, multi): """ if multi: ctx = get_context(key) - ctx.update(struct.pack('!H', len(mac))) + ctx.update(struct.pack("!H", len(mac))) ctx.update(mac) return ctx else: @@ -269,8 +271,9 @@ def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False): return (tsig, _maybe_start_digest(key, mac, multi)) -def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, - multi=False): +def validate( + wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, multi=False +): """Validate the specified TSIG rdata against the other input parameters. @raises FormError: The TSIG is badly formed. @@ -294,7 +297,7 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, elif rdata.error == dns.rcode.BADTRUNC: raise PeerBadTruncation else: - raise PeerError('unknown TSIG error code %d' % rdata.error) + raise PeerError("unknown TSIG error code %d" % rdata.error) if abs(rdata.time_signed - now) > rdata.fudge: raise BadTime if key.name != owner: @@ -332,14 +335,15 @@ class Key: self.algorithm = algorithm def __eq__(self, other): - return (isinstance(other, Key) and - self.name == other.name and - self.secret == other.secret and - self.algorithm == other.algorithm) + return ( + isinstance(other, Key) + and self.name == other.name + and self.secret == other.secret + and self.algorithm == other.algorithm + ) def __repr__(self): - r = f" Dict[str, Any]: @rtype: dict""" textring = {} + def b64encode(secret): return base64.encodebytes(secret).decode().rstrip() + for (name, key) in keyring.items(): tname = name.to_text() if isinstance(key, bytes): diff --git a/dns/ttl.py b/dns/ttl.py index 9f5730e7..264b0338 100644 --- a/dns/ttl.py +++ b/dns/ttl.py @@ -62,15 +62,15 @@ def from_text(text: str) -> int: if need_digit: raise BadTTL c = c.lower() - if c == 'w': + if c == "w": total += current * 604800 - elif c == 'd': + elif c == "d": total += current * 86400 - elif c == 'h': + elif c == "h": total += current * 3600 - elif c == 'm': + elif c == "m": total += current * 60 - elif c == 's': + elif c == "s": total += current else: raise BadTTL("unknown unit '%s'" % c) @@ -89,4 +89,4 @@ def make(value: Union[int, str]) -> int: elif isinstance(value, str): return dns.ttl.from_text(value) else: - raise ValueError('cannot convert value to TTL') + raise ValueError("cannot convert value to TTL") diff --git a/dns/update.py b/dns/update.py index eb7b9364..91c8aa49 100644 --- a/dns/update.py +++ b/dns/update.py @@ -31,6 +31,7 @@ import dns.tsig class UpdateSection(dns.enum.IntEnum): """Update sections""" + ZONE = 0 PREREQ = 1 UPDATE = 2 @@ -46,11 +47,15 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] # ignore the mypy error here as we mean to use a different enum _section_enum = UpdateSection # type: ignore - def __init__(self, zone: Optional[Union[dns.name.Name, str]]=None, - rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, - keyring: Optional[Any]=None, keyname: Optional[dns.name.Name]=None, - keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm, - id: Optional[int]=None): + def __init__( + self, + zone: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + keyring: Optional[Any] = None, + keyname: Optional[dns.name.Name] = None, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + id: Optional[int] = None, + ): """Initialize a new DNS Update object. See the documentation of the Message class for a complete @@ -74,8 +79,14 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] rdclass = dns.rdataclass.RdataClass.make(rdclass) self.zone_rdclass = rdclass if self.origin: - self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA, - create=True, force_unique=True) + self.find_rrset( + self.zone, + self.origin, + rdclass, + dns.rdatatype.SOA, + create=True, + force_unique=True, + ) if keyring is not None: self.use_tsig(keyring, keyname, algorithm=keyalgorithm) @@ -112,8 +123,9 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] if section is None: section = self.update covers = rd.covers() - rrset = self.find_rrset(section, name, self.zone_rdclass, rd.rdtype, - covers, deleting, True, True) + rrset = self.find_rrset( + section, name, self.zone_rdclass, rd.rdtype, covers, deleting, True, True + ) rrset.add(rd, ttl) def _add(self, replace, section, name, *args): @@ -153,8 +165,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] if replace: self.delete(name, rdtype) for s in args: - rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, - self.origin) + rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, self.origin) self._add_rr(name, ttl, rd, section=section) def add(self, name: Union[dns.name.Name, str], *args: Any) -> None: @@ -190,9 +201,16 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] if isinstance(name, str): name = dns.name.from_text(name, None) if len(args) == 0: - self.find_rrset(self.update, name, dns.rdataclass.ANY, - dns.rdatatype.ANY, dns.rdatatype.NONE, - dns.rdataclass.ANY, True, True) + self.find_rrset( + self.update, + name, + dns.rdataclass.ANY, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + True, + True, + ) elif isinstance(args[0], dns.rdataset.Rdataset): for rds in args: for rd in rds: @@ -205,15 +223,24 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] else: rdtype = dns.rdatatype.RdataType.make(largs.pop(0)) if len(largs) == 0: - self.find_rrset(self.update, name, - self.zone_rdclass, rdtype, - dns.rdatatype.NONE, - dns.rdataclass.ANY, - True, True) + self.find_rrset( + self.update, + name, + self.zone_rdclass, + rdtype, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + True, + True, + ) else: for s in largs: - rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, # type: ignore[arg-type] - self.origin) + rd = dns.rdata.from_text( + self.zone_rdclass, + rdtype, + s, # type: ignore[arg-type] + self.origin, + ) self._add_rr(name, 0, rd, dns.rdataclass.NONE) def replace(self, name: Union[dns.name.Name, str], *args: Any) -> None: @@ -252,13 +279,21 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] if isinstance(name, str): name = dns.name.from_text(name, None) if len(args) == 0: - self.find_rrset(self.prerequisite, name, - dns.rdataclass.ANY, dns.rdatatype.ANY, - dns.rdatatype.NONE, None, - True, True) - elif isinstance(args[0], dns.rdataset.Rdataset) or \ - isinstance(args[0], dns.rdata.Rdata) or \ - len(args) > 1: + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.ANY, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + None, + True, + True, + ) + elif ( + isinstance(args[0], dns.rdataset.Rdataset) + or isinstance(args[0], dns.rdata.Rdata) + or len(args) > 1 + ): if not isinstance(args[0], dns.rdataset.Rdataset): # Add a 0 TTL largs = list(args) @@ -268,29 +303,50 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self._add(False, self.prerequisite, name, *args) else: rdtype = dns.rdatatype.RdataType.make(args[0]) - self.find_rrset(self.prerequisite, name, - dns.rdataclass.ANY, rdtype, - dns.rdatatype.NONE, None, - True, True) - - def absent(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str]=None) -> None: + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.ANY, + rdtype, + dns.rdatatype.NONE, + None, + True, + True, + ) + + def absent( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = None, + ) -> None: """Require that an owner name (and optionally an rdata type) does not exist as a prerequisite to the execution of the update.""" if isinstance(name, str): name = dns.name.from_text(name, None) if rdtype is None: - self.find_rrset(self.prerequisite, name, - dns.rdataclass.NONE, dns.rdatatype.ANY, - dns.rdatatype.NONE, None, - True, True) + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.NONE, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + None, + True, + True, + ) else: the_rdtype = dns.rdatatype.RdataType.make(rdtype) - self.find_rrset(self.prerequisite, name, - dns.rdataclass.NONE, the_rdtype, - dns.rdatatype.NONE, None, - True, True) + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.NONE, + the_rdtype, + dns.rdatatype.NONE, + None, + True, + True, + ) def _get_one_rr_per_rrset(self, value): # Updates are always one_rr_per_rrset @@ -300,9 +356,11 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] deleting = None empty = False if section == UpdateSection.ZONE: - if dns.rdataclass.is_metaclass(rdclass) or \ - rdtype != dns.rdatatype.SOA or \ - self.zone: + if ( + dns.rdataclass.is_metaclass(rdclass) + or rdtype != dns.rdatatype.SOA + or self.zone + ): raise dns.exception.FormError else: if not self.zone: @@ -310,10 +368,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE): deleting = rdclass rdclass = self.zone[0].rdclass - empty = (deleting == dns.rdataclass.ANY or - section == UpdateSection.PREREQ) + empty = ( + deleting == dns.rdataclass.ANY or section == UpdateSection.PREREQ + ) return (rdclass, rdtype, deleting, empty) + # backwards compatibility Update = UpdateMessage diff --git a/dns/version.py b/dns/version.py index c6cdf6c6..66e8faaa 100644 --- a/dns/version.py +++ b/dns/version.py @@ -28,16 +28,31 @@ RELEASELEVEL = 0x00 #: SERIAL SERIAL = 0 -if RELEASELEVEL == 0x0f: # pragma: no cover lgtm[py/unreachable-statement] +if RELEASELEVEL == 0x0F: # pragma: no cover lgtm[py/unreachable-statement] #: version - version = '%d.%d.%d' % (MAJOR, MINOR, MICRO) # lgtm[py/unreachable-statement] + version = "%d.%d.%d" % (MAJOR, MINOR, MICRO) # lgtm[py/unreachable-statement] elif RELEASELEVEL == 0x00: # pragma: no cover lgtm[py/unreachable-statement] - version = '%d.%d.%ddev%d' % (MAJOR, MINOR, MICRO, SERIAL) # lgtm[py/unreachable-statement] -elif RELEASELEVEL == 0x0c: # pragma: no cover lgtm[py/unreachable-statement] - version = '%d.%d.%drc%d' % (MAJOR, MINOR, MICRO, SERIAL) # lgtm[py/unreachable-statement] + version = "%d.%d.%ddev%d" % ( + MAJOR, + MINOR, + MICRO, + SERIAL, + ) # lgtm[py/unreachable-statement] +elif RELEASELEVEL == 0x0C: # pragma: no cover lgtm[py/unreachable-statement] + version = "%d.%d.%drc%d" % ( + MAJOR, + MINOR, + MICRO, + SERIAL, + ) # lgtm[py/unreachable-statement] else: # pragma: no cover lgtm[py/unreachable-statement] - version = '%d.%d.%d%x%d' % (MAJOR, MINOR, MICRO, RELEASELEVEL, SERIAL) # lgtm[py/unreachable-statement] + version = "%d.%d.%d%x%d" % ( + MAJOR, + MINOR, + MICRO, + RELEASELEVEL, + SERIAL, + ) # lgtm[py/unreachable-statement] #: hexversion -hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | \ - SERIAL +hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | SERIAL diff --git a/dns/versioned.py b/dns/versioned.py index 9ed9cef6..5cf29e99 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -5,10 +5,11 @@ from typing import Callable, Deque, Optional, Set, Union import collections + try: import threading as _threading except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore + import dummy_threading as _threading # type: ignore import dns.exception import dns.immutable @@ -36,15 +37,25 @@ Transaction = dns.zone.Transaction class Zone(dns.zone.Zone): # lgtm[py/missing-equals] - __slots__ = ['_versions', '_versions_lock', '_write_txn', - '_write_waiters', '_write_event', '_pruning_policy', - '_readers'] + __slots__ = [ + "_versions", + "_versions_lock", + "_write_txn", + "_write_waiters", + "_write_event", + "_pruning_policy", + "_readers", + ] node_factory = Node - def __init__(self, origin: Optional[Union[dns.name.Name, str]], - rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True, - pruning_policy: Optional[Callable[['Zone', Version], Optional[bool]]]=None): + def __init__( + self, + origin: Optional[Union[dns.name.Name, str]], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + pruning_policy: Optional[Callable[["Zone", Version], Optional[bool]]] = None, + ): """Initialize a versioned zone object. *origin* is the origin of the zone. It may be a ``dns.name.Name``, @@ -71,13 +82,15 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] self._write_event: Optional[_threading.Event] = None self._write_waiters: Deque[_threading.Event] = collections.deque() self._readers: Set[Transaction] = set() - self._commit_version_unlocked(None, - WritableVersion(self, replacement=True), - origin) + self._commit_version_unlocked( + None, WritableVersion(self, replacement=True), origin + ) - def reader(self, id: Optional[int]=None, serial: Optional[int]=None) -> Transaction: # pylint: disable=arguments-differ + def reader( + self, id: Optional[int] = None, serial: Optional[int] = None + ) -> Transaction: # pylint: disable=arguments-differ if id is not None and serial is not None: - raise ValueError('cannot specify both id and serial') + raise ValueError("cannot specify both id and serial") with self._version_lock: if id is not None: version = None @@ -86,7 +99,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] version = v break if version is None: - raise KeyError('version not found') + raise KeyError("version not found") elif serial is not None: if self.relativize: oname = dns.name.empty @@ -102,14 +115,14 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] version = v break if version is None: - raise KeyError('serial not found') + raise KeyError("serial not found") else: version = self._versions[-1] txn = Transaction(self, False, version) self._readers.add(txn) return txn - def writer(self, replacement: bool=False) -> Transaction: + def writer(self, replacement: bool = False) -> Transaction: event = None while True: with self._version_lock: @@ -123,8 +136,9 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] # give up the lock, so that we hold the lock as # short a time as possible. This is why we call # _setup_version() below. - self._write_txn = Transaction(self, replacement, - make_immutable=True) + self._write_txn = Transaction( + self, replacement, make_immutable=True + ) # give up our exclusive right to make a Transaction self._write_event = None break @@ -165,6 +179,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] # pylint: disable=unused-argument def _default_pruning_policy(self, zone, version): return True + # pylint: enable=unused-argument def _prune_versions_unlocked(self): @@ -180,8 +195,9 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] least_kept = min(txn.version.id for txn in self._readers) else: least_kept = self._versions[-1].id - while self._versions[0].id < least_kept and \ - self._pruning_policy(self, self._versions[0]): + while self._versions[0].id < least_kept and self._pruning_policy( + self, self._versions[0] + ): self._versions.popleft() def set_max_versions(self, max_versions: Optional[int]) -> None: @@ -189,16 +205,22 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] of versions """ if max_versions is not None and max_versions < 1: - raise ValueError('max versions must be at least 1') + raise ValueError("max versions must be at least 1") if max_versions is None: + def policy(zone, _): # pylint: disable=unused-argument return False + else: + def policy(zone, _): return len(zone._versions) > max_versions + self.set_pruning_policy(policy) - def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]) -> None: + def set_pruning_policy( + self, policy: Optional[Callable[["Zone", Version], Optional[bool]]] + ) -> None: """Set the pruning policy for the zone. The *policy* function takes a `Version` and returns `True` if @@ -251,7 +273,9 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] id = 1 return id - def find_node(self, name: Union[dns.name.Name, str], create: bool=False) -> dns.node.Node: + def find_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> dns.node.Node: if create: raise UseTransaction return super().find_node(name) @@ -259,19 +283,25 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] def delete_node(self, name: Union[dns.name.Name, str]) -> None: raise UseTransaction - def find_rdataset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, - create: bool=False) -> dns.rdataset.Rdataset: + def find_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: if create: raise UseTransaction rdataset = super().find_rdataset(name, rdtype, covers) return dns.rdataset.ImmutableRdataset(rdataset) - def get_rdataset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, - create: bool=False) -> Optional[dns.rdataset.Rdataset]: + def get_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: if create: raise UseTransaction rdataset = super().get_rdataset(name, rdtype, covers) @@ -280,10 +310,15 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] else: return None - def delete_rdataset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> None: + def delete_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> None: raise UseTransaction - def replace_rdataset(self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset) -> None: + def replace_rdataset( + self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset + ) -> None: raise UseTransaction diff --git a/dns/win32util.py b/dns/win32util.py index f4ded206..7a17b0bb 100644 --- a/dns/win32util.py +++ b/dns/win32util.py @@ -1,6 +1,6 @@ import sys -if sys.platform == 'win32': +if sys.platform == "win32": from typing import Any @@ -14,9 +14,10 @@ if sys.platform == 'win32': try: import threading as _threading except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore + import dummy_threading as _threading # type: ignore import pythoncom import wmi + _have_wmi = True except Exception: _have_wmi = False @@ -25,7 +26,7 @@ if sys.platform == 'win32': # Sometimes DHCP servers add a '.' prefix to the default domain, and # Windows just stores such values in the registry (see #687). # Check for this and fix it. - if domain.startswith('.'): + if domain.startswith("."): domain = domain[1:] return dns.name.from_text(domain) @@ -36,6 +37,7 @@ if sys.platform == 'win32': self.search = [] if _have_wmi: + class _WMIGetter(_threading.Thread): def __init__(self): super().__init__() @@ -49,8 +51,10 @@ if sys.platform == 'win32': if interface.IPEnabled: self.info.domain = _config_domain(interface.DNSDomain) self.info.nameservers = list(interface.DNSServerSearchOrder) - self.info.search = [dns.name.from_text(x) for x in - interface.DNSDomainSuffixSearchOrder] + self.info.search = [ + dns.name.from_text(x) + for x in interface.DNSDomainSuffixSearchOrder + ] break finally: pythoncom.CoUninitialize() @@ -61,11 +65,12 @@ if sys.platform == 'win32': self.start() self.join() return self.info + else: + class _WMIGetter: # type: ignore pass - class _RegistryGetter: def __init__(self): self.info = DnsInfo() @@ -76,13 +81,13 @@ if sys.platform == 'win32': # delimiter in between ' ' and ',' (and vice-versa) in various # versions of windows. # - if entry.find(' ') >= 0: - split_char = ' ' - elif entry.find(',') >= 0: - split_char = ',' + if entry.find(" ") >= 0: + split_char = " " + elif entry.find(",") >= 0: + split_char = "," else: # probably a singleton; treat as a space-separated list. - split_char = ' ' + split_char = " " return split_char def _config_nameservers(self, nameservers): @@ -102,38 +107,38 @@ if sys.platform == 'win32': def _config_fromkey(self, key, always_try_domain): try: - servers, _ = winreg.QueryValueEx(key, 'NameServer') + servers, _ = winreg.QueryValueEx(key, "NameServer") except WindowsError: servers = None if servers: self._config_nameservers(servers) if servers or always_try_domain: try: - dom, _ = winreg.QueryValueEx(key, 'Domain') + dom, _ = winreg.QueryValueEx(key, "Domain") if dom: self.info.domain = _config_domain(dom) except WindowsError: pass else: try: - servers, _ = winreg.QueryValueEx(key, 'DhcpNameServer') + servers, _ = winreg.QueryValueEx(key, "DhcpNameServer") except WindowsError: servers = None if servers: self._config_nameservers(servers) try: - dom, _ = winreg.QueryValueEx(key, 'DhcpDomain') + dom, _ = winreg.QueryValueEx(key, "DhcpDomain") if dom: self.info.domain = _config_domain(dom) except WindowsError: pass try: - search, _ = winreg.QueryValueEx(key, 'SearchList') + search, _ = winreg.QueryValueEx(key, "SearchList") except WindowsError: search = None if search is None: try: - search, _ = winreg.QueryValueEx(key, 'DhcpSearchList') + search, _ = winreg.QueryValueEx(key, "DhcpSearchList") except WindowsError: search = None if search: @@ -150,25 +155,27 @@ if sys.platform == 'win32': # from Windows 2000 through Vista. connection_key = winreg.OpenKey( lm, - r'SYSTEM\CurrentControlSet\Control\Network' - r'\{4D36E972-E325-11CE-BFC1-08002BE10318}' - r'\%s\Connection' % guid) + r"SYSTEM\CurrentControlSet\Control\Network" + r"\{4D36E972-E325-11CE-BFC1-08002BE10318}" + r"\%s\Connection" % guid, + ) try: # The PnpInstanceID points to a key inside Enum (pnp_id, ttype) = winreg.QueryValueEx( - connection_key, 'PnpInstanceID') + connection_key, "PnpInstanceID" + ) if ttype != winreg.REG_SZ: raise ValueError # pragma: no cover device_key = winreg.OpenKey( - lm, r'SYSTEM\CurrentControlSet\Enum\%s' % pnp_id) + lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id + ) try: # Get ConfigFlags for this device - (flags, ttype) = winreg.QueryValueEx( - device_key, 'ConfigFlags') + (flags, ttype) = winreg.QueryValueEx(device_key, "ConfigFlags") if ttype != winreg.REG_DWORD: raise ValueError # pragma: no cover @@ -194,17 +201,19 @@ if sys.platform == 'win32': lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) try: - tcp_params = winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters') + tcp_params = winreg.OpenKey( + lm, r"SYSTEM\CurrentControlSet" r"\Services\Tcpip\Parameters" + ) try: self._config_fromkey(tcp_params, True) finally: tcp_params.Close() - interfaces = winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters' - r'\Interfaces') + interfaces = winreg.OpenKey( + lm, + r"SYSTEM\CurrentControlSet" + r"\Services\Tcpip\Parameters" + r"\Interfaces", + ) try: i = 0 while True: diff --git a/dns/wire.py b/dns/wire.py index 905930f7..cadf1686 100644 --- a/dns/wire.py +++ b/dns/wire.py @@ -8,8 +8,9 @@ import struct import dns.exception import dns.name + class Parser: - def __init__(self, wire: bytes, current: int=0): + def __init__(self, wire: bytes, current: int = 0): self.wire = wire self.current = 0 self.end = len(self.wire) @@ -24,34 +25,34 @@ class Parser: assert size >= 0 if size > self.remaining(): raise dns.exception.FormError - output = self.wire[self.current:self.current + size] + output = self.wire[self.current : self.current + size] self.current += size self.furthest = max(self.furthest, self.current) return output - def get_counted_bytes(self, length_size: int=1) -> bytes: - length = int.from_bytes(self.get_bytes(length_size), 'big') + def get_counted_bytes(self, length_size: int = 1) -> bytes: + length = int.from_bytes(self.get_bytes(length_size), "big") return self.get_bytes(length) def get_remaining(self) -> bytes: return self.get_bytes(self.remaining()) def get_uint8(self) -> int: - return struct.unpack('!B', self.get_bytes(1))[0] + return struct.unpack("!B", self.get_bytes(1))[0] def get_uint16(self) -> int: - return struct.unpack('!H', self.get_bytes(2))[0] + return struct.unpack("!H", self.get_bytes(2))[0] def get_uint32(self) -> int: - return struct.unpack('!I', self.get_bytes(4))[0] + return struct.unpack("!I", self.get_bytes(4))[0] def get_uint48(self) -> int: - return int.from_bytes(self.get_bytes(6), 'big') + return int.from_bytes(self.get_bytes(6), "big") def get_struct(self, format: str) -> Tuple: return struct.unpack(format, self.get_bytes(struct.calcsize(format))) - def get_name(self, origin: Optional['dns.name.Name']=None) -> 'dns.name.Name': + def get_name(self, origin: Optional["dns.name.Name"] = None) -> "dns.name.Name": name = dns.name.from_wire_parser(self) if origin: name = name.relativize(origin) diff --git a/dns/xfr.py b/dns/xfr.py index a360deba..89e92caf 100644 --- a/dns/xfr.py +++ b/dns/xfr.py @@ -33,7 +33,7 @@ class TransferError(dns.exception.DNSException): """A zone transfer response got a non-zero rcode.""" def __init__(self, rcode): - message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) + message = "Zone transfer error: %s" % dns.rcode.to_text(rcode) super().__init__(message) self.rcode = rcode @@ -51,9 +51,13 @@ class Inbound: State machine for zone transfers. """ - def __init__(self, txn_manager: dns.transaction.TransactionManager, - rdtype: dns.rdatatype.RdataType=dns.rdatatype.AXFR, - serial: Optional[int]=None, is_udp: bool=False): + def __init__( + self, + txn_manager: dns.transaction.TransactionManager, + rdtype: dns.rdatatype.RdataType = dns.rdatatype.AXFR, + serial: Optional[int] = None, + is_udp: bool = False, + ): """Initialize an inbound zone transfer. *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. @@ -71,9 +75,9 @@ class Inbound: self.rdtype = rdtype if rdtype == dns.rdatatype.IXFR: if serial is None: - raise ValueError('a starting serial must be supplied for IXFRs') + raise ValueError("a starting serial must be supplied for IXFRs") elif is_udp: - raise ValueError('is_udp specified for AXFR') + raise ValueError("is_udp specified for AXFR") self.serial = serial self.is_udp = is_udp (_, _, self.origin) = txn_manager.origin_information() @@ -113,8 +117,9 @@ class Inbound: # the origin. # if not message.answer or message.answer[0].name != self.origin: - raise dns.exception.FormError("No answer or RRset not " - "for zone origin") + raise dns.exception.FormError( + "No answer or RRset not " "for zone origin" + ) rrset = message.answer[0] rdataset = rrset if rdataset.rdtype != dns.rdatatype.SOA: @@ -127,8 +132,7 @@ class Inbound: # We're already up-to-date. # self.done = True - elif dns.serial.Serial(self.soa_rdataset[0].serial) < \ - self.serial: + elif dns.serial.Serial(self.soa_rdataset[0].serial) < self.serial: # It went backwards! raise SerialWentBackwards else: @@ -153,8 +157,7 @@ class Inbound: if self.done: raise dns.exception.FormError("answers after final SOA") assert self.txn is not None # for mypy - if rdataset.rdtype == dns.rdatatype.SOA and \ - name == self.origin: + if rdataset.rdtype == dns.rdatatype.SOA and name == self.origin: # # Every time we see an origin SOA delete_mode inverts # @@ -166,20 +169,23 @@ class Inbound: # check that we're seeing the record in the expected # part of the response. # - if rdataset == self.soa_rdataset and \ - (self.rdtype == dns.rdatatype.AXFR or - (self.rdtype == dns.rdatatype.IXFR and - self.delete_mode)): + if rdataset == self.soa_rdataset and ( + self.rdtype == dns.rdatatype.AXFR + or (self.rdtype == dns.rdatatype.IXFR and self.delete_mode) + ): # # This is the final SOA # if self.expecting_SOA: # We got an empty IXFR sequence! - raise dns.exception.FormError('empty IXFR sequence') - if self.rdtype == dns.rdatatype.IXFR \ - and self.serial != rdataset[0].serial: - raise dns.exception.FormError('unexpected end of IXFR ' - 'sequence') + raise dns.exception.FormError("empty IXFR sequence") + if ( + self.rdtype == dns.rdatatype.IXFR + and self.serial != rdataset[0].serial + ): + raise dns.exception.FormError( + "unexpected end of IXFR " "sequence" + ) self.txn.replace(name, rdataset) self.txn.commit() self.txn = None @@ -194,15 +200,17 @@ class Inbound: # This is the start of an IXFR deletion set if rdataset[0].serial != self.serial: raise dns.exception.FormError( - "IXFR base serial mismatch") + "IXFR base serial mismatch" + ) else: # This is the start of an IXFR addition set self.serial = rdataset[0].serial self.txn.replace(name, rdataset) else: # We saw a non-final SOA for the origin in an AXFR. - raise dns.exception.FormError('unexpected origin SOA ' - 'in AXFR') + raise dns.exception.FormError( + "unexpected origin SOA " "in AXFR" + ) continue if self.expecting_SOA: # @@ -229,7 +237,7 @@ class Inbound: # This is a UDP IXFR and we didn't get to done, and we didn't # get the proper "truncated" response # - raise dns.exception.FormError('unexpected end of UDP IXFR') + raise dns.exception.FormError("unexpected end of UDP IXFR") return self.done # @@ -245,12 +253,18 @@ class Inbound: return False -def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional[int]=0, - use_edns: Optional[Union[int, bool]]=None, ednsflags: Optional[int]=None, payload: Optional[int]=None, - request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None, - keyring: Any=None, keyname: Optional[dns.name.Name]=None, - keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) \ - -> Tuple[dns.message.QueryMessage, Optional[int]]: +def make_query( + txn_manager: dns.transaction.TransactionManager, + serial: Optional[int] = 0, + use_edns: Optional[Union[int, bool]] = None, + ednsflags: Optional[int] = None, + payload: Optional[int] = None, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + keyring: Any = None, + keyname: Optional[dns.name.Name] = None, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, +) -> Tuple[dns.message.QueryMessage, Optional[int]]: """Make an AXFR or IXFR query. *txn_manager* is a ``dns.transaction.TransactionManager``, typically a @@ -272,14 +286,14 @@ def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional """ (zone_origin, _, origin) = txn_manager.origin_information() if zone_origin is None: - raise ValueError('no zone origin') + raise ValueError("no zone origin") if serial is None: rdtype = dns.rdatatype.AXFR elif not isinstance(serial, int): - raise ValueError('serial is not an integer') + raise ValueError("serial is not an integer") elif serial == 0: with txn_manager.reader() as txn: - rdataset = txn.get(origin, 'SOA') + rdataset = txn.get(origin, "SOA") if rdataset: serial = rdataset[0].serial rdtype = dns.rdatatype.IXFR @@ -289,20 +303,30 @@ def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional elif serial > 0 and serial < 4294967296: rdtype = dns.rdatatype.IXFR else: - raise ValueError('serial out-of-range') + raise ValueError("serial out-of-range") rdclass = txn_manager.get_class() - q = dns.message.make_query(zone_origin, rdtype, rdclass, - use_edns, False, ednsflags, payload, - request_payload, options) + q = dns.message.make_query( + zone_origin, + rdtype, + rdclass, + use_edns, + False, + ednsflags, + payload, + request_payload, + options, + ) if serial is not None: - rdata = dns.rdata.from_text(rdclass, 'SOA', f'. . {serial} 0 0 0 0') - rrset = q.find_rrset(q.authority, zone_origin, rdclass, - dns.rdatatype.SOA, create=True) + rdata = dns.rdata.from_text(rdclass, "SOA", f". . {serial} 0 0 0 0") + rrset = q.find_rrset( + q.authority, zone_origin, rdclass, dns.rdatatype.SOA, create=True + ) rrset.add(rdata, 0) if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) return (q, serial) + def extract_serial_from_query(query: dns.message.Message) -> Optional[int]: """Extract the SOA serial number from query if it is an IXFR and return it, otherwise return None. @@ -313,12 +337,13 @@ def extract_serial_from_query(query: dns.message.Message) -> Optional[int]: an appropriate SOA RRset in the authority section. """ if not isinstance(query, dns.message.QueryMessage): - raise ValueError('query not a QueryMessage') + raise ValueError("query not a QueryMessage") question = query.question[0] if question.rdtype == dns.rdatatype.AXFR: return None elif question.rdtype != dns.rdatatype.IXFR: raise ValueError("query is not an AXFR or IXFR") - soa = query.find_rrset(query.authority, question.name, question.rdclass, - dns.rdatatype.SOA) + soa = query.find_rrset( + query.authority, question.name, question.rdclass, dns.rdatatype.SOA + ) return soa[0].serial diff --git a/dns/zone.py b/dns/zone.py index d57838c6..d0d99284 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -97,10 +97,14 @@ class Zone(dns.transaction.TransactionManager): node_factory = dns.node.Node - __slots__ = ['rdclass', 'origin', 'nodes', 'relativize'] - - def __init__(self, origin: Optional[Union[dns.name.Name, str]], - rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True): + __slots__ = ["rdclass", "origin", "nodes", "relativize"] + + def __init__( + self, + origin: Optional[Union[dns.name.Name, str]], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + ): """Initialize a zone object. *origin* is the origin of the zone. It may be a ``dns.name.Name``, @@ -117,8 +121,9 @@ class Zone(dns.transaction.TransactionManager): if isinstance(origin, str): origin = dns.name.from_text(origin) elif not isinstance(origin, dns.name.Name): - raise ValueError("origin parameter must be convertible to a " - "DNS name") + raise ValueError( + "origin parameter must be convertible to a " "DNS name" + ) if not origin.is_absolute(): raise ValueError("origin parameter must be an absolute name") self.origin = origin @@ -135,9 +140,11 @@ class Zone(dns.transaction.TransactionManager): if not isinstance(other, Zone): return False - if self.rdclass != other.rdclass or \ - self.origin != other.origin or \ - self.nodes != other.nodes: + if ( + self.rdclass != other.rdclass + or self.origin != other.origin + or self.nodes != other.nodes + ): return False return True @@ -159,16 +166,15 @@ class Zone(dns.transaction.TransactionManager): # This should probably never happen as other code (e.g. # _rr_line) will notice the lack of an origin before us, but # we check just in case! - raise KeyError('no zone origin is defined') + raise KeyError("no zone origin is defined") if not name.is_subdomain(self.origin): - raise KeyError( - "name parameter must be a subdomain of the zone origin") + raise KeyError("name parameter must be a subdomain of the zone origin") if self.relativize: name = name.relativize(self.origin) elif not self.relativize: # We have a relative name in a non-relative zone, so derelativize. if self.origin is None: - raise KeyError('no zone origin is defined') + raise KeyError("no zone origin is defined") name = name.derelativize(self.origin) return name @@ -204,7 +210,9 @@ class Zone(dns.transaction.TransactionManager): key = self._validate_name(key) return key in self.nodes - def find_node(self, name: Union[dns.name.Name, str], create: bool=False) -> dns.node.Node: + def find_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> dns.node.Node: """Find a node in the zone, possibly creating it. *name*: the name of the node to find. @@ -230,7 +238,9 @@ class Zone(dns.transaction.TransactionManager): self.nodes[name] = node return node - def get_node(self, name: Union[dns.name.Name, str], create: bool=False) -> Optional[dns.node.Node]: + def get_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> Optional[dns.node.Node]: """Get a node in the zone, possibly creating it. This method is like ``find_node()``, except it returns None instead @@ -272,10 +282,13 @@ class Zone(dns.transaction.TransactionManager): if name in self.nodes: del self.nodes[name] - def find_rdataset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, - create: bool=False) -> dns.rdataset.Rdataset: + def find_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: """Look for an rdataset with the specified name and type in the zone, and return an rdataset encapsulating it. @@ -316,10 +329,13 @@ class Zone(dns.transaction.TransactionManager): node = self.find_node(the_name, create) return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create) - def get_rdataset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, - create: bool=False) -> Optional[dns.rdataset.Rdataset]: + def get_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: """Look for an rdataset with the specified name and type in the zone. This method is like ``find_rdataset()``, except it returns None instead @@ -361,34 +377,33 @@ class Zone(dns.transaction.TransactionManager): rdataset = None return rdataset - def delete_rdataset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> None: + def delete_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> None: """Delete the rdataset matching *rdtype* and *covers*, if it exists at the node specified by *name*. - It is not an error if the node does not exist, or if there is no - matching rdataset at the node. + It is not an error if the node does not exist, or if there is no matching + rdataset at the node. - If the node has no rdatasets after the deletion, it will itself - be deleted. + If the node has no rdatasets after the deletion, it will itself be deleted. - *name*: the name of the node to find. - The value may be a ``dns.name.Name`` or a ``str``. If absolute, the - name must be a subdomain of the zone's origin. If ``zone.relativize`` - is ``True``, then the name will be relativized. + *name*: the name of the node to find. The value may be a ``dns.name.Name`` or a + ``str``. If absolute, the name must be a subdomain of the zone's origin. If + ``zone.relativize`` is ``True``, then the name will be relativized. *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, a ``dns.rdatatype.RdataType`` or ``str`` or ``None``, the covered type. - Usually this value is ``dns.rdatatype.NONE``, but if the - rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, - then the covers value will be the rdata type the SIG/RRSIG - covers. The library treats the SIG and RRSIG types as if they - were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). - This makes RRSIGs much easier to work with than if RRSIGs - covering different rdata types were aggregated into a single - RRSIG rdataset. + *covers*, a ``dns.rdatatype.RdataType`` or ``str`` or ``None``, the covered + type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is + ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be + the rdata type the SIG/RRSIG covers. The library treats the SIG and RRSIG types + as if they were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This + makes RRSIGs much easier to work with than if RRSIGs covering different rdata + types were aggregated into a single RRSIG rdataset. """ the_name = self._validate_name(name) @@ -400,8 +415,9 @@ class Zone(dns.transaction.TransactionManager): if len(node) == 0: self.delete_node(the_name) - def replace_rdataset(self, name: Union[dns.name.Name, str], - replacement: dns.rdataset.Rdataset) -> None: + def replace_rdataset( + self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset + ) -> None: """Replace an rdataset at name. It is not an error if there is no rdataset matching I{replacement}. @@ -421,13 +437,16 @@ class Zone(dns.transaction.TransactionManager): """ if replacement.rdclass != self.rdclass: - raise ValueError('replacement.rdclass != zone.rdclass') + raise ValueError("replacement.rdclass != zone.rdclass") node = self.find_node(name, True) node.replace_rdataset(replacement) - def find_rrset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rrset.RRset: + def find_rrset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> dns.rrset.RRset: """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. @@ -474,9 +493,12 @@ class Zone(dns.transaction.TransactionManager): rrset.update(rdataset) return rrset - def get_rrset(self, name: Union[dns.name.Name, str], - rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> Optional[dns.rrset.RRset]: + def get_rrset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Optional[dns.rrset.RRset]: """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. @@ -520,9 +542,11 @@ class Zone(dns.transaction.TransactionManager): rrset = None return rrset - def iterate_rdatasets(self, rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.ANY, - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) \ - -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: + def iterate_rdatasets( + self, + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY, + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: """Return a generator which yields (name, rdataset) tuples for all rdatasets in the zone which have the specified *rdtype* and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, @@ -545,13 +569,16 @@ class Zone(dns.transaction.TransactionManager): covers = dns.rdatatype.RdataType.make(covers) for (name, node) in self.items(): for rds in node: - if rdtype == dns.rdatatype.ANY or \ - (rds.rdtype == rdtype and rds.covers == covers): + if rdtype == dns.rdatatype.ANY or ( + rds.rdtype == rdtype and rds.covers == covers + ): yield (name, rds) - def iterate_rdatas(self, rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.ANY, - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) \ - -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]: + def iterate_rdatas( + self, + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY, + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]: """Return a generator which yields (name, ttl, rdata) tuples for all rdatas in the zone which have the specified *rdtype* and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, @@ -574,13 +601,21 @@ class Zone(dns.transaction.TransactionManager): covers = dns.rdatatype.RdataType.make(covers) for (name, node) in self.items(): for rds in node: - if rdtype == dns.rdatatype.ANY or \ - (rds.rdtype == rdtype and rds.covers == covers): + if rdtype == dns.rdatatype.ANY or ( + rds.rdtype == rdtype and rds.covers == covers + ): for rdata in rds: yield (name, rds.ttl, rdata) - def to_file(self, f: Any, sorted: bool=True, relativize: bool=True, nl: Optional[str]=None, - want_comments: bool=False, want_origin: bool=False) -> None: + def to_file( + self, + f: Any, + sorted: bool = True, + relativize: bool = True, + nl: Optional[str] = None, + want_comments: bool = False, + want_origin: bool = False, + ) -> None: """Write a zone to a file. *f*, a file or `str`. If *f* is a string, it is treated @@ -610,18 +645,18 @@ class Zone(dns.transaction.TransactionManager): with contextlib.ExitStack() as stack: if isinstance(f, str): - f = stack.enter_context(open(f, 'wb')) + f = stack.enter_context(open(f, "wb")) # must be in this way, f.encoding may contain None, or even # attribute may not be there - file_enc = getattr(f, 'encoding', None) + file_enc = getattr(f, "encoding", None) if file_enc is None: - file_enc = 'utf-8' + file_enc = "utf-8" if nl is None: # binary mode, '\n' is not enough nl_b = os.linesep.encode(file_enc) - nl = '\n' + nl = "\n" elif isinstance(nl, str): nl_b = nl.encode(file_enc) else: @@ -630,7 +665,7 @@ class Zone(dns.transaction.TransactionManager): if want_origin: assert self.origin is not None - l = '$ORIGIN ' + self.origin.to_text() + l = "$ORIGIN " + self.origin.to_text() l_b = l.encode(file_enc) try: f.write(l_b) @@ -645,9 +680,12 @@ class Zone(dns.transaction.TransactionManager): else: names = self.keys() for n in names: - l = self[n].to_text(n, origin=self.origin, - relativize=relativize, - want_comments=want_comments) + l = self[n].to_text( + n, + origin=self.origin, + relativize=relativize, + want_comments=want_comments, + ) l_b = l.encode(file_enc) try: @@ -657,8 +695,14 @@ class Zone(dns.transaction.TransactionManager): f.write(l) f.write(nl) - def to_text(self, sorted: bool=True, relativize: bool=True, nl: Optional[str]=None, - want_comments: bool=False, want_origin: bool=False) -> str: + def to_text( + self, + sorted: bool = True, + relativize: bool = True, + nl: Optional[str] = None, + want_comments: bool = False, + want_origin: bool = False, + ) -> str: """Return a zone's text as though it were written to a file. *sorted*, a ``bool``. If True, the default, then the file @@ -685,8 +729,7 @@ class Zone(dns.transaction.TransactionManager): Returns a ``str``. """ temp_buffer = io.StringIO() - self.to_file(temp_buffer, sorted, relativize, nl, want_comments, - want_origin) + self.to_file(temp_buffer, sorted, relativize, nl, want_comments, want_origin) return_value = temp_buffer.getvalue() temp_buffer.close() return return_value @@ -710,7 +753,9 @@ class Zone(dns.transaction.TransactionManager): if self.get_rdataset(name, dns.rdatatype.NS) is None: raise NoNS - def get_soa(self, txn: Optional[dns.transaction.Transaction]=None) -> dns.rdtypes.ANY.SOA.SOA: + def get_soa( + self, txn: Optional[dns.transaction.Transaction] = None + ) -> dns.rdtypes.ANY.SOA.SOA: """Get the zone SOA rdata. Raises ``dns.zone.NoSOA`` if there is no SOA RRset. @@ -734,7 +779,11 @@ class Zone(dns.transaction.TransactionManager): raise NoSOA return soa[0] - def _compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme: DigestScheme=DigestScheme.SIMPLE) -> bytes: + def _compute_digest( + self, + hash_algorithm: DigestHashAlgorithm, + scheme: DigestScheme = DigestScheme.SIMPLE, + ) -> bytes: hashinfo = _digest_hashers.get(hash_algorithm) if not hashinfo: raise UnsupportedDigestHashAlgorithm @@ -749,30 +798,35 @@ class Zone(dns.transaction.TransactionManager): hasher = hashinfo() for (name, node) in sorted(self.items()): rrnamebuf = name.to_digestable(self.origin) - for rdataset in sorted(node, - key=lambda rds: (rds.rdtype, rds.covers)): - if name == origin_name and \ - dns.rdatatype.ZONEMD in (rdataset.rdtype, rdataset.covers): + for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)): + if name == origin_name and dns.rdatatype.ZONEMD in ( + rdataset.rdtype, + rdataset.covers, + ): continue - rrfixed = struct.pack('!HHI', rdataset.rdtype, - rdataset.rdclass, rdataset.ttl) - rdatas = [rdata.to_digestable(self.origin) - for rdata in rdataset] + rrfixed = struct.pack( + "!HHI", rdataset.rdtype, rdataset.rdclass, rdataset.ttl + ) + rdatas = [rdata.to_digestable(self.origin) for rdata in rdataset] for rdata in sorted(rdatas): - rrlen = struct.pack('!H', len(rdata)) + rrlen = struct.pack("!H", len(rdata)) hasher.update(rrnamebuf + rrfixed + rrlen + rdata) return hasher.digest() - def compute_digest(self, hash_algorithm: DigestHashAlgorithm, - scheme: DigestScheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD: + def compute_digest( + self, + hash_algorithm: DigestHashAlgorithm, + scheme: DigestScheme = DigestScheme.SIMPLE, + ) -> dns.rdtypes.ANY.ZONEMD.ZONEMD: serial = self.get_soa().serial digest = self._compute_digest(hash_algorithm, scheme) - return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass, - dns.rdatatype.ZONEMD, - serial, scheme, hash_algorithm, - digest) + return dns.rdtypes.ANY.ZONEMD.ZONEMD( + self.rdclass, dns.rdatatype.ZONEMD, serial, scheme, hash_algorithm, digest + ) - def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None) -> None: + def verify_digest( + self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD] = None + ) -> None: digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]] if zonemd: digests = [zonemd] @@ -784,8 +838,7 @@ class Zone(dns.transaction.TransactionManager): digests = rds for digest in digests: try: - computed = self._compute_digest(digest.hash_algorithm, - digest.scheme) + computed = self._compute_digest(digest.hash_algorithm, digest.scheme) if computed == digest.digest: return except Exception: @@ -794,16 +847,17 @@ class Zone(dns.transaction.TransactionManager): # TransactionManager methods - def reader(self) -> 'Transaction': - return Transaction(self, False, - Version(self, 1, self.nodes, self.origin)) + def reader(self) -> "Transaction": + return Transaction(self, False, Version(self, 1, self.nodes, self.origin)) - def writer(self, replacement: bool=False) -> 'Transaction': + def writer(self, replacement: bool = False) -> "Transaction": txn = Transaction(self, replacement) txn._setup_version() return txn - def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: + def origin_information( + self, + ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: effective: Optional[dns.name.Name] if self.relativize: effective = dns.name.empty @@ -839,8 +893,9 @@ class Zone(dns.transaction.TransactionManager): # A node with a version id. + class VersionedNode(dns.node.Node): # lgtm[py/missing-equals] - __slots__ = ['id'] + __slots__ = ["id"] def __init__(self): super().__init__() @@ -850,7 +905,7 @@ class VersionedNode(dns.node.Node): # lgtm[py/missing-equals] @dns.immutable.immutable class ImmutableVersionedNode(VersionedNode): - __slots__ = ['id'] + __slots__ = ["id"] def __init__(self, node): super().__init__() @@ -859,22 +914,34 @@ class ImmutableVersionedNode(VersionedNode): [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] ) - def find_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create: bool=False) -> dns.rdataset.Rdataset: + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: if create: raise TypeError("immutable") return super().find_rdataset(rdclass, rdtype, covers, False) - def get_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create: bool=False) -> Optional[dns.rdataset.Rdataset]: + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: if create: raise TypeError("immutable") return super().get_rdataset(rdclass, rdtype, covers, False) - def delete_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None: + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: raise TypeError("immutable") def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: @@ -885,9 +952,13 @@ class ImmutableVersionedNode(VersionedNode): class Version: - def __init__(self, zone: Zone, id: int, - nodes: Optional[Dict[dns.name.Name, dns.node.Node]]=None, - origin: Optional[dns.name.Name]=None): + def __init__( + self, + zone: Zone, + id: int, + nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None, + origin: Optional[dns.name.Name] = None, + ): self.zone = zone self.id = id if nodes is not None: @@ -902,7 +973,7 @@ class Version: # This should probably never happen as other code (e.g. # _rr_line) will notice the lack of an origin before us, but # we check just in case! - raise KeyError('no zone origin is defined') + raise KeyError("no zone origin is defined") if not name.is_subdomain(self.origin): raise KeyError("name is not a subdomain of the zone origin") if self.zone.relativize: @@ -910,7 +981,7 @@ class Version: elif not self.zone.relativize: # We have a relative name in a non-relative zone, so derelativize. if self.origin is None: - raise KeyError('no zone origin is defined') + raise KeyError("no zone origin is defined") name = name.derelativize(self.origin) return name @@ -918,8 +989,12 @@ class Version: name = self._validate_name(name) return self.nodes.get(name) - def get_rdataset(self, name: dns.name.Name, rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType) -> Optional[dns.rdataset.Rdataset]: + def get_rdataset( + self, + name: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> Optional[dns.rdataset.Rdataset]: node = self.get_node(name) if node is None: return None @@ -930,7 +1005,7 @@ class Version: class WritableVersion(Version): - def __init__(self, zone: Zone, replacement: bool=False): + def __init__(self, zone: Zone, replacement: bool = False): # The zone._versions_lock must be held by our caller in a versioned # zone. id = zone._get_next_version_id() @@ -951,14 +1026,14 @@ class WritableVersion(Version): node = self.nodes.get(name) if node is None or name not in self.changed: new_node = self.zone.node_factory() - if hasattr(new_node, 'id'): + if hasattr(new_node, "id"): # We keep doing this for backwards compatibility, as earlier # code used new_node.id != self.id for the "do we need to CoW?" # test. Now we use the changed set as this works with both # regular zones and versioned zones. # # We ignore the mypy error as this is safe but it doesn't see it. - new_node.id = self.id # type: ignore + new_node.id = self.id # type: ignore if node is not None: # moo! copy on write! new_node.rdatasets.extend(node.rdatasets) @@ -974,12 +1049,18 @@ class WritableVersion(Version): del self.nodes[name] self.changed.add(name) - def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset) -> None: + def put_rdataset( + self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset + ) -> None: node = self._maybe_cow(name) node.replace_rdataset(rdataset) - def delete_rdataset(self, name: dns.name.Name, rdtype:dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType) -> None: + def delete_rdataset( + self, + name: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> None: node = self._maybe_cow(name) node.delete_rdataset(self.zone.rdclass, rdtype, covers) if len(node) == 0: @@ -1009,7 +1090,6 @@ class ImmutableVersion(Version): class Transaction(dns.transaction.Transaction): - def __init__(self, zone, replacement, version=None, make_immutable=False): read_only = version is not None super().__init__(zone, replacement, read_only) @@ -1086,11 +1166,17 @@ class Transaction(dns.transaction.Transaction): return (absolute, relativize, effective) -def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None, - rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, - relativize: bool=True, zone_factory: Any=Zone, filename: Optional[str]=None, - allow_include: bool=False, check_origin: bool=True, - idna_codec: Optional[dns.name.IDNACodec]=None) -> Zone: +def from_text( + text: str, + origin: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + zone_factory: Any = Zone, + filename: Optional[str] = None, + allow_include: bool = False, + check_origin: bool = True, + idna_codec: Optional[dns.name.IDNACodec] = None, +) -> Zone: """Build a zone object from a zone file format string. *text*, a ``str``, the zone file format input. @@ -1099,7 +1185,8 @@ def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None, of the zone; if not specified, the first ``$ORIGIN`` statement in the zone file will determine the origin of the zone. - *rdclass*, a ``dns.rdataclass.RdataClass``, the zone's rdata class; the default is class IN. + *rdclass*, a ``dns.rdataclass.RdataClass``, the zone's rdata class; the default is + class IN. *relativize*, a ``bool``, determine's whether domain names are relativized to the zone's origin. The default is ``True``. @@ -1137,12 +1224,11 @@ def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None, # interface is from_file(). if filename is None: - filename = '' + filename = "" zone = zone_factory(origin, rdclass, relativize=relativize) with zone.writer(True) as txn: tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) - reader = dns.zonefile.Reader(tok, rdclass, txn, - allow_include=allow_include) + reader = dns.zonefile.Reader(tok, rdclass, txn, allow_include=allow_include) try: reader.read() except dns.zonefile.UnknownOrigin: @@ -1154,10 +1240,16 @@ def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None, return zone -def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None, - rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, - relativize: bool=True, zone_factory: Any=Zone, filename: Optional[str]=None, - allow_include: bool=True, check_origin: bool=True) -> Zone: +def from_file( + f: Any, + origin: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + zone_factory: Any = Zone, + filename: Optional[str] = None, + allow_include: bool = True, + check_origin: bool = True, +) -> Zone: """Read a zone file and build a zone object. *f*, a file or ``str``. If *f* is a string, it is treated @@ -1205,12 +1297,25 @@ def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None, if filename is None: filename = f f = stack.enter_context(open(f)) - return from_text(f, origin, rdclass, relativize, zone_factory, - filename, allow_include, check_origin) + return from_text( + f, + origin, + rdclass, + relativize, + zone_factory, + filename, + allow_include, + check_origin, + ) assert False # make mypy happy lgtm[py/unreachable-statement] -def from_xfr(xfr: Any, zone_factory: Any=Zone, relativize: bool=True, check_origin: bool=True) -> Zone: +def from_xfr( + xfr: Any, + zone_factory: Any = Zone, + relativize: bool = True, + check_origin: bool = True, +) -> Zone: """Convert the output of a zone transfer generator into a zone object. *xfr*, a generator of ``dns.message.Message`` objects, typically @@ -1250,13 +1355,12 @@ def from_xfr(xfr: Any, zone_factory: Any=Zone, relativize: bool=True, check_orig if not znode: znode = z.node_factory() z.nodes[rrset.name] = znode - zrds = znode.find_rdataset(rrset.rdclass, rrset.rdtype, - rrset.covers, True) + zrds = znode.find_rdataset(rrset.rdclass, rrset.rdtype, rrset.covers, True) zrds.update_ttl(rrset.ttl) for rd in rrset: zrds.add(rd) if z is None: - raise ValueError('empty transfer') + raise ValueError("empty transfer") if check_origin: z.check_origin() return z diff --git a/dns/zonefile.py b/dns/zonefile.py index 479f0d63..fd17073d 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -51,42 +51,53 @@ def _check_cname_and_other_data(txn, name, rdataset): # empty nodes are neutral. return node_kind = node.classify() - if node_kind == dns.node.NodeKind.CNAME and \ - rdataset_kind == dns.node.NodeKind.REGULAR: - raise CNAMEAndOtherData('rdataset type is not compatible with a ' - 'CNAME node') - elif node_kind == dns.node.NodeKind.REGULAR and \ - rdataset_kind == dns.node.NodeKind.CNAME: - raise CNAMEAndOtherData('CNAME rdataset is not compatible with a ' - 'regular data node') + if ( + node_kind == dns.node.NodeKind.CNAME + and rdataset_kind == dns.node.NodeKind.REGULAR + ): + raise CNAMEAndOtherData("rdataset type is not compatible with a " "CNAME node") + elif ( + node_kind == dns.node.NodeKind.REGULAR + and rdataset_kind == dns.node.NodeKind.CNAME + ): + raise CNAMEAndOtherData( + "CNAME rdataset is not compatible with a " "regular data node" + ) # Otherwise at least one of the node and the rdataset is neutral, so # adding the rdataset is ok -SavedStateType = Tuple[dns.tokenizer.Tokenizer, - Optional[dns.name.Name], # current_origin - Optional[dns.name.Name], # last_name - Optional[Any], # current_file - int, # last_ttl - bool, # last_ttl_known - int, # default_ttl - bool] # default_ttl_known +SavedStateType = Tuple[ + dns.tokenizer.Tokenizer, + Optional[dns.name.Name], # current_origin + Optional[dns.name.Name], # last_name + Optional[Any], # current_file + int, # last_ttl + bool, # last_ttl_known + int, # default_ttl + bool, +] # default_ttl_known class Reader: """Read a DNS zone file into a transaction.""" - def __init__(self, tok: dns.tokenizer.Tokenizer, rdclass: dns.rdataclass.RdataClass, - txn: dns.transaction.Transaction, allow_include: bool=False, - allow_directives: bool=True, force_name: Optional[dns.name.Name]=None, - force_ttl: Optional[int]=None, - force_rdclass: Optional[dns.rdataclass.RdataClass]=None, - force_rdtype: Optional[dns.rdatatype.RdataType]=None, - default_ttl: Optional[int]=None): + def __init__( + self, + tok: dns.tokenizer.Tokenizer, + rdclass: dns.rdataclass.RdataClass, + txn: dns.transaction.Transaction, + allow_include: bool = False, + allow_directives: bool = True, + force_name: Optional[dns.name.Name] = None, + force_ttl: Optional[int] = None, + force_rdclass: Optional[dns.rdataclass.RdataClass] = None, + force_rdtype: Optional[dns.rdatatype.RdataType] = None, + default_ttl: Optional[int] = None, + ): self.tok = tok - (self.zone_origin, self.relativize, _) = \ - txn.manager.origin_information() + (self.zone_origin, self.relativize, _) = txn.manager.origin_information() self.current_origin = self.zone_origin self.last_ttl = 0 self.last_ttl_known = False @@ -191,13 +202,17 @@ class Reader: try: rdtype = dns.rdatatype.from_text(token.value) except Exception: - raise dns.exception.SyntaxError( - "unknown rdatatype '%s'" % token.value) + raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) try: - rd = dns.rdata.from_text(rdclass, rdtype, self.tok, - self.current_origin, self.relativize, - self.zone_origin) + rd = dns.rdata.from_text( + rdclass, + rdtype, + self.tok, + self.current_origin, + self.relativize, + self.zone_origin, + ) except dns.exception.SyntaxError: # Catch and reraise. raise @@ -209,7 +224,8 @@ class Reader: # helpful filename:line info. (ty, va) = sys.exc_info()[:2] raise dns.exception.SyntaxError( - "caught exception {}: {}".format(str(ty), str(va))) + "caught exception {}: {}".format(str(ty), str(va)) + ) if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default @@ -240,30 +256,30 @@ class Reader: g1 = is_generate1.match(side) if g1: mod, sign, offset, width, base = g1.groups() - if sign == '': - sign = '+' + if sign == "": + sign = "+" g2 = is_generate2.match(side) if g2: mod, sign, offset = g2.groups() - if sign == '': - sign = '+' + if sign == "": + sign = "+" width = 0 - base = 'd' + base = "d" g3 = is_generate3.match(side) if g3: mod, sign, offset, width = g3.groups() - if sign == '': - sign = '+' - base = 'd' + if sign == "": + sign = "+" + base = "d" if not (g1 or g2 or g3): - mod = '' - sign = '+' + mod = "" + sign = "+" offset = 0 width = 0 - base = 'd' + base = "d" - if base != 'd': + if base != "d": raise NotImplementedError() return mod, sign, offset, width, base @@ -328,8 +344,7 @@ class Reader: if not token.is_identifier(): raise dns.exception.SyntaxError except Exception: - raise dns.exception.SyntaxError("unknown rdatatype '%s'" % - token.value) + raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) # rhs (required) rhs = token.value @@ -341,24 +356,25 @@ class Reader: for i in range(start, stop + 1, step): # +1 because bind is inclusive and python is exclusive - if lsign == '+': + if lsign == "+": lindex = i + int(loffset) - elif lsign == '-': + elif lsign == "-": lindex = i - int(loffset) - if rsign == '-': + if rsign == "-": rindex = i - int(roffset) - elif rsign == '+': + elif rsign == "+": rindex = i + int(roffset) lzfindex = str(lindex).zfill(int(lwidth)) rzfindex = str(rindex).zfill(int(rwidth)) - name = lhs.replace('$%s' % (lmod), lzfindex) - rdata = rhs.replace('$%s' % (rmod), rzfindex) + name = lhs.replace("$%s" % (lmod), lzfindex) + rdata = rhs.replace("$%s" % (rmod), rzfindex) - self.last_name = dns.name.from_text(name, self.current_origin, - self.tok.idna_codec) + self.last_name = dns.name.from_text( + name, self.current_origin, self.tok.idna_codec + ) name = self.last_name if not name.is_subdomain(self.zone_origin): self._eat_line() @@ -367,9 +383,14 @@ class Reader: name = name.relativize(self.zone_origin) try: - rd = dns.rdata.from_text(rdclass, rdtype, rdata, - self.current_origin, self.relativize, - self.zone_origin) + rd = dns.rdata.from_text( + rdclass, + rdtype, + rdata, + self.current_origin, + self.relativize, + self.zone_origin, + ) except dns.exception.SyntaxError: # Catch and reraise. raise @@ -380,8 +401,9 @@ class Reader: # We convert them to syntax errors so that we can emit # helpful filename:line info. (ty, va) = sys.exc_info()[:2] - raise dns.exception.SyntaxError("caught exception %s: %s" % - (str(ty), str(va))) + raise dns.exception.SyntaxError( + "caught exception %s: %s" % (str(ty), str(va)) + ) self.txn.add(name, ttl, rd) @@ -399,14 +421,16 @@ class Reader: if self.current_file is not None: self.current_file.close() if len(self.saved_state) > 0: - (self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known) = self.saved_state.pop(-1) + ( + self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known, + ) = self.saved_state.pop(-1) continue break elif token.is_eol(): @@ -414,51 +438,56 @@ class Reader: elif token.is_comment(): self.tok.get_eol() continue - elif token.value[0] == '$' and self.allow_directives: + elif token.value[0] == "$" and self.allow_directives: c = token.value.upper() - if c == '$TTL': + if c == "$TTL": token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError("bad $TTL") self.default_ttl = dns.ttl.from_text(token.value) self.default_ttl_known = True self.tok.get_eol() - elif c == '$ORIGIN': + elif c == "$ORIGIN": self.current_origin = self.tok.get_name() self.tok.get_eol() if self.zone_origin is None: self.zone_origin = self.current_origin self.txn._set_origin(self.current_origin) - elif c == '$INCLUDE' and self.allow_include: + elif c == "$INCLUDE" and self.allow_include: token = self.tok.get() filename = token.value token = self.tok.get() new_origin: Optional[dns.name.Name] if token.is_identifier(): - new_origin = dns.name.from_text(token.value, self.current_origin, self.tok.idna_codec) + new_origin = dns.name.from_text( + token.value, self.current_origin, self.tok.idna_codec + ) self.tok.get_eol() elif not token.is_eol_or_eof(): - raise dns.exception.SyntaxError( - "bad origin in $INCLUDE") + raise dns.exception.SyntaxError("bad origin in $INCLUDE") else: new_origin = self.current_origin - self.saved_state.append((self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known)) - self.current_file = open(filename, 'r') - self.tok = dns.tokenizer.Tokenizer(self.current_file, - filename) + self.saved_state.append( + ( + self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known, + ) + ) + self.current_file = open(filename, "r") + self.tok = dns.tokenizer.Tokenizer(self.current_file, filename) self.current_origin = new_origin - elif c == '$GENERATE': + elif c == "$GENERATE": self._generate_line() else: raise dns.exception.SyntaxError( - "Unknown zone file directive '" + c + "'") + "Unknown zone file directive '" + c + "'" + ) continue self.tok.unget(token) self._rr_line() @@ -467,13 +496,13 @@ class Reader: if detail is None: detail = "syntax error" ex = dns.exception.SyntaxError( - "%s:%d: %s" % (filename, line_number, detail)) + "%s:%d: %s" % (filename, line_number, detail) + ) tb = sys.exc_info()[2] raise ex.with_traceback(tb) from None class RRsetsReaderTransaction(dns.transaction.Transaction): - def __init__(self, manager, replacement, read_only): assert not read_only super().__init__(manager, replacement, read_only) @@ -525,8 +554,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): if commit and self._changed(): rrsets = [] for (name, _, _), rdataset in self.rdatasets.items(): - rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype, - rdataset.covers) + rrset = dns.rrset.RRset( + name, rdataset.rdclass, rdataset.rdtype, rdataset.covers + ) rrset.update(rdataset) rrsets.append(rrset) self.manager.set_rrsets(rrsets) @@ -536,8 +566,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): class RRSetsReaderManager(dns.transaction.TransactionManager): - def __init__(self, origin=dns.name.root, relativize=False, - rdclass=dns.rdataclass.IN): + def __init__( + self, origin=dns.name.root, relativize=False, rdclass=dns.rdataclass.IN + ): self.origin = origin self.relativize = relativize self.rdclass = rdclass @@ -561,16 +592,18 @@ class RRSetsReaderManager(dns.transaction.TransactionManager): self.rrsets = rrsets -def read_rrsets(text: Any, - name: Optional[Union[dns.name.Name, str]]=None, - ttl: Optional[int]=None, - rdclass: Optional[Union[dns.rdataclass.RdataClass, str]]=dns.rdataclass.IN, - default_rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - rdtype: Optional[Union[dns.rdatatype.RdataType, str]]=None, - default_ttl: Optional[Union[int, str]]=None, - idna_codec: Optional[dns.name.IDNACodec]=None, - origin: Optional[Union[dns.name.Name, str]]=dns.name.root, - relativize: bool=False) -> List[dns.rrset.RRset]: +def read_rrsets( + text: Any, + name: Optional[Union[dns.name.Name, str]] = None, + ttl: Optional[int] = None, + rdclass: Optional[Union[dns.rdataclass.RdataClass, str]] = dns.rdataclass.IN, + default_rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + rdtype: Optional[Union[dns.rdatatype.RdataType, str]] = None, + default_ttl: Optional[Union[int, str]] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[Union[dns.name.Name, str]] = dns.name.root, + relativize: bool = False, +) -> List[dns.rrset.RRset]: """Read one or more rrsets from the specified text, possibly subject to restrictions. @@ -639,9 +672,17 @@ def read_rrsets(text: Any, the_rdtype = None manager = RRSetsReaderManager(origin, relativize, default_rdclass) with manager.writer(True) as txn: - tok = dns.tokenizer.Tokenizer(text, '', idna_codec=idna_codec) - reader = Reader(tok, the_default_rdclass, txn, allow_directives=False, - force_name=name, force_ttl=ttl, force_rdclass=the_rdclass, - force_rdtype=the_rdtype, default_ttl=default_ttl) + tok = dns.tokenizer.Tokenizer(text, "", idna_codec=idna_codec) + reader = Reader( + tok, + the_default_rdclass, + txn, + allow_directives=False, + force_name=name, + force_ttl=ttl, + force_rdclass=the_rdclass, + force_rdtype=the_rdtype, + default_ttl=default_ttl, + ) reader.read() return manager.rrsets diff --git a/doc/conf.py b/doc/conf.py index 17e33eab..8bee4571 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -19,7 +19,8 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ------------------------------------------------ @@ -32,37 +33,37 @@ sys.path.insert(0, os.path.abspath('..')) # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', - 'sphinx.ext.todo' - ] + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", + "sphinx.ext.todo", +] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'dnspython' -copyright = 'Dnspython Contributors' -author = 'Dnspython Contributors' +project = "dnspython" +copyright = "Dnspython Contributors" +author = "Dnspython Contributors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '2.3' +version = "2.3" # The full version, including alpha/beta/rc tags. -release = '2.3.0' +release = "2.3.0" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -74,24 +75,24 @@ language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True # -- Options for autodoc -------------------------------------------------- -autoclass_content = 'both' +autoclass_content = "both" # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -#html_theme = 'alabaster' +# html_theme = 'alabaster' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -107,7 +108,7 @@ autoclass_content = 'both' # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'dnspythondoc' +htmlhelp_basename = "dnspythondoc" # -- Options for LaTeX output --------------------------------------------- @@ -116,15 +117,12 @@ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -134,8 +132,7 @@ latex_elements = { # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'dnspython.tex', 'dnspython Documentation', - 'Nominum, Inc.', 'manual'), + (master_doc, "dnspython.tex", "dnspython Documentation", "Nominum, Inc.", "manual"), ] @@ -143,10 +140,7 @@ latex_documents = [ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'dnspython', 'dnspython Documentation', - [author], 1) -] +man_pages = [(master_doc, "dnspython", "dnspython Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -155,7 +149,13 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'dnspython', 'dnspython Documentation', - author, 'dnspython', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "dnspython", + "dnspython Documentation", + author, + "dnspython", + "One line description of project.", + "Miscellaneous", + ), ] diff --git a/doc/util/auto-values.py b/doc/util/auto-values.py index cd738185..8a41ba90 100644 --- a/doc/util/auto-values.py +++ b/doc/util/auto-values.py @@ -7,10 +7,10 @@ name = sys.argv[1] title = sys.argv[2] print(title) -print('=' * len(title)) +print("=" * len(title)) print() module = importlib.import_module(name) for t in sorted(module._by_text.keys()): - print('.. py:data:: {}.{}'.format(name, t)) - print(' :annotation: = {}'.format(module._by_text[t])) + print(".. py:data:: {}.{}".format(name, t)) + print(" :annotation: = {}".format(module._by_text[t])) diff --git a/examples/async_dns.py b/examples/async_dns.py index c42defcc..f7e3fe5d 100644 --- a/examples/async_dns.py +++ b/examples/async_dns.py @@ -1,4 +1,3 @@ - import sys import trio @@ -7,24 +6,26 @@ import dns.message import dns.asyncquery import dns.asyncresolver + async def main(): if len(sys.argv) > 1: host = sys.argv[0] else: - host = 'www.dnspython.org' - q = dns.message.make_query(host, 'A') - r = await dns.asyncquery.udp(q, '8.8.8.8') + host = "www.dnspython.org" + q = dns.message.make_query(host, "A") + r = await dns.asyncquery.udp(q, "8.8.8.8") print(r) - q = dns.message.make_query(host, 'A') - r = await dns.asyncquery.tcp(q, '8.8.8.8') + q = dns.message.make_query(host, "A") + r = await dns.asyncquery.tcp(q, "8.8.8.8") print(r) - q = dns.message.make_query(host, 'A') - r = await dns.asyncquery.tls(q, '8.8.8.8') + q = dns.message.make_query(host, "A") + r = await dns.asyncquery.tls(q, "8.8.8.8") print(r) - a = await dns.asyncresolver.resolve(host, 'A') + a = await dns.asyncresolver.resolve(host, "A") print(a.response) zn = await dns.asyncresolver.zone_for_name(host) print(zn) -if __name__ == '__main__': + +if __name__ == "__main__": trio.run(main) diff --git a/examples/ddns.py b/examples/ddns.py index c584f422..154ab3da 100755 --- a/examples/ddns.py +++ b/examples/ddns.py @@ -35,17 +35,15 @@ import dns.tsigkeyring # Replace the keyname and secret with appropriate values for your # configuration. # -keyring = dns.tsigkeyring.from_text({ - 'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ==' - }) +keyring = dns.tsigkeyring.from_text({"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="}) # # Replace "example." with your domain, and "host" with your hostname. # -update = dns.update.Update('example.', keyring=keyring) -update.replace('host', 300, 'A', sys.argv[1]) +update = dns.update.Update("example.", keyring=keyring) +update.replace("host", 300, "A", sys.argv[1]) # # Replace "10.0.0.1" with the IP address of your master server. # -response = dns.query.tcp(update, '10.0.0.1', timeout=10) +response = dns.query.tcp(update, "10.0.0.1", timeout=10) diff --git a/examples/doh-json.py b/examples/doh-json.py index 8cfe1b0c..e9fa0876 100755 --- a/examples/doh-json.py +++ b/examples/doh-json.py @@ -24,26 +24,27 @@ import dns.rdatatype # "simple" below means "simple python data types", i.e. things made of # combinations of dictionaries, lists, strings, and numbers. + def make_rr(simple, rdata): csimple = copy.copy(simple) - csimple['data'] = rdata.to_text() + csimple["data"] = rdata.to_text() return csimple + def flatten_rrset(rrs): simple = { - 'name': str(rrs.name), - 'type': rrs.rdtype, + "name": str(rrs.name), + "type": rrs.rdtype, } if len(rrs) > 0: - simple['TTL'] = rrs.ttl + simple["TTL"] = rrs.ttl return [make_rr(simple, rdata) for rdata in rrs] else: return [simple] + def to_doh_simple(message): - simple = { - 'Status': message.rcode() - } + simple = {"Status": message.rcode()} for f in dns.flags.Flag: if f != dns.flags.Flag.AA and f != dns.flags.Flag.QR: # DoH JSON doesn't need AA and omits it. DoH JSON is only @@ -57,6 +58,7 @@ def to_doh_simple(message): # we don't encode the ecs_client_subnet field return simple + def from_doh_simple(simple, add_qr=False): message = dns.message.QueryMessage() flags = 0 @@ -66,27 +68,35 @@ def from_doh_simple(simple, add_qr=False): if add_qr: # QR is implied flags |= dns.flags.QR message.flags = flags - message.set_rcode(simple.get('Status', 0)) + message.set_rcode(simple.get("Status", 0)) for i, sn in enumerate(dns.message.MessageSection): rr_list = simple.get(sn.name.title(), []) for rr in rr_list: - rdtype = dns.rdatatype.RdataType(rr['type']) - rrs = message.find_rrset(i, dns.name.from_text(rr['name']), - dns.rdataclass.IN, rdtype, - create=True) - if 'data' in rr: - rrs.add(dns.rdata.from_text(dns.rdataclass.IN, rdtype, - rr['data']), rr.get('TTL', 0)) + rdtype = dns.rdatatype.RdataType(rr["type"]) + rrs = message.find_rrset( + i, + dns.name.from_text(rr["name"]), + dns.rdataclass.IN, + rdtype, + create=True, + ) + if "data" in rr: + rrs.add( + dns.rdata.from_text(dns.rdataclass.IN, rdtype, rr["data"]), + rr.get("TTL", 0), + ) # we don't decode the ecs_client_subnet field return message -a = dns.resolver.resolve('www.dnspython.org', 'a') +a = dns.resolver.resolve("www.dnspython.org", "a") p = to_doh_simple(a.response) print(json.dumps(p, indent=4)) -response = requests.get('https://dns.google/resolve?', verify=True, - params={'name': 'www.dnspython.org', - 'type': 1}) +response = requests.get( + "https://dns.google/resolve?", + verify=True, + params={"name": "www.dnspython.org", "type": 1}, +) p = json.loads(response.text) m = from_doh_simple(p, True) print(m) diff --git a/examples/doh.py b/examples/doh.py index e789bf10..17787ed3 100755 --- a/examples/doh.py +++ b/examples/doh.py @@ -13,8 +13,8 @@ import dns.rdatatype def main(): - where = '1.1.1.1' - qname = 'example.com.' + where = "1.1.1.1" + qname = "example.com." # one method is to use context manager, session will automatically close with requests.sessions.Session() as session: q = dns.message.make_query(qname, dns.rdatatype.A) @@ -24,8 +24,8 @@ def main(): # ... do more lookups - where = 'https://dns.google/dns-query' - qname = 'example.net.' + where = "https://dns.google/dns-query" + qname = "example.net." # second method, close session manually session = requests.sessions.Session() q = dns.message.make_query(qname, dns.rdatatype.A) @@ -38,5 +38,6 @@ def main(): # close the session when you're done session.close() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/examples/e164.py b/examples/e164.py index 6d9e8727..8b677bf1 100755 --- a/examples/e164.py +++ b/examples/e164.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import dns.e164 + n = dns.e164.from_e164("+1 555 1212") print(n) print(dns.e164.to_e164(n)) diff --git a/examples/ecs.py b/examples/ecs.py index f7b31d84..d5a84f24 100755 --- a/examples/ecs.py +++ b/examples/ecs.py @@ -1,15 +1,13 @@ - import dns.edns import dns.message import dns.query # This example demonstrates how to use the EDNS client subnet option -ADDRESS = '0.0.0.0' # replace this with the address you want to check +ADDRESS = "0.0.0.0" # replace this with the address you want to check PREFIX = 0 # replace this with a prefix length (typically 24 for IPv4) ecs = dns.edns.ECSOption(ADDRESS, PREFIX) -q = dns.message.make_query('www.google.com', 'A', use_edns=0, options=[ecs]) -r = dns.query.udp(q, '8.8.8.8') +q = dns.message.make_query("www.google.com", "A", use_edns=0, options=[ecs]) +r = dns.query.udp(q, "8.8.8.8") print(r) - diff --git a/examples/edns.py b/examples/edns.py index a130f85a..0566bfb8 100755 --- a/examples/edns.py +++ b/examples/edns.py @@ -5,10 +5,10 @@ import dns.message import dns.query import dns.resolver -n = '.' +n = "." t = dns.rdatatype.SOA -l = '199.7.83.42' # Address of l.root-servers.net -i = '149.20.1.73' # Address of ns1.isc.org, for COOKIEs +l = "199.7.83.42" # Address of l.root-servers.net +i = "149.20.1.73" # Address of ns1.isc.org, for COOKIEs q_list = [] @@ -16,7 +16,7 @@ q_list = [] q_list.append((l, dns.message.make_query(n, t))) # The same query, but with EDNS0 turned on with no options -q_list.append((l,dns.message.make_query(n, t, use_edns=0))) +q_list.append((l, dns.message.make_query(n, t, use_edns=0))) # Use use_edns() to specify EDNS0 options, such as buffer size this_q = dns.message.make_query(n, t) @@ -25,28 +25,46 @@ q_list.append((l, this_q)) # With an NSID option # use_edns=0 is not needed if options are specified) -q_list.append((l, dns.message.make_query(n, t,\ - options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b'')]))) +q_list.append( + ( + l, + dns.message.make_query( + n, t, options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b"")] + ), + ) +) # With an NSID option, but with use_edns() to specify the options this_q = dns.message.make_query(n, t) -this_q.use_edns(0, options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b'')]) +this_q.use_edns(0, options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b"")]) q_list.append((l, this_q)) # With a COOKIE -q_list.append((i, dns.message.make_query(n, t,\ - options=[dns.edns.GenericOption(dns.edns.OptionType.COOKIE, b'0xfe11ac99bebe3322')]))) +q_list.append( + ( + i, + dns.message.make_query( + n, + t, + options=[ + dns.edns.GenericOption( + dns.edns.OptionType.COOKIE, b"0xfe11ac99bebe3322" + ) + ], + ), + ) +) # With an ECS option using dns.edns.ECSOption to form the option -q_list.append((l, dns.message.make_query(n, t,\ - options=[dns.edns.ECSOption('192.168.0.0', 20)]))) +q_list.append( + (l, dns.message.make_query(n, t, options=[dns.edns.ECSOption("192.168.0.0", 20)])) +) for (addr, q) in q_list: - r = dns.query.udp(q, addr) - if not r.options: - print('No EDNS options returned') - else: - for o in r.options: - print(o.otype.value, o.data) - print() - + r = dns.query.udp(q, addr) + if not r.options: + print("No EDNS options returned") + else: + for o in r.options: + print(o.otype.value, o.data) + print() diff --git a/examples/edns_resolver.py b/examples/edns_resolver.py index fe5cc0f9..6edf4a9e 100644 --- a/examples/edns_resolver.py +++ b/examples/edns_resolver.py @@ -5,10 +5,10 @@ import dns.message import dns.query import dns.resolver -n = '.' +n = "." t = dns.rdatatype.SOA -l = 'google.com' # Address of l.root-servers.net, '199.7.83.42' -i = 'ns1.isc.org' # Address of ns1.isc.org, for COOKIEs, '149.20.1.73' +l = "google.com" # Address of l.root-servers.net, '199.7.83.42' +i = "ns1.isc.org" # Address of ns1.isc.org, for COOKIEs, '149.20.1.73' o_list = [] @@ -22,21 +22,32 @@ o_list.append((l, dict(options=[]))) o_list.append((l, dict(payload=2000))) # With an NSID option, but with use_edns() to specify the options -edns_kwargs = dict(edns=0, options=[ - dns.edns.GenericOption(dns.edns.OptionType.NSID, b'')]) +edns_kwargs = dict( + edns=0, options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b"")] +) o_list.append((l, edns_kwargs)) # With a COOKIE -o_list.append((i, dict(options=[ - dns.edns.GenericOption(dns.edns.OptionType.COOKIE, b'0xfe11ac99bebe3322')]))) +o_list.append( + ( + i, + dict( + options=[ + dns.edns.GenericOption( + dns.edns.OptionType.COOKIE, b"0xfe11ac99bebe3322" + ) + ] + ), + ) +) # With an ECS option using cloudflare dns address -o_list.append((l, dict(options=[dns.edns.ECSOption('1.1.1.1', 24)]))) +o_list.append((l, dict(options=[dns.edns.ECSOption("1.1.1.1", 24)]))) # With an ECS option using the current machine address import urllib.request -external_ip = urllib.request.urlopen('https://ident.me').read().decode('utf8') +external_ip = urllib.request.urlopen("https://ident.me").read().decode("utf8") o_list.append((l, dict(options=[dns.edns.ECSOption(external_ip, 24)]))) @@ -45,5 +56,5 @@ aresolver = dns.resolver.Resolver() for (addr, edns_kwargs) in o_list: if edns_kwargs: aresolver.use_edns(**edns_kwargs) - aresolver.nameservers = ['8.8.8.8'] - print(list(aresolver.resolve(addr, 'A'))) + aresolver.nameservers = ["8.8.8.8"] + print(list(aresolver.resolve(addr, "A"))) diff --git a/examples/mx.py b/examples/mx.py index 2c310ea9..5e9075b4 100755 --- a/examples/mx.py +++ b/examples/mx.py @@ -2,6 +2,6 @@ import dns.resolver -answers = dns.resolver.resolve('nominum.com', 'MX') +answers = dns.resolver.resolve("nominum.com", "MX") for rdata in answers: - print('Host', rdata.exchange, 'has preference', rdata.preference) + print("Host", rdata.exchange, "has preference", rdata.preference) diff --git a/examples/name.py b/examples/name.py index 614fdbc0..ff687e89 100755 --- a/examples/name.py +++ b/examples/name.py @@ -2,12 +2,12 @@ import dns.name -n = dns.name.from_text('www.dnspython.org') -o = dns.name.from_text('dnspython.org') -print(n.is_subdomain(o)) # True -print(n.is_superdomain(o)) # False -print(n > o) # True -rel = n.relativize(o) # rel is the relative name www +n = dns.name.from_text("www.dnspython.org") +o = dns.name.from_text("dnspython.org") +print(n.is_subdomain(o)) # True +print(n.is_superdomain(o)) # False +print(n > o) # True +rel = n.relativize(o) # rel is the relative name www n2 = rel + o -print(n2 == n) # True -print(n.labels) # ['www', 'dnspython', 'org', ''] +print(n2 == n) # True +print(n.labels) # ['www', 'dnspython', 'org', ''] diff --git a/examples/query_specific.py b/examples/query_specific.py index c82207c1..73dc3513 100644 --- a/examples/query_specific.py +++ b/examples/query_specific.py @@ -9,30 +9,30 @@ import dns.query # This way is just like nslookup/dig: -qname = dns.name.from_text('amazon.com') +qname = dns.name.from_text("amazon.com") q = dns.message.make_query(qname, dns.rdatatype.NS) -print('The query is:') +print("The query is:") print(q) -print('') -r = dns.query.udp(q, '8.8.8.8') -print('The response is:') +print("") +r = dns.query.udp(q, "8.8.8.8") +print("The response is:") print(r) -print('') -print('The nameservers are:') +print("") +print("The nameservers are:") ns_rrset = r.find_rrset(r.answer, qname, dns.rdataclass.IN, dns.rdatatype.NS) for rr in ns_rrset: print(rr.target) -print('') -print('') +print("") +print("") # A higher-level way import dns.resolver resolver = dns.resolver.Resolver(configure=False) -resolver.nameservers = ['8.8.8.8'] -answer = resolver.resolve('amazon.com', 'NS') -print('The nameservers are:') +resolver.nameservers = ["8.8.8.8"] +answer = resolver.resolve("amazon.com", "NS") +print("The nameservers are:") for rr in answer: print(rr.target) @@ -42,7 +42,7 @@ for rr in answer: # This sends a query with RD=0 for the root SOA RRset to the IP address # for l.root-servers.net. -q = dns.message.make_query('.', dns.rdatatype.SOA, flags=0) -r = dns.query.udp(q, '199.7.83.42') -print('\nThe flags in the response are {}'.format(dns.flags.to_text(r.flags))) +q = dns.message.make_query(".", dns.rdatatype.SOA, flags=0) +r = dns.query.udp(q, "199.7.83.42") +print("\nThe flags in the response are {}".format(dns.flags.to_text(r.flags))) print('The SOA in the response is "{}"'.format((r.answer)[0][0])) diff --git a/examples/receive_notify.py b/examples/receive_notify.py index c41b3363..97d01f30 100644 --- a/examples/receive_notify.py +++ b/examples/receive_notify.py @@ -13,7 +13,7 @@ import dns.name from typing import cast -address = '127.0.0.1' +address = "127.0.0.1" port = 53535 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -23,16 +23,17 @@ while True: notify = dns.message.from_wire(wire) try: - soa = notify.find_rrset(notify.answer, notify.question[0].name, - dns.rdataclass.IN, dns.rdatatype.SOA) + soa = notify.find_rrset( + notify.answer, notify.question[0].name, dns.rdataclass.IN, dns.rdatatype.SOA + ) # Do something with the SOA RR here - print('The serial number for', soa.name, 'is', soa[0].serial) + print("The serial number for", soa.name, "is", soa[0].serial) except KeyError: # No SOA RR in the answer section. pass - response = dns.message.make_response(notify) # type: dns.message.Message + response = dns.message.make_response(notify) # type: dns.message.Message response.flags |= dns.flags.AA wire = response.to_wire(cast(dns.name.Name, response)) s.sendto(wire, address) diff --git a/examples/reverse.py b/examples/reverse.py index 83b99b77..5829c681 100755 --- a/examples/reverse.py +++ b/examples/reverse.py @@ -20,14 +20,13 @@ import dns.zone import dns.ipv4 import os.path import sys -from typing import Dict, List # pylint: disable=unused-import +from typing import Dict, List # pylint: disable=unused-import -reverse_map = {} # type: Dict[str, List[str]] +reverse_map = {} # type: Dict[str, List[str]] for filename in sys.argv[1:]: - zone = dns.zone.from_file(filename, os.path.basename(filename), - relativize=False) - for (name, ttl, rdata) in zone.iterate_rdatas('A'): + zone = dns.zone.from_file(filename, os.path.basename(filename), relativize=False) + for (name, ttl, rdata) in zone.iterate_rdatas("A"): print(type(rdata)) try: reverse_map[rdata.address].append(name.to_text()) diff --git a/examples/reverse_name.py b/examples/reverse_name.py index 02b2e514..ec7fe1c3 100755 --- a/examples/reverse_name.py +++ b/examples/reverse_name.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import dns.reversename + n = dns.reversename.from_address("127.0.0.1") print(n) print(dns.reversename.to_address(n)) diff --git a/examples/xfr.py b/examples/xfr.py index a20cae3d..1c8175e5 100755 --- a/examples/xfr.py +++ b/examples/xfr.py @@ -4,9 +4,9 @@ import dns.query import dns.resolver import dns.zone -soa_answer = dns.resolver.resolve('dnspython.org', 'SOA') -master_answer = dns.resolver.resolve(soa_answer[0].mname, 'A') +soa_answer = dns.resolver.resolve("dnspython.org", "SOA") +master_answer = dns.resolver.resolve(soa_answer[0].mname, "A") -z = dns.zone.from_xfr(dns.query.xfr(master_answer[0].address, 'dnspython.org')) +z = dns.zone.from_xfr(dns.query.xfr(master_answer[0].address, "dnspython.org")) for n in sorted(z.nodes.keys()): print(z[n].to_text(n)) diff --git a/examples/zonediff.py b/examples/zonediff.py index 164bf2b5..2957f87e 100755 --- a/examples/zonediff.py +++ b/examples/zonediff.py @@ -21,9 +21,9 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. """See diff_zones.__doc__ for more information""" -from typing import cast, Union, Any # pylint: disable=unused-import +from typing import cast, Union, Any # pylint: disable=unused-import -__all__ = ['diff_zones', 'format_changes_plain', 'format_changes_html'] +__all__ = ["diff_zones", "format_changes_plain", "format_changes_html"] try: import dns.zone @@ -32,11 +32,12 @@ except ImportError: raise SystemExit("Please install dnspython") -def diff_zones(zone1, # type: dns.zone.Zone - zone2, # type: dns.zone.Zone - ignore_ttl=False, - ignore_soa=False - ): # type: (...) -> list +def diff_zones( + zone1, # type: dns.zone.Zone + zone2, # type: dns.zone.Zone + ignore_ttl=False, + ignore_soa=False, +): # type: (...) -> list """diff_zones(zone1, zone2, ignore_ttl=False, ignore_soa=False) -> changes Compares two dns.zone.Zone objects and returns a list of all changes in the format (name, oldnode, newnode). @@ -67,11 +68,13 @@ def diff_zones(zone1, # type: dns.zone.Zone changes.append((str(name), n3, n4)) return changes -def _nodes_differ(n1, # type: dns.node.Node - n2, # type: dns.node.Node - ignore_ttl, # type: bool - ignore_soa # type: bool - ): # type: (...) -> bool + +def _nodes_differ( + n1, # type: dns.node.Node + n2, # type: dns.node.Node + ignore_ttl, # type: bool + ignore_soa, # type: bool +): # type: (...) -> bool if ignore_soa or not ignore_ttl: # Compare datasets directly for r in n1.rdatasets: @@ -91,11 +94,13 @@ def _nodes_differ(n1, # type: dns.node.Node else: return n1 != n2 -def format_changes_plain(oldf, # type: str - newf, # type: str - changes, # type: list - ignore_ttl=False - ): # type: (...) -> str + +def format_changes_plain( + oldf, # type: str + newf, # type: str + changes, # type: list + ignore_ttl=False, +): # type: (...) -> str """format_changes(oldfile, newfile, changes, ignore_ttl=False) -> str Given 2 filenames and a list of changes from diff_zones, produce diff-like output. If ignore_ttl is True, TTL-only changes are not displayed""" @@ -105,35 +110,37 @@ def format_changes_plain(oldf, # type: str ret += "@ %s\n" % name if not old: for r in new.rdatasets: - ret += "+ %s\n" % str(r).replace('\n', '\n+ ') + ret += "+ %s\n" % str(r).replace("\n", "\n+ ") elif not new: for r in old.rdatasets: - ret += "- %s\n" % str(r).replace('\n', '\n+ ') + ret += "- %s\n" % str(r).replace("\n", "\n+ ") else: for r in old.rdatasets: if r not in new.rdatasets or ( - r.ttl != new.find_rdataset(r.rdclass, r.rdtype).ttl and - not ignore_ttl + r.ttl != new.find_rdataset(r.rdclass, r.rdtype).ttl + and not ignore_ttl ): - ret += "- %s\n" % str(r).replace('\n', '\n+ ') + ret += "- %s\n" % str(r).replace("\n", "\n+ ") for r in new.rdatasets: if r not in old.rdatasets or ( - r.ttl != old.find_rdataset(r.rdclass, r.rdtype).ttl and - not ignore_ttl + r.ttl != old.find_rdataset(r.rdclass, r.rdtype).ttl + and not ignore_ttl ): - ret += "+ %s\n" % str(r).replace('\n', '\n+ ') + ret += "+ %s\n" % str(r).replace("\n", "\n+ ") return ret -def format_changes_html(oldf, # type: str - newf, # type: str - changes, # type: list - ignore_ttl=False - ): # type: (...) -> str + +def format_changes_html( + oldf, # type: str + newf, # type: str + changes, # type: list + ignore_ttl=False, +): # type: (...) -> str """format_changes(oldfile, newfile, changes, ignore_ttl=False) -> str Given 2 filenames and a list of changes from diff_zones, produce nice html output. If ignore_ttl is True, TTL-only changes are not displayed""" - ret = ''' + ret = """
@@ -141,7 +148,10 @@ def format_changes_html(oldf, # type: str - \n''' % (oldf, newf) + \n""" % ( + oldf, + newf, + ) for name, old, new in changes: ret += ' \n \n' % name @@ -150,36 +160,36 @@ def format_changes_html(oldf, # type: str ret += ( ' \n' ' \n' - ) % str(r).replace('\n', '
') + ) % str(r).replace("\n", "
") elif not new: for r in old.rdatasets: ret += ( ' \n' ' \n' - ) % str(r).replace('\n', '
') + ) % str(r).replace("\n", "
") else: ret += ' \n' + ret += str(r).replace("\n", "
") + ret += "\n" ret += ' \n' - ret += ' \n' - return ret + ' \n
 %s
%s %s%s ' for r in old.rdatasets: if r not in new.rdatasets or ( - r.ttl != new.find_rdataset(r.rdclass, r.rdtype).ttl and - not ignore_ttl + r.ttl != new.find_rdataset(r.rdclass, r.rdtype).ttl + and not ignore_ttl ): - ret += str(r).replace('\n', '
') - ret += '
' for r in new.rdatasets: if r not in old.rdatasets or ( - r.ttl != old.find_rdataset(r.rdclass, r.rdtype).ttl and - not ignore_ttl + r.ttl != old.find_rdataset(r.rdclass, r.rdtype).ttl + and not ignore_ttl ): - ret += str(r).replace('\n', '
') - ret += '
' + ret += str(r).replace("\n", "
") + ret += "\n" + ret += " \n" + return ret + " \n" # Make this module usable as a script too. -def main(): # type: () -> None +def main(): # type: () -> None import argparse import subprocess import sys @@ -191,24 +201,66 @@ def main(): # type: () -> None The differences shown will be logical differences, not textual differences. """ p = argparse.ArgumentParser(usage=usage) - p.add_argument('-s', '--ignore-soa', action="store_true", default=False, dest="ignore_soa", - help="Ignore SOA-only changes to records") - p.add_argument('-t', '--ignore-ttl', action="store_true", default=False, dest="ignore_ttl", - help="Ignore TTL-only changes to Rdata") - p.add_argument('-T', '--traceback', action="store_true", default=False, dest="tracebacks", - help="Show python tracebacks when errors occur") - p.add_argument('-H', '--html', action="store_true", default=False, dest="html", - help="Print HTML output") - p.add_argument('-g', '--git', action="store_true", default=False, dest="use_git", - help="Use git revisions instead of real files") - p.add_argument('-b', '--bzr', action="store_true", default=False, dest="use_bzr", - help="Use bzr revisions instead of real files") - p.add_argument('-r', '--rcs', action="store_true", default=False, dest="use_rcs", - help="Use rcs revisions instead of real files") + p.add_argument( + "-s", + "--ignore-soa", + action="store_true", + default=False, + dest="ignore_soa", + help="Ignore SOA-only changes to records", + ) + p.add_argument( + "-t", + "--ignore-ttl", + action="store_true", + default=False, + dest="ignore_ttl", + help="Ignore TTL-only changes to Rdata", + ) + p.add_argument( + "-T", + "--traceback", + action="store_true", + default=False, + dest="tracebacks", + help="Show python tracebacks when errors occur", + ) + p.add_argument( + "-H", + "--html", + action="store_true", + default=False, + dest="html", + help="Print HTML output", + ) + p.add_argument( + "-g", + "--git", + action="store_true", + default=False, + dest="use_git", + help="Use git revisions instead of real files", + ) + p.add_argument( + "-b", + "--bzr", + action="store_true", + default=False, + dest="use_bzr", + help="Use bzr revisions instead of real files", + ) + p.add_argument( + "-r", + "--rcs", + action="store_true", + default=False, + dest="use_rcs", + help="Use rcs revisions instead of real files", + ) opts, args = p.parse_args() opts.use_vc = opts.use_git or opts.use_bzr or opts.use_rcs - def _open(what, err): # type: (Union[list,str], str) -> Any + def _open(what, err): # type: (Union[list,str], str) -> Any if isinstance(what, list): # Must be a list, open subprocess try: @@ -224,7 +276,7 @@ The differences shown will be logical differences, not textual differences. else: # Open as normal file try: - return open(what, 'rb') + return open(what, "rb") except IOError: sys.stderr.write(err + "\n") if opts.tracebacks: @@ -254,23 +306,35 @@ The differences shown will be logical differences, not textual differences. old, new = None, None oldz, newz = None, None if opts.use_bzr: - old = _open(["bzr", "cat", "-r" + oldr, filename], - "Unable to retrieve revision {} of {}".format(oldr, filename)) + old = _open( + ["bzr", "cat", "-r" + oldr, filename], + "Unable to retrieve revision {} of {}".format(oldr, filename), + ) if newr is not None: - new = _open(["bzr", "cat", "-r" + newr, filename], - "Unable to retrieve revision {} of {}".format(newr, filename)) + new = _open( + ["bzr", "cat", "-r" + newr, filename], + "Unable to retrieve revision {} of {}".format(newr, filename), + ) elif opts.use_git: - old = _open(["git", "show", oldn], - "Unable to retrieve revision {} of {}".format(oldr, filename)) + old = _open( + ["git", "show", oldn], + "Unable to retrieve revision {} of {}".format(oldr, filename), + ) if newr is not None: - new = _open(["git", "show", newn], - "Unable to retrieve revision {} of {}".format(newr, filename)) + new = _open( + ["git", "show", newn], + "Unable to retrieve revision {} of {}".format(newr, filename), + ) elif opts.use_rcs: - old = _open(["co", "-q", "-p", "-r" + oldr, filename], - "Unable to retrieve revision {} of {}".format(oldr, filename)) + old = _open( + ["co", "-q", "-p", "-r" + oldr, filename], + "Unable to retrieve revision {} of {}".format(oldr, filename), + ) if newr is not None: - new = _open(["co", "-q", "-p", "-r" + newr, filename], - "Unable to retrieve revision {} of {}".format(newr, filename)) + new = _open( + ["co", "-q", "-p", "-r" + newr, filename], + "Unable to retrieve revision {} of {}".format(newr, filename), + ) if not opts.use_vc: old = _open(oldn, "Unable to open %s" % oldn) if not opts.use_vc or newr is None: @@ -281,13 +345,13 @@ The differences shown will be logical differences, not textual differences. # Parse the zones try: - oldz = dns.zone.from_file(old, origin='.', check_origin=False) + oldz = dns.zone.from_file(old, origin=".", check_origin=False) except dns.exception.DNSException: sys.stderr.write("Incorrect zonefile: %s\n" % old) if opts.tracebacks: traceback.print_exc() try: - newz = dns.zone.from_file(new, origin='.', check_origin=False) + newz = dns.zone.from_file(new, origin=".", check_origin=False) except dns.exception.DNSException: sys.stderr.write("Incorrect zonefile: %s\n" % new) if opts.tracebacks: @@ -306,5 +370,6 @@ The differences shown will be logical differences, not textual differences. print(format_changes_plain(oldn, newn, changes, opts.ignore_ttl)) sys.exit(1) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/pylintrc b/pylintrc index f23e8ea4..306429ea 100644 --- a/pylintrc +++ b/pylintrc @@ -32,6 +32,7 @@ disable= raise-missing-from, # we should start doing this, but too noisy for now consider-using-f-string, unspecified-encoding, + useless-super-delegation, # not against this, but we have to do it for mypy happiness [REPORTS] @@ -49,4 +50,4 @@ reports=no msg-template='{path}:{line}: [{msg_id}({symbol}), {obj}] {msg})' [FORMAT] -max-line-length=120 +max-line-length=88 diff --git a/setup.py b/setup.py index 42e794bc..c91c7193 100755 --- a/setup.py +++ b/setup.py @@ -22,18 +22,20 @@ from setuptools import setup try: - sys.argv.remove("--cython-compile") + sys.argv.remove("--cython-compile") except ValueError: - compile_cython = False + compile_cython = False else: compile_cython = True from Cython.Build import cythonize - ext_modules = cythonize(['dns/*.py', 'dns/rdtypes/*.py', 'dns/rdtypes/*/*.py'], - language_level='3') + + ext_modules = cythonize( + ["dns/*.py", "dns/rdtypes/*.py", "dns/rdtypes/*/*.py"], language_level="3" + ) kwargs = { - 'ext_modules': ext_modules if compile_cython else None, - 'zip_safe': False if compile_cython else None, - } + "ext_modules": ext_modules if compile_cython else None, + "zip_safe": False if compile_cython else None, +} setup(**kwargs) diff --git a/tests/md_module.py b/tests/md_module.py index 19568bd9..6ad05a83 100644 --- a/tests/md_module.py +++ b/tests/md_module.py @@ -1,4 +1,5 @@ import dns.rdtypes.nsbase + class MD(dns.rdtypes.nsbase.NSBase): """Test MD record.""" diff --git a/tests/nanonameserver.py b/tests/nanonameserver.py index 3391b323..7246e112 100644 --- a/tests/nanonameserver.py +++ b/tests/nanonameserver.py @@ -13,23 +13,26 @@ import dns.asyncquery import dns.message import dns.rcode + async def read_exactly(stream, count): """Read the specified number of bytes from stream. Keep trying until we either get the desired amount, or we hit EOF. """ - s = b'' + s = b"" while count > 0: n = await stream.receive_some(count) - if n == b'': + if n == b"": raise EOFError count = count - len(n) s = s + n return s + class ConnectionType(enum.IntEnum): UDP = 1 TCP = 2 + class Request: def __init__(self, message, wire, peer, local, connection_type): self.message = message @@ -54,6 +57,7 @@ class Request: def qtype(self): return self.question.rdtype + class Server(threading.Thread): """The nanoserver is a nameserver skeleton suitable for faking a DNS @@ -75,9 +79,16 @@ class Server(threading.Thread): called. """ - def __init__(self, address='127.0.0.1', port=0, enable_udp=True, - enable_tcp=True, use_thread=True, origin=None, - keyring=None): + def __init__( + self, + address="127.0.0.1", + port=0, + enable_udp=True, + enable_tcp=True, + use_thread=True, + origin=None, + keyring=None, + ): super().__init__() self.address = address self.port = port @@ -102,21 +113,20 @@ class Server(threading.Thread): try: while True: if self.enable_udp: - self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, - 0) + self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) self.udp.bind((self.address, self.port)) self.udp_address = self.udp.getsockname() if self.enable_tcp: - self.tcp = socket.socket(socket.AF_INET, - socket.SOCK_STREAM, 0) - self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, - 1) + self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if self.port == 0 and self.enable_udp: try: self.tcp.bind((self.address, self.udp_address[1])) except OSError as e: - if e.errno == errno.EADDRINUSE and \ - len(open_udp_sockets) < 100: + if ( + e.errno == errno.EADDRINUSE + and len(open_udp_sockets) < 100 + ): open_udp_sockets.append(self.udp) continue raise @@ -247,8 +257,7 @@ class Server(threading.Thread): while True: try: (wire, peer) = await sock.recvfrom(65535) - for wire in self.handle_wire(wire, peer, local, - ConnectionType.UDP): + for wire in self.handle_wire(wire, peer, local, ConnectionType.UDP): await sock.sendto(wire, peer) except Exception: pass @@ -261,8 +270,7 @@ class Server(threading.Thread): ldata = await read_exactly(stream, 2) (l,) = struct.unpack("!H", ldata) wire = await read_exactly(stream, l) - for wire in self.handle_wire(wire, peer, local, - ConnectionType.TCP): + for wire in self.handle_wire(wire, peer, local, ConnectionType.TCP): l = len(wire) stream_message = struct.pack("!H", l) + wire await stream.send_all(stream_message) @@ -274,8 +282,12 @@ class Server(threading.Thread): self.tcp = None # we own cleanup listener = trio.SocketListener(sock) async with trio.open_nursery() as nursery: - serve = functools.partial(trio.serve_listeners, self.serve_tcp, - [listener], handler_nursery=nursery) + serve = functools.partial( + trio.serve_listeners, + self.serve_tcp, + [listener], + handler_nursery=nursery, + ) nursery.start_soon(serve) async def main(self): @@ -292,9 +304,10 @@ class Server(threading.Thread): def run(self): if not self.use_thread: - raise RuntimeError('start() called on a use_thread=False Server') + raise RuntimeError("start() called on a use_thread=False Server") trio.run(self.main) + if __name__ == "__main__": import sys import time @@ -302,8 +315,10 @@ if __name__ == "__main__": async def trio_main(): try: with Server(port=5354, use_thread=False) as server: - print(f'Trio mode: listening on UDP: {server.udp_address}, ' + - f'TCP: {server.tcp_address}') + print( + f"Trio mode: listening on UDP: {server.udp_address}, " + + f"TCP: {server.tcp_address}" + ) async with trio.open_nursery() as nursery: nursery.start_soon(server.main) except Exception: @@ -311,11 +326,13 @@ if __name__ == "__main__": def threaded_main(): with Server(port=5354) as server: - print(f'Thread Mode: listening on UDP: {server.udp_address}, ' + - f'TCP: {server.tcp_address}') + print( + f"Thread Mode: listening on UDP: {server.udp_address}, " + + f"TCP: {server.tcp_address}" + ) time.sleep(300) - if len(sys.argv) > 1 and sys.argv[1] == 'trio': + if len(sys.argv) > 1 and sys.argv[1] == "trio": trio.run(trio_main) else: threaded_main() diff --git a/tests/stxt_module.py b/tests/stxt_module.py index 7f612357..dfe875c9 100644 --- a/tests/stxt_module.py +++ b/tests/stxt_module.py @@ -1,4 +1,5 @@ import dns.rdtypes.txtbase + class STXT(dns.rdtypes.txtbase.TXTBase): """Test singleton TXT-like record""" diff --git a/tests/test_address.py b/tests/test_address.py index 1ee7022c..bb689b73 100644 --- a/tests/test_address.py +++ b/tests/test_address.py @@ -8,6 +8,7 @@ import dns.exception import dns.ipv4 import dns.ipv6 + class IPv4Tests(unittest.TestCase): def test_valid(self): valid = ( @@ -23,8 +24,7 @@ class IPv4Tests(unittest.TestCase): "192.0.2.128", ) for s in valid: - self.assertEqual(dns.ipv4.inet_aton(s), - socket.inet_pton(socket.AF_INET, s)) + self.assertEqual(dns.ipv4.inet_aton(s), socket.inet_pton(socket.AF_INET, s)) def test_invalid(self): invalid = ( @@ -74,10 +74,12 @@ class IPv4Tests(unittest.TestCase): "::", ) for s in invalid: - with self.assertRaises(dns.exception.SyntaxError, - msg=f'invalid IPv4 address: "{s}"'): + with self.assertRaises( + dns.exception.SyntaxError, msg=f'invalid IPv4 address: "{s}"' + ): dns.ipv4.inet_aton(s) + class IPv6Tests(unittest.TestCase): def test_valid(self): valid = ( @@ -258,12 +260,13 @@ class IPv6Tests(unittest.TestCase): } for s in valid: - if sys.platform == 'win32' and s in win32_invalid: + if sys.platform == "win32" and s in win32_invalid: # socket.inet_pton() on win32 rejects some valid (as # far as we can tell) IPv6 addresses. Skip them. continue - self.assertEqual(dns.ipv6.inet_aton(s), - socket.inet_pton(socket.AF_INET6, s)) + self.assertEqual( + dns.ipv6.inet_aton(s), socket.inet_pton(socket.AF_INET6, s) + ) def test_invalid(self): invalid = ( @@ -576,6 +579,7 @@ class IPv6Tests(unittest.TestCase): "':10.0.0.1", ) for s in invalid: - with self.assertRaises(dns.exception.SyntaxError, - msg=f'invalid IPv6 address: "{s}"'): + with self.assertRaises( + dns.exception.SyntaxError, msg=f'invalid IPv6 address: "{s}"' + ): dns.ipv6.inet_aton(s) diff --git a/tests/test_async.py b/tests/test_async.py index ce0caa14..3c9a7e6d 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -46,7 +46,7 @@ except Exception: # skip those if it's not there. _network_available = True try: - socket.gethostbyname('dnspython.org') + socket.gethostbyname("dnspython.org") except socket.gaierror: _network_available = False @@ -57,15 +57,17 @@ except socket.gaierror: _systemd_resolved_present = False try: _resolver = dns.resolver.Resolver() - if _resolver.nameservers == ['127.0.0.53']: + if _resolver.nameservers == ["127.0.0.53"]: _systemd_resolved_present = True except Exception: pass # Probe for IPv4 and IPv6 query_addresses = [] -for (af, address) in ((socket.AF_INET, '8.8.8.8'), - (socket.AF_INET6, '2001:4860:4860::8888')): +for (af, address) in ( + (socket.AF_INET, "8.8.8.8"), + (socket.AF_INET6, "2001:4860:4860::8888"), +): try: with socket.socket(af, socket.SOCK_DGRAM) as s: # Connecting a UDP socket is supposed to return ENETUNREACH if @@ -75,31 +77,37 @@ for (af, address) in ((socket.AF_INET, '8.8.8.8'), except Exception: pass -KNOWN_ANYCAST_DOH_RESOLVER_URLS = ['https://cloudflare-dns.com/dns-query', - 'https://dns.google/dns-query', - # 'https://dns11.quad9.net/dns-query', - ] +KNOWN_ANYCAST_DOH_RESOLVER_URLS = [ + "https://cloudflare-dns.com/dns-query", + "https://dns.google/dns-query", + # 'https://dns11.quad9.net/dns-query', +] class AsyncDetectionTests(unittest.TestCase): - sniff_result = 'asyncio' + sniff_result = "asyncio" def async_run(self, afunc): return asyncio.run(afunc()) def test_sniff(self): dns.asyncbackend._default_backend = None + async def run(): self.assertEqual(dns.asyncbackend.sniff(), self.sniff_result) + self.async_run(run) def test_get_default_backend(self): dns.asyncbackend._default_backend = None + async def run(): backend = dns.asyncbackend.get_default_backend() self.assertEqual(backend.name(), self.sniff_result) + self.async_run(run) + class NoSniffioAsyncDetectionTests(AsyncDetectionTests): expect_raise = False @@ -112,10 +120,13 @@ class NoSniffioAsyncDetectionTests(AsyncDetectionTests): def test_sniff(self): dns.asyncbackend._default_backend = None if self.expect_raise: + async def abad(): dns.asyncbackend.sniff() + def bad(): self.async_run(abad) + self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad) else: super().test_sniff() @@ -123,10 +134,13 @@ class NoSniffioAsyncDetectionTests(AsyncDetectionTests): def test_get_default_backend(self): dns.asyncbackend._default_backend = None if self.expect_raise: + async def abad(): dns.asyncbackend.get_default_backend() + def bad(): self.async_run(abad) + self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad) else: super().test_get_default_backend() @@ -135,13 +149,16 @@ class NoSniffioAsyncDetectionTests(AsyncDetectionTests): class MiscBackend(unittest.TestCase): def test_sniff_without_run_loop(self): dns.asyncbackend._default_backend = None + def bad(): dns.asyncbackend.sniff() + self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad) def test_bogus_backend(self): def bad(): - dns.asyncbackend.get_backend('bogus') + dns.asyncbackend.get_backend("bogus") + self.assertRaises(NotImplementedError, bad) @@ -151,256 +168,297 @@ class MiscQuery(unittest.TestCase): self.assertEqual(t, None) t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 0) self.assertEqual(t, None) - t = dns.asyncquery._source_tuple(socket.AF_INET, '1.2.3.4', 53) - self.assertEqual(t, ('1.2.3.4', 53)) - t = dns.asyncquery._source_tuple(socket.AF_INET6, '1::2', 53) - self.assertEqual(t, ('1::2', 53)) + t = dns.asyncquery._source_tuple(socket.AF_INET, "1.2.3.4", 53) + self.assertEqual(t, ("1.2.3.4", 53)) + t = dns.asyncquery._source_tuple(socket.AF_INET6, "1::2", 53) + self.assertEqual(t, ("1::2", 53)) t = dns.asyncquery._source_tuple(socket.AF_INET, None, 53) - self.assertEqual(t, ('0.0.0.0', 53)) + self.assertEqual(t, ("0.0.0.0", 53)) t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 53) - self.assertEqual(t, ('::', 53)) + self.assertEqual(t, ("::", 53)) @unittest.skipIf(not _network_available, "Internet not reachable") class AsyncTests(unittest.TestCase): - connect_udp = sys.platform == 'win32' + connect_udp = sys.platform == "win32" def setUp(self): - self.backend = dns.asyncbackend.set_default_backend('asyncio') + self.backend = dns.asyncbackend.set_default_backend("asyncio") def async_run(self, afunc): return asyncio.run(afunc()) def testResolve(self): async def run(): - answer = await dns.asyncresolver.resolve('dns.google.', 'A') + answer = await dns.asyncresolver.resolve("dns.google.", "A") return set([rdata.address for rdata in answer]) + seen = self.async_run(run) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testResolveAddress(self): async def run(): - return await dns.asyncresolver.resolve_address('8.8.8.8') + return await dns.asyncresolver.resolve_address("8.8.8.8") + answer = self.async_run(run) - dnsgoogle = dns.name.from_text('dns.google.') + dnsgoogle = dns.name.from_text("dns.google.") self.assertEqual(answer[0].target, dnsgoogle) def testCanonicalNameNoCNAME(self): - cname = dns.name.from_text('www.google.com') + cname = dns.name.from_text("www.google.com") + async def run(): - return await dns.asyncresolver.canonical_name('www.google.com') + return await dns.asyncresolver.canonical_name("www.google.com") + self.assertEqual(self.async_run(run), cname) def testCanonicalNameCNAME(self): - name = dns.name.from_text('www.dnspython.org') - cname = dns.name.from_text('dmfrjf4ips8xa.cloudfront.net') + name = dns.name.from_text("www.dnspython.org") + cname = dns.name.from_text("dmfrjf4ips8xa.cloudfront.net") + async def run(): return await dns.asyncresolver.canonical_name(name) + self.assertEqual(self.async_run(run), cname) @unittest.skipIf(_systemd_resolved_present, "systemd-resolved in use") def testCanonicalNameDangling(self): - name = dns.name.from_text('dangling-cname.dnspython.org') - cname = dns.name.from_text('dangling-target.dnspython.org') + name = dns.name.from_text("dangling-cname.dnspython.org") + cname = dns.name.from_text("dangling-target.dnspython.org") + async def run(): return await dns.asyncresolver.canonical_name(name) - self.assertEqual(self.async_run(run), cname) + self.assertEqual(self.async_run(run), cname) def testZoneForName1(self): async def run(): - name = dns.name.from_text('www.dnspython.org.') + name = dns.name.from_text("www.dnspython.org.") return await dns.asyncresolver.zone_for_name(name) - ezname = dns.name.from_text('dnspython.org.') + + ezname = dns.name.from_text("dnspython.org.") zname = self.async_run(run) self.assertEqual(zname, ezname) def testZoneForName2(self): async def run(): - name = dns.name.from_text('a.b.www.dnspython.org.') + name = dns.name.from_text("a.b.www.dnspython.org.") return await dns.asyncresolver.zone_for_name(name) - ezname = dns.name.from_text('dnspython.org.') + + ezname = dns.name.from_text("dnspython.org.") zname = self.async_run(run) self.assertEqual(zname, ezname) def testZoneForName3(self): async def run(): - name = dns.name.from_text('dnspython.org.') + name = dns.name.from_text("dnspython.org.") return await dns.asyncresolver.zone_for_name(name) - ezname = dns.name.from_text('dnspython.org.') + + ezname = dns.name.from_text("dnspython.org.") zname = self.async_run(run) self.assertEqual(zname, ezname) def testZoneForName4(self): def bad(): - name = dns.name.from_text('dnspython.org', None) + name = dns.name.from_text("dnspython.org", None) + async def run(): return await dns.asyncresolver.zone_for_name(name) + self.async_run(run) + self.assertRaises(dns.resolver.NotAbsolute, bad) def testQueryUDP(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") + async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.udp(q, address, timeout=2) + response = self.async_run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryUDPWithSocket(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") + async def run(): if self.connect_udp: - dtuple=(address, 53) + dtuple = (address, 53) else: - dtuple=None + dtuple = None async with await self.backend.make_socket( - dns.inet.af_for_address(address), - socket.SOCK_DGRAM, 0, None, dtuple) as s: + dns.inet.af_for_address(address), socket.SOCK_DGRAM, 0, None, dtuple + ) as s: q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.udp(q, address, sock=s, - timeout=2) + return await dns.asyncquery.udp(q, address, sock=s, timeout=2) + response = self.async_run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryTCP(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") + async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.tcp(q, address, timeout=2) + response = self.async_run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryTCPWithSocket(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") + async def run(): async with await self.backend.make_socket( - dns.inet.af_for_address(address), - socket.SOCK_STREAM, 0, - None, - (address, 53), 2) as s: + dns.inet.af_for_address(address), + socket.SOCK_STREAM, + 0, + None, + (address, 53), + 2, + ) as s: # for basic coverage await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.tcp(q, address, sock=s, - timeout=2) + return await dns.asyncquery.tcp(q, address, sock=s, timeout=2) + response = self.async_run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) @unittest.skipIf(not _ssl_available, "SSL not available") def testQueryTLS(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") + async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.tls(q, address, timeout=2) + response = self.async_run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) @unittest.skipIf(not _ssl_available, "SSL not available") def testQueryTLSWithSocket(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") + async def run(): ssl_context = ssl.create_default_context() ssl_context.check_hostname = False async with await self.backend.make_socket( - dns.inet.af_for_address(address), - socket.SOCK_STREAM, 0, - None, - (address, 853), 2, - ssl_context, None) as s: + dns.inet.af_for_address(address), + socket.SOCK_STREAM, + 0, + None, + (address, 853), + 2, + ssl_context, + None, + ) as s: # for basic coverage await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.tls(q, '8.8.8.8', sock=s, - timeout=2) + return await dns.asyncquery.tls(q, "8.8.8.8", sock=s, timeout=2) + response = self.async_run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryUDPFallback(self): for address in query_addresses: - qname = dns.name.from_text('.') + qname = dns.name.from_text(".") + async def run(): q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) - return await dns.asyncquery.udp_with_fallback(q, address, - timeout=2) + return await dns.asyncquery.udp_with_fallback(q, address, timeout=2) + (_, tcp) = self.async_run(run) self.assertTrue(tcp) def testQueryUDPFallbackNoFallback(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") + async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.udp_with_fallback(q, address, - timeout=2) + return await dns.asyncquery.udp_with_fallback(q, address, timeout=2) + (_, tcp) = self.async_run(run) self.assertFalse(tcp) def testUDPReceiveQuery(self): if self.connect_udp: - self.skipTest('test needs connectionless sockets') + self.skipTest("test needs connectionless sockets") + async def run(): async with await self.backend.make_socket( - socket.AF_INET, socket.SOCK_DGRAM, - source=('127.0.0.1', 0)) as listener: + socket.AF_INET, socket.SOCK_DGRAM, source=("127.0.0.1", 0) + ) as listener: listener_address = await listener.getsockname() async with await self.backend.make_socket( - socket.AF_INET, socket.SOCK_DGRAM, - source=('127.0.0.1', 0)) as sender: + socket.AF_INET, socket.SOCK_DGRAM, source=("127.0.0.1", 0) + ) as sender: sender_address = await sender.getsockname() - q = dns.message.make_query('dns.google', dns.rdatatype.A) + q = dns.message.make_query("dns.google", dns.rdatatype.A) await dns.asyncquery.send_udp(sender, q, listener_address) expiration = time.time() + 2 (_, _, recv_address) = await dns.asyncquery.receive_udp( - listener, expiration=expiration) + listener, expiration=expiration + ) return (sender_address, recv_address) + (sender_address, recv_address) = self.async_run(run) self.assertEqual(sender_address, recv_address) def testUDPReceiveTimeout(self): if self.connect_udp: - self.skipTest('test needs connectionless sockets') + self.skipTest("test needs connectionless sockets") + async def arun(): - async with await self.backend.make_socket(socket.AF_INET, - socket.SOCK_DGRAM, 0, - ('127.0.0.1', 0)) as s: + async with await self.backend.make_socket( + socket.AF_INET, socket.SOCK_DGRAM, 0, ("127.0.0.1", 0) + ) as s: try: # for basic coverage await s.getpeername() @@ -408,62 +466,69 @@ class AsyncTests(unittest.TestCase): # we expect failure as we haven't connected the socket pass await s.recvfrom(1000, 0.05) + def run(): self.async_run(arun) + self.assertRaises(dns.exception.Timeout, run) @unittest.skipIf(not dns.query._have_httpx, "httpx not available") def testDOHGetRequest(self): - if self.backend.name() == 'curio': - self.skipTest('anyio dropped curio support') + if self.backend.name() == "curio": + self.skipTest("anyio dropped curio support") + async def run(): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = await dns.asyncquery.https(q, nameserver_url, post=False, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4) self.assertTrue(q.is_response(r)) + self.async_run(run) @unittest.skipIf(not dns.query._have_httpx, "httpx not available") def testDOHGetRequestHttp1(self): - if self.backend.name() == 'curio': - self.skipTest('anyio dropped curio support') + if self.backend.name() == "curio": + self.skipTest("anyio dropped curio support") + async def run(): saved_have_http2 = dns.query._have_http2 try: dns.query._have_http2 = False nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = await dns.asyncquery.https(q, nameserver_url, post=False, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4) self.assertTrue(q.is_response(r)) finally: dns.query._have_http2 = saved_have_http2 + self.async_run(run) @unittest.skipIf(not dns.query._have_httpx, "httpx not available") def testDOHPostRequest(self): - if self.backend.name() == 'curio': - self.skipTest('anyio dropped curio support') + if self.backend.name() == "curio": + self.skipTest("anyio dropped curio support") + async def run(): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = await dns.asyncquery.https(q, nameserver_url, post=True, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = await dns.asyncquery.https(q, nameserver_url, post=True, timeout=4) self.assertTrue(q.is_response(r)) + self.async_run(run) @unittest.skipIf(not dns.query._have_httpx, "httpx not available") def testResolverDOH(self): - if self.backend.name() == 'curio': - self.skipTest('anyio dropped curio support') + if self.backend.name() == "curio": + self.skipTest("anyio dropped curio support") + async def run(): res = dns.asyncresolver.Resolver(configure=False) - res.nameservers = ['https://dns.google/dns-query'] - answer = await res.resolve('dns.google', 'A', backend=self.backend) + res.nameservers = ["https://dns.google/dns-query"] + answer = await res.resolve("dns.google", "A", backend=self.backend) seen = set([rdata.address for rdata in answer]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) + self.async_run(run) def testSleep(self): @@ -472,29 +537,35 @@ class AsyncTests(unittest.TestCase): await self.backend.sleep(0.1) after = time.time() self.assertTrue(after - before >= 0.1) + self.async_run(run) + try: import trio import sniffio class TrioAsyncDetectionTests(AsyncDetectionTests): - sniff_result = 'trio' + sniff_result = "trio" + def async_run(self, afunc): return trio.run(afunc) class TrioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests): expect_raise = True + def async_run(self, afunc): return trio.run(afunc) class TrioAsyncTests(AsyncTests): connect_udp = False + def setUp(self): - self.backend = dns.asyncbackend.set_default_backend('trio') + self.backend = dns.asyncbackend.set_default_backend("trio") def async_run(self, afunc): return trio.run(afunc) + except ImportError: pass @@ -503,21 +574,25 @@ try: import sniffio class CurioAsyncDetectionTests(AsyncDetectionTests): - sniff_result = 'curio' + sniff_result = "curio" + def async_run(self, afunc): return curio.run(afunc) class CurioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests): expect_raise = True + def async_run(self, afunc): return curio.run(afunc) class CurioAsyncTests(AsyncTests): connect_udp = False + def setUp(self): - self.backend = dns.asyncbackend.set_default_backend('curio') + self.backend = dns.asyncbackend.set_default_backend("curio") def async_run(self, afunc): return curio.run(afunc) + except ImportError: pass diff --git a/tests/test_bugs.py b/tests/test_bugs.py index 3080e50c..ac76b8de 100644 --- a/tests/test_bugs.py +++ b/tests/test_bugs.py @@ -28,69 +28,72 @@ import dns.ttl class BugsTestCase(unittest.TestCase): - def test_float_LOC(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.LOC, - u"30 30 0.000 N 100 30 0.000 W 10.00m 20m 2000m 20m") + rdata = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.LOC, + "30 30 0.000 N 100 30 0.000 W 10.00m 20m 2000m 20m", + ) self.assertEqual(rdata.float_latitude, 30.5) self.assertEqual(rdata.float_longitude, -100.5) def test_SOA_BIND8_TTL(self): - rdata1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - u"a b 100 1s 1m 1h 1d") - rdata2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - u"a b 100 1 60 3600 86400") + rdata1 = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "a b 100 1s 1m 1h 1d" + ) + rdata2 = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "a b 100 1 60 3600 86400" + ) self.assertEqual(rdata1, rdata2) def test_empty_NSEC3_window(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NSEC3, - u"1 0 100 ABCD SCBCQHKU35969L2A68P3AD59LHF30715") + rdata = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.NSEC3, + "1 0 100 ABCD SCBCQHKU35969L2A68P3AD59LHF30715", + ) self.assertEqual(rdata.windows, ()) def test_zero_size_APL(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.APL, - "") - rdata2 = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.APL, - "", 0, 0) + rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.APL, "") + rdata2 = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.APL, "", 0, 0) self.assertEqual(rdata, rdata2) def test_CAA_from_wire(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CAA, - '0 issue "ca.example.net"') + rdata = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CAA, '0 issue "ca.example.net"' + ) f = BytesIO() rdata.to_wire(f) wire = f.getvalue() rdlen = len(wire) wire += b"trailing garbage" - rdata2 = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.CAA, - wire, 0, rdlen) + rdata2 = dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.CAA, wire, 0, rdlen + ) self.assertEqual(rdata, rdata2) def test_trailing_zero_APL(self): in4 = "!1:127.0.0.0/1" rd4 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.APL, in4) out4 = rd4.to_digestable(dns.name.from_text("test")) - text4 = binascii.hexlify(out4).decode('ascii') - self.assertEqual(text4, '000101817f') + text4 = binascii.hexlify(out4).decode("ascii") + self.assertEqual(text4, "000101817f") in6 = "!2:::1000/1" rd6 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.APL, in6) out6 = rd6.to_digestable(dns.name.from_text("test")) - text6 = binascii.hexlify(out6).decode('ascii') - self.assertEqual(text6, '0002018f000000000000000000000000000010') + text6 = binascii.hexlify(out6).decode("ascii") + self.assertEqual(text6, "0002018f000000000000000000000000000010") def test_TXT_conversions(self): - t1 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, - [b'foo']) - t2 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, - b'foo') - t3 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, - 'foo') - t4 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, - ['foo']) + t1 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, [b"foo"]) + t2 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, b"foo") + t3 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, "foo") + t4 = dns.rdtypes.ANY.TXT.TXT(dns.rdataclass.IN, dns.rdatatype.TXT, ["foo"]) self.assertEqual(t1, t2) self.assertEqual(t1, t2) self.assertEqual(t1, t4) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_constants.py b/tests/test_constants.py index 0c38c281..bf0d9709 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -15,10 +15,10 @@ import tests.util class ConstantsTestCase(unittest.TestCase): - def test_dnssec_constants(self): - tests.util.check_enum_exports(dns.dnssec, self.assertEqual, - only={dns.dnssec.Algorithm}) + tests.util.check_enum_exports( + dns.dnssec, self.assertEqual, only={dns.dnssec.Algorithm} + ) tests.util.check_enum_exports(dns.rdtypes.dnskeybase, self.assertEqual) def test_flags_constants(self): @@ -35,5 +35,6 @@ class ConstantsTestCase(unittest.TestCase): tests.util.check_enum_exports(dns.rdatatype, self.assertEqual) def test_edns_constants(self): - tests.util.check_enum_exports(dns.edns, self.assertEqual, - only={dns.edns.OptionType}) + tests.util.check_enum_exports( + dns.edns, self.assertEqual, only={dns.edns.OptionType} + ) diff --git a/tests/test_dnssec.py b/tests/test_dnssec.py index e8189925..fbc83f91 100644 --- a/tests/test_dnssec.py +++ b/tests/test_dnssec.py @@ -30,246 +30,419 @@ import dns.rrset # pylint: disable=line-too-long -abs_dnspython_org = dns.name.from_text('dnspython.org') +abs_dnspython_org = dns.name.from_text("dnspython.org") abs_keys = { abs_dnspython_org: dns.rrset.from_text( - 'dnspython.org.', 3600, 'IN', 'DNSKEY', - '257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=', - '256 3 5 AwEAAdSSghOGjU33IQZgwZM2Hh771VGXX05olJK49FxpSyuEAjDBXY58 LGU9R2Zgeecnk/b9EAhFu/vCV9oECtiTCvwuVAkt9YEweqYDluQInmgP NGMJCKdSLlnX93DkjDw8rMYv5dqXCuSGPlKChfTJOLQxIAxGloS7lL+c 0CTZydAF' + "dnspython.org.", + 3600, + "IN", + "DNSKEY", + "257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=", + "256 3 5 AwEAAdSSghOGjU33IQZgwZM2Hh771VGXX05olJK49FxpSyuEAjDBXY58 LGU9R2Zgeecnk/b9EAhFu/vCV9oECtiTCvwuVAkt9YEweqYDluQInmgP NGMJCKdSLlnX93DkjDw8rMYv5dqXCuSGPlKChfTJOLQxIAxGloS7lL+c 0CTZydAF", ) } abs_keys_duplicate_keytag = { abs_dnspython_org: dns.rrset.from_text( - 'dnspython.org.', 3600, 'IN', 'DNSKEY', - '257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=', - '256 3 5 AwEAAdSSg++++THIS/IS/NOT/THE/CORRECT/KEY++++++++++++++++ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ AaOSydAF', - '256 3 5 AwEAAdSSghOGjU33IQZgwZM2Hh771VGXX05olJK49FxpSyuEAjDBXY58 LGU9R2Zgeecnk/b9EAhFu/vCV9oECtiTCvwuVAkt9YEweqYDluQInmgP NGMJCKdSLlnX93DkjDw8rMYv5dqXCuSGPlKChfTJOLQxIAxGloS7lL+c 0CTZydAF' + "dnspython.org.", + 3600, + "IN", + "DNSKEY", + "257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=", + "256 3 5 AwEAAdSSg++++THIS/IS/NOT/THE/CORRECT/KEY++++++++++++++++ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ AaOSydAF", + "256 3 5 AwEAAdSSghOGjU33IQZgwZM2Hh771VGXX05olJK49FxpSyuEAjDBXY58 LGU9R2Zgeecnk/b9EAhFu/vCV9oECtiTCvwuVAkt9YEweqYDluQInmgP NGMJCKdSLlnX93DkjDw8rMYv5dqXCuSGPlKChfTJOLQxIAxGloS7lL+c 0CTZydAF", ) } rel_keys = { dns.name.empty: dns.rrset.from_text( - '@', 3600, 'IN', 'DNSKEY', - '257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=', - '256 3 5 AwEAAdSSghOGjU33IQZgwZM2Hh771VGXX05olJK49FxpSyuEAjDBXY58 LGU9R2Zgeecnk/b9EAhFu/vCV9oECtiTCvwuVAkt9YEweqYDluQInmgP NGMJCKdSLlnX93DkjDw8rMYv5dqXCuSGPlKChfTJOLQxIAxGloS7lL+c 0CTZydAF' + "@", + 3600, + "IN", + "DNSKEY", + "257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=", + "256 3 5 AwEAAdSSghOGjU33IQZgwZM2Hh771VGXX05olJK49FxpSyuEAjDBXY58 LGU9R2Zgeecnk/b9EAhFu/vCV9oECtiTCvwuVAkt9YEweqYDluQInmgP NGMJCKdSLlnX93DkjDw8rMYv5dqXCuSGPlKChfTJOLQxIAxGloS7lL+c 0CTZydAF", ) } when = 1290250287 -abs_soa = dns.rrset.from_text('dnspython.org.', 3600, 'IN', 'SOA', - 'howl.dnspython.org. hostmaster.dnspython.org. 2010020047 3600 1800 604800 3600') - -abs_other_soa = dns.rrset.from_text('dnspython.org.', 3600, 'IN', 'SOA', - 'foo.dnspython.org. hostmaster.dnspython.org. 2010020047 3600 1800 604800 3600') - -abs_soa_rrsig = dns.rrset.from_text('dnspython.org.', 3600, 'IN', 'RRSIG', - 'SOA 5 2 3600 20101127004331 20101119213831 61695 dnspython.org. sDUlltRlFTQw5ITFxOXW3TgmrHeMeNpdqcZ4EXxM9FHhIlte6V9YCnDw t6dvM9jAXdIEi03l9H/RAd9xNNW6gvGMHsBGzpvvqFQxIBR2PoiZA1mX /SWHZFdbt4xjYTtXqpyYvrMK0Dt7bUYPadyhPFCJ1B+I8Zi7B5WJEOd0 8vs=') - -rel_soa = dns.rrset.from_text('@', 3600, 'IN', 'SOA', - 'howl hostmaster 2010020047 3600 1800 604800 3600') - -rel_other_soa = dns.rrset.from_text('@', 3600, 'IN', 'SOA', - 'foo hostmaster 2010020047 3600 1800 604800 3600') - -rel_soa_rrsig = dns.rrset.from_text('@', 3600, 'IN', 'RRSIG', - 'SOA 5 2 3600 20101127004331 20101119213831 61695 @ sDUlltRlFTQw5ITFxOXW3TgmrHeMeNpdqcZ4EXxM9FHhIlte6V9YCnDw t6dvM9jAXdIEi03l9H/RAd9xNNW6gvGMHsBGzpvvqFQxIBR2PoiZA1mX /SWHZFdbt4xjYTtXqpyYvrMK0Dt7bUYPadyhPFCJ1B+I8Zi7B5WJEOd0 8vs=') - -sep_key = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DNSKEY, - '257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=') - -good_ds = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DS, - '57349 5 2 53A79A3E7488AB44FFC56B2D1109F0699D1796DD977E72108B841F96 E47D7013') +abs_soa = dns.rrset.from_text( + "dnspython.org.", + 3600, + "IN", + "SOA", + "howl.dnspython.org. hostmaster.dnspython.org. 2010020047 3600 1800 604800 3600", +) + +abs_other_soa = dns.rrset.from_text( + "dnspython.org.", + 3600, + "IN", + "SOA", + "foo.dnspython.org. hostmaster.dnspython.org. 2010020047 3600 1800 604800 3600", +) + +abs_soa_rrsig = dns.rrset.from_text( + "dnspython.org.", + 3600, + "IN", + "RRSIG", + "SOA 5 2 3600 20101127004331 20101119213831 61695 dnspython.org. sDUlltRlFTQw5ITFxOXW3TgmrHeMeNpdqcZ4EXxM9FHhIlte6V9YCnDw t6dvM9jAXdIEi03l9H/RAd9xNNW6gvGMHsBGzpvvqFQxIBR2PoiZA1mX /SWHZFdbt4xjYTtXqpyYvrMK0Dt7bUYPadyhPFCJ1B+I8Zi7B5WJEOd0 8vs=", +) + +rel_soa = dns.rrset.from_text( + "@", 3600, "IN", "SOA", "howl hostmaster 2010020047 3600 1800 604800 3600" +) + +rel_other_soa = dns.rrset.from_text( + "@", 3600, "IN", "SOA", "foo hostmaster 2010020047 3600 1800 604800 3600" +) + +rel_soa_rrsig = dns.rrset.from_text( + "@", + 3600, + "IN", + "RRSIG", + "SOA 5 2 3600 20101127004331 20101119213831 61695 @ sDUlltRlFTQw5ITFxOXW3TgmrHeMeNpdqcZ4EXxM9FHhIlte6V9YCnDw t6dvM9jAXdIEi03l9H/RAd9xNNW6gvGMHsBGzpvvqFQxIBR2PoiZA1mX /SWHZFdbt4xjYTtXqpyYvrMK0Dt7bUYPadyhPFCJ1B+I8Zi7B5WJEOd0 8vs=", +) + +sep_key = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DNSKEY, + "257 3 5 AwEAAenVTr9L1OMlL1/N2ta0Qj9LLLnnmFWIr1dJoAsWM9BQfsbV7kFZ XbAkER/FY9Ji2o7cELxBwAsVBuWn6IUUAJXLH74YbC1anY0lifjgt29z SwDzuB7zmC7yVYZzUunBulVW4zT0tg1aePbpVL2EtTL8VzREqbJbE25R KuQYHZtFwG8S4iBxJUmT2Bbd0921LLxSQgVoFXlQx/gFV2+UERXcJ5ce iX6A6wc02M/pdg/YbJd2rBa0MYL3/Fz/Xltre0tqsImZGxzi6YtYDs45 NC8gH+44egz82e2DATCVM1ICPmRDjXYTLldQiWA2ZXIWnK0iitl5ue24 7EsWJefrIhE=", +) + +good_ds = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DS, + "57349 5 2 53A79A3E7488AB44FFC56B2D1109F0699D1796DD977E72108B841F96 E47D7013", +) when2 = 1290425644 -abs_example = dns.name.from_text('example') +abs_example = dns.name.from_text("example") abs_dsa_keys = { abs_example: dns.rrset.from_text( - 'example.', 86400, 'IN', 'DNSKEY', - '257 3 3 CI3nCqyJsiCJHTjrNsJOT4RaszetzcJPYuoH3F9ZTVt3KJXncCVR3bwn 1w0iavKljb9hDlAYSfHbFCp4ic/rvg4p1L8vh5s8ToMjqDNl40A0hUGQ Ybx5hsECyK+qHoajilUX1phYSAD8d9WAGO3fDWzUPBuzR7o85NiZCDxz yXuNVfni0uhj9n1KYhEO5yAbbruDGN89wIZcxMKuQsdUY2GYD93ssnBv a55W6XRABYWayKZ90WkRVODLVYLSn53Pj/wwxGH+XdhIAZJXimrZL4yl My7rtBsLMqq8Ihs4Tows7LqYwY7cp6y/50tw6pj8tFqMYcPUjKZV36l1 M/2t5BVg3i7IK61Aidt6aoC3TDJtzAxg3ZxfjZWJfhHjMJqzQIfbW5b9 q1mjFsW5EUv39RaNnX+3JWPRLyDqD4pIwDyqfutMsdk/Py3paHn82FGp CaOg+nicqZ9TiMZURN/XXy5JoXUNQ3RNvbHCUiPUe18KUkY6mTfnyHld 1l9YCWmzXQVClkx/hOYxjJ4j8Ife58+Obu5X', - '256 3 3 CJE1yb9YRQiw5d2xZrMUMR+cGCTt1bp1KDCefmYKmS+Z1+q9f42ETVhx JRiQwXclYwmxborzIkSZegTNYIV6mrYwbNB27Q44c3UGcspb3PiOw5TC jNPRYEcdwGvDZ2wWy+vkSV/S9tHXY8O6ODiE6abZJDDg/RnITyi+eoDL R3KZ5n/V1f1T1b90rrV6EewhBGQJpQGDogaXb2oHww9Tm6NfXyo7SoMM pbwbzOckXv+GxRPJIQNSF4D4A9E8XCksuzVVdE/0lr37+uoiAiPia38U 5W2QWe/FJAEPLjIp2eTzf0TrADc1pKP1wrA2ASpdzpm/aX3IB5RPp8Ew S9U72eBFZJAUwg635HxJVxH1maG6atzorR566E+e0OZSaxXS9o1o6QqN 3oPlYLGPORDiExilKfez3C/x/yioOupW9K5eKF0gmtaqrHX0oq9s67f/ RIM2xVaKHgG9Vf2cgJIZkhv7sntujr+E4htnRmy9P9BxyFxsItYxPI6Z bzygHAZpGhlI/7ltEGlIwKxyTK3ZKBm67q7B' + "example.", + 86400, + "IN", + "DNSKEY", + "257 3 3 CI3nCqyJsiCJHTjrNsJOT4RaszetzcJPYuoH3F9ZTVt3KJXncCVR3bwn 1w0iavKljb9hDlAYSfHbFCp4ic/rvg4p1L8vh5s8ToMjqDNl40A0hUGQ Ybx5hsECyK+qHoajilUX1phYSAD8d9WAGO3fDWzUPBuzR7o85NiZCDxz yXuNVfni0uhj9n1KYhEO5yAbbruDGN89wIZcxMKuQsdUY2GYD93ssnBv a55W6XRABYWayKZ90WkRVODLVYLSn53Pj/wwxGH+XdhIAZJXimrZL4yl My7rtBsLMqq8Ihs4Tows7LqYwY7cp6y/50tw6pj8tFqMYcPUjKZV36l1 M/2t5BVg3i7IK61Aidt6aoC3TDJtzAxg3ZxfjZWJfhHjMJqzQIfbW5b9 q1mjFsW5EUv39RaNnX+3JWPRLyDqD4pIwDyqfutMsdk/Py3paHn82FGp CaOg+nicqZ9TiMZURN/XXy5JoXUNQ3RNvbHCUiPUe18KUkY6mTfnyHld 1l9YCWmzXQVClkx/hOYxjJ4j8Ife58+Obu5X", + "256 3 3 CJE1yb9YRQiw5d2xZrMUMR+cGCTt1bp1KDCefmYKmS+Z1+q9f42ETVhx JRiQwXclYwmxborzIkSZegTNYIV6mrYwbNB27Q44c3UGcspb3PiOw5TC jNPRYEcdwGvDZ2wWy+vkSV/S9tHXY8O6ODiE6abZJDDg/RnITyi+eoDL R3KZ5n/V1f1T1b90rrV6EewhBGQJpQGDogaXb2oHww9Tm6NfXyo7SoMM pbwbzOckXv+GxRPJIQNSF4D4A9E8XCksuzVVdE/0lr37+uoiAiPia38U 5W2QWe/FJAEPLjIp2eTzf0TrADc1pKP1wrA2ASpdzpm/aX3IB5RPp8Ew S9U72eBFZJAUwg635HxJVxH1maG6atzorR566E+e0OZSaxXS9o1o6QqN 3oPlYLGPORDiExilKfez3C/x/yioOupW9K5eKF0gmtaqrHX0oq9s67f/ RIM2xVaKHgG9Vf2cgJIZkhv7sntujr+E4htnRmy9P9BxyFxsItYxPI6Z bzygHAZpGhlI/7ltEGlIwKxyTK3ZKBm67q7B", ) } -abs_dsa_soa = dns.rrset.from_text('example.', 86400, 'IN', 'SOA', - 'ns1.example. hostmaster.example. 2 10800 3600 604800 86400') - -abs_other_dsa_soa = dns.rrset.from_text('example.', 86400, 'IN', 'SOA', - 'ns1.example. hostmaster.example. 2 10800 3600 604800 86401') - -abs_dsa_soa_rrsig = dns.rrset.from_text('example.', 86400, 'IN', 'RRSIG', - 'SOA 3 1 86400 20101129143231 20101122112731 42088 example. CGul9SuBofsktunV8cJs4eRs6u+3NCS3yaPKvBbD+pB2C76OUXDZq9U=') - -example_sep_key = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DNSKEY, - '257 3 3 CI3nCqyJsiCJHTjrNsJOT4RaszetzcJPYuoH3F9ZTVt3KJXncCVR3bwn 1w0iavKljb9hDlAYSfHbFCp4ic/rvg4p1L8vh5s8ToMjqDNl40A0hUGQ Ybx5hsECyK+qHoajilUX1phYSAD8d9WAGO3fDWzUPBuzR7o85NiZCDxz yXuNVfni0uhj9n1KYhEO5yAbbruDGN89wIZcxMKuQsdUY2GYD93ssnBv a55W6XRABYWayKZ90WkRVODLVYLSn53Pj/wwxGH+XdhIAZJXimrZL4yl My7rtBsLMqq8Ihs4Tows7LqYwY7cp6y/50tw6pj8tFqMYcPUjKZV36l1 M/2t5BVg3i7IK61Aidt6aoC3TDJtzAxg3ZxfjZWJfhHjMJqzQIfbW5b9 q1mjFsW5EUv39RaNnX+3JWPRLyDqD4pIwDyqfutMsdk/Py3paHn82FGp CaOg+nicqZ9TiMZURN/XXy5JoXUNQ3RNvbHCUiPUe18KUkY6mTfnyHld 1l9YCWmzXQVClkx/hOYxjJ4j8Ife58+Obu5X') - -example_ds_sha1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DS, - '18673 3 1 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7') - -example_ds_sha256 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DS, - '18673 3 2 eb8344cbbf07c9d3d3d6c81d10c76653e28d8611a65e639ef8f716e4e4e5d913') - -example_ds_sha384 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DS, - '18673 3 4 61ab241025c5f88d2537be04dcfba96f952adaefe0b382ecbc4108c97b75768c9e99fd16caed2a09634c51e8089fb84f') +abs_dsa_soa = dns.rrset.from_text( + "example.", + 86400, + "IN", + "SOA", + "ns1.example. hostmaster.example. 2 10800 3600 604800 86400", +) + +abs_other_dsa_soa = dns.rrset.from_text( + "example.", + 86400, + "IN", + "SOA", + "ns1.example. hostmaster.example. 2 10800 3600 604800 86401", +) + +abs_dsa_soa_rrsig = dns.rrset.from_text( + "example.", + 86400, + "IN", + "RRSIG", + "SOA 3 1 86400 20101129143231 20101122112731 42088 example. CGul9SuBofsktunV8cJs4eRs6u+3NCS3yaPKvBbD+pB2C76OUXDZq9U=", +) + +example_sep_key = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DNSKEY, + "257 3 3 CI3nCqyJsiCJHTjrNsJOT4RaszetzcJPYuoH3F9ZTVt3KJXncCVR3bwn 1w0iavKljb9hDlAYSfHbFCp4ic/rvg4p1L8vh5s8ToMjqDNl40A0hUGQ Ybx5hsECyK+qHoajilUX1phYSAD8d9WAGO3fDWzUPBuzR7o85NiZCDxz yXuNVfni0uhj9n1KYhEO5yAbbruDGN89wIZcxMKuQsdUY2GYD93ssnBv a55W6XRABYWayKZ90WkRVODLVYLSn53Pj/wwxGH+XdhIAZJXimrZL4yl My7rtBsLMqq8Ihs4Tows7LqYwY7cp6y/50tw6pj8tFqMYcPUjKZV36l1 M/2t5BVg3i7IK61Aidt6aoC3TDJtzAxg3ZxfjZWJfhHjMJqzQIfbW5b9 q1mjFsW5EUv39RaNnX+3JWPRLyDqD4pIwDyqfutMsdk/Py3paHn82FGp CaOg+nicqZ9TiMZURN/XXy5JoXUNQ3RNvbHCUiPUe18KUkY6mTfnyHld 1l9YCWmzXQVClkx/hOYxjJ4j8Ife58+Obu5X", +) + +example_ds_sha1 = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DS, + "18673 3 1 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7", +) + +example_ds_sha256 = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DS, + "18673 3 2 eb8344cbbf07c9d3d3d6c81d10c76653e28d8611a65e639ef8f716e4e4e5d913", +) + +example_ds_sha384 = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DS, + "18673 3 4 61ab241025c5f88d2537be04dcfba96f952adaefe0b382ecbc4108c97b75768c9e99fd16caed2a09634c51e8089fb84f", +) when3 = 1379801800 abs_ecdsa256_keys = { abs_example: dns.rrset.from_text( - 'example.', 86400, 'IN', 'DNSKEY', + "example.", + 86400, + "IN", + "DNSKEY", "256 3 13 +3ss1sCpdARVA61DJigEsL/8quo2a8MszKtn2gkkfxgzFs8S2UHtpb4N fY+XFmNW+JK6MsCkI3jHYN8eEQUgMw==", - "257 3 13 eJCEVH7AS3wnoaQpaNlAXH0W8wxymtT9P6P3qjN2ZCV641ED8pF7wZ5V yWfOpgTs6oaZevbJgehl/GaRPUgVyQ==" + "257 3 13 eJCEVH7AS3wnoaQpaNlAXH0W8wxymtT9P6P3qjN2ZCV641ED8pF7wZ5V yWfOpgTs6oaZevbJgehl/GaRPUgVyQ==", ) } -abs_ecdsa256_soa = dns.rrset.from_text('example.', 86400, 'IN', 'SOA', - 'ns1.example. hostmaster.example. 4 10800 3600 604800 86400') - -abs_other_ecdsa256_soa = dns.rrset.from_text('example.', 86400, 'IN', 'SOA', - 'ns1.example. hostmaster.example. 2 10800 3600 604800 86401') - -abs_ecdsa256_soa_rrsig = dns.rrset.from_text('example.', 86400, 'IN', 'RRSIG', - "SOA 13 1 86400 20130921221753 20130921221638 7460 example. Sm09SOGz1ULB5D/duwdE2Zpn8bWbVBM77H6N1wPkc42LevvVO+kZEjpq 2nq4GOMJcih52667GIAbMrwmU5P2MQ==") +abs_ecdsa256_soa = dns.rrset.from_text( + "example.", + 86400, + "IN", + "SOA", + "ns1.example. hostmaster.example. 4 10800 3600 604800 86400", +) + +abs_other_ecdsa256_soa = dns.rrset.from_text( + "example.", + 86400, + "IN", + "SOA", + "ns1.example. hostmaster.example. 2 10800 3600 604800 86401", +) + +abs_ecdsa256_soa_rrsig = dns.rrset.from_text( + "example.", + 86400, + "IN", + "RRSIG", + "SOA 13 1 86400 20130921221753 20130921221638 7460 example. Sm09SOGz1ULB5D/duwdE2Zpn8bWbVBM77H6N1wPkc42LevvVO+kZEjpq 2nq4GOMJcih52667GIAbMrwmU5P2MQ==", +) when4 = 1379804850 abs_ecdsa384_keys = { abs_example: dns.rrset.from_text( - 'example.', 86400, 'IN', 'DNSKEY', + "example.", + 86400, + "IN", + "DNSKEY", "256 3 14 1bG8qWviKNXQX3BIuG6/T5jrP1FISiLW/8qGF6BsM9DQtWYhhZUA3Owr OAEiyHAhQwjkN2kTvWiAYoPN80Ii+5ff9/atzY4F9W50P4l75Dj9PYrL HN/hLUgWMNVc9pvA", - "257 3 14 mSub2n0KRt6u2FaD5XJ3oQu0R4XvB/9vUJcyW6+oo0y+KzfQeTdkf1ro ZMVKoyWXW9zUKBYGJpMUIdbAxzrYi7f5HyZ3yDpBFz1hw9+o3CX+gtgb +RyhHfJDwwFXBid9" + "257 3 14 mSub2n0KRt6u2FaD5XJ3oQu0R4XvB/9vUJcyW6+oo0y+KzfQeTdkf1ro ZMVKoyWXW9zUKBYGJpMUIdbAxzrYi7f5HyZ3yDpBFz1hw9+o3CX+gtgb +RyhHfJDwwFXBid9", ) } -abs_ecdsa384_soa = dns.rrset.from_text('example.', 86400, 'IN', 'SOA', - 'ns1.example. hostmaster.example. 2 10800 3600 604800 86400') - -abs_other_ecdsa384_soa = dns.rrset.from_text('example.', 86400, 'IN', 'SOA', - 'ns1.example. hostmaster.example. 2 10800 3600 604800 86401') - -abs_ecdsa384_soa_rrsig = dns.rrset.from_text('example.', 86400, 'IN', 'RRSIG', - "SOA 14 1 86400 20130929021229 20130921230729 63571 example. CrnCu34EeeRz0fEhL9PLlwjpBKGYW8QjBjFQTwd+ViVLRAS8tNkcDwQE NhSV89NEjj7ze1a/JcCfcJ+/mZgnvH4NHLNg3Tf6KuLZsgs2I4kKQXEk 37oIHravPEOlGYNI") - -abs_example_com = dns.name.from_text('example.com') - -abs_ed25519_mx = dns.rrset.from_text('example.com.', 3600, 'IN', 'MX', - '10 mail.example.com.') -abs_other_ed25519_mx = dns.rrset.from_text('example.com.', 3600, 'IN', 'MX', - '11 mail.example.com.') +abs_ecdsa384_soa = dns.rrset.from_text( + "example.", + 86400, + "IN", + "SOA", + "ns1.example. hostmaster.example. 2 10800 3600 604800 86400", +) + +abs_other_ecdsa384_soa = dns.rrset.from_text( + "example.", + 86400, + "IN", + "SOA", + "ns1.example. hostmaster.example. 2 10800 3600 604800 86401", +) + +abs_ecdsa384_soa_rrsig = dns.rrset.from_text( + "example.", + 86400, + "IN", + "RRSIG", + "SOA 14 1 86400 20130929021229 20130921230729 63571 example. CrnCu34EeeRz0fEhL9PLlwjpBKGYW8QjBjFQTwd+ViVLRAS8tNkcDwQE NhSV89NEjj7ze1a/JcCfcJ+/mZgnvH4NHLNg3Tf6KuLZsgs2I4kKQXEk 37oIHravPEOlGYNI", +) + +abs_example_com = dns.name.from_text("example.com") + +abs_ed25519_mx = dns.rrset.from_text( + "example.com.", 3600, "IN", "MX", "10 mail.example.com." +) +abs_other_ed25519_mx = dns.rrset.from_text( + "example.com.", 3600, "IN", "MX", "11 mail.example.com." +) abs_ed25519_keys_1 = { abs_example_com: dns.rrset.from_text( - 'example.com', 3600, 'IN', 'DNSKEY', - '257 3 15 l02Woi0iS8Aa25FQkUd9RMzZHJpBoRQwAQEX1SxZJA4=') + "example.com", + 3600, + "IN", + "DNSKEY", + "257 3 15 l02Woi0iS8Aa25FQkUd9RMzZHJpBoRQwAQEX1SxZJA4=", + ) } -abs_ed25519_mx_rrsig_1 = dns.rrset.from_text('example.com.', 3600, 'IN', 'RRSIG', - 'MX 15 2 3600 1440021600 1438207200 3613 example.com. oL9krJun7xfBOIWcGHi7mag5/hdZrKWw15jPGrHpjQeRAvTdszaPD+QLs3fx8A4M3e23mRZ9VrbpMngwcrqNAg==') +abs_ed25519_mx_rrsig_1 = dns.rrset.from_text( + "example.com.", + 3600, + "IN", + "RRSIG", + "MX 15 2 3600 1440021600 1438207200 3613 example.com. oL9krJun7xfBOIWcGHi7mag5/hdZrKWw15jPGrHpjQeRAvTdszaPD+QLs3fx8A4M3e23mRZ9VrbpMngwcrqNAg==", +) abs_ed25519_keys_2 = { abs_example_com: dns.rrset.from_text( - 'example.com', 3600, 'IN', 'DNSKEY', - '257 3 15 zPnZ/QwEe7S8C5SPz2OfS5RR40ATk2/rYnE9xHIEijs=') + "example.com", + 3600, + "IN", + "DNSKEY", + "257 3 15 zPnZ/QwEe7S8C5SPz2OfS5RR40ATk2/rYnE9xHIEijs=", + ) } -abs_ed25519_mx_rrsig_2 = dns.rrset.from_text('example.com.', 3600, 'IN', 'RRSIG', - 'MX 15 2 3600 1440021600 1438207200 35217 example.com. zXQ0bkYgQTEFyfLyi9QoiY6D8ZdYo4wyUhVioYZXFdT410QPRITQSqJSnzQoSm5poJ7gD7AQR0O7KuI5k2pcBg==') +abs_ed25519_mx_rrsig_2 = dns.rrset.from_text( + "example.com.", + 3600, + "IN", + "RRSIG", + "MX 15 2 3600 1440021600 1438207200 35217 example.com. zXQ0bkYgQTEFyfLyi9QoiY6D8ZdYo4wyUhVioYZXFdT410QPRITQSqJSnzQoSm5poJ7gD7AQR0O7KuI5k2pcBg==", +) abs_ed448_mx = abs_ed25519_mx abs_other_ed448_mx = abs_other_ed25519_mx abs_ed448_keys_1 = { abs_example_com: dns.rrset.from_text( - 'example.com', 3600, 'IN', 'DNSKEY', - '257 3 16 3kgROaDjrh0H2iuixWBrc8g2EpBBLCdGzHmn+G2MpTPhpj/OiBVHHSfPodx1FYYUcJKm1MDpJtIA') + "example.com", + 3600, + "IN", + "DNSKEY", + "257 3 16 3kgROaDjrh0H2iuixWBrc8g2EpBBLCdGzHmn+G2MpTPhpj/OiBVHHSfPodx1FYYUcJKm1MDpJtIA", + ) } -abs_ed448_mx_rrsig_1 = dns.rrset.from_text('example.com.', 3600, 'IN', 'RRSIG', - 'MX 16 2 3600 1440021600 1438207200 9713 example.com. 3cPAHkmlnxcDHMyg7vFC34l0blBhuG1qpwLmjInI8w1CMB29FkEAIJUA0amxWndkmnBZ6SKiwZSAxGILn/NBtOXft0+Gj7FSvOKxE/07+4RQvE581N3Aj/JtIyaiYVdnYtyMWbSNyGEY2213WKsJlwEA') +abs_ed448_mx_rrsig_1 = dns.rrset.from_text( + "example.com.", + 3600, + "IN", + "RRSIG", + "MX 16 2 3600 1440021600 1438207200 9713 example.com. 3cPAHkmlnxcDHMyg7vFC34l0blBhuG1qpwLmjInI8w1CMB29FkEAIJUA0amxWndkmnBZ6SKiwZSAxGILn/NBtOXft0+Gj7FSvOKxE/07+4RQvE581N3Aj/JtIyaiYVdnYtyMWbSNyGEY2213WKsJlwEA", +) abs_ed448_keys_2 = { abs_example_com: dns.rrset.from_text( - 'example.com', 3600, 'IN', 'DNSKEY', - '257 3 16 kkreGWoccSDmUBGAe7+zsbG6ZAFQp+syPmYUurBRQc3tDjeMCJcVMRDmgcNLp5HlHAMy12VoISsA') + "example.com", + 3600, + "IN", + "DNSKEY", + "257 3 16 kkreGWoccSDmUBGAe7+zsbG6ZAFQp+syPmYUurBRQc3tDjeMCJcVMRDmgcNLp5HlHAMy12VoISsA", + ) } -abs_ed448_mx_rrsig_2 = dns.rrset.from_text('example.com.', 3600, 'IN', 'RRSIG', - 'MX 16 2 3600 1440021600 1438207200 38353 example.com. E1/oLjSGIbmLny/4fcgM1z4oL6aqo+izT3urCyHyvEp4Sp8Syg1eI+lJ57CSnZqjJP41O/9l4m0AsQ4f7qI1gVnML8vWWiyW2KXhT9kuAICUSxv5OWbf81Rq7Yu60npabODB0QFPb/rkW3kUZmQ0YQUA') +abs_ed448_mx_rrsig_2 = dns.rrset.from_text( + "example.com.", + 3600, + "IN", + "RRSIG", + "MX 16 2 3600 1440021600 1438207200 38353 example.com. E1/oLjSGIbmLny/4fcgM1z4oL6aqo+izT3urCyHyvEp4Sp8Syg1eI+lJ57CSnZqjJP41O/9l4m0AsQ4f7qI1gVnML8vWWiyW2KXhT9kuAICUSxv5OWbf81Rq7Yu60npabODB0QFPb/rkW3kUZmQ0YQUA", +) when5 = 1440021600 when5_start = 1438207200 wildcard_keys = { - abs_example_com : dns.rrset.from_text( - 'example.com', 3600, 'IN', 'DNSKEY', - '256 3 5 AwEAAecNZbwD2thg3kaRLVqCC7ASP/3F79ZIu7pCu8HvZZ6ZdinffnxT npNoVvavjouHKFYTtJyUZAfw3ZMJSsGvEerc7uh6Ex9TgvOJtWPGUtxB Nnni2u9Nk+5k6nJzMiS3sL3RLvrfZW5d2Bwbl9L5f9Ud+r2Dbm7EG3tY pMY5OE8f') + abs_example_com: dns.rrset.from_text( + "example.com", + 3600, + "IN", + "DNSKEY", + "256 3 5 AwEAAecNZbwD2thg3kaRLVqCC7ASP/3F79ZIu7pCu8HvZZ6ZdinffnxT npNoVvavjouHKFYTtJyUZAfw3ZMJSsGvEerc7uh6Ex9TgvOJtWPGUtxB Nnni2u9Nk+5k6nJzMiS3sL3RLvrfZW5d2Bwbl9L5f9Ud+r2Dbm7EG3tY pMY5OE8f", + ) } -wildcard_example_com = dns.name.from_text('*', abs_example_com) -wildcard_txt = dns.rrset.from_text('*.example.com.', 3600, 'IN', 'TXT', 'foo') -wildcard_txt_rrsig = dns.rrset.from_text('*.example.com.', 3600, 'IN', 'RRSIG', - 'TXT 5 2 3600 20200707211255 20200630180755 42486 example.com. qevJYhdAHq1VmehXQ5i+Epa32xs4zcd4qmb39pHa3GUKr1V504nxzdzQ gsT5mvDkRoY95+HAiysDON6DCDtZc69iBUIHWWuFo/OrcD2q/mWANG4x vyU28Pf0U1gN6Gd5iapKC0Ya12flKh//NQiNN2skOQ2MoF2MW2/MaAK2 HBc=') +wildcard_example_com = dns.name.from_text("*", abs_example_com) +wildcard_txt = dns.rrset.from_text("*.example.com.", 3600, "IN", "TXT", "foo") +wildcard_txt_rrsig = dns.rrset.from_text( + "*.example.com.", + 3600, + "IN", + "RRSIG", + "TXT 5 2 3600 20200707211255 20200630180755 42486 example.com. qevJYhdAHq1VmehXQ5i+Epa32xs4zcd4qmb39pHa3GUKr1V504nxzdzQ gsT5mvDkRoY95+HAiysDON6DCDtZc69iBUIHWWuFo/OrcD2q/mWANG4x vyU28Pf0U1gN6Gd5iapKC0Ya12flKh//NQiNN2skOQ2MoF2MW2/MaAK2 HBc=", +) wildcard_when = 1593541048 rsamd5_keys = { abs_example: dns.rrset.from_text( - 'example', 3600, 'in', 'dnskey', - '257 3 1 AwEAAewnoEWe+AVEnQzcZTwpl8K/QKuScYIX 9xHOhejAL1enMjE0j97Gq3XXJJPWF7eQQGHs 1De4Srv2UT0zRCLkH9r36lOR/ggANvthO/Ub Es0hlD3A58LumEPudgIDwEkxGvQAXMFTMw0x 1d/a82UtzmNoPVzFOl2r+OCXx9Jbdh/L; KSK; alg = RSAMD5; key id = 30239', - '256 3 1 AwEAAb8OJM5YcqaYG0fenUdRlrhBQ6LuwCvr 5BRlrVbVzadSDBpq+yIiklfdGNBg3WZztDy1 du62NWC/olMfc6uRe/SjqTa7IJ3MdEuZQXQw MedGdNSF73zbokx8wg7zBBr74xHczJcEpQhr ZLzwCDmIPu0yoVi3Yqdl4dm4vNBj9hAD; ZSK; alg = RSAMD5; key id = 62992') + "example", + 3600, + "in", + "dnskey", + "257 3 1 AwEAAewnoEWe+AVEnQzcZTwpl8K/QKuScYIX 9xHOhejAL1enMjE0j97Gq3XXJJPWF7eQQGHs 1De4Srv2UT0zRCLkH9r36lOR/ggANvthO/Ub Es0hlD3A58LumEPudgIDwEkxGvQAXMFTMw0x 1d/a82UtzmNoPVzFOl2r+OCXx9Jbdh/L; KSK; alg = RSAMD5; key id = 30239", + "256 3 1 AwEAAb8OJM5YcqaYG0fenUdRlrhBQ6LuwCvr 5BRlrVbVzadSDBpq+yIiklfdGNBg3WZztDy1 du62NWC/olMfc6uRe/SjqTa7IJ3MdEuZQXQw MedGdNSF73zbokx8wg7zBBr74xHczJcEpQhr ZLzwCDmIPu0yoVi3Yqdl4dm4vNBj9hAD; ZSK; alg = RSAMD5; key id = 62992", + ) } -rsamd5_ns = dns.rrset.from_text('example.', 3600, 'in', 'ns', - 'ns1.example.', 'ns2.example.') -rsamd5_ns_rrsig = dns.rrset.from_text('example.', 3600, 'in', 'rrsig', - 'NS 1 1 3600 20200825153103 20200726153103 62992 example. YPv0WVqzQBDH45mFcYGo9psCVoMoeeHeAugh 9RZuO2NmdwfQ3mmiQm7WJ3AYnzYIozFGf7CL nwn3vN8/fjsfcQgEv5xfhFTSd4IoAzJJiZAa vrI4L5590C/+aXQ8tjRmbMTPiqoudaXvsevE jP2lTFg5DCruJyFq5dnAY5b90RY=') +rsamd5_ns = dns.rrset.from_text( + "example.", 3600, "in", "ns", "ns1.example.", "ns2.example." +) +rsamd5_ns_rrsig = dns.rrset.from_text( + "example.", + 3600, + "in", + "rrsig", + "NS 1 1 3600 20200825153103 20200726153103 62992 example. YPv0WVqzQBDH45mFcYGo9psCVoMoeeHeAugh 9RZuO2NmdwfQ3mmiQm7WJ3AYnzYIozFGf7CL nwn3vN8/fjsfcQgEv5xfhFTSd4IoAzJJiZAa vrI4L5590C/+aXQ8tjRmbMTPiqoudaXvsevE jP2lTFg5DCruJyFq5dnAY5b90RY=", +) rsamd5_when = 1595781671 rsasha512_keys = { abs_example: dns.rrset.from_text( - 'example', 3600, 'in', 'dnskey', - '256 3 10 AwEAAb2JvKjZ6l5qg2ab3qqUQhLGGjsiMIuQ 2zhaXJHdTntS+8LgUXo5yLFn7YF9YL1VX9V4 5ASGxUpz0u0chjWqBNtUO3Ymzas/vck9o21M 2Ce/LrpfYsqvJaLvGf/dozW9uSeMQq1mPKYG xo4uxyhZBhZewX8znXZySrAIozBPH3yp ; ZSK; alg = RSASHA512 ; key id = 5957', - '257 3 10 AwEAAc7Lnoe+mHijJ8OOHgyJHKYantQGKx5t rIs267gOePyAL7cUt9HO1Sm3vABSGNsoHL6w 8/542SxGbT21osVISamtq7kUPTgDU9iKqCBq VdXEdzXYbhBKVoQkGPl4PflfbOgg/45xAiTi 7qOUERuRCPdKEkd4FW0tg6VfZmm7QjP1 ; KSK; alg = RSASHA512 ; key id = 53212') + "example", + 3600, + "in", + "dnskey", + "256 3 10 AwEAAb2JvKjZ6l5qg2ab3qqUQhLGGjsiMIuQ 2zhaXJHdTntS+8LgUXo5yLFn7YF9YL1VX9V4 5ASGxUpz0u0chjWqBNtUO3Ymzas/vck9o21M 2Ce/LrpfYsqvJaLvGf/dozW9uSeMQq1mPKYG xo4uxyhZBhZewX8znXZySrAIozBPH3yp ; ZSK; alg = RSASHA512 ; key id = 5957", + "257 3 10 AwEAAc7Lnoe+mHijJ8OOHgyJHKYantQGKx5t rIs267gOePyAL7cUt9HO1Sm3vABSGNsoHL6w 8/542SxGbT21osVISamtq7kUPTgDU9iKqCBq VdXEdzXYbhBKVoQkGPl4PflfbOgg/45xAiTi 7qOUERuRCPdKEkd4FW0tg6VfZmm7QjP1 ; KSK; alg = RSASHA512 ; key id = 53212", + ) } -rsasha512_ns = dns.rrset.from_text('example.', 3600, 'in', 'ns', - 'ns1.example.', 'ns2.example.') +rsasha512_ns = dns.rrset.from_text( + "example.", 3600, "in", "ns", "ns1.example.", "ns2.example." +) rsasha512_ns_rrsig = dns.rrset.from_text( - 'example.', 3600, 'in', 'rrsig', - 'NS 10 1 3600 20200825161255 20200726161255 5957 example. P9A+1zYke7yIiKEnxFMm+UIW2CIwy2WDvbx6 g8hHiI8qISe6oeKveFW23OSk9+VwFgBiOpeM ygzzFbckY7RkGbOr4TR8ogDRANt6LhV402Hu SXTV9hCLVFWU4PS+/fxxfOHCetsY5tWWSxZi zSHfgpGfsHWzQoAamag4XYDyykc=') + "example.", + 3600, + "in", + "rrsig", + "NS 10 1 3600 20200825161255 20200726161255 5957 example. P9A+1zYke7yIiKEnxFMm+UIW2CIwy2WDvbx6 g8hHiI8qISe6oeKveFW23OSk9+VwFgBiOpeM ygzzFbckY7RkGbOr4TR8ogDRANt6LhV402Hu SXTV9hCLVFWU4PS+/fxxfOHCetsY5tWWSxZi zSHfgpGfsHWzQoAamag4XYDyykc=", +) rsasha512_when = 1595783997 unknown_alg_keys = { abs_example: dns.rrset.from_text( - 'example', 3600, 'in', 'dnskey', - '256 3 100 Ym9ndXM=', - '257 3 100 Ym9ndXM=') + "example", 3600, "in", "dnskey", "256 3 100 Ym9ndXM=", "257 3 100 Ym9ndXM=" + ) } unknown_alg_ns_rrsig = dns.rrset.from_text( - 'example.', 3600, 'in', 'rrsig', - 'NS 100 1 3600 20200825161255 20200726161255 16713 example. P9A+1zYke7yIiKEnxFMm+UIW2CIwy2WDvbx6 g8hHiI8qISe6oeKveFW23OSk9+VwFgBiOpeM ygzzFbckY7RkGbOr4TR8ogDRANt6LhV402Hu SXTV9hCLVFWU4PS+/fxxfOHCetsY5tWWSxZi zSHfgpGfsHWzQoAamag4XYDyykc=') + "example.", + 3600, + "in", + "rrsig", + "NS 100 1 3600 20200825161255 20200726161255 16713 example. P9A+1zYke7yIiKEnxFMm+UIW2CIwy2WDvbx6 g8hHiI8qISe6oeKveFW23OSk9+VwFgBiOpeM ygzzFbckY7RkGbOr4TR8ogDRANt6LhV402Hu SXTV9hCLVFWU4PS+/fxxfOHCetsY5tWWSxZi zSHfgpGfsHWzQoAamag4XYDyykc=", +) fake_gost_keys = { abs_example: dns.rrset.from_text( - 'example', 3600, 'in', 'dnskey', - '256 3 12 Ym9ndXM=', - '257 3 12 Ym9ndXM=') + "example", 3600, "in", "dnskey", "256 3 12 Ym9ndXM=", "257 3 12 Ym9ndXM=" + ) } fake_gost_ns_rrsig = dns.rrset.from_text( - 'example.', 3600, 'in', 'rrsig', - 'NS 12 1 3600 20200825161255 20200726161255 16625 example. P9A+1zYke7yIiKEnxFMm+UIW2CIwy2WDvbx6 g8hHiI8qISe6oeKveFW23OSk9+VwFgBiOpeM ygzzFbckY7RkGbOr4TR8ogDRANt6LhV402Hu SXTV9hCLVFWU4PS+/fxxfOHCetsY5tWWSxZi zSHfgpGfsHWzQoAamag4XYDyykc=') + "example.", + 3600, + "in", + "rrsig", + "NS 12 1 3600 20200825161255 20200726161255 16625 example. P9A+1zYke7yIiKEnxFMm+UIW2CIwy2WDvbx6 g8hHiI8qISe6oeKveFW23OSk9+VwFgBiOpeM ygzzFbckY7RkGbOr4TR8ogDRANt6LhV402Hu SXTV9hCLVFWU4PS+/fxxfOHCetsY5tWWSxZi zSHfgpGfsHWzQoAamag4XYDyykc=", +) -@unittest.skipUnless(dns.dnssec._have_pyca, - "Python Cryptography cannot be imported") -class DNSSECValidatorTestCase(unittest.TestCase): +@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported") +class DNSSECValidatorTestCase(unittest.TestCase): def testAbsoluteRSAMD5Good(self): # type: () -> None - dns.dnssec.validate(rsamd5_ns, rsamd5_ns_rrsig, rsamd5_keys, None, - rsamd5_when) + dns.dnssec.validate(rsamd5_ns, rsamd5_ns_rrsig, rsamd5_keys, None, rsamd5_when) def testRSAMD5Keyid(self): self.assertEqual(dns.dnssec.key_id(rsamd5_keys[abs_example][0]), 30239) @@ -279,114 +452,148 @@ class DNSSECValidatorTestCase(unittest.TestCase): dns.dnssec.validate(abs_soa, abs_soa_rrsig, abs_keys, None, when) def testDuplicateKeytag(self): # type: () -> None - dns.dnssec.validate(abs_soa, abs_soa_rrsig, abs_keys_duplicate_keytag, None, when) + dns.dnssec.validate( + abs_soa, abs_soa_rrsig, abs_keys_duplicate_keytag, None, when + ) def testAbsoluteRSABad(self): # type: () -> None def bad(): # type: () -> None - dns.dnssec.validate(abs_other_soa, abs_soa_rrsig, abs_keys, None, - when) + dns.dnssec.validate(abs_other_soa, abs_soa_rrsig, abs_keys, None, when) + self.assertRaises(dns.dnssec.ValidationFailure, bad) def testRelativeRSAGood(self): # type: () -> None - dns.dnssec.validate(rel_soa, rel_soa_rrsig, rel_keys, - abs_dnspython_org, when) + dns.dnssec.validate(rel_soa, rel_soa_rrsig, rel_keys, abs_dnspython_org, when) # test the text conversion for origin too - dns.dnssec.validate(rel_soa, rel_soa_rrsig, rel_keys, - 'dnspython.org', when) + dns.dnssec.validate(rel_soa, rel_soa_rrsig, rel_keys, "dnspython.org", when) def testRelativeRSABad(self): # type: () -> None def bad(): # type: () -> None - dns.dnssec.validate(rel_other_soa, rel_soa_rrsig, rel_keys, - abs_dnspython_org, when) + dns.dnssec.validate( + rel_other_soa, rel_soa_rrsig, rel_keys, abs_dnspython_org, when + ) + self.assertRaises(dns.dnssec.ValidationFailure, bad) def testAbsoluteDSAGood(self): # type: () -> None - dns.dnssec.validate(abs_dsa_soa, abs_dsa_soa_rrsig, abs_dsa_keys, None, - when2) + dns.dnssec.validate(abs_dsa_soa, abs_dsa_soa_rrsig, abs_dsa_keys, None, when2) def testAbsoluteDSABad(self): # type: () -> None def bad(): # type: () -> None - dns.dnssec.validate(abs_other_dsa_soa, abs_dsa_soa_rrsig, - abs_dsa_keys, None, when2) + dns.dnssec.validate( + abs_other_dsa_soa, abs_dsa_soa_rrsig, abs_dsa_keys, None, when2 + ) + self.assertRaises(dns.dnssec.ValidationFailure, bad) def testAbsoluteECDSA256Good(self): # type: () -> None - dns.dnssec.validate(abs_ecdsa256_soa, abs_ecdsa256_soa_rrsig, - abs_ecdsa256_keys, None, when3) + dns.dnssec.validate( + abs_ecdsa256_soa, abs_ecdsa256_soa_rrsig, abs_ecdsa256_keys, None, when3 + ) def testAbsoluteECDSA256Bad(self): # type: () -> None def bad(): # type: () -> None - dns.dnssec.validate(abs_other_ecdsa256_soa, abs_ecdsa256_soa_rrsig, - abs_ecdsa256_keys, None, when3) + dns.dnssec.validate( + abs_other_ecdsa256_soa, + abs_ecdsa256_soa_rrsig, + abs_ecdsa256_keys, + None, + when3, + ) + self.assertRaises(dns.dnssec.ValidationFailure, bad) def testAbsoluteECDSA384Good(self): # type: () -> None - dns.dnssec.validate(abs_ecdsa384_soa, abs_ecdsa384_soa_rrsig, - abs_ecdsa384_keys, None, when4) + dns.dnssec.validate( + abs_ecdsa384_soa, abs_ecdsa384_soa_rrsig, abs_ecdsa384_keys, None, when4 + ) def testAbsoluteECDSA384Bad(self): # type: () -> None def bad(): # type: () -> None - dns.dnssec.validate(abs_other_ecdsa384_soa, abs_ecdsa384_soa_rrsig, - abs_ecdsa384_keys, None, when4) + dns.dnssec.validate( + abs_other_ecdsa384_soa, + abs_ecdsa384_soa_rrsig, + abs_ecdsa384_keys, + None, + when4, + ) + self.assertRaises(dns.dnssec.ValidationFailure, bad) def testAbsoluteED25519Good(self): # type: () -> None - dns.dnssec.validate(abs_ed25519_mx, abs_ed25519_mx_rrsig_1, - abs_ed25519_keys_1, None, when5) - dns.dnssec.validate(abs_ed25519_mx, abs_ed25519_mx_rrsig_2, - abs_ed25519_keys_2, None, when5) + dns.dnssec.validate( + abs_ed25519_mx, abs_ed25519_mx_rrsig_1, abs_ed25519_keys_1, None, when5 + ) + dns.dnssec.validate( + abs_ed25519_mx, abs_ed25519_mx_rrsig_2, abs_ed25519_keys_2, None, when5 + ) def testAbsoluteED25519Bad(self): # type: () -> None with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_other_ed25519_mx, abs_ed25519_mx_rrsig_1, - abs_ed25519_keys_1, None, when5) + dns.dnssec.validate( + abs_other_ed25519_mx, + abs_ed25519_mx_rrsig_1, + abs_ed25519_keys_1, + None, + when5, + ) with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_other_ed25519_mx, abs_ed25519_mx_rrsig_2, - abs_ed25519_keys_2, None, when5) + dns.dnssec.validate( + abs_other_ed25519_mx, + abs_ed25519_mx_rrsig_2, + abs_ed25519_keys_2, + None, + when5, + ) def testAbsoluteED448Good(self): # type: () -> None - dns.dnssec.validate(abs_ed448_mx, abs_ed448_mx_rrsig_1, - abs_ed448_keys_1, None, when5) - dns.dnssec.validate(abs_ed448_mx, abs_ed448_mx_rrsig_2, - abs_ed448_keys_2, None, when5) + dns.dnssec.validate( + abs_ed448_mx, abs_ed448_mx_rrsig_1, abs_ed448_keys_1, None, when5 + ) + dns.dnssec.validate( + abs_ed448_mx, abs_ed448_mx_rrsig_2, abs_ed448_keys_2, None, when5 + ) def testAbsoluteED448Bad(self): # type: () -> None with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_other_ed448_mx, abs_ed448_mx_rrsig_1, - abs_ed448_keys_1, None, when5) + dns.dnssec.validate( + abs_other_ed448_mx, abs_ed448_mx_rrsig_1, abs_ed448_keys_1, None, when5 + ) with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_other_ed448_mx, abs_ed448_mx_rrsig_2, - abs_ed448_keys_2, None, when5) + dns.dnssec.validate( + abs_other_ed448_mx, abs_ed448_mx_rrsig_2, abs_ed448_keys_2, None, when5 + ) def testAbsoluteRSASHA512Good(self): - dns.dnssec.validate(rsasha512_ns, rsasha512_ns_rrsig, rsasha512_keys, - None, rsasha512_when) + dns.dnssec.validate( + rsasha512_ns, rsasha512_ns_rrsig, rsasha512_keys, None, rsasha512_when + ) def testWildcardGoodAndBad(self): - dns.dnssec.validate(wildcard_txt, wildcard_txt_rrsig, - wildcard_keys, None, wildcard_when) + dns.dnssec.validate( + wildcard_txt, wildcard_txt_rrsig, wildcard_keys, None, wildcard_when + ) def clone_rrset(rrset, name): return dns.rrset.from_rdata(name, rrset.ttl, rrset[0]) - a_name = dns.name.from_text('a.example.com') + a_name = dns.name.from_text("a.example.com") a_txt = clone_rrset(wildcard_txt, a_name) a_txt_rrsig = clone_rrset(wildcard_txt_rrsig, a_name) - dns.dnssec.validate(a_txt, a_txt_rrsig, wildcard_keys, None, - wildcard_when) + dns.dnssec.validate(a_txt, a_txt_rrsig, wildcard_keys, None, wildcard_when) - abc_name = dns.name.from_text('a.b.c.example.com') + abc_name = dns.name.from_text("a.b.c.example.com") abc_txt = clone_rrset(wildcard_txt, abc_name) abc_txt_rrsig = clone_rrset(wildcard_txt_rrsig, abc_name) - dns.dnssec.validate(abc_txt, abc_txt_rrsig, wildcard_keys, None, - wildcard_when) + dns.dnssec.validate(abc_txt, abc_txt_rrsig, wildcard_keys, None, wildcard_when) - com_name = dns.name.from_text('com.') + com_name = dns.name.from_text("com.") com_txt = clone_rrset(wildcard_txt, com_name) com_txt_rrsig = clone_rrset(wildcard_txt_rrsig, abc_name) with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate_rrsig(com_txt, com_txt_rrsig[0], wildcard_keys, - None, wildcard_when) + dns.dnssec.validate_rrsig( + com_txt, com_txt_rrsig[0], wildcard_keys, None, wildcard_when + ) def testAlternateParameterFormats(self): # type: () -> None # Pass rrset and rrsigset as (name, rdataset) tuples, not rrsets @@ -408,46 +615,63 @@ class DNSSECValidatorTestCase(unittest.TestCase): dns.dnssec.validate(abs_soa, abs_soa_rrsig, keys, None, when) # Pass origin as a string, not a name. - dns.dnssec.validate(rel_soa, rel_soa_rrsig, rel_keys, - 'dnspython.org', when) - dns.dnssec.validate_rrsig(rel_soa, rel_soa_rrsig[0], rel_keys, - 'dnspython.org', when) + dns.dnssec.validate(rel_soa, rel_soa_rrsig, rel_keys, "dnspython.org", when) + dns.dnssec.validate_rrsig( + rel_soa, rel_soa_rrsig[0], rel_keys, "dnspython.org", when + ) def testAbsoluteKeyNotFound(self): with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_ed448_mx, abs_ed448_mx_rrsig_1, {}, None, - when5) + dns.dnssec.validate(abs_ed448_mx, abs_ed448_mx_rrsig_1, {}, None, when5) def testTimeBounds(self): # not yet valid with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_ed448_mx, abs_ed448_mx_rrsig_1, - abs_ed448_keys_1, None, when5_start - 1) + dns.dnssec.validate( + abs_ed448_mx, + abs_ed448_mx_rrsig_1, + abs_ed448_keys_1, + None, + when5_start - 1, + ) # expired with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_ed448_mx, abs_ed448_mx_rrsig_1, - abs_ed448_keys_1, None, when5 + 1) + dns.dnssec.validate( + abs_ed448_mx, abs_ed448_mx_rrsig_1, abs_ed448_keys_1, None, when5 + 1 + ) # expired using the current time (to test the "get the time" code # path) with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(abs_ed448_mx, abs_ed448_mx_rrsig_1, - abs_ed448_keys_1, None) + dns.dnssec.validate( + abs_ed448_mx, abs_ed448_mx_rrsig_1, abs_ed448_keys_1, None + ) def testOwnerNameMismatch(self): - bogus = dns.name.from_text('example.bogus') + bogus = dns.name.from_text("example.bogus") with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate((bogus, abs_ed448_mx), abs_ed448_mx_rrsig_1, - abs_ed448_keys_1, None, when5 + 1) + dns.dnssec.validate( + (bogus, abs_ed448_mx), + abs_ed448_mx_rrsig_1, + abs_ed448_keys_1, + None, + when5 + 1, + ) def testGOSTNotSupported(self): with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(rsasha512_ns, fake_gost_ns_rrsig, - fake_gost_keys, None, rsasha512_when) + dns.dnssec.validate( + rsasha512_ns, fake_gost_ns_rrsig, fake_gost_keys, None, rsasha512_when + ) def testUnknownAlgorithm(self): with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec.validate(rsasha512_ns, unknown_alg_ns_rrsig, - unknown_alg_keys, None, rsasha512_when) + dns.dnssec.validate( + rsasha512_ns, + unknown_alg_ns_rrsig, + unknown_alg_keys, + None, + rsasha512_when, + ) class DNSSECMiscTestCase(unittest.TestCase): @@ -468,83 +692,100 @@ class DNSSECMiscTestCase(unittest.TestCase): class DNSSECMakeDSTestCase(unittest.TestCase): - def testMnemonicParser(self): - good_ds_mnemonic = dns.rdata.from_text(dns.rdataclass.IN, - dns.rdatatype.DS, - '57349 RSASHA1 2 53A79A3E7488AB44FFC56B2D1109F0699D1796DD977E72108B841F96 E47D7013') + good_ds_mnemonic = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DS, + "57349 RSASHA1 2 53A79A3E7488AB44FFC56B2D1109F0699D1796DD977E72108B841F96 E47D7013", + ) self.assertEqual(good_ds, good_ds_mnemonic) def testMakeExampleSHA1DS(self): # type: () -> None algorithm: Any - for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1): + for algorithm in ("SHA1", "sha1", dns.dnssec.DSDigest.SHA1): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) self.assertEqual(ds, example_ds_sha1) - ds = dns.dnssec.make_ds('example.', example_sep_key, algorithm) + ds = dns.dnssec.make_ds("example.", example_sep_key, algorithm) self.assertEqual(ds, example_ds_sha1) def testMakeExampleSHA256DS(self): # type: () -> None algorithm: Any - for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256): + for algorithm in ("SHA256", "sha256", dns.dnssec.DSDigest.SHA256): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) self.assertEqual(ds, example_ds_sha256) def testMakeExampleSHA384DS(self): # type: () -> None algorithm: Any - for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384): + for algorithm in ("SHA384", "sha384", dns.dnssec.DSDigest.SHA384): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) self.assertEqual(ds, example_ds_sha384) def testMakeSHA256DS(self): # type: () -> None - ds = dns.dnssec.make_ds(abs_dnspython_org, sep_key, 'SHA256') + ds = dns.dnssec.make_ds(abs_dnspython_org, sep_key, "SHA256") self.assertEqual(ds, good_ds) def testInvalidAlgorithm(self): # type: () -> None algorithm: Any - for algorithm in (10, 'shax'): + for algorithm in (10, "shax"): with self.assertRaises(dns.dnssec.UnsupportedAlgorithm): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) def testReservedDigestType(self): # type: () -> None with self.assertRaises(dns.exception.SyntaxError) as cm: - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DS, - f'18673 3 0 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7') - self.assertEqual('digest type 0 is reserved', str(cm.exception)) + dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.DS, + f"18673 3 0 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7", + ) + self.assertEqual("digest type 0 is reserved", str(cm.exception)) def testUnknownDigestType(self): # type: () -> None digest_types = [dns.rdatatype.DS, dns.rdatatype.CDS] for rdtype in digest_types: - rd = dns.rdata.from_text(dns.rdataclass.IN, rdtype, - f'18673 3 5 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7') - assert isinstance(rd, dns.rdtypes.ANY.DS.DS) or isinstance(rd, dns.rdtypes.ANY.CDS.CDS) + rd = dns.rdata.from_text( + dns.rdataclass.IN, + rdtype, + f"18673 3 5 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7", + ) + assert isinstance(rd, dns.rdtypes.ANY.DS.DS) or isinstance( + rd, dns.rdtypes.ANY.CDS.CDS + ) self.assertEqual(rd.digest_type, 5) - self.assertEqual(rd.digest, bytes.fromhex('71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7')) + self.assertEqual( + rd.digest, bytes.fromhex("71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7") + ) def testInvalidDigestLength(self): # type: () -> None test_records = [] for rdata in [example_ds_sha1, example_ds_sha256, example_ds_sha384]: - flags, digest = rdata.to_text().rsplit(' ', 1) + flags, digest = rdata.to_text().rsplit(" ", 1) # Make sure the construction is working - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DS, f'{flags} {digest}') + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.DS, f"{flags} {digest}" + ) - test_records.append(f'{flags} {digest[:len(digest)//2]}') # too short digest - test_records.append(f'{flags} {digest*2}') # too long digest + test_records.append( + f"{flags} {digest[:len(digest)//2]}" + ) # too short digest + test_records.append(f"{flags} {digest*2}") # too long digest for record in test_records: with self.assertRaises(dns.exception.SyntaxError) as cm: dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.DS, record) - self.assertEqual('digest length inconsistent with digest type', str(cm.exception)) + self.assertEqual( + "digest length inconsistent with digest type", str(cm.exception) + ) def testInvalidDigestLengthCDS0(self): # type: () -> None # Make sure the construction is working - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CDS, f'0 0 0 00') + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CDS, f"0 0 0 00") test_records = { - 'expecting another identifier': ['0 0 0', '0 0 0 '], - 'digest length inconsistent with digest type': ['0 0 0 0000'], - 'Odd-length string': ['0 0 0 0', '0 0 0 000'], + "expecting another identifier": ["0 0 0", "0 0 0 "], + "digest length inconsistent with digest type": ["0 0 0 0000"], + "Odd-length string": ["0 0 0 0", "0 0 0 000"], } for msg, records in test_records.items(): for record in records: @@ -553,5 +794,5 @@ class DNSSECMakeDSTestCase(unittest.TestCase): self.assertEqual(msg, str(cm.exception)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_doh.py b/tests/test_doh.py index bc02e952..35cb1d3c 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -17,8 +17,10 @@ import unittest import random import socket + try: import ssl + _have_ssl = True except Exception: _have_ssl = False @@ -41,19 +43,19 @@ resolver_v6_addresses = [] try: with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: s.settimeout(4) - s.connect(('8.8.8.8', 53)) + s.connect(("8.8.8.8", 53)) resolver_v4_addresses = [ - '1.1.1.1', - '8.8.8.8', + "1.1.1.1", + "8.8.8.8", # '9.9.9.9', ] except Exception: pass try: with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s: - s.connect(('2001:4860:4860::8888', 53)) + s.connect(("2001:4860:4860::8888", 53)) resolver_v6_addresses = [ - '2606:4700:4700::1111', + "2606:4700:4700::1111", # Google says 404 # '2001:4860:4860::8888', # '2620:fe::fe', @@ -61,22 +63,25 @@ try: except Exception: pass -KNOWN_ANYCAST_DOH_RESOLVER_URLS = ['https://cloudflare-dns.com/dns-query', - 'https://dns.google/dns-query', - # 'https://dns11.quad9.net/dns-query', - ] +KNOWN_ANYCAST_DOH_RESOLVER_URLS = [ + "https://cloudflare-dns.com/dns-query", + "https://dns.google/dns-query", + # 'https://dns11.quad9.net/dns-query', +] # Some tests require the internet to be available to run, so let's # skip those if it's not there. _network_available = True try: - socket.gethostbyname('dnspython.org') + socket.gethostbyname("dnspython.org") except socket.gaierror: _network_available = False -@unittest.skipUnless(dns.query._have_requests and _network_available, - "Python requests cannot be imported; no DNS over HTTPS (DOH)") +@unittest.skipUnless( + dns.query._have_requests and _network_available, + "Python requests cannot be imported; no DNS over HTTPS (DOH)", +) class DNSOverHTTPSTestCaseRequests(unittest.TestCase): def setUp(self): self.session = requests.sessions.Session() @@ -86,71 +91,77 @@ class DNSOverHTTPSTestCaseRequests(unittest.TestCase): def test_get_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, session=self.session, post=False, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = dns.query.https( + q, nameserver_url, session=self.session, post=False, timeout=4 + ) self.assertTrue(q.is_response(r)) def test_post_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, session=self.session, post=True, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = dns.query.https( + q, nameserver_url, session=self.session, post=True, timeout=4 + ) self.assertTrue(q.is_response(r)) def test_build_url_from_ip(self): self.assertTrue(resolver_v4_addresses or resolver_v6_addresses) if resolver_v4_addresses: nameserver_ip = random.choice(resolver_v4_addresses) - q = dns.message.make_query('example.com.', dns.rdatatype.A) + q = dns.message.make_query("example.com.", dns.rdatatype.A) # For some reason Google's DNS over HTTPS fails when you POST to # https://8.8.8.8/dns-query # So we're just going to do GET requests here - r = dns.query.https(q, nameserver_ip, session=self.session, - post=False, timeout=4) + r = dns.query.https( + q, nameserver_ip, session=self.session, post=False, timeout=4 + ) self.assertTrue(q.is_response(r)) if resolver_v6_addresses: nameserver_ip = random.choice(resolver_v6_addresses) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_ip, session=self.session, - post=False, timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = dns.query.https( + q, nameserver_ip, session=self.session, post=False, timeout=4 + ) self.assertTrue(q.is_response(r)) def test_bootstrap_address(self): # We test this to see if v4 is available if resolver_v4_addresses: - ip = '185.228.168.168' - invalid_tls_url = 'https://{}/doh/family-filter/'.format(ip) - valid_tls_url = 'https://doh.cleanbrowsing.org/doh/family-filter/' - q = dns.message.make_query('example.com.', dns.rdatatype.A) + ip = "185.228.168.168" + invalid_tls_url = "https://{}/doh/family-filter/".format(ip) + valid_tls_url = "https://doh.cleanbrowsing.org/doh/family-filter/" + q = dns.message.make_query("example.com.", dns.rdatatype.A) # make sure CleanBrowsing's IP address will fail TLS certificate # check with self.assertRaises(SSLError): - dns.query.https(q, invalid_tls_url, session=self.session, - timeout=4) + dns.query.https(q, invalid_tls_url, session=self.session, timeout=4) # use host header - r = dns.query.https(q, valid_tls_url, session=self.session, - bootstrap_address=ip, timeout=4) + r = dns.query.https( + q, valid_tls_url, session=self.session, bootstrap_address=ip, timeout=4 + ) self.assertTrue(q.is_response(r)) def test_new_session(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) + q = dns.message.make_query("example.com.", dns.rdatatype.A) r = dns.query.https(q, nameserver_url, timeout=4) self.assertTrue(q.is_response(r)) def test_resolver(self): res = dns.resolver.Resolver(configure=False) - res.nameservers = ['https://dns.google/dns-query'] - answer = res.resolve('dns.google', 'A') + res.nameservers = ["https://dns.google/dns-query"] + answer = res.resolve("dns.google", "A") seen = set([rdata.address for rdata in answer]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) -@unittest.skipUnless(dns.query._have_httpx and _network_available and _have_ssl, - "Python httpx cannot be imported; no DNS over HTTPS (DOH)") +@unittest.skipUnless( + dns.query._have_httpx and _network_available and _have_ssl, + "Python httpx cannot be imported; no DNS over HTTPS (DOH)", +) class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): def setUp(self): self.session = httpx.Client(http1=True, http2=True, verify=True) @@ -160,9 +171,10 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): def test_get_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, session=self.session, post=False, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = dns.query.https( + q, nameserver_url, session=self.session, post=False, timeout=4 + ) self.assertTrue(q.is_response(r)) def test_get_request_http1(self): @@ -170,72 +182,80 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): try: dns.query._have_http2 = False nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, session=self.session, post=False, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = dns.query.https( + q, nameserver_url, session=self.session, post=False, timeout=4 + ) self.assertTrue(q.is_response(r)) finally: dns.query._have_http2 = saved_have_http2 def test_post_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, session=self.session, post=True, - timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = dns.query.https( + q, nameserver_url, session=self.session, post=True, timeout=4 + ) self.assertTrue(q.is_response(r)) def test_build_url_from_ip(self): self.assertTrue(resolver_v4_addresses or resolver_v6_addresses) if resolver_v4_addresses: nameserver_ip = random.choice(resolver_v4_addresses) - q = dns.message.make_query('example.com.', dns.rdatatype.A) + q = dns.message.make_query("example.com.", dns.rdatatype.A) # For some reason Google's DNS over HTTPS fails when you POST to # https://8.8.8.8/dns-query # So we're just going to do GET requests here - r = dns.query.https(q, nameserver_ip, session=self.session, - post=False, timeout=4) + r = dns.query.https( + q, nameserver_ip, session=self.session, post=False, timeout=4 + ) self.assertTrue(q.is_response(r)) if resolver_v6_addresses: nameserver_ip = random.choice(resolver_v6_addresses) - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_ip, session=self.session, - post=False, timeout=4) + q = dns.message.make_query("example.com.", dns.rdatatype.A) + r = dns.query.https( + q, nameserver_ip, session=self.session, post=False, timeout=4 + ) self.assertTrue(q.is_response(r)) def test_bootstrap_address_fails(self): # We test this to see if v4 is available if resolver_v4_addresses: - ip = '185.228.168.168' - invalid_tls_url = 'https://{}/doh/family-filter/'.format(ip) - valid_tls_url = 'https://doh.cleanbrowsing.org/doh/family-filter/' - q = dns.message.make_query('example.com.', dns.rdatatype.A) + ip = "185.228.168.168" + invalid_tls_url = "https://{}/doh/family-filter/".format(ip) + valid_tls_url = "https://doh.cleanbrowsing.org/doh/family-filter/" + q = dns.message.make_query("example.com.", dns.rdatatype.A) # make sure CleanBrowsing's IP address will fail TLS certificate # check. with self.assertRaises(httpx.ConnectError): - dns.query.https(q, invalid_tls_url, session=self.session, - timeout=4) + dns.query.https(q, invalid_tls_url, session=self.session, timeout=4) # We can't do the Host header and SNI magic with httpx, but # we are demanding httpx be used by providing a session, so # we should get a NoDOH exception. with self.assertRaises(dns.query.NoDOH): - dns.query.https(q, valid_tls_url, session=self.session, - bootstrap_address=ip, timeout=4) + dns.query.https( + q, + valid_tls_url, + session=self.session, + bootstrap_address=ip, + timeout=4, + ) def test_new_session(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query('example.com.', dns.rdatatype.A) + q = dns.message.make_query("example.com.", dns.rdatatype.A) r = dns.query.https(q, nameserver_url, timeout=4) self.assertTrue(q.is_response(r)) def test_resolver(self): res = dns.resolver.Resolver(configure=False) - res.nameservers = ['https://dns.google/dns-query'] - answer = res.resolve('dns.google', 'A') + res.nameservers = ["https://dns.google/dns-query"] + answer = res.resolve("dns.google", "A") seen = set([rdata.address for rdata in answer]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_edns.py b/tests/test_edns.py index 427eb29c..37bd9709 100644 --- a/tests/test_edns.py +++ b/tests/test_edns.py @@ -25,128 +25,130 @@ from io import BytesIO import dns.edns import dns.wire + class OptionTestCase(unittest.TestCase): def testGenericOption(self): - opt = dns.edns.GenericOption(3, b'data') + opt = dns.edns.GenericOption(3, b"data") io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'data') + self.assertEqual(data, b"data") self.assertEqual(dns.edns.option_from_wire(3, data, 0, len(data)), opt) - self.assertEqual(str(opt), 'Generic 3') + self.assertEqual(str(opt), "Generic 3") def testECSOption_prefix_length(self): - opt = dns.edns.ECSOption('1.2.255.33', 20) + opt = dns.edns.ECSOption("1.2.255.33", 20) io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'\x00\x01\x14\x00\x01\x02\xf0') + self.assertEqual(data, b"\x00\x01\x14\x00\x01\x02\xf0") def testECSOption(self): - opt = dns.edns.ECSOption('1.2.3.4', 24) + opt = dns.edns.ECSOption("1.2.3.4", 24) io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'\x00\x01\x18\x00\x01\x02\x03') + self.assertEqual(data, b"\x00\x01\x18\x00\x01\x02\x03") # default srclen - opt = dns.edns.ECSOption('1.2.3.4') + opt = dns.edns.ECSOption("1.2.3.4") io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'\x00\x01\x18\x00\x01\x02\x03') - self.assertEqual(opt.to_text(), 'ECS 1.2.3.4/24 scope/0') + self.assertEqual(data, b"\x00\x01\x18\x00\x01\x02\x03") + self.assertEqual(opt.to_text(), "ECS 1.2.3.4/24 scope/0") def testECSOption25(self): - opt = dns.edns.ECSOption('1.2.3.255', 25) + opt = dns.edns.ECSOption("1.2.3.255", 25) io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'\x00\x01\x19\x00\x01\x02\x03\x80') + self.assertEqual(data, b"\x00\x01\x19\x00\x01\x02\x03\x80") opt2 = dns.edns.option_from_wire(dns.edns.ECS, data, 0, len(data)) self.assertEqual(opt2.otype, dns.edns.ECS) - self.assertEqual(opt2.address, '1.2.3.128') + self.assertEqual(opt2.address, "1.2.3.128") self.assertEqual(opt2.srclen, 25) self.assertEqual(opt2.scopelen, 0) def testECSOption_v6(self): - opt = dns.edns.ECSOption('2001:4b98::1') + opt = dns.edns.ECSOption("2001:4b98::1") io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'\x00\x02\x38\x00\x20\x01\x4b\x98\x00\x00\x00') + self.assertEqual(data, b"\x00\x02\x38\x00\x20\x01\x4b\x98\x00\x00\x00") opt2 = dns.edns.option_from_wire(dns.edns.ECS, data, 0, len(data)) self.assertEqual(opt2.otype, dns.edns.ECS) - self.assertEqual(opt2.address, '2001:4b98::') + self.assertEqual(opt2.address, "2001:4b98::") self.assertEqual(opt2.srclen, 56) self.assertEqual(opt2.scopelen, 0) def testECSOption_from_text_valid(self): - ecs1 = dns.edns.ECSOption.from_text('1.2.3.4/24/0') - self.assertEqual(ecs1, dns.edns.ECSOption('1.2.3.4', 24, 0)) + ecs1 = dns.edns.ECSOption.from_text("1.2.3.4/24/0") + self.assertEqual(ecs1, dns.edns.ECSOption("1.2.3.4", 24, 0)) - ecs2 = dns.edns.ECSOption.from_text('1.2.3.4/24') - self.assertEqual(ecs2, dns.edns.ECSOption('1.2.3.4', 24, 0)) + ecs2 = dns.edns.ECSOption.from_text("1.2.3.4/24") + self.assertEqual(ecs2, dns.edns.ECSOption("1.2.3.4", 24, 0)) - ecs3 = dns.edns.ECSOption.from_text('ECS 1.2.3.4/24') - self.assertEqual(ecs3, dns.edns.ECSOption('1.2.3.4', 24, 0)) + ecs3 = dns.edns.ECSOption.from_text("ECS 1.2.3.4/24") + self.assertEqual(ecs3, dns.edns.ECSOption("1.2.3.4", 24, 0)) - ecs4 = dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32') - self.assertEqual(ecs4, dns.edns.ECSOption('1.2.3.4', 24, 32)) + ecs4 = dns.edns.ECSOption.from_text("ECS 1.2.3.4/24/32") + self.assertEqual(ecs4, dns.edns.ECSOption("1.2.3.4", 24, 32)) - ecs5 = dns.edns.ECSOption.from_text('2001:4b98::1/64/56') - self.assertEqual(ecs5, dns.edns.ECSOption('2001:4b98::1', 64, 56)) + ecs5 = dns.edns.ECSOption.from_text("2001:4b98::1/64/56") + self.assertEqual(ecs5, dns.edns.ECSOption("2001:4b98::1", 64, 56)) - ecs6 = dns.edns.ECSOption.from_text('2001:4b98::1/64') - self.assertEqual(ecs6, dns.edns.ECSOption('2001:4b98::1', 64, 0)) + ecs6 = dns.edns.ECSOption.from_text("2001:4b98::1/64") + self.assertEqual(ecs6, dns.edns.ECSOption("2001:4b98::1", 64, 0)) - ecs7 = dns.edns.ECSOption.from_text('ECS 2001:4b98::1/0') - self.assertEqual(ecs7, dns.edns.ECSOption('2001:4b98::1', 0, 0)) + ecs7 = dns.edns.ECSOption.from_text("ECS 2001:4b98::1/0") + self.assertEqual(ecs7, dns.edns.ECSOption("2001:4b98::1", 0, 0)) - ecs8 = dns.edns.ECSOption.from_text('ECS 2001:4b98::1/64/128') - self.assertEqual(ecs8, dns.edns.ECSOption('2001:4b98::1', 64, 128)) + ecs8 = dns.edns.ECSOption.from_text("ECS 2001:4b98::1/64/128") + self.assertEqual(ecs8, dns.edns.ECSOption("2001:4b98::1", 64, 128)) def testECSOption_from_text_invalid(self): with self.assertRaises(ValueError): - dns.edns.ECSOption.from_text('some random text 1.2.3.4/24/0 24') + dns.edns.ECSOption.from_text("some random text 1.2.3.4/24/0 24") with self.assertRaises(ValueError): - dns.edns.ECSOption.from_text('1.2.3.4/twentyfour') + dns.edns.ECSOption.from_text("1.2.3.4/twentyfour") with self.assertRaises(ValueError): - dns.edns.ECSOption.from_text('BOGUS 1.2.3.4/5/6/7') + dns.edns.ECSOption.from_text("BOGUS 1.2.3.4/5/6/7") with self.assertRaises(ValueError): - dns.edns.ECSOption.from_text('1.2.3.4/5/6/7') + dns.edns.ECSOption.from_text("1.2.3.4/5/6/7") with self.assertRaises(ValueError): - dns.edns.ECSOption.from_text('1.2.3.4/24/O') # <-- that's not a zero + dns.edns.ECSOption.from_text("1.2.3.4/24/O") # <-- that's not a zero with self.assertRaises(ValueError): - dns.edns.ECSOption.from_text('') + dns.edns.ECSOption.from_text("") with self.assertRaises(ValueError): - dns.edns.ECSOption.from_text('1.2.3.4/2001:4b98::1/24') + dns.edns.ECSOption.from_text("1.2.3.4/2001:4b98::1/24") def testECSOption_from_wire_invalid(self): with self.assertRaises(ValueError): - opt = dns.edns.option_from_wire(dns.edns.ECS, - b'\x00\xff\x18\x00\x01\x02\x03', - 0, 7) + opt = dns.edns.option_from_wire( + dns.edns.ECS, b"\x00\xff\x18\x00\x01\x02\x03", 0, 7 + ) + def testEDEOption(self): opt = dns.edns.EDEOption(3) io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'\x00\x03') - self.assertEqual(str(opt), 'EDE 3') + self.assertEqual(data, b"\x00\x03") + self.assertEqual(str(opt), "EDE 3") # with text - opt = dns.edns.EDEOption(16, 'test') + opt = dns.edns.EDEOption(16, "test") io = BytesIO() opt.to_wire(io) data = io.getvalue() - self.assertEqual(data, b'\x00\x10test') + self.assertEqual(data, b"\x00\x10test") def testEDEOption_invalid(self): with self.assertRaises(ValueError): @@ -157,39 +159,41 @@ class OptionTestCase(unittest.TestCase): opt = dns.edns.EDEOption(0, 0) def testEDEOption_from_wire(self): - data = b'\x00\01' + data = b"\x00\01" self.assertEqual( - dns.edns.option_from_wire(dns.edns.EDE, data, 0, 2), - dns.edns.EDEOption(1)) - data = b'\x00\01test' + dns.edns.option_from_wire(dns.edns.EDE, data, 0, 2), dns.edns.EDEOption(1) + ) + data = b"\x00\01test" self.assertEqual( dns.edns.option_from_wire(dns.edns.EDE, data, 0, 6), - dns.edns.EDEOption(1, 'test')) + dns.edns.EDEOption(1, "test"), + ) # utf-8 text MAY be null-terminated - data = b'\x00\01test\x00' + data = b"\x00\01test\x00" self.assertEqual( dns.edns.option_from_wire(dns.edns.EDE, data, 0, 7), - dns.edns.EDEOption(1, 'test')) + dns.edns.EDEOption(1, "test"), + ) def test_basic_relations(self): - o1 = dns.edns.ECSOption.from_text('1.2.3.0/24/0') - o2 = dns.edns.ECSOption.from_text('1.2.4.0/24/0') + o1 = dns.edns.ECSOption.from_text("1.2.3.0/24/0") + o2 = dns.edns.ECSOption.from_text("1.2.4.0/24/0") self.assertTrue(o1 == o1) self.assertTrue(o1 != o2) self.assertTrue(o1 < o2) self.assertTrue(o1 <= o2) self.assertTrue(o2 > o1) self.assertTrue(o2 >= o1) - o1 = dns.edns.ECSOption.from_text('1.2.4.0/23/0') - o2 = dns.edns.ECSOption.from_text('1.2.4.0/24/0') + o1 = dns.edns.ECSOption.from_text("1.2.4.0/23/0") + o2 = dns.edns.ECSOption.from_text("1.2.4.0/24/0") self.assertTrue(o1 < o2) - o1 = dns.edns.ECSOption.from_text('1.2.4.0/24/0') - o2 = dns.edns.ECSOption.from_text('1.2.4.0/24/1') + o1 = dns.edns.ECSOption.from_text("1.2.4.0/24/0") + o2 = dns.edns.ECSOption.from_text("1.2.4.0/24/1") self.assertTrue(o1 < o2) def test_incompatible_relations(self): - o1 = dns.edns.GenericOption(3, b'data') - o2 = dns.edns.ECSOption.from_text('1.2.3.5/24/0') + o1 = dns.edns.GenericOption(3, b"data") + o2 = dns.edns.ECSOption.from_text("1.2.3.5/24/0") for oper in [operator.lt, operator.le, operator.ge, operator.gt]: self.assertRaises(TypeError, lambda: oper(o1, o2)) self.assertFalse(o1 == o2) @@ -206,7 +210,7 @@ class OptionTestCase(unittest.TestCase): self.value = value def to_wire(self, file=None): - data = struct.pack('!I', self.value) + data = struct.pack("!I", self.value) if file: file.write(data) else: @@ -214,15 +218,16 @@ class OptionTestCase(unittest.TestCase): @classmethod def from_wire_parser(cls, otype, parser): - (value,) = parser.get_struct('!I') + (value,) = parser.get_struct("!I") return cls(value) try: dns.edns.register_type(U32Option, U32OptionType) - generic = dns.edns.GenericOption(U32OptionType, b'\x00\x00\x00\x01') + generic = dns.edns.GenericOption(U32OptionType, b"\x00\x00\x00\x01") wire1 = generic.to_wire() - u32 = dns.edns.option_from_wire_parser(U32OptionType, - dns.wire.Parser(wire1)) + u32 = dns.edns.option_from_wire_parser( + U32OptionType, dns.wire.Parser(wire1) + ) self.assertEqual(u32.value, 1) wire2 = u32.to_wire() self.assertEqual(wire1, wire2) diff --git a/tests/test_entropy.py b/tests/test_entropy.py index 74092e7e..b502eebb 100644 --- a/tests/test_entropy.py +++ b/tests/test_entropy.py @@ -6,15 +6,16 @@ import dns.entropy # these tests are mostly for minimal coverage testing + class EntropyTestCase(unittest.TestCase): def test_pool(self): - pool = dns.entropy.EntropyPool(b'seed-value') + pool = dns.entropy.EntropyPool(b"seed-value") self.assertEqual(pool.random_8(), 94) self.assertEqual(pool.random_16(), 61532) self.assertEqual(pool.random_32(), 4226376065) self.assertEqual(pool.random_between(10, 50), 29) # stir in some not-really-entropy to exercise the stir API - pool.stir(b'not-really-entropy') + pool.stir(b"not-really-entropy") def test_pool_random(self): pool = dns.entropy.EntropyPool() @@ -24,8 +25,10 @@ class EntropyTestCase(unittest.TestCase): def test_pool_random_between(self): pool = dns.entropy.EntropyPool() + def bad(): pool.random_between(0, 4294967296) + self.assertRaises(ValueError, bad) v = pool.random_between(50, 50 + 100000) self.assertTrue(v >= 50 and v <= 50 + 100000) @@ -42,7 +45,6 @@ class EntropyTestCase(unittest.TestCase): class EntropyForcePoolTestCase(unittest.TestCase): - def setUp(self): self.saved_system_random = dns.entropy.system_random dns.entropy.system_random = None diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index caaf88a4..9d983793 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -22,11 +22,10 @@ from dns.exception import DNSException class FormatedError(DNSException): fmt = "Custom format: {parameter}" - supp_kwargs = {'parameter'} + supp_kwargs = {"parameter"} class ExceptionTestCase(unittest.TestCase): - def test_custom_message(self): msg = "this is a custom message" try: @@ -42,7 +41,7 @@ class ExceptionTestCase(unittest.TestCase): def test_formatted_error(self): """Exceptions with explicit format has to respect it.""" - params = {'parameter': 'value'} + params = {"parameter": "value"} try: raise FormatedError(**params) except FormatedError as ex: @@ -59,5 +58,6 @@ class ExceptionTestCase(unittest.TestCase): with self.assertRaises(AssertionError): raise FormatedError(unsupported=2) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_flags.py b/tests/test_flags.py index 3f5fc696..27cf03bf 100644 --- a/tests/test_flags.py +++ b/tests/test_flags.py @@ -21,10 +21,10 @@ import dns.flags import dns.rcode import dns.opcode -class FlagsTestCase(unittest.TestCase): +class FlagsTestCase(unittest.TestCase): def test_rcode1(self): - self.assertEqual(dns.rcode.from_text('FORMERR'), dns.rcode.FORMERR) + self.assertEqual(dns.rcode.from_text("FORMERR"), dns.rcode.FORMERR) def test_rcode2(self): self.assertEqual(dns.rcode.to_text(dns.rcode.FORMERR), "FORMERR") @@ -33,12 +33,10 @@ class FlagsTestCase(unittest.TestCase): self.assertEqual(dns.rcode.to_flags(dns.rcode.FORMERR), (1, 0)) def test_rcode4(self): - self.assertEqual(dns.rcode.to_flags(dns.rcode.BADVERS), - (0, 0x01000000)) + self.assertEqual(dns.rcode.to_flags(dns.rcode.BADVERS), (0, 0x01000000)) def test_rcode6(self): - self.assertEqual(dns.rcode.from_flags(0, 0x01000000), - dns.rcode.BADVERS) + self.assertEqual(dns.rcode.from_flags(0, 0x01000000), dns.rcode.BADVERS) def test_rcode7(self): self.assertEqual(dns.rcode.from_flags(5, 0), dns.rcode.REFUSED) @@ -46,36 +44,39 @@ class FlagsTestCase(unittest.TestCase): def test_rcode8(self): def bad(): dns.rcode.to_flags(4096) + self.assertRaises(ValueError, bad) def test_flags1(self): - self.assertEqual(dns.flags.from_text("RA RD AA QR"), - dns.flags.QR|dns.flags.AA|dns.flags.RD|dns.flags.RA) + self.assertEqual( + dns.flags.from_text("RA RD AA QR"), + dns.flags.QR | dns.flags.AA | dns.flags.RD | dns.flags.RA, + ) def test_flags2(self): - flags = dns.flags.QR|dns.flags.AA|dns.flags.RD|dns.flags.RA + flags = dns.flags.QR | dns.flags.AA | dns.flags.RD | dns.flags.RA self.assertEqual(dns.flags.to_text(flags), "QR AA RD RA") def test_rcode_badvers(self): rcode = dns.rcode.BADVERS self.assertEqual(rcode.value, 16) - self.assertEqual(rcode.name, 'BADVERS') - self.assertEqual(dns.rcode.to_text(rcode), 'BADVERS') + self.assertEqual(rcode.name, "BADVERS") + self.assertEqual(dns.rcode.to_text(rcode), "BADVERS") def test_rcode_badsig(self): rcode = dns.rcode.BADSIG self.assertEqual(rcode.value, 16) # Yes, we mean BADVERS on the next line. BADSIG and BADVERS have # the same code. - self.assertEqual(rcode.name, 'BADVERS') - self.assertEqual(dns.rcode.to_text(rcode), 'BADVERS') + self.assertEqual(rcode.name, "BADVERS") + self.assertEqual(dns.rcode.to_text(rcode), "BADVERS") # In TSIG text mode, it should be BADSIG - self.assertEqual(dns.rcode.to_text(rcode, True), 'BADSIG') + self.assertEqual(dns.rcode.to_text(rcode, True), "BADSIG") def test_unknown_rcode(self): with self.assertRaises(dns.rcode.UnknownRcode): - dns.rcode.Rcode.make('BOGUS') + dns.rcode.Rcode.make("BOGUS") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_generate.py b/tests/test_generate.py index 3f7c9259..75f22583 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -16,7 +16,8 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import sys -sys.path.insert(0, '../') # Force the local project to be *the* dns + +sys.path.insert(0, "../") # Force the local project to be *the* dns import unittest @@ -145,449 +146,595 @@ $GENERATE 1-10 foo$ CNAME $.0 @ 3600 IN NS ns2 """ + def _rdata_sort(a): return (a[0], a[2].rdclass, a[2].to_text()) class GenerateTestCase(unittest.TestCase): + def testFromText(self): # type: () -> None + def bad(): # type: () -> None + dns.zone.from_text(example_text, "example.", relativize=True) - def testFromText(self): # type: () -> None - def bad(): # type: () -> None - dns.zone.from_text(example_text, 'example.', relativize=True) self.assertRaises(dns.zone.NoSOA, bad) - def testFromText1(self): # type: () -> None - def bad(): # type: () -> None - dns.zone.from_text(example_text1, 'example.', relativize=True) + def testFromText1(self): # type: () -> None + def bad(): # type: () -> None + dns.zone.from_text(example_text1, "example.", relativize=True) + self.assertRaises(dns.zone.NoSOA, bad) - def testIterateAllRdatas2(self): # type: () -> None - z = dns.zone.from_text(example_text2, 'example.', relativize=True) + def testIterateAllRdatas2(self): # type: () -> None + z = dns.zone.from_text(example_text2, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - (dns.name.from_text('foo3', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.3')), - (dns.name.from_text('foo4', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.4')), - (dns.name.from_text('foo5', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.5'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("foo3", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.3"), + ), + ( + dns.name.from_text("foo4", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.4"), + ), + ( + dns.name.from_text("foo5", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.5"), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testIterateAllRdatas3(self): # type: () -> None - z = dns.zone.from_text(example_text3, 'example.', relativize=True) + def testIterateAllRdatas3(self): # type: () -> None + z = dns.zone.from_text(example_text3, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - (dns.name.from_text('foo4', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.4')), - (dns.name.from_text('foo6', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.6')), - (dns.name.from_text('foo8', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.8'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("foo4", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.4"), + ), + ( + dns.name.from_text("foo6", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.6"), + ), + ( + dns.name.from_text("foo8", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.8"), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testGenerate1(self): # type: () -> None - z = dns.zone.from_text(example_text4, 'example.', relativize=True) + + def testGenerate1(self): # type: () -> None + z = dns.zone.from_text(example_text4, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - - (dns.name.from_text('wp-db01.services.mozilla.com', None), - 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.')), - - (dns.name.from_text('wp-db02.services.mozilla.com', None), - 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.')), - - (dns.name.from_text('wp-db03.services.mozilla.com', None), - 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("wp-db01.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ( + dns.name.from_text("wp-db02.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ( + dns.name.from_text("wp-db03.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testGenerate2(self): # type: () -> None - z = dns.zone.from_text(example_text5, 'example.', relativize=True) + def testGenerate2(self): # type: () -> None + z = dns.zone.from_text(example_text5, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - - (dns.name.from_text('wp-db21.services.mozilla.com', None), 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.')), - - (dns.name.from_text('wp-db22.services.mozilla.com', None), 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.')), - - (dns.name.from_text('wp-db23.services.mozilla.com', None), 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("wp-db21.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ( + dns.name.from_text("wp-db22.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ( + dns.name.from_text("wp-db23.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testGenerate3(self): # type: () -> None - z = dns.zone.from_text(example_text6, 'example.', relativize=True) + def testGenerate3(self): # type: () -> None + z = dns.zone.from_text(example_text6, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - (dns.name.from_text('wp-db21.services.mozilla.com', None), 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.')), - - (dns.name.from_text('wp-db22.services.mozilla.com', None), 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.')), - - (dns.name.from_text('wp-db23.services.mozilla.com', None), 0, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - 'SERVER.FOOBAR.'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("wp-db21.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ( + dns.name.from_text("wp-db22.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ( + dns.name.from_text("wp-db23.services.mozilla.com", None), + 0, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.CNAME, "SERVER.FOOBAR." + ), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testGenerate4(self): # type: () -> None - z = dns.zone.from_text(example_text7, 'example.', relativize=True) + def testGenerate4(self): # type: () -> None + z = dns.zone.from_text(example_text7, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - - (dns.name.from_text('sync1.db', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0')), - - (dns.name.from_text('sync2.db', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0')), - - (dns.name.from_text('sync3.db', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("sync1.db", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ( + dns.name.from_text("sync2.db", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ( + dns.name.from_text("sync3.db", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testGenerate6(self): # type: () -> None - z = dns.zone.from_text(example_text9, 'example.', relativize=True) + def testGenerate6(self): # type: () -> None + z = dns.zone.from_text(example_text9, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - - (dns.name.from_text('wp-db01', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0')), - (dns.name.from_text('wp-db02', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0')), - - (dns.name.from_text('sync1.db', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0')), - - (dns.name.from_text('sync2.db', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0')), - - (dns.name.from_text('sync3.db', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.10.16.0'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("wp-db01", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ( + dns.name.from_text("wp-db02", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ( + dns.name.from_text("sync1.db", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ( + dns.name.from_text("sync2.db", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ( + dns.name.from_text("sync3.db", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.10.16.0"), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testGenerate7(self): # type: () -> None - z = dns.zone.from_text(example_text10, 'example.', relativize=True) + def testGenerate7(self): # type: () -> None + z = dns.zone.from_text(example_text10, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2')), - - (dns.name.from_text('27.2', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.PTR, - 'zlb1.oob')), - - (dns.name.from_text('28.2', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.PTR, - 'zlb2.oob'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ( + dns.name.from_text("27.2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.PTR, "zlb1.oob"), + ), + ( + dns.name.from_text("28.2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.PTR, "zlb2.oob"), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) - def testGenerate8(self): # type: () -> None - z = dns.zone.from_text(example_text11, 'example.', relativize=True) + def testGenerate8(self): # type: () -> None + z = dns.zone.from_text(example_text11, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + exl = [ + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), + 3600, + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - - (dns.name.from_text('prefix-027', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.27')), - - (dns.name.from_text('prefix-028', None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.28')), - - (dns.name.from_text('ns1', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), - 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("prefix-027", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.27"), + ), + ( + dns.name.from_text("prefix-028", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.28"), + ), + ( + dns.name.from_text("ns1", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), + 3600, + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) def testNoOrigin(self): def bad(): - dns.zone.from_text('$GENERATE 1-10 fooo$ CNAME $.0') + dns.zone.from_text("$GENERATE 1-10 fooo$ CNAME $.0") + self.assertRaises(dns.zone.UnknownOrigin, bad) def testBadRdata(self): def bad(): - dns.zone.from_text('$GENERATE 1-10 fooo$ CNAME 10 $.0', 'example') + dns.zone.from_text("$GENERATE 1-10 fooo$ CNAME 10 $.0", "example") + self.assertRaises(dns.exception.SyntaxError, bad) def testUsesLastTTL(self): - z = dns.zone.from_text(last_ttl_input, 'example') - rrs = z.find_rrset('foo9', 'CNAME') + z = dns.zone.from_text(last_ttl_input, "example") + rrs = z.find_rrset("foo9", "CNAME") self.assertEqual(rrs.ttl, 300) def testClassMismatch(self): def bad(): - dns.zone.from_text('$GENERATE 1-10 fooo$ CH CNAME $.0', 'example') + dns.zone.from_text("$GENERATE 1-10 fooo$ CH CNAME $.0", "example") + self.assertRaises(dns.exception.SyntaxError, bad) def testUnknownRdatatype(self): def bad(): - dns.zone.from_text('$GENERATE 1-10 fooo$ BOGUSTYPE $.0', 'example') + dns.zone.from_text("$GENERATE 1-10 fooo$ BOGUSTYPE $.0", "example") + self.assertRaises(dns.exception.SyntaxError, bad) def testBadAndDangling(self): def bad1(): - dns.zone.from_text('$GENERATE bogus fooo$ CNAME $.0', - 'example.') + dns.zone.from_text("$GENERATE bogus fooo$ CNAME $.0", "example.") + self.assertRaises(dns.exception.SyntaxError, bad1) + def bad2(): - dns.zone.from_text('$GENERATE 1-10', - 'example.') + dns.zone.from_text("$GENERATE 1-10", "example.") + self.assertRaises(dns.exception.SyntaxError, bad2) + def bad3(): - dns.zone.from_text('$GENERATE 1-10 foo$', - 'example.') + dns.zone.from_text("$GENERATE 1-10 foo$", "example.") + self.assertRaises(dns.exception.SyntaxError, bad3) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_grange.py b/tests/test_grange.py index 9b5ddd24..e28b042b 100644 --- a/tests/test_grange.py +++ b/tests/test_grange.py @@ -16,7 +16,8 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import sys -sys.path.insert(0, '../') + +sys.path.insert(0, "../") import unittest @@ -26,39 +27,38 @@ import dns.grange class GRangeTestCase(unittest.TestCase): - def testFromText1(self): - start, stop, step = dns.grange.from_text('1-1') + start, stop, step = dns.grange.from_text("1-1") self.assertEqual(start, 1) self.assertEqual(stop, 1) self.assertEqual(step, 1) def testFromText2(self): - start, stop, step = dns.grange.from_text('1-4') + start, stop, step = dns.grange.from_text("1-4") self.assertEqual(start, 1) self.assertEqual(stop, 4) self.assertEqual(step, 1) def testFromText3(self): - start, stop, step = dns.grange.from_text('4-255') + start, stop, step = dns.grange.from_text("4-255") self.assertEqual(start, 4) self.assertEqual(stop, 255) self.assertEqual(step, 1) def testFromText4(self): - start, stop, step = dns.grange.from_text('1-1/1') + start, stop, step = dns.grange.from_text("1-1/1") self.assertEqual(start, 1) self.assertEqual(stop, 1) self.assertEqual(step, 1) def testFromText5(self): - start, stop, step = dns.grange.from_text('1-4/2') + start, stop, step = dns.grange.from_text("1-4/2") self.assertEqual(start, 1) self.assertEqual(stop, 4) self.assertEqual(step, 2) def testFromText6(self): - start, stop, step = dns.grange.from_text('4-255/77') + start, stop, step = dns.grange.from_text("4-255/77") self.assertEqual(start, 4) self.assertEqual(stop, 255) self.assertEqual(step, 77) @@ -68,26 +68,27 @@ class GRangeTestCase(unittest.TestCase): start = 2 stop = 1 step = 1 - dns.grange.from_text('%d-%d/%d' % (start, stop, step)) + dns.grange.from_text("%d-%d/%d" % (start, stop, step)) self.assertTrue(False) def testFailFromText2(self): with self.assertRaises(dns.exception.SyntaxError): - start = '-1' + start = "-1" stop = 3 step = 1 - dns.grange.from_text('%s-%d/%d' % (start, stop, step)) + dns.grange.from_text("%s-%d/%d" % (start, stop, step)) def testFailFromText3(self): with self.assertRaises(dns.exception.SyntaxError): start = 1 stop = 4 - step = '-2' - dns.grange.from_text('%d-%d/%s' % (start, stop, step)) + step = "-2" + dns.grange.from_text("%d-%d/%s" % (start, stop, step)) def testFailFromText4(self): with self.assertRaises(dns.exception.SyntaxError): - dns.grange.from_text('1') + dns.grange.from_text("1") + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_immutable.py b/tests/test_immutable.py index 8ab145ea..fa762d89 100644 --- a/tests/test_immutable.py +++ b/tests/test_immutable.py @@ -7,17 +7,16 @@ import dns._immutable_ctx class ImmutableTestCase(unittest.TestCase): - def test_immutable_dict_hash(self): - d1 = dns.immutable.Dict({'a': 1, 'b': 2}) - d2 = dns.immutable.Dict({'b': 2, 'a': 1}) - d3 = {'b': 2, 'a': 1} + d1 = dns.immutable.Dict({"a": 1, "b": 2}) + d2 = dns.immutable.Dict({"b": 2, "a": 1}) + d3 = {"b": 2, "a": 1} self.assertEqual(d1, d2) self.assertEqual(d2, d3) self.assertEqual(hash(d1), hash(d2)) def test_immutable_dict_hash_cache(self): - d = dns.immutable.Dict({'a': 1, 'b': 2}) + d = dns.immutable.Dict({"a": 1, "b": 2}) self.assertEqual(d._hash, None) h1 = hash(d) self.assertEqual(d._hash, h1) @@ -26,19 +25,17 @@ class ImmutableTestCase(unittest.TestCase): def test_constify(self): items = ( - (bytearray([1, 2, 3]), b'\x01\x02\x03'), + (bytearray([1, 2, 3]), b"\x01\x02\x03"), ((1, 2, 3), (1, 2, 3)), ((1, [2], 3), (1, (2,), 3)), ([1, 2, 3], (1, 2, 3)), - ([1, {'a': [1, 2]}], - (1, dns.immutable.Dict({'a': (1, 2)}))), - ('hi', 'hi'), - (b'hi', b'hi'), + ([1, {"a": [1, 2]}], (1, dns.immutable.Dict({"a": (1, 2)}))), + ("hi", "hi"), + (b"hi", b"hi"), ) for input, expected in items: self.assertEqual(dns.immutable.constify(input), expected) - self.assertIsInstance(dns.immutable.constify({'a': 1}), - dns.immutable.Dict) + self.assertIsInstance(dns.immutable.constify({"a": 1}), dns.immutable.Dict) class DecoratorTestCase(unittest.TestCase): @@ -55,6 +52,7 @@ class DecoratorTestCase(unittest.TestCase): def __init__(self, a, b): super().__init__(a, akw=20) self.b = b + B = self.immutable_module.immutable(B) # note C is immutable by inheritance @@ -62,19 +60,23 @@ class DecoratorTestCase(unittest.TestCase): def __init__(self, a, b, c): super().__init__(a, b) self.c = c + C = self.immutable_module.immutable(C) class SA: - __slots__ = ('a', 'akw') + __slots__ = ("a", "akw") + def __init__(self, a, akw=10): self.a = a self.akw = akw class SB(A): - __slots__ = ('b') + __slots__ = "b" + def __init__(self, a, b): super().__init__(a, akw=20) self.b = b + SB = self.immutable_module.immutable(SB) # note SC is immutable by inheritance and has no slots of its own @@ -82,6 +84,7 @@ class DecoratorTestCase(unittest.TestCase): def __init__(self, a, b, c): super().__init__(a, b) self.c = c + SC = self.immutable_module.immutable(SC) return ((A, B, C), (SA, SB, SC)) @@ -115,10 +118,11 @@ class DecoratorTestCase(unittest.TestCase): self.a = a self.b = a del self.b + A = self.immutable_module.immutable(A) a = A(10) self.assertEqual(a.a, 10) - self.assertFalse(hasattr(a, 'b')) + self.assertFalse(hasattr(a, "b")) def test_no_collateral_damage(self): @@ -129,6 +133,7 @@ class DecoratorTestCase(unittest.TestCase): class A: def __init__(self, a): self.a = a + A = self.immutable_module.immutable(A) class B: @@ -136,6 +141,7 @@ class DecoratorTestCase(unittest.TestCase): self.b = a.a + b # rudely attempt to mutate innocent immutable bystander 'a' a.a = 1000 + B = self.immutable_module.immutable(B) a = A(10) diff --git a/tests/test_message.py b/tests/test_message.py index b0112aea..d2f5b0ea 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -46,8 +46,10 @@ wwww.dnspython.org. IN A ;AUTHORITY ;ADDITIONAL""" -goodhex = b'04d201000001000000000001047777777709646e73707974686f6e' \ - b'036f726700000100010000291000000080000000' +goodhex = ( + b"04d201000001000000000001047777777709646e73707974686f6e" + b"036f726700000100010000291000000080000000" +) goodwire = binascii.unhexlify(goodhex) @@ -67,19 +69,21 @@ dnspython.org. 3600 IN NS woof.play-bow.org. woof.play-bow.org. 3600 IN A 204.152.186.150 """ -goodhex2 = '04d2 8500 0001 0001 0003 0001' \ - '09646e73707974686f6e036f726700 0006 0001' \ - 'c00c 0006 0001 00000e10 0028 ' \ - '04776f6f66c00c 0a686f73746d6173746572c00c' \ - '7764289c 00000e10 00000708 00093a80 00000e10' \ - 'c00c 0002 0001 00000e10 0014' \ - '036e7331057374616666076e6f6d696e756dc016' \ - 'c00c 0002 0001 00000e10 0006 036e7332c063' \ - 'c00c 0002 0001 00000e10 0010 04776f6f6608706c61792d626f77c016' \ - 'c091 0001 0001 00000e10 0004 cc98ba96' +goodhex2 = ( + "04d2 8500 0001 0001 0003 0001" + "09646e73707974686f6e036f726700 0006 0001" + "c00c 0006 0001 00000e10 0028 " + "04776f6f66c00c 0a686f73746d6173746572c00c" + "7764289c 00000e10 00000708 00093a80 00000e10" + "c00c 0002 0001 00000e10 0014" + "036e7331057374616666076e6f6d696e756dc016" + "c00c 0002 0001 00000e10 0006 036e7332c063" + "c00c 0002 0001 00000e10 0010 04776f6f6608706c61792d626f77c016" + "c091 0001 0001 00000e10 0004 cc98ba96" +) -goodwire2 = binascii.unhexlify(goodhex2.replace(' ', '').encode()) +goodwire2 = binascii.unhexlify(goodhex2.replace(" ", "").encode()) query_text_2 = """id 1234 opcode QUERY @@ -94,8 +98,10 @@ wwww.dnspython.org. IN A ;AUTHORITY ;ADDITIONAL""" -goodhex3 = b'04d2010f0001000000000001047777777709646e73707974686f6e' \ - b'036f726700000100010000291000ff0080000000' +goodhex3 = ( + b"04d2010f0001000000000001047777777709646e73707974686f6e" + b"036f726700000100010000291000ff0080000000" +) goodwire3 = binascii.unhexlify(goodhex3) @@ -109,8 +115,8 @@ Königsgäßchen. IN NS Königsgäßchen. 3600 IN NS Königsgäßchen. """ -class MessageTestCase(unittest.TestCase): +class MessageTestCase(unittest.TestCase): def test_class(self): m = dns.message.from_text(query_text) self.assertTrue(isinstance(m, dns.message.QueryMessage)) @@ -155,8 +161,8 @@ class MessageTestCase(unittest.TestCase): self.assertEqual(str(m), query_text_2) def test_EDNS_options_wire(self): - m = dns.message.make_query('foo', 'A') - opt = dns.edns.GenericOption(3, b'data') + m = dns.message.make_query("foo", "A") + opt = dns.edns.GenericOption(3, b"data") m.use_edns(options=[opt]) m2 = dns.message.from_wire(m.to_wire()) self.assertEqual(m2.edns, 0) @@ -167,12 +173,16 @@ class MessageTestCase(unittest.TestCase): def bad(): q = dns.message.from_text(query_text) for i in range(0, 25): - rrset = dns.rrset.from_text('foo%d.' % i, 3600, - dns.rdataclass.IN, - dns.rdatatype.A, - '10.0.0.%d' % i) + rrset = dns.rrset.from_text( + "foo%d." % i, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + "10.0.0.%d" % i, + ) q.additional.append(rrset) q.to_wire(max_size=512) + self.assertRaises(dns.exception.TooBig, bad) def test_answer1(self): @@ -182,31 +192,34 @@ class MessageTestCase(unittest.TestCase): def test_TrailingJunk(self): def bad(): - badwire = goodwire + b'\x00' + badwire = goodwire + b"\x00" dns.message.from_wire(badwire) + self.assertRaises(dns.message.TrailingJunk, bad) def test_ShortHeader(self): def bad(): - badwire = b'\x00' * 11 + badwire = b"\x00" * 11 dns.message.from_wire(badwire) + self.assertRaises(dns.message.ShortHeader, bad) def test_RespondingToResponse(self): def bad(): - q = dns.message.make_query('foo', 'A') + q = dns.message.make_query("foo", "A") r1 = dns.message.make_response(q) dns.message.make_response(r1) + self.assertRaises(dns.exception.FormError, bad) def test_RespondingToEDNSRequestAndSettingRA(self): - q = dns.message.make_query('foo', 'A', use_edns=0) + q = dns.message.make_query("foo", "A", use_edns=0) r = dns.message.make_response(q, True) self.assertTrue(r.flags & dns.flags.RA != 0) self.assertEqual(r.edns, 0) def test_ExtendedRcodeSetting(self): - m = dns.message.make_query('foo', 'A') + m = dns.message.make_query("foo", "A") m.set_rcode(4095) self.assertEqual(m.rcode(), 4095) self.assertEqual(m.edns, 0) @@ -214,63 +227,59 @@ class MessageTestCase(unittest.TestCase): self.assertEqual(m.rcode(), 2) def test_EDNSVersionCoherence(self): - m = dns.message.make_query('foo', 'A') + m = dns.message.make_query("foo", "A") m.use_edns(1) self.assertEqual((m.ednsflags >> 16) & 0xFF, 1) def test_SettingNoEDNSOptionsImpliesNoEDNS(self): - m = dns.message.make_query('foo', 'A') + m = dns.message.make_query("foo", "A") self.assertEqual(m.edns, -1) def test_SettingEDNSFlagsImpliesEDNS(self): - m = dns.message.make_query('foo', 'A', ednsflags=dns.flags.DO) + m = dns.message.make_query("foo", "A", ednsflags=dns.flags.DO) self.assertEqual(m.edns, 0) def test_SettingEDNSPayloadImpliesEDNS(self): - m = dns.message.make_query('foo', 'A', payload=4096) + m = dns.message.make_query("foo", "A", payload=4096) self.assertEqual(m.edns, 0) def test_SettingEDNSRequestPayloadImpliesEDNS(self): - m = dns.message.make_query('foo', 'A', request_payload=4096) + m = dns.message.make_query("foo", "A", request_payload=4096) self.assertEqual(m.edns, 0) def test_SettingOptionsImpliesEDNS(self): - m = dns.message.make_query('foo', 'A', options=[]) + m = dns.message.make_query("foo", "A", options=[]) self.assertEqual(m.edns, 0) def test_FindRRset(self): a = dns.message.from_text(answer_text) - n = dns.name.from_text('dnspython.org.') + n = dns.name.from_text("dnspython.org.") rrs1 = a.find_rrset(a.answer, n, dns.rdataclass.IN, dns.rdatatype.SOA) - rrs2 = a.find_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, - dns.rdatatype.SOA) + rrs2 = a.find_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, dns.rdatatype.SOA) self.assertEqual(rrs1, rrs2) def test_FindRRsetUnindexed(self): a = dns.message.from_text(answer_text) a.index = None - n = dns.name.from_text('dnspython.org.') + n = dns.name.from_text("dnspython.org.") rrs1 = a.find_rrset(a.answer, n, dns.rdataclass.IN, dns.rdatatype.SOA) - rrs2 = a.find_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, - dns.rdatatype.SOA) + rrs2 = a.find_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, dns.rdatatype.SOA) self.assertEqual(rrs1, rrs2) def test_GetRRset(self): a = dns.message.from_text(answer_text) a.index = None - n = dns.name.from_text('dnspython.org.') + n = dns.name.from_text("dnspython.org.") rrs1 = a.get_rrset(a.answer, n, dns.rdataclass.IN, dns.rdatatype.SOA) - rrs2 = a.get_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, - dns.rdatatype.SOA) + rrs2 = a.get_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, dns.rdatatype.SOA) self.assertEqual(rrs1, rrs2) def test_GetNonexistentRRset(self): a = dns.message.from_text(answer_text) a.index = None - n = dns.name.from_text('dnspython.org.') + n = dns.name.from_text("dnspython.org.") rrs1 = a.get_rrset(a.answer, n, dns.rdataclass.IN, dns.rdatatype.TXT) - rrs2 = a.get_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, - dns.rdatatype.TXT) + rrs2 = a.get_rrset(dns.message.ANSWER, n, dns.rdataclass.IN, dns.rdatatype.TXT) self.assertTrue(rrs1 is None) self.assertEqual(rrs1, rrs2) @@ -280,6 +289,7 @@ class MessageTestCase(unittest.TestCase): a.flags |= dns.flags.TC wire = a.to_wire(want_shuffle=False) dns.message.from_wire(wire, raise_on_truncation=True) + self.assertRaises(dns.message.Truncated, bad) def test_MessyTruncated(self): @@ -288,78 +298,82 @@ class MessageTestCase(unittest.TestCase): a.flags |= dns.flags.TC wire = a.to_wire(want_shuffle=False) dns.message.from_wire(wire[:-3], raise_on_truncation=True) + self.assertRaises(dns.message.Truncated, bad) def test_IDNA_2003(self): a = dns.message.from_text(idna_text, idna_codec=dns.name.IDNA_2003) - rrs = dns.rrset.from_text_list('xn--knigsgsschen-lcb0w.', 30, - 'in', 'ns', - ['xn--knigsgsschen-lcb0w.'], - idna_codec=dns.name.IDNA_2003) + rrs = dns.rrset.from_text_list( + "xn--knigsgsschen-lcb0w.", + 30, + "in", + "ns", + ["xn--knigsgsschen-lcb0w."], + idna_codec=dns.name.IDNA_2003, + ) self.assertEqual(a.answer[0], rrs) - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def test_IDNA_2008(self): a = dns.message.from_text(idna_text, idna_codec=dns.name.IDNA_2008) - rrs = dns.rrset.from_text_list('xn--knigsgchen-b4a3dun.', 30, - 'in', 'ns', - ['xn--knigsgchen-b4a3dun.'], - idna_codec=dns.name.IDNA_2008) + rrs = dns.rrset.from_text_list( + "xn--knigsgchen-b4a3dun.", + 30, + "in", + "ns", + ["xn--knigsgchen-b4a3dun."], + idna_codec=dns.name.IDNA_2008, + ) self.assertEqual(a.answer[0], rrs) def test_bad_section_number(self): - m = dns.message.make_query('foo', 'A') - self.assertRaises(ValueError, - lambda: m.section_number(123)) + m = dns.message.make_query("foo", "A") + self.assertRaises(ValueError, lambda: m.section_number(123)) def test_section_from_number(self): - m = dns.message.make_query('foo', 'A') - self.assertEqual(m.section_from_number(dns.message.QUESTION), - m.question) - self.assertEqual(m.section_from_number(dns.message.ANSWER), - m.answer) - self.assertEqual(m.section_from_number(dns.message.AUTHORITY), - m.authority) - self.assertEqual(m.section_from_number(dns.message.ADDITIONAL), - m.additional) - self.assertRaises(ValueError, - lambda: m.section_from_number(999)) + m = dns.message.make_query("foo", "A") + self.assertEqual(m.section_from_number(dns.message.QUESTION), m.question) + self.assertEqual(m.section_from_number(dns.message.ANSWER), m.answer) + self.assertEqual(m.section_from_number(dns.message.AUTHORITY), m.authority) + self.assertEqual(m.section_from_number(dns.message.ADDITIONAL), m.additional) + self.assertRaises(ValueError, lambda: m.section_from_number(999)) def test_wanting_EDNS_true_is_EDNS0(self): - m = dns.message.make_query('foo', 'A') + m = dns.message.make_query("foo", "A") self.assertEqual(m.edns, -1) m.use_edns(True) self.assertEqual(m.edns, 0) def test_wanting_DNSSEC_turns_on_EDNS(self): - m = dns.message.make_query('foo', 'A') + m = dns.message.make_query("foo", "A") self.assertEqual(m.edns, -1) m.want_dnssec() self.assertEqual(m.edns, 0) self.assertTrue(m.ednsflags & dns.flags.DO) def test_EDNS_default_payload_is_1232(self): - m = dns.message.make_query('foo', 'A') + m = dns.message.make_query("foo", "A") m.use_edns() self.assertEqual(m.payload, dns.message.DEFAULT_EDNS_PAYLOAD) def test_from_file(self): - m = dns.message.from_file(here('query')) + m = dns.message.from_file(here("query")) expected = dns.message.from_text(query_text) self.assertEqual(m, expected) def test_explicit_header_comment(self): - m = dns.message.from_text(';HEADER\n' + query_text) + m = dns.message.from_text(";HEADER\n" + query_text) expected = dns.message.from_text(query_text) self.assertEqual(m, expected) def test_repr(self): q = dns.message.from_text(query_text) - self.assertEqual(repr(q), '') + self.assertEqual(repr(q), "") def test_non_question_setters(self): - rrset = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.1') + rrset = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1") q = dns.message.QueryMessage(id=1) q.answer = [rrset] self.assertEqual(q.sections[1], [rrset]) @@ -372,7 +386,7 @@ class MessageTestCase(unittest.TestCase): self.assertEqual(q.sections[3], [rrset]) def test_is_a_response_empty_question(self): - q = dns.message.make_query('www.dnspython.org.', 'a') + q = dns.message.make_query("www.dnspython.org.", "a") r = dns.message.make_response(q) r.question = [] r.set_rcode(dns.rcode.FORMERR) @@ -386,8 +400,8 @@ class MessageTestCase(unittest.TestCase): self.assertFalse(q.is_response(r)) r = dns.update.UpdateMessage(id=1) self.assertFalse(q.is_response(r)) - q1 = dns.message.make_query('www.dnspython.org.', 'a') - q2 = dns.message.make_query('www.google.com.', 'a') + q1 = dns.message.make_query("www.dnspython.org.", "a") + q2 = dns.message.make_query("www.google.com.", "a") # Give them the same id, as we want to test if responses for # differing questions are rejected. q1.id = 1 @@ -403,17 +417,22 @@ class MessageTestCase(unittest.TestCase): # something in the response's question section that is not in # the question's. We have to do multiple questions to test # this :) - r = dns.message.make_query('www.dnspython.org.', 'a') + r = dns.message.make_query("www.dnspython.org.", "a") r.flags |= dns.flags.QR r.id = 1 - r.find_rrset(r.question, dns.name.from_text('example'), - dns.rdataclass.IN, dns.rdatatype.A, create=True, - force_unique=True) + r.find_rrset( + r.question, + dns.name.from_text("example"), + dns.rdataclass.IN, + dns.rdatatype.A, + create=True, + force_unique=True, + ) self.assertFalse(q1.is_response(r)) def test_more_not_equal_cases(self): - q1 = dns.message.make_query('www.dnspython.org.', 'a') - q2 = dns.message.make_query('www.dnspython.org.', 'a') + q1 = dns.message.make_query("www.dnspython.org.", "a") + q2 = dns.message.make_query("www.dnspython.org.", "a") # ensure ids are same q1.id = 1 q2.id = 1 @@ -421,39 +440,49 @@ class MessageTestCase(unittest.TestCase): q2.flags |= dns.flags.QR self.assertFalse(q1 == q2) q2.flags = q1.flags - q2.find_rrset(q2.question, dns.name.from_text('example'), - dns.rdataclass.IN, dns.rdatatype.A, create=True, - force_unique=True) + q2.find_rrset( + q2.question, + dns.name.from_text("example"), + dns.rdataclass.IN, + dns.rdatatype.A, + create=True, + force_unique=True, + ) self.assertFalse(q1 == q2) def test_edns_properties(self): - q = dns.message.make_query('www.dnspython.org.', 'a') + q = dns.message.make_query("www.dnspython.org.", "a") self.assertEqual(q.edns, -1) self.assertEqual(q.payload, 0) self.assertEqual(q.options, ()) - q = dns.message.make_query('www.dnspython.org.', 'a', use_edns=0, - payload=4096) + q = dns.message.make_query("www.dnspython.org.", "a", use_edns=0, payload=4096) self.assertEqual(q.edns, 0) self.assertEqual(q.payload, 4096) self.assertEqual(q.options, ()) def test_setting_id(self): - q = dns.message.make_query('www.dnspython.org.', 'a', id=12345) + q = dns.message.make_query("www.dnspython.org.", "a", id=12345) self.assertEqual(q.id, 12345) def test_setting_flags(self): - q = dns.message.make_query('www.dnspython.org.', 'a', - flags=dns.flags.RD|dns.flags.CD) - self.assertEqual(q.flags, dns.flags.RD|dns.flags.CD) + q = dns.message.make_query( + "www.dnspython.org.", "a", flags=dns.flags.RD | dns.flags.CD + ) + self.assertEqual(q.flags, dns.flags.RD | dns.flags.CD) self.assertEqual(q.flags, 0x0110) def test_generic_message_class(self): q1 = dns.message.Message(id=1) q1.set_opcode(dns.opcode.NOTIFY) q1.flags |= dns.flags.AA - q1.find_rrset(q1.question, dns.name.from_text('example'), - dns.rdataclass.IN, dns.rdatatype.SOA, create=True, - force_unique=True) + q1.find_rrset( + q1.question, + dns.name.from_text("example"), + dns.rdataclass.IN, + dns.rdatatype.SOA, + create=True, + force_unique=True, + ) w = q1.to_wire() q2 = dns.message.from_wire(w) self.assertTrue(isinstance(q2, dns.message.Message)) @@ -478,7 +507,7 @@ class MessageTestCase(unittest.TestCase): dns.message.from_wire(wire) # Owner name not root name q = dns.message.Message(id=1) - rrs = dns.rrset.from_rdata('foo.', 0, opt) + rrs = dns.rrset.from_rdata("foo.", 0, opt) q.additional.append(rrs) wire = q.to_wire() with self.assertRaises(dns.message.BadEDNS): @@ -493,12 +522,20 @@ class MessageTestCase(unittest.TestCase): dns.message.from_wire(wire) def test_bad_tsig(self): - keyname = dns.name.from_text('key.') + keyname = dns.name.from_text("key.") # Not in additional q = dns.message.Message(id=1) - tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, - dns.tsig.HMAC_SHA256, 0, 300, b'1234', - 0, 0, b'') + tsig = dns.rdtypes.ANY.TSIG.TSIG( + dns.rdataclass.ANY, + dns.rdatatype.TSIG, + dns.tsig.HMAC_SHA256, + 0, + 300, + b"1234", + 0, + 0, + b"", + ) rrs = dns.rrset.from_rdata(keyname, 0, tsig) q.answer.append(rrs) wire = q.to_wire() @@ -512,31 +549,39 @@ class MessageTestCase(unittest.TestCase): with self.assertRaises(dns.message.BadTSIG): dns.message.from_wire(wire) # Class not ANY - tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.IN, dns.rdatatype.TSIG, - dns.tsig.HMAC_SHA256, 0, 300, b'1234', - 0, 0, b'') + tsig = dns.rdtypes.ANY.TSIG.TSIG( + dns.rdataclass.IN, + dns.rdatatype.TSIG, + dns.tsig.HMAC_SHA256, + 0, + 300, + b"1234", + 0, + 0, + b"", + ) rrs = dns.rrset.from_rdata(keyname, 0, tsig) wire = q.to_wire() with self.assertRaises(dns.message.BadTSIG): dns.message.from_wire(wire) def test_read_no_content_message(self): - m = dns.message.from_text(';comment') + m = dns.message.from_text(";comment") self.assertIsInstance(m, dns.message.QueryMessage) def test_eflags_turns_on_edns(self): - m = dns.message.from_text('eflags DO') + m = dns.message.from_text("eflags DO") self.assertIsInstance(m, dns.message.QueryMessage) self.assertEqual(m.edns, 0) def test_payload_turns_on_edns(self): - m = dns.message.from_text('payload 1200') + m = dns.message.from_text("payload 1200") self.assertIsInstance(m, dns.message.QueryMessage) self.assertEqual(m.payload, 1200) def test_bogus_header(self): with self.assertRaises(dns.message.UnknownHeaderField): - dns.message.from_text('bogus foo') + dns.message.from_text("bogus foo") def test_question_only(self): m = dns.message.from_text(answer_text) @@ -549,142 +594,178 @@ class MessageTestCase(unittest.TestCase): self.assertEqual(len(r.additional), 0) def test_bad_resolve_chaining(self): - r = dns.message.make_query('www.dnspython.org.', 'a') + r = dns.message.make_query("www.dnspython.org.", "a") with self.assertRaises(dns.message.NotQueryResponse): r.resolve_chaining() r.flags |= dns.flags.QR r.id = 1 - r.find_rrset(r.question, dns.name.from_text('example'), - dns.rdataclass.IN, dns.rdatatype.A, create=True, - force_unique=True) + r.find_rrset( + r.question, + dns.name.from_text("example"), + dns.rdataclass.IN, + dns.rdatatype.A, + create=True, + force_unique=True, + ) with self.assertRaises(dns.exception.FormError): r.resolve_chaining() def test_resolve_chaining_no_infinite_loop(self): - r = dns.message.from_text('''id 1 + r = dns.message.from_text( + """id 1 flags QR ;QUESTION www.example. IN CNAME ;AUTHORITY example. 300 IN SOA . . 1 2 3 4 5 -''') +""" + ) # passing is not going into an infinite loop in this call result = r.resolve_chaining() - self.assertEqual(result.canonical_name, - dns.name.from_text('www.example.')) + self.assertEqual(result.canonical_name, dns.name.from_text("www.example.")) self.assertEqual(result.minimum_ttl, 5) self.assertIsNone(result.answer) def test_bad_text_questions(self): with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 ;QUESTION example. -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 ;QUESTION example. IN -''') +""" + ) with self.assertRaises(dns.rdatatype.UnknownRdatatype): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 ;QUESTION example. INA -''') +""" + ) with self.assertRaises(dns.rdatatype.UnknownRdatatype): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 ;QUESTION example. IN BOGUS -''') +""" + ) def test_bad_text_rrs(self): with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;QUESTION example. IN A ;ANSWER example. -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;QUESTION example. IN A ;ANSWER example. IN -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;QUESTION example. IN A ;ANSWER example. 300 -''') +""" + ) with self.assertRaises(dns.rdatatype.UnknownRdatatype): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;QUESTION example. IN A ;ANSWER example. 30a IN A -''') +""" + ) with self.assertRaises(dns.rdatatype.UnknownRdatatype): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;QUESTION example. IN A ;ANSWER example. 300 INA A -''') +""" + ) with self.assertRaises(dns.exception.UnexpectedEnd): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;QUESTION example. IN A ;ANSWER example. 300 IN A -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR opcode UPDATE ;ZONE example. IN SOA ;UPDATE example. 300 IN A -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR opcode UPDATE ;ZONE example. IN SOA ;UPDATE example. 300 NONE A -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR opcode UPDATE ;ZONE example. IN SOA ;PREREQ example. 300 NONE A 10.0.0.1 -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;ANSWER 300 IN A 10.0.0.1 -''') +""" + ) with self.assertRaises(dns.exception.SyntaxError): - dns.message.from_text('''id 1 + dns.message.from_text( + """id 1 flags QR ;QUESTION IN SOA -''') +""" + ) def test_from_wire_makes_Flag(self): m = dns.message.from_wire(goodwire) @@ -693,7 +774,7 @@ flags QR def test_continue_on_error(self): good_message = dns.message.from_text( -"""id 1234 + """id 1234 opcode QUERY rcode NOERROR flags QR AA RD @@ -703,21 +784,20 @@ www.dnspython.org. IN SOA www.dnspython.org. 300 IN SOA . . 1 2 3 4 4294967295 www.dnspython.org. 300 IN A 1.2.3.4 www.dnspython.org. 300 IN AAAA ::1 -""") +""" + ) wire = good_message.to_wire() # change ANCOUNT to 255 - bad_wire = wire[:6] + b'\x00\xff' + wire[8:] + bad_wire = wire[:6] + b"\x00\xff" + wire[8:] # change AAAA into rdata with rdlen 0 - bad_wire = bad_wire[:-18] + b'\x00' * 2 + bad_wire = bad_wire[:-18] + b"\x00" * 2 m = dns.message.from_wire(bad_wire, continue_on_error=True) self.assertEqual(len(m.errors), 2) print(m.errors) - self.assertEqual(str(m.errors[0].exception), - 'IPv6 addresses are 16 bytes long') - self.assertEqual(str(m.errors[1].exception), - 'DNS message is malformed.') + self.assertEqual(str(m.errors[0].exception), "IPv6 addresses are 16 bytes long") + self.assertEqual(str(m.errors[1].exception), "DNS message is malformed.") expected_message = dns.message.from_text( -"""id 1234 + """id 1234 opcode QUERY rcode NOERROR flags QR AA RD @@ -726,9 +806,10 @@ www.dnspython.org. IN SOA ;ANSWER www.dnspython.org. 300 IN SOA . . 1 2 3 4 4294967295 www.dnspython.org. 300 IN A 1.2.3.4 -""") +""" + ) self.assertEqual(m, expected_message) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_name.py b/tests/test_name.py index 45f83793..815fb102 100644 --- a/tests/test_name.py +++ b/tests/test_name.py @@ -16,7 +16,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -from typing import Dict # pylint: disable=unused-import +from typing import Dict # pylint: disable=unused-import import copy import operator import pickle @@ -33,58 +33,58 @@ import dns.e164 class NameTestCase(unittest.TestCase): def setUp(self): - self.origin = dns.name.from_text('example.') + self.origin = dns.name.from_text("example.") def testFromTextRel1(self): - n = dns.name.from_text('foo.bar') - self.assertEqual(n.labels, (b'foo', b'bar', b'')) + n = dns.name.from_text("foo.bar") + self.assertEqual(n.labels, (b"foo", b"bar", b"")) def testFromTextRel2(self): - n = dns.name.from_text('foo.bar', origin=self.origin) - self.assertEqual(n.labels, (b'foo', b'bar', b'example', b'')) + n = dns.name.from_text("foo.bar", origin=self.origin) + self.assertEqual(n.labels, (b"foo", b"bar", b"example", b"")) def testFromTextRel3(self): - n = dns.name.from_text('foo.bar', origin=None) - self.assertEqual(n.labels, (b'foo', b'bar')) + n = dns.name.from_text("foo.bar", origin=None) + self.assertEqual(n.labels, (b"foo", b"bar")) def testFromTextRel4(self): - n = dns.name.from_text('@', origin=None) + n = dns.name.from_text("@", origin=None) self.assertEqual(n, dns.name.empty) def testFromTextRel5(self): - n = dns.name.from_text('@', origin=self.origin) + n = dns.name.from_text("@", origin=self.origin) self.assertEqual(n, self.origin) def testFromTextAbs1(self): - n = dns.name.from_text('foo.bar.') - self.assertEqual(n.labels, (b'foo', b'bar', b'')) + n = dns.name.from_text("foo.bar.") + self.assertEqual(n.labels, (b"foo", b"bar", b"")) def testTortureFromText(self): good = [ - br'.', - br'a', - br'a.', - br'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', - br'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', - br'\000.\008.\010.\032.\046.\092.\099.\255', - br'\\', - br'\..\.', - br'\\.\\', - br'!"#%&/()=+-', - br'\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255', - ] + rb".", + rb"a", + rb"a.", + rb"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + rb"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + rb"\000.\008.\010.\032.\046.\092.\099.\255", + rb"\\", + rb"\..\.", + rb"\\.\\", + rb'!"#%&/()=+-', + rb"\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255", + ] bad = [ - br'..', - br'.a', - br'\\..', - b'\\', # yes, we don't want the 'r' prefix! - br'\0', - br'\00', - br'\00Z', - br'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', - br'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', - br'\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255', - ] + rb"..", + rb".a", + rb"\\..", + b"\\", # yes, we don't want the 'r' prefix! + rb"\0", + rb"\00", + rb"\00Z", + rb"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + rb"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + rb"\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255.\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255\255", + ] for t in good: try: dns.name.from_text(t) @@ -102,11 +102,13 @@ class NameTestCase(unittest.TestCase): def testImmutable1(self): def bad(): self.origin.labels = () + self.assertRaises(TypeError, bad) def testImmutable2(self): def bad(): - self.origin.labels[0] = 'foo' # type: ignore + self.origin.labels[0] = "foo" # type: ignore + self.assertRaises(TypeError, bad) def testAbs1(self): @@ -119,19 +121,19 @@ class NameTestCase(unittest.TestCase): self.assertTrue(self.origin.is_absolute()) def testAbs4(self): - n = dns.name.from_text('foo', origin=None) + n = dns.name.from_text("foo", origin=None) self.assertFalse(n.is_absolute()) def testWild1(self): - n = dns.name.from_text('*.foo', origin=None) + n = dns.name.from_text("*.foo", origin=None) self.assertTrue(n.is_wild()) def testWild2(self): - n = dns.name.from_text('*a.foo', origin=None) + n = dns.name.from_text("*a.foo", origin=None) self.assertFalse(n.is_wild()) def testWild3(self): - n = dns.name.from_text('a.*.foo', origin=None) + n = dns.name.from_text("a.*.foo", origin=None) self.assertFalse(n.is_wild()) def testWild4(self): @@ -141,21 +143,21 @@ class NameTestCase(unittest.TestCase): self.assertFalse(dns.name.empty.is_wild()) def testHash1(self): - n1 = dns.name.from_text('fOo.COM') - n2 = dns.name.from_text('foo.com') + n1 = dns.name.from_text("fOo.COM") + n2 = dns.name.from_text("foo.com") self.assertEqual(hash(n1), hash(n2)) def testCompare1(self): - n1 = dns.name.from_text('a') - n2 = dns.name.from_text('b') + n1 = dns.name.from_text("a") + n2 = dns.name.from_text("b") self.assertLess(n1, n2) self.assertLessEqual(n1, n2) self.assertGreater(n2, n1) self.assertGreaterEqual(n2, n1) def testCompare2(self): - n1 = dns.name.from_text('') - n2 = dns.name.from_text('b') + n1 = dns.name.from_text("") + n2 = dns.name.from_text("b") self.assertLess(n1, n2) self.assertLessEqual(n1, n2) self.assertGreater(n2, n1) @@ -175,15 +177,15 @@ class NameTestCase(unittest.TestCase): self.assertFalse(dns.name.root.is_subdomain(dns.name.empty)) def testSubdomain3(self): - n = dns.name.from_text('foo', origin=self.origin) + n = dns.name.from_text("foo", origin=self.origin) self.assertTrue(n.is_subdomain(self.origin)) def testSubdomain4(self): - n = dns.name.from_text('foo', origin=self.origin) + n = dns.name.from_text("foo", origin=self.origin) self.assertTrue(n.is_subdomain(dns.name.root)) def testSubdomain5(self): - n = dns.name.from_text('foo', origin=self.origin) + n = dns.name.from_text("foo", origin=self.origin) self.assertTrue(n.is_subdomain(n)) def testSuperdomain1(self): @@ -193,277 +195,286 @@ class NameTestCase(unittest.TestCase): self.assertFalse(dns.name.root.is_superdomain(dns.name.empty)) def testSuperdomain3(self): - n = dns.name.from_text('foo', origin=self.origin) + n = dns.name.from_text("foo", origin=self.origin) self.assertTrue(self.origin.is_superdomain(n)) def testSuperdomain4(self): - n = dns.name.from_text('foo', origin=self.origin) + n = dns.name.from_text("foo", origin=self.origin) self.assertTrue(dns.name.root.is_superdomain(n)) def testSuperdomain5(self): - n = dns.name.from_text('foo', origin=self.origin) + n = dns.name.from_text("foo", origin=self.origin) self.assertTrue(n.is_superdomain(n)) def testCanonicalize1(self): - n = dns.name.from_text('FOO.bar', origin=self.origin) + n = dns.name.from_text("FOO.bar", origin=self.origin) c = n.canonicalize() - self.assertEqual(c.labels, (b'foo', b'bar', b'example', b'')) + self.assertEqual(c.labels, (b"foo", b"bar", b"example", b"")) def testToText1(self): - n = dns.name.from_text('FOO.bar', origin=self.origin) + n = dns.name.from_text("FOO.bar", origin=self.origin) t = n.to_text() - self.assertEqual(t, 'FOO.bar.example.') + self.assertEqual(t, "FOO.bar.example.") def testToText2(self): - n = dns.name.from_text('FOO.bar', origin=self.origin) + n = dns.name.from_text("FOO.bar", origin=self.origin) t = n.to_text(True) - self.assertEqual(t, 'FOO.bar.example') + self.assertEqual(t, "FOO.bar.example") def testToText3(self): - n = dns.name.from_text('FOO.bar', origin=None) + n = dns.name.from_text("FOO.bar", origin=None) t = n.to_text() - self.assertEqual(t, 'FOO.bar') + self.assertEqual(t, "FOO.bar") def testToText4(self): t = dns.name.empty.to_text() - self.assertEqual(t, '@') + self.assertEqual(t, "@") def testToText5(self): t = dns.name.root.to_text() - self.assertEqual(t, '.') + self.assertEqual(t, ".") def testToText6(self): - n = dns.name.from_text('FOO bar', origin=None) + n = dns.name.from_text("FOO bar", origin=None) t = n.to_text() - self.assertEqual(t, r'FOO\032bar') + self.assertEqual(t, r"FOO\032bar") def testToText7(self): - n = dns.name.from_text(r'FOO\.bar', origin=None) + n = dns.name.from_text(r"FOO\.bar", origin=None) t = n.to_text() - self.assertEqual(t, r'FOO\.bar') + self.assertEqual(t, r"FOO\.bar") def testToText8(self): - n = dns.name.from_text(r'\070OO\.bar', origin=None) + n = dns.name.from_text(r"\070OO\.bar", origin=None) t = n.to_text() - self.assertEqual(t, r'FOO\.bar') + self.assertEqual(t, r"FOO\.bar") def testToText9(self): - n = dns.name.from_text('FOO bar', origin=None) + n = dns.name.from_text("FOO bar", origin=None) t = n.to_unicode() - self.assertEqual(t, 'FOO\\032bar') + self.assertEqual(t, "FOO\\032bar") def testToText10(self): t = dns.name.empty.to_unicode() - self.assertEqual(t, '@') + self.assertEqual(t, "@") def testToText11(self): t = dns.name.root.to_unicode() - self.assertEqual(t, '.') + self.assertEqual(t, ".") def testToText12(self): - n = dns.name.from_text(r'a\.b.c') + n = dns.name.from_text(r"a\.b.c") t = n.to_unicode() - self.assertEqual(t, r'a\.b.c.') + self.assertEqual(t, r"a\.b.c.") def testToText13(self): - n = dns.name.from_text(r'\150\151\152\153\154\155\156\157\158\159.') + n = dns.name.from_text(r"\150\151\152\153\154\155\156\157\158\159.") t = n.to_text() - self.assertEqual(t, r'\150\151\152\153\154\155\156\157\158\159.') + self.assertEqual(t, r"\150\151\152\153\154\155\156\157\158\159.") def testToText14(self): # Something that didn't start as unicode should go to escapes and not # raise due to interpreting arbitrary binary DNS labels as UTF-8. - n = dns.name.from_text(r'\150\151\152\153\154\155\156\157\158\159.') + n = dns.name.from_text(r"\150\151\152\153\154\155\156\157\158\159.") t = n.to_unicode() - self.assertEqual(t, r'\150\151\152\153\154\155\156\157\158\159.') + self.assertEqual(t, r"\150\151\152\153\154\155\156\157\158\159.") def testSlice1(self): - n = dns.name.from_text(r'a.b.c.', origin=None) + n = dns.name.from_text(r"a.b.c.", origin=None) s = n[:] - self.assertEqual(s, (b'a', b'b', b'c', b'')) + self.assertEqual(s, (b"a", b"b", b"c", b"")) def testSlice2(self): - n = dns.name.from_text(r'a.b.c.', origin=None) + n = dns.name.from_text(r"a.b.c.", origin=None) s = n[:2] - self.assertEqual(s, (b'a', b'b')) + self.assertEqual(s, (b"a", b"b")) def testSlice3(self): - n = dns.name.from_text(r'a.b.c.', origin=None) + n = dns.name.from_text(r"a.b.c.", origin=None) s = n[2:] - self.assertEqual(s, (b'c', b'')) + self.assertEqual(s, (b"c", b"")) def testEmptyLabel1(self): def bad(): - dns.name.Name(['a', '', 'b']) + dns.name.Name(["a", "", "b"]) + self.assertRaises(dns.name.EmptyLabel, bad) def testEmptyLabel2(self): def bad(): - dns.name.Name(['', 'b']) + dns.name.Name(["", "b"]) + self.assertRaises(dns.name.EmptyLabel, bad) def testEmptyLabel3(self): - n = dns.name.Name(['b', '']) + n = dns.name.Name(["b", ""]) self.assertTrue(n) def testLongLabel(self): - n = dns.name.Name(['a' * 63]) + n = dns.name.Name(["a" * 63]) self.assertTrue(n) def testLabelTooLong(self): def bad(): - dns.name.Name(['a' * 64, 'b']) + dns.name.Name(["a" * 64, "b"]) + self.assertRaises(dns.name.LabelTooLong, bad) def testLongName(self): - n = dns.name.Name(['a' * 63, 'a' * 63, 'a' * 63, 'a' * 62]) + n = dns.name.Name(["a" * 63, "a" * 63, "a" * 63, "a" * 62]) self.assertTrue(n) def testNameTooLong(self): def bad(): - dns.name.Name(['a' * 63, 'a' * 63, 'a' * 63, 'a' * 63]) + dns.name.Name(["a" * 63, "a" * 63, "a" * 63, "a" * 63]) + self.assertRaises(dns.name.NameTooLong, bad) def testConcat1(self): - n1 = dns.name.Name(['a', 'b']) - n2 = dns.name.Name(['c', 'd']) - e = dns.name.Name(['a', 'b', 'c', 'd']) + n1 = dns.name.Name(["a", "b"]) + n2 = dns.name.Name(["c", "d"]) + e = dns.name.Name(["a", "b", "c", "d"]) r = n1 + n2 self.assertEqual(r, e) def testConcat2(self): - n1 = dns.name.Name(['a', 'b']) + n1 = dns.name.Name(["a", "b"]) n2 = dns.name.Name([]) - e = dns.name.Name(['a', 'b']) + e = dns.name.Name(["a", "b"]) r = n1 + n2 self.assertEqual(r, e) def testConcat3(self): n1 = dns.name.Name([]) - n2 = dns.name.Name(['a', 'b']) - e = dns.name.Name(['a', 'b']) + n2 = dns.name.Name(["a", "b"]) + e = dns.name.Name(["a", "b"]) r = n1 + n2 self.assertEqual(r, e) def testConcat4(self): - n1 = dns.name.Name(['a', 'b', '']) + n1 = dns.name.Name(["a", "b", ""]) n2 = dns.name.Name([]) - e = dns.name.Name(['a', 'b', '']) + e = dns.name.Name(["a", "b", ""]) r = n1 + n2 self.assertEqual(r, e) def testConcat5(self): - n1 = dns.name.Name(['a', 'b']) - n2 = dns.name.Name(['c', '']) - e = dns.name.Name(['a', 'b', 'c', '']) + n1 = dns.name.Name(["a", "b"]) + n2 = dns.name.Name(["c", ""]) + e = dns.name.Name(["a", "b", "c", ""]) r = n1 + n2 self.assertEqual(r, e) def testConcat6(self): def bad(): - n1 = dns.name.Name(['a', 'b', '']) - n2 = dns.name.Name(['c']) + n1 = dns.name.Name(["a", "b", ""]) + n2 = dns.name.Name(["c"]) return n1 + n2 + self.assertRaises(dns.name.AbsoluteConcatenation, bad) def testBadEscape(self): def bad(): - n = dns.name.from_text(r'a.b\0q1.c.') + n = dns.name.from_text(r"a.b\0q1.c.") + self.assertRaises(dns.name.BadEscape, bad) def testDigestable1(self): - n = dns.name.from_text('FOO.bar') + n = dns.name.from_text("FOO.bar") d = n.to_digestable() - self.assertEqual(d, b'\x03foo\x03bar\x00') + self.assertEqual(d, b"\x03foo\x03bar\x00") def testDigestable2(self): - n1 = dns.name.from_text('FOO.bar') - n2 = dns.name.from_text('foo.BAR.') + n1 = dns.name.from_text("FOO.bar") + n2 = dns.name.from_text("foo.BAR.") d1 = n1.to_digestable() d2 = n2.to_digestable() self.assertEqual(d1, d2) def testDigestable3(self): d = dns.name.root.to_digestable() - self.assertEqual(d, b'\x00') + self.assertEqual(d, b"\x00") def testDigestable4(self): - n = dns.name.from_text('FOO.bar', None) + n = dns.name.from_text("FOO.bar", None) d = n.to_digestable(dns.name.root) - self.assertEqual(d, b'\x03foo\x03bar\x00') + self.assertEqual(d, b"\x03foo\x03bar\x00") def testBadDigestable(self): def bad(): - n = dns.name.from_text('FOO.bar', None) + n = dns.name.from_text("FOO.bar", None) n.to_digestable() + self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin, bad) def testToWire1(self): - n = dns.name.from_text('FOO.bar') + n = dns.name.from_text("FOO.bar") f = BytesIO() - compress = {} # type: Dict[dns.name.Name,int] + compress = {} # type: Dict[dns.name.Name,int] n.to_wire(f, compress) - self.assertEqual(f.getvalue(), b'\x03FOO\x03bar\x00') + self.assertEqual(f.getvalue(), b"\x03FOO\x03bar\x00") def testToWire2(self): - n = dns.name.from_text('FOO.bar') + n = dns.name.from_text("FOO.bar") f = BytesIO() - compress = {} # type: Dict[dns.name.Name,int] + compress = {} # type: Dict[dns.name.Name,int] n.to_wire(f, compress) n.to_wire(f, compress) - self.assertEqual(f.getvalue(), b'\x03FOO\x03bar\x00\xc0\x00') + self.assertEqual(f.getvalue(), b"\x03FOO\x03bar\x00\xc0\x00") def testToWire3(self): - n1 = dns.name.from_text('FOO.bar') - n2 = dns.name.from_text('foo.bar') + n1 = dns.name.from_text("FOO.bar") + n2 = dns.name.from_text("foo.bar") f = BytesIO() - compress = {} # type: Dict[dns.name.Name,int] + compress = {} # type: Dict[dns.name.Name,int] n1.to_wire(f, compress) n2.to_wire(f, compress) - self.assertEqual(f.getvalue(), b'\x03FOO\x03bar\x00\xc0\x00') + self.assertEqual(f.getvalue(), b"\x03FOO\x03bar\x00\xc0\x00") def testToWire4(self): - n1 = dns.name.from_text('FOO.bar') - n2 = dns.name.from_text('a.foo.bar') + n1 = dns.name.from_text("FOO.bar") + n2 = dns.name.from_text("a.foo.bar") f = BytesIO() - compress = {} # type: Dict[dns.name.Name,int] + compress = {} # type: Dict[dns.name.Name,int] n1.to_wire(f, compress) n2.to_wire(f, compress) - self.assertEqual(f.getvalue(), b'\x03FOO\x03bar\x00\x01\x61\xc0\x00') + self.assertEqual(f.getvalue(), b"\x03FOO\x03bar\x00\x01\x61\xc0\x00") def testToWire5(self): - n1 = dns.name.from_text('FOO.bar') - n2 = dns.name.from_text('a.foo.bar') + n1 = dns.name.from_text("FOO.bar") + n2 = dns.name.from_text("a.foo.bar") f = BytesIO() - compress = {} # type: Dict[dns.name.Name,int] + compress = {} # type: Dict[dns.name.Name,int] n1.to_wire(f, compress) n2.to_wire(f, None) - self.assertEqual(f.getvalue(), - b'\x03FOO\x03bar\x00\x01\x61\x03foo\x03bar\x00') + self.assertEqual(f.getvalue(), b"\x03FOO\x03bar\x00\x01\x61\x03foo\x03bar\x00") def testToWire6(self): - n = dns.name.from_text('FOO.bar') + n = dns.name.from_text("FOO.bar") v = n.to_wire() - self.assertEqual(v, b'\x03FOO\x03bar\x00') + self.assertEqual(v, b"\x03FOO\x03bar\x00") def testToWireRelativeNameWithOrigin(self): - n = dns.name.from_text('FOO', None) - o = dns.name.from_text('bar') + n = dns.name.from_text("FOO", None) + o = dns.name.from_text("bar") v = n.to_wire(origin=o) - self.assertEqual(v, b'\x03FOO\x03bar\x00') + self.assertEqual(v, b"\x03FOO\x03bar\x00") def testToWireRelativeNameWithoutOrigin(self): - n = dns.name.from_text('FOO', None) + n = dns.name.from_text("FOO", None) + def bad(): v = n.to_wire() + self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin, bad) def testBadToWire(self): def bad(): - n = dns.name.from_text('FOO.bar', None) + n = dns.name.from_text("FOO.bar", None) f = BytesIO() - compress = {} # type: Dict[dns.name.Name,int] + compress = {} # type: Dict[dns.name.Name,int] n.to_wire(f, compress) + self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin, bad) def testGiantCompressionTable(self): @@ -471,14 +482,14 @@ class NameTestCase(unittest.TestCase): f = BytesIO() compress = {} # type: Dict[dns.name.Name,int] # exactly 16 bytes encoded - n = dns.name.from_text('0000000000.com.') + n = dns.name.from_text("0000000000.com.") n.to_wire(f, compress) # There are now two entries in the compression table (for the full # name, and for the com. suffix. self.assertEqual(len(compress), 2) for i in range(1023): # exactly 16 bytes encoded with compression - n = dns.name.from_text(f'{i:013d}.com') + n = dns.name.from_text(f"{i:013d}.com") n.to_wire(f, compress) # There are now 1025 entries in the compression table with # the last entry at offset 16368. @@ -486,143 +497,145 @@ class NameTestCase(unittest.TestCase): self.assertEqual(compress[n], 16368) # Adding another name should not increase the size of the compression # table, as the pointer would be at offset 16384, which is too big. - n = dns.name.from_text('toobig.com.') + n = dns.name.from_text("toobig.com.") n.to_wire(f, compress) self.assertEqual(len(compress), 1025) def testSplit1(self): - n = dns.name.from_text('foo.bar.') + n = dns.name.from_text("foo.bar.") (prefix, suffix) = n.split(2) - ep = dns.name.from_text('foo', None) - es = dns.name.from_text('bar.', None) + ep = dns.name.from_text("foo", None) + es = dns.name.from_text("bar.", None) self.assertEqual(prefix, ep) self.assertEqual(suffix, es) def testSplit2(self): - n = dns.name.from_text('foo.bar.') + n = dns.name.from_text("foo.bar.") (prefix, suffix) = n.split(1) - ep = dns.name.from_text('foo.bar', None) - es = dns.name.from_text('.', None) + ep = dns.name.from_text("foo.bar", None) + es = dns.name.from_text(".", None) self.assertEqual(prefix, ep) self.assertEqual(suffix, es) def testSplit3(self): - n = dns.name.from_text('foo.bar.') + n = dns.name.from_text("foo.bar.") (prefix, suffix) = n.split(0) - ep = dns.name.from_text('foo.bar.', None) - es = dns.name.from_text('', None) + ep = dns.name.from_text("foo.bar.", None) + es = dns.name.from_text("", None) self.assertEqual(prefix, ep) self.assertEqual(suffix, es) def testSplit4(self): - n = dns.name.from_text('foo.bar.') + n = dns.name.from_text("foo.bar.") (prefix, suffix) = n.split(3) - ep = dns.name.from_text('', None) - es = dns.name.from_text('foo.bar.', None) + ep = dns.name.from_text("", None) + es = dns.name.from_text("foo.bar.", None) self.assertEqual(prefix, ep) self.assertEqual(suffix, es) def testBadSplit1(self): def bad(): - n = dns.name.from_text('foo.bar.') + n = dns.name.from_text("foo.bar.") n.split(-1) + self.assertRaises(ValueError, bad) def testBadSplit2(self): def bad(): - n = dns.name.from_text('foo.bar.') + n = dns.name.from_text("foo.bar.") n.split(4) + self.assertRaises(ValueError, bad) def testRelativize1(self): - n = dns.name.from_text('a.foo.bar.', None) - o = dns.name.from_text('bar.', None) - e = dns.name.from_text('a.foo', None) + n = dns.name.from_text("a.foo.bar.", None) + o = dns.name.from_text("bar.", None) + e = dns.name.from_text("a.foo", None) self.assertEqual(n.relativize(o), e) def testRelativize2(self): - n = dns.name.from_text('a.foo.bar.', None) + n = dns.name.from_text("a.foo.bar.", None) o = n e = dns.name.empty self.assertEqual(n.relativize(o), e) def testRelativize3(self): - n = dns.name.from_text('a.foo.bar.', None) - o = dns.name.from_text('blaz.', None) + n = dns.name.from_text("a.foo.bar.", None) + o = dns.name.from_text("blaz.", None) e = n self.assertEqual(n.relativize(o), e) def testRelativize4(self): - n = dns.name.from_text('a.foo', None) + n = dns.name.from_text("a.foo", None) o = dns.name.root e = n self.assertEqual(n.relativize(o), e) def testDerelativize1(self): - n = dns.name.from_text('a.foo', None) - o = dns.name.from_text('bar.', None) - e = dns.name.from_text('a.foo.bar.', None) + n = dns.name.from_text("a.foo", None) + o = dns.name.from_text("bar.", None) + e = dns.name.from_text("a.foo.bar.", None) self.assertEqual(n.derelativize(o), e) def testDerelativize2(self): n = dns.name.empty - o = dns.name.from_text('a.foo.bar.', None) + o = dns.name.from_text("a.foo.bar.", None) e = o self.assertEqual(n.derelativize(o), e) def testDerelativize3(self): - n = dns.name.from_text('a.foo.bar.', None) - o = dns.name.from_text('blaz.', None) + n = dns.name.from_text("a.foo.bar.", None) + o = dns.name.from_text("blaz.", None) e = n self.assertEqual(n.derelativize(o), e) def testChooseRelativity1(self): - n = dns.name.from_text('a.foo.bar.', None) - o = dns.name.from_text('bar.', None) - e = dns.name.from_text('a.foo', None) + n = dns.name.from_text("a.foo.bar.", None) + o = dns.name.from_text("bar.", None) + e = dns.name.from_text("a.foo", None) self.assertEqual(n.choose_relativity(o, True), e) def testChooseRelativity2(self): - n = dns.name.from_text('a.foo.bar.', None) - o = dns.name.from_text('bar.', None) + n = dns.name.from_text("a.foo.bar.", None) + o = dns.name.from_text("bar.", None) e = n self.assertEqual(n.choose_relativity(o, False), e) def testChooseRelativity3(self): - n = dns.name.from_text('a.foo', None) - o = dns.name.from_text('bar.', None) - e = dns.name.from_text('a.foo.bar.', None) + n = dns.name.from_text("a.foo", None) + o = dns.name.from_text("bar.", None) + e = dns.name.from_text("a.foo.bar.", None) self.assertEqual(n.choose_relativity(o, False), e) def testChooseRelativity4(self): - n = dns.name.from_text('a.foo', None) + n = dns.name.from_text("a.foo", None) o = None e = n self.assertEqual(n.choose_relativity(o, True), e) def testChooseRelativity5(self): - n = dns.name.from_text('a.foo', None) + n = dns.name.from_text("a.foo", None) o = None e = n self.assertEqual(n.choose_relativity(o, False), e) def testChooseRelativity6(self): - n = dns.name.from_text('a.foo.', None) + n = dns.name.from_text("a.foo.", None) o = None e = n self.assertEqual(n.choose_relativity(o, True), e) def testChooseRelativity7(self): - n = dns.name.from_text('a.foo.', None) + n = dns.name.from_text("a.foo.", None) o = None e = n self.assertEqual(n.choose_relativity(o, False), e) def testFromWire1(self): - w = b'\x03foo\x00\xc0\x00' + w = b"\x03foo\x00\xc0\x00" (n1, cused1) = dns.name.from_wire(w, 0) (n2, cused2) = dns.name.from_wire(w, cused1) - en1 = dns.name.from_text('foo.') + en1 = dns.name.from_text("foo.") en2 = en1 ecused1 = 5 ecused2 = 2 @@ -632,16 +645,16 @@ class NameTestCase(unittest.TestCase): self.assertEqual(cused2, ecused2) def testFromWire2(self): - w = b'\x03foo\x00\x01a\xc0\x00\x01b\xc0\x05' + w = b"\x03foo\x00\x01a\xc0\x00\x01b\xc0\x05" current = 0 (n1, cused1) = dns.name.from_wire(w, current) current += cused1 (n2, cused2) = dns.name.from_wire(w, current) current += cused2 (n3, cused3) = dns.name.from_wire(w, current) - en1 = dns.name.from_text('foo.') - en2 = dns.name.from_text('a.foo.') - en3 = dns.name.from_text('b.a.foo.') + en1 = dns.name.from_text("foo.") + en2 = dns.name.from_text("a.foo.") + en3 = dns.name.from_text("b.a.foo.") ecused1 = 5 ecused2 = 4 ecused3 = 4 @@ -654,405 +667,445 @@ class NameTestCase(unittest.TestCase): def testBadFromWire1(self): def bad(): - w = b'\x03foo\xc0\x04' + w = b"\x03foo\xc0\x04" dns.name.from_wire(w, 0) + self.assertRaises(dns.name.BadPointer, bad) def testBadFromWire2(self): def bad(): - w = b'\x03foo\xc0\x05' + w = b"\x03foo\xc0\x05" dns.name.from_wire(w, 0) + self.assertRaises(dns.name.BadPointer, bad) def testBadFromWire3(self): def bad(): - w = b'\xbffoo' + w = b"\xbffoo" dns.name.from_wire(w, 0) + self.assertRaises(dns.name.BadLabelType, bad) def testBadFromWire4(self): def bad(): - w = b'\x41foo' + w = b"\x41foo" dns.name.from_wire(w, 0) + self.assertRaises(dns.name.BadLabelType, bad) def testParent1(self): - n = dns.name.from_text('foo.bar.') - self.assertEqual(n.parent(), dns.name.from_text('bar.')) + n = dns.name.from_text("foo.bar.") + self.assertEqual(n.parent(), dns.name.from_text("bar.")) self.assertEqual(n.parent().parent(), dns.name.root) def testParent2(self): - n = dns.name.from_text('foo.bar', None) - self.assertEqual(n.parent(), dns.name.from_text('bar', None)) + n = dns.name.from_text("foo.bar", None) + self.assertEqual(n.parent(), dns.name.from_text("bar", None)) self.assertEqual(n.parent().parent(), dns.name.empty) def testParent3(self): def bad(): n = dns.name.root n.parent() + self.assertRaises(dns.name.NoParent, bad) def testParent4(self): def bad(): n = dns.name.empty n.parent() + self.assertRaises(dns.name.NoParent, bad) def testFromUnicode1(self): - n = dns.name.from_text('foo.bar') - self.assertEqual(n.labels, (b'foo', b'bar', b'')) + n = dns.name.from_text("foo.bar") + self.assertEqual(n.labels, (b"foo", b"bar", b"")) def testFromUnicode2(self): - n = dns.name.from_text('foo\u1234bar.bar') - self.assertEqual(n.labels, (b'xn--foobar-r5z', b'bar', b'')) + n = dns.name.from_text("foo\u1234bar.bar") + self.assertEqual(n.labels, (b"xn--foobar-r5z", b"bar", b"")) def testFromUnicodeAlternateDot1(self): - n = dns.name.from_text('foo\u3002bar') - self.assertEqual(n.labels, (b'foo', b'bar', b'')) + n = dns.name.from_text("foo\u3002bar") + self.assertEqual(n.labels, (b"foo", b"bar", b"")) def testFromUnicodeAlternateDot2(self): - n = dns.name.from_text('foo\uff0ebar') - self.assertEqual(n.labels, (b'foo', b'bar', b'')) + n = dns.name.from_text("foo\uff0ebar") + self.assertEqual(n.labels, (b"foo", b"bar", b"")) def testFromUnicodeAlternateDot3(self): - n = dns.name.from_text('foo\uff61bar') - self.assertEqual(n.labels, (b'foo', b'bar', b'')) + n = dns.name.from_text("foo\uff61bar") + self.assertEqual(n.labels, (b"foo", b"bar", b"")) def testFromUnicodeRoot(self): - n = dns.name.from_text('.') - self.assertEqual(n.labels, (b'',)) + n = dns.name.from_text(".") + self.assertEqual(n.labels, (b"",)) def testFromUnicodeAlternateRoot1(self): - n = dns.name.from_text('\u3002') - self.assertEqual(n.labels, (b'',)) + n = dns.name.from_text("\u3002") + self.assertEqual(n.labels, (b"",)) def testFromUnicodeAlternateRoot2(self): - n = dns.name.from_text('\uff0e') - self.assertEqual(n.labels, (b'',)) + n = dns.name.from_text("\uff0e") + self.assertEqual(n.labels, (b"",)) def testFromUnicodeAlternateRoot3(self): - n = dns.name.from_text('\uff61') - self.assertEqual(n.labels, (b'', )) + n = dns.name.from_text("\uff61") + self.assertEqual(n.labels, (b"",)) def testFromUnicodeIDNA2003Explicit(self): - t = 'Königsgäßchen' + t = "Königsgäßchen" e = dns.name.from_unicode(t, idna_codec=dns.name.IDNA_2003) - self.assertEqual(str(e), 'xn--knigsgsschen-lcb0w.') + self.assertEqual(str(e), "xn--knigsgsschen-lcb0w.") def testFromUnicodeIDNA2003Default(self): - t = 'Königsgäßchen' + t = "Königsgäßchen" e = dns.name.from_unicode(t) - self.assertEqual(str(e), 'xn--knigsgsschen-lcb0w.') + self.assertEqual(str(e), "xn--knigsgsschen-lcb0w.") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testFromUnicodeIDNA2008(self): - t = 'Königsgäßchen' + t = "Königsgäßchen" + def bad(): codec = dns.name.IDNA_2008_Strict return dns.name.from_unicode(t, idna_codec=codec) + self.assertRaises(dns.name.IDNAException, bad) e1 = dns.name.from_unicode(t, idna_codec=dns.name.IDNA_2008) - self.assertEqual(str(e1), 'xn--knigsgchen-b4a3dun.') + self.assertEqual(str(e1), "xn--knigsgchen-b4a3dun.") c2 = dns.name.IDNA_2008_Transitional e2 = dns.name.from_unicode(t, idna_codec=c2) - self.assertEqual(str(e2), 'xn--knigsgsschen-lcb0w.') + self.assertEqual(str(e2), "xn--knigsgsschen-lcb0w.") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testFromUnicodeIDNA2008Mixed(self): # the IDN rules for names are very restrictive, disallowing # practical names like '_sip._tcp.Königsgäßchen'. Dnspython # has a "practical" mode which permits labels which are purely # ASCII to go straight through, and thus not invalid useful # things in the real world. - t = '_sip._tcp.Königsgäßchen' + t = "_sip._tcp.Königsgäßchen" + def bad1(): codec = dns.name.IDNA_2008_Strict return dns.name.from_unicode(t, idna_codec=codec) + def bad2(): codec = dns.name.IDNA_2008_UTS_46 return dns.name.from_unicode(t, idna_codec=codec) + def bad3(): codec = dns.name.IDNA_2008_Transitional return dns.name.from_unicode(t, idna_codec=codec) + self.assertRaises(dns.name.IDNAException, bad1) self.assertRaises(dns.name.IDNAException, bad2) self.assertRaises(dns.name.IDNAException, bad3) - e = dns.name.from_unicode(t, - idna_codec=dns.name.IDNA_2008_Practical) - self.assertEqual(str(e), '_sip._tcp.xn--knigsgchen-b4a3dun.') + e = dns.name.from_unicode(t, idna_codec=dns.name.IDNA_2008_Practical) + self.assertEqual(str(e), "_sip._tcp.xn--knigsgchen-b4a3dun.") def testFromUnicodeEscapes(self): - n = dns.name.from_unicode(r'\097.\098.\099.') + n = dns.name.from_unicode(r"\097.\098.\099.") t = n.to_unicode() - self.assertEqual(t, 'a.b.c.') + self.assertEqual(t, "a.b.c.") def testToUnicode1(self): - n = dns.name.from_text('foo.bar') + n = dns.name.from_text("foo.bar") s = n.to_unicode() - self.assertEqual(s, 'foo.bar.') + self.assertEqual(s, "foo.bar.") def testToUnicode2(self): - n = dns.name.from_text('foo\u1234bar.bar') + n = dns.name.from_text("foo\u1234bar.bar") s = n.to_unicode() - self.assertEqual(s, 'foo\u1234bar.bar.') + self.assertEqual(s, "foo\u1234bar.bar.") def testToUnicode3(self): - n = dns.name.from_text('foo.bar') + n = dns.name.from_text("foo.bar") s = n.to_unicode() - self.assertEqual(s, 'foo.bar.') + self.assertEqual(s, "foo.bar.") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testToUnicode4(self): - n = dns.name.from_text('ドメイン.テスト', - idna_codec=dns.name.IDNA_2008) + n = dns.name.from_text("ドメイン.テスト", idna_codec=dns.name.IDNA_2008) s = n.to_unicode() - self.assertEqual(str(n), 'xn--eckwd4c7c.xn--zckzah.') - self.assertEqual(s, 'ドメイン.テスト.') + self.assertEqual(str(n), "xn--eckwd4c7c.xn--zckzah.") + self.assertEqual(s, "ドメイン.テスト.") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testToUnicode5(self): # Exercise UTS 46 remapping in decode. This doesn't normally happen # as you can see from us having to instantiate the codec as # transitional with strict decoding, not one of our usual choices. codec = dns.name.IDNA2008Codec(True, True, False, True) - n = dns.name.from_text('xn--gro-7ka.com') - self.assertEqual(n.to_unicode(idna_codec=codec), - 'gross.com.') + n = dns.name.from_text("xn--gro-7ka.com") + self.assertEqual(n.to_unicode(idna_codec=codec), "gross.com.") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testToUnicode6(self): # Test strict 2008 decoding without UTS 46 - n = dns.name.from_text('xn--gro-7ka.com') - self.assertEqual(n.to_unicode(idna_codec=dns.name.IDNA_2008_Strict), - 'groß.com.') + n = dns.name.from_text("xn--gro-7ka.com") + self.assertEqual( + n.to_unicode(idna_codec=dns.name.IDNA_2008_Strict), "groß.com." + ) def testDefaultDecodeIsJustPunycode(self): # groß.com. in IDNA2008 form, pre-encoded. - n = dns.name.from_text('xn--gro-7ka.com') + n = dns.name.from_text("xn--gro-7ka.com") # output using default codec which just decodes the punycode and # doesn't test for IDNA2003 or IDNA2008. - self.assertEqual(n.to_unicode(), 'groß.com.') + self.assertEqual(n.to_unicode(), "groß.com.") def testStrictINDA2003Decode(self): # groß.com. in IDNA2008 form, pre-encoded. - n = dns.name.from_text('xn--gro-7ka.com') + n = dns.name.from_text("xn--gro-7ka.com") + def bad(): # This throws in IDNA2003 because it doesn't "round trip". n.to_unicode(idna_codec=dns.name.IDNA_2003_Strict) + self.assertRaises(dns.name.IDNAException, bad) def testINDA2008Decode(self): # groß.com. in IDNA2008 form, pre-encoded. - n = dns.name.from_text('xn--gro-7ka.com') - self.assertEqual(n.to_unicode(idna_codec=dns.name.IDNA_2008), - 'groß.com.') + n = dns.name.from_text("xn--gro-7ka.com") + self.assertEqual(n.to_unicode(idna_codec=dns.name.IDNA_2008), "groß.com.") def testToUnicodeOmitFinalDot(self): # groß.com. in IDNA2008 form, pre-encoded. - n = dns.name.from_text('xn--gro-7ka.com') - self.assertEqual(n.to_unicode(True, dns.name.IDNA_2008), - 'groß.com') + n = dns.name.from_text("xn--gro-7ka.com") + self.assertEqual(n.to_unicode(True, dns.name.IDNA_2008), "groß.com") def testIDNA2003Misc(self): - self.assertEqual(dns.name.IDNA_2003.encode(''), b'') - self.assertRaises(dns.name.LabelTooLong, - lambda: dns.name.IDNA_2003.encode('x' * 64)) - - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + self.assertEqual(dns.name.IDNA_2003.encode(""), b"") + self.assertRaises( + dns.name.LabelTooLong, lambda: dns.name.IDNA_2003.encode("x" * 64) + ) + + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testIDNA2008Misc(self): - self.assertEqual(dns.name.IDNA_2008.encode(''), b'') - self.assertRaises(dns.name.LabelTooLong, - lambda: dns.name.IDNA_2008.encode('x' * 64)) - self.assertRaises(dns.name.LabelTooLong, - lambda: dns.name.IDNA_2008.encode('groß' + 'x' * 60)) + self.assertEqual(dns.name.IDNA_2008.encode(""), b"") + self.assertRaises( + dns.name.LabelTooLong, lambda: dns.name.IDNA_2008.encode("x" * 64) + ) + self.assertRaises( + dns.name.LabelTooLong, lambda: dns.name.IDNA_2008.encode("groß" + "x" * 60) + ) def testReverseIPv4(self): - e = dns.name.from_text('1.0.0.127.in-addr.arpa.') - n = dns.reversename.from_address('127.0.0.1') + e = dns.name.from_text("1.0.0.127.in-addr.arpa.") + n = dns.reversename.from_address("127.0.0.1") self.assertEqual(e, n) def testReverseIPv6(self): - e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.') - n = dns.reversename.from_address('::1') + e = dns.name.from_text( + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa." + ) + n = dns.reversename.from_address("::1") self.assertEqual(e, n) def testReverseIPv6MappedIpv4(self): - e = dns.name.from_text('1.0.0.127.in-addr.arpa.') - n = dns.reversename.from_address('::ffff:127.0.0.1') + e = dns.name.from_text("1.0.0.127.in-addr.arpa.") + n = dns.reversename.from_address("::ffff:127.0.0.1") self.assertEqual(e, n) def testBadReverseIPv4(self): def bad(): - dns.reversename.from_address('127.0.foo.1') + dns.reversename.from_address("127.0.foo.1") + self.assertRaises(dns.exception.SyntaxError, bad) def testBadReverseIPv6(self): def bad(): - dns.reversename.from_address('::1::1') + dns.reversename.from_address("::1::1") + self.assertRaises(dns.exception.SyntaxError, bad) def testReverseIPv4AlternateOrigin(self): - e = dns.name.from_text('1.0.0.127.foo.bar.') - origin = dns.name.from_text('foo.bar') - n = dns.reversename.from_address('127.0.0.1', v4_origin=origin) + e = dns.name.from_text("1.0.0.127.foo.bar.") + origin = dns.name.from_text("foo.bar") + n = dns.reversename.from_address("127.0.0.1", v4_origin=origin) self.assertEqual(e, n) def testReverseIPv6AlternateOrigin(self): - e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.foo.bar.') - origin = dns.name.from_text('foo.bar') - n = dns.reversename.from_address('::1', v6_origin=origin) + e = dns.name.from_text( + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.foo.bar." + ) + origin = dns.name.from_text("foo.bar") + n = dns.reversename.from_address("::1", v6_origin=origin) self.assertEqual(e, n) def testForwardIPv4(self): - n = dns.name.from_text('1.0.0.127.in-addr.arpa.') - e = '127.0.0.1' + n = dns.name.from_text("1.0.0.127.in-addr.arpa.") + e = "127.0.0.1" text = dns.reversename.to_address(n) self.assertEqual(text, e) def testForwardIPv6(self): - n = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.') - e = '::1' + n = dns.name.from_text( + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa." + ) + e = "::1" text = dns.reversename.to_address(n) self.assertEqual(text, e) def testForwardIPv4AlternateOrigin(self): - n = dns.name.from_text('1.0.0.127.foo.bar.') - e = '127.0.0.1' - origin = dns.name.from_text('foo.bar') + n = dns.name.from_text("1.0.0.127.foo.bar.") + e = "127.0.0.1" + origin = dns.name.from_text("foo.bar") text = dns.reversename.to_address(n, v4_origin=origin) self.assertEqual(text, e) def testForwardIPv6AlternateOrigin(self): - n = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.foo.bar.') - e = '::1' - origin = dns.name.from_text('foo.bar') + n = dns.name.from_text( + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.foo.bar." + ) + e = "::1" + origin = dns.name.from_text("foo.bar") text = dns.reversename.to_address(n, v6_origin=origin) self.assertEqual(text, e) def testUnknownReverseOrigin(self): - n = dns.name.from_text('1.2.3.4.unknown.') + n = dns.name.from_text("1.2.3.4.unknown.") with self.assertRaises(dns.exception.SyntaxError): dns.reversename.to_address(n) def testE164ToEnum(self): - text = '+1 650 555 1212' - e = dns.name.from_text('2.1.2.1.5.5.5.0.5.6.1.e164.arpa.') + text = "+1 650 555 1212" + e = dns.name.from_text("2.1.2.1.5.5.5.0.5.6.1.e164.arpa.") n = dns.e164.from_e164(text) self.assertEqual(n, e) def testEnumToE164(self): - n = dns.name.from_text('2.1.2.1.5.5.5.0.5.6.1.e164.arpa.') - e = '+16505551212' + n = dns.name.from_text("2.1.2.1.5.5.5.0.5.6.1.e164.arpa.") + e = "+16505551212" text = dns.e164.to_e164(n) self.assertEqual(text, e) def testBadEnumToE164(self): - n = dns.name.from_text('2.1.2.q.5.5.5.0.5.6.1.e164.arpa.') - self.assertRaises(dns.exception.SyntaxError, - lambda: dns.e164.to_e164(n)) + n = dns.name.from_text("2.1.2.q.5.5.5.0.5.6.1.e164.arpa.") + self.assertRaises(dns.exception.SyntaxError, lambda: dns.e164.to_e164(n)) def test_incompatible_relations(self): - n1 = dns.name.from_text('example') - n2 = 'abc' + n1 = dns.name.from_text("example") + n2 = "abc" for oper in [operator.lt, operator.le, operator.ge, operator.gt]: self.assertRaises(TypeError, lambda: oper(n1, n2)) self.assertFalse(n1 == n2) self.assertTrue(n1 != n2) def testFromUnicodeSimpleEscape(self): - n = dns.name.from_unicode(r'a.\b') - e = dns.name.from_unicode(r'a.b') + n = dns.name.from_unicode(r"a.\b") + e = dns.name.from_unicode(r"a.b") self.assertEqual(n, e) def testFromUnicodeBadEscape(self): def bad1(): - n = dns.name.from_unicode(r'a.b\0q1.c.') + n = dns.name.from_unicode(r"a.b\0q1.c.") + self.assertRaises(dns.name.BadEscape, bad1) + def bad2(): - n = dns.name.from_unicode(r'a.b\0') + n = dns.name.from_unicode(r"a.b\0") + self.assertRaises(dns.name.BadEscape, bad2) def testFromUnicodeNotString(self): def bad(): - dns.name.from_unicode(b'123') # type: ignore + dns.name.from_unicode(b"123") # type: ignore + self.assertRaises(ValueError, bad) def testFromUnicodeBadOrigin(self): def bad(): - dns.name.from_unicode('example', 123) # type: ignore + dns.name.from_unicode("example", 123) # type: ignore + self.assertRaises(ValueError, bad) def testFromUnicodeEmptyLabel(self): def bad(): - dns.name.from_unicode('a..b.example') + dns.name.from_unicode("a..b.example") + self.assertRaises(dns.name.EmptyLabel, bad) def testFromUnicodeEmptyName(self): - self.assertEqual(dns.name.from_unicode('@', None), dns.name.empty) + self.assertEqual(dns.name.from_unicode("@", None), dns.name.empty) def testFromTextNotString(self): def bad(): dns.name.from_text(123) # type: ignore + self.assertRaises(ValueError, bad) def testFromTextBadOrigin(self): def bad(): - dns.name.from_text('example', 123) # type: ignore + dns.name.from_text("example", 123) # type: ignore + self.assertRaises(ValueError, bad) def testFromWireNotBytes(self): def bad(): dns.name.from_wire(123, 0) # type: ignore + self.assertRaises(ValueError, bad) def testBadPunycode(self): c = dns.name.IDNACodec() with self.assertRaises(dns.name.IDNAException): - c.decode(b'xn--0000h') + c.decode(b"xn--0000h") def testRootLabel2003StrictDecode(self): c = dns.name.IDNA_2003_Strict - self.assertEqual(c.decode(b''), '') + self.assertEqual(c.decode(b""), "") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testRootLabel2008StrictDecode(self): c = dns.name.IDNA_2008_Strict - self.assertEqual(c.decode(b''), '') + self.assertEqual(c.decode(b""), "") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testCodecNotFoundRaises(self): dns.name.have_idna_2008 = False with self.assertRaises(dns.name.NoIDNA2008): c = dns.name.IDNA2008Codec() - c.encode('Königsgäßchen') + c.encode("Königsgäßchen") with self.assertRaises(dns.name.NoIDNA2008): c = dns.name.IDNA2008Codec(strict_decode=True) - c.decode(b'xn--eckwd4c7c.xn--zckzah.') + c.decode(b"xn--eckwd4c7c.xn--zckzah.") dns.name.have_idna_2008 = True - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testBadPunycodeStrict2008(self): c = dns.name.IDNA2008Codec(strict_decode=True) with self.assertRaises(dns.name.IDNAException): - c.decode(b'xn--0000h') + c.decode(b"xn--0000h") def testRelativizeSubtractionSyntax(self): - n = dns.name.from_text('foo.example.') - o = dns.name.from_text('example.') - e = dns.name.from_text('foo', None) + n = dns.name.from_text("foo.example.") + o = dns.name.from_text("example.") + e = dns.name.from_text("foo", None) self.assertEqual(n - o, e) def testCopy(self): - n1 = dns.name.from_text('foo.example.') + n1 = dns.name.from_text("foo.example.") n2 = copy.copy(n1) self.assertTrue(n1 is not n2) # the Name constructor always copies labels, so there is no @@ -1063,7 +1116,7 @@ class NameTestCase(unittest.TestCase): self.assertTrue(l is n2[i]) def testDeepCopy(self): - n1 = dns.name.from_text('foo.example.') + n1 = dns.name.from_text("foo.example.") n2 = copy.deepcopy(n1) self.assertTrue(n1 is not n2) self.assertTrue(n1.labels is not n2.labels) @@ -1072,19 +1125,20 @@ class NameTestCase(unittest.TestCase): self.assertTrue(l is n2[i]) def testNoAttributeDeletion(self): - n = dns.name.from_text('foo.example.') + n = dns.name.from_text("foo.example.") with self.assertRaises(TypeError): del n.labels def testUnicodeEscapify(self): - n = dns.name.from_unicode('Königsgäßchen;\ttext') - self.assertEqual(n.to_unicode(), 'königsgässchen\\;\\009text.') + n = dns.name.from_unicode("Königsgäßchen;\ttext") + self.assertEqual(n.to_unicode(), "königsgässchen\\;\\009text.") def test_pickle(self): - n1 = dns.name.from_text('foo.example') + n1 = dns.name.from_text("foo.example") p = pickle.dumps(n1) n2 = pickle.loads(p) self.assertEqual(n1, n2) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_namedict.py b/tests/test_namedict.py index 73097a6a..c5563e55 100644 --- a/tests/test_namedict.py +++ b/tests/test_namedict.py @@ -20,17 +20,17 @@ import unittest import dns.name import dns.namedict -class NameTestCase(unittest.TestCase): +class NameTestCase(unittest.TestCase): def setUp(self): self.ndict = dns.namedict.NameDict() - n1 = dns.name.from_text('foo.bar.') - n2 = dns.name.from_text('bar.') + n1 = dns.name.from_text("foo.bar.") + n2 = dns.name.from_text("bar.") self.ndict[n1] = 1 self.ndict[n2] = 2 self.rndict = dns.namedict.NameDict() - n1 = dns.name.from_text('foo.bar', None) - n2 = dns.name.from_text('bar', None) + n1 = dns.name.from_text("foo.bar", None) + n2 = dns.name.from_text("bar", None) self.rndict[n1] = 1 self.rndict[n2] = 2 @@ -38,94 +38,97 @@ class NameTestCase(unittest.TestCase): self.assertEqual(self.ndict.max_depth, 3) def testLookup1(self): - k = dns.name.from_text('foo.bar.') + k = dns.name.from_text("foo.bar.") self.assertEqual(self.ndict[k], 1) def testLookup2(self): - k = dns.name.from_text('foo.bar.') + k = dns.name.from_text("foo.bar.") self.assertEqual(self.ndict.get_deepest_match(k)[1], 1) def testLookup3(self): - k = dns.name.from_text('a.b.c.foo.bar.') + k = dns.name.from_text("a.b.c.foo.bar.") self.assertEqual(self.ndict.get_deepest_match(k)[1], 1) def testLookup4(self): - k = dns.name.from_text('a.b.c.bar.') + k = dns.name.from_text("a.b.c.bar.") self.assertEqual(self.ndict.get_deepest_match(k)[1], 2) def testLookup5(self): def bad(): - n = dns.name.from_text('a.b.c.') + n = dns.name.from_text("a.b.c.") self.ndict.get_deepest_match(n) + self.assertRaises(KeyError, bad) def testLookup6(self): def bad(): self.ndict.get_deepest_match(dns.name.empty) + self.assertRaises(KeyError, bad) def testLookup7(self): self.ndict[dns.name.empty] = 100 - n = dns.name.from_text('a.b.c.') + n = dns.name.from_text("a.b.c.") v = self.ndict.get_deepest_match(n)[1] self.assertEqual(v, 100) def testLookup8(self): def bad(): - self.ndict['foo'] = 100 + self.ndict["foo"] = 100 + self.assertRaises(ValueError, bad) def testRelDepth(self): self.assertEqual(self.rndict.max_depth, 2) def testRelLookup1(self): - k = dns.name.from_text('foo.bar', None) + k = dns.name.from_text("foo.bar", None) self.assertEqual(self.rndict[k], 1) def testRelLookup2(self): - k = dns.name.from_text('foo.bar', None) + k = dns.name.from_text("foo.bar", None) self.assertEqual(self.rndict.get_deepest_match(k)[1], 1) def testRelLookup3(self): - k = dns.name.from_text('a.b.c.foo.bar', None) + k = dns.name.from_text("a.b.c.foo.bar", None) self.assertEqual(self.rndict.get_deepest_match(k)[1], 1) def testRelLookup4(self): - k = dns.name.from_text('a.b.c.bar', None) + k = dns.name.from_text("a.b.c.bar", None) self.assertEqual(self.rndict.get_deepest_match(k)[1], 2) def testRelLookup7(self): self.rndict[dns.name.empty] = 100 - n = dns.name.from_text('a.b.c', None) + n = dns.name.from_text("a.b.c", None) v = self.rndict.get_deepest_match(n)[1] self.assertEqual(v, 100) def test_max_depth_increases(self): - n = dns.name.from_text('a.foo.bar.') + n = dns.name.from_text("a.foo.bar.") self.assertEqual(self.ndict.max_depth, 3) self.ndict[n] = 1 self.assertEqual(self.ndict.max_depth, 4) def test_delete_no_max_depth_change(self): self.assertEqual(self.ndict.max_depth, 3) - n = dns.name.from_text('bar.') + n = dns.name.from_text("bar.") del self.ndict[n] self.assertEqual(self.ndict.max_depth, 3) self.assertEqual(self.ndict.get(n), None) def test_delete_max_depth_changes(self): self.assertEqual(self.ndict.max_depth, 3) - n = dns.name.from_text('foo.bar.') + n = dns.name.from_text("foo.bar.") del self.ndict[n] self.assertEqual(self.ndict.max_depth, 2) self.assertEqual(self.ndict.get(n), None) def test_delete_multiple_max_depth_changes(self): self.assertEqual(self.ndict.max_depth, 3) - nr = dns.name.from_text('roo.') + nr = dns.name.from_text("roo.") self.ndict[nr] = 1 - nf = dns.name.from_text('foo.bar.') - nb = dns.name.from_text('bar.bar.') + nf = dns.name.from_text("foo.bar.") + nb = dns.name.from_text("bar.bar.") self.ndict[nb] = 1 self.assertEqual(self.ndict.max_depth, 3) self.assertEqual(self.ndict.max_depth_items, 2) @@ -139,8 +142,8 @@ class NameTestCase(unittest.TestCase): self.assertEqual(self.ndict.get(nb), None) def test_iter(self): - nf = dns.name.from_text('foo.bar.') - nb = dns.name.from_text('bar.') + nf = dns.name.from_text("foo.bar.") + nb = dns.name.from_text("bar.") keys = set([x for x in self.ndict]) self.assertEqual(len(keys), 2) self.assertTrue(nf in keys) @@ -150,12 +153,13 @@ class NameTestCase(unittest.TestCase): self.assertEqual(len(self.ndict), 2) def test_haskey(self): - nf = dns.name.from_text('foo.bar.') - nb = dns.name.from_text('bar.') - nx = dns.name.from_text('x.') + nf = dns.name.from_text("foo.bar.") + nb = dns.name.from_text("bar.") + nx = dns.name.from_text("x.") self.assertTrue(self.ndict.has_key(nf)) self.assertTrue(self.ndict.has_key(nb)) self.assertFalse(self.ndict.has_key(nx)) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_nsec3.py b/tests/test_nsec3.py index bf7d1151..dd86f2c4 100644 --- a/tests/test_nsec3.py +++ b/tests/test_nsec3.py @@ -24,25 +24,32 @@ import dns.rdatatype import dns.rdtypes.ANY.TXT import dns.ttl + class NSEC3TestCase(unittest.TestCase): def test_NSEC3_bitmap(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NSEC3, - u"1 0 100 ABCD SCBCQHKU35969L2A68P3AD59LHF30715 A CAA TYPE65534") - bitmap = bytearray(b'\0' * 32) + rdata = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.NSEC3, + "1 0 100 ABCD SCBCQHKU35969L2A68P3AD59LHF30715 A CAA TYPE65534", + ) + bitmap = bytearray(b"\0" * 32) bitmap[31] = bitmap[31] | 2 - self.assertEqual(rdata.windows, ((0, b'@'), - (1, b'@'), # CAA = 257 - (255, bitmap) - )) + self.assertEqual( + rdata.windows, ((0, b"@"), (1, b"@"), (255, bitmap)) # CAA = 257 + ) def test_NSEC3_bad_bitmaps(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NSEC3, - u"1 0 100 ABCD SCBCQHKU35969L2A68P3AD59LHF30715 A CAA") + rdata = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.NSEC3, + "1 0 100 ABCD SCBCQHKU35969L2A68P3AD59LHF30715 A CAA", + ) with self.assertRaises(dns.exception.FormError): copy = bytearray(rdata.to_wire()) copy[-3] = 0 - dns.rdata.from_wire('IN', 'NSEC3', copy, 0, len(copy)) + dns.rdata.from_wire("IN", "NSEC3", copy, 0, len(copy)) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_nsec3_hash.py b/tests/test_nsec3_hash.py index f7c43372..8cb6792a 100644 --- a/tests/test_nsec3_hash.py +++ b/tests/test_nsec3_hash.py @@ -56,10 +56,10 @@ class NSEC3Hash(unittest.TestCase): "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", - dnssec.NSEC3Hash.SHA1 + dnssec.NSEC3Hash.SHA1, ), ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "SHA1"), - ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "sha1") + ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "sha1"), ] def test_hash_function(self): diff --git a/tests/test_ntoaaton.py b/tests/test_ntoaaton.py index 7e30bce2..2468486b 100644 --- a/tests/test_ntoaaton.py +++ b/tests/test_ntoaaton.py @@ -30,166 +30,181 @@ ntoa4 = dns.ipv4.inet_ntoa aton6 = dns.ipv6.inet_aton ntoa6 = dns.ipv6.inet_ntoa -v4_bad_addrs = ['256.1.1.1', '1.1.1', '1.1.1.1.1', - '+1.1.1.1', '1.1.1.1+', '1..2.3.4', '.1.2.3.4', - '1.2.3.4.'] +v4_bad_addrs = [ + "256.1.1.1", + "1.1.1", + "1.1.1.1.1", + "+1.1.1.1", + "1.1.1.1+", + "1..2.3.4", + ".1.2.3.4", + "1.2.3.4.", +] -class NtoAAtoNTestCase(unittest.TestCase): +class NtoAAtoNTestCase(unittest.TestCase): def test_aton1(self): - a = aton6('::') - self.assertEqual(a, b'\x00' * 16) + a = aton6("::") + self.assertEqual(a, b"\x00" * 16) def test_aton2(self): - a = aton6('::1') - self.assertEqual(a, b'\x00' * 15 + b'\x01') + a = aton6("::1") + self.assertEqual(a, b"\x00" * 15 + b"\x01") def test_aton3(self): - a = aton6('::10.0.0.1') - self.assertEqual(a, b'\x00' * 12 + b'\x0a\x00\x00\x01') + a = aton6("::10.0.0.1") + self.assertEqual(a, b"\x00" * 12 + b"\x0a\x00\x00\x01") def test_aton4(self): - a = aton6('abcd::dcba') - self.assertEqual(a, b'\xab\xcd' + b'\x00' * 12 + b'\xdc\xba') + a = aton6("abcd::dcba") + self.assertEqual(a, b"\xab\xcd" + b"\x00" * 12 + b"\xdc\xba") def test_aton5(self): - a = aton6('1:2:3:4:5:6:7:8') - self.assertEqual(a, - binascii.unhexlify(b'00010002000300040005000600070008')) + a = aton6("1:2:3:4:5:6:7:8") + self.assertEqual(a, binascii.unhexlify(b"00010002000300040005000600070008")) def test_bad_aton1(self): def bad(): - aton6('abcd:dcba') + aton6("abcd:dcba") + self.assertRaises(dns.exception.SyntaxError, bad) def test_bad_aton2(self): def bad(): - aton6('abcd::dcba::1') + aton6("abcd::dcba::1") + self.assertRaises(dns.exception.SyntaxError, bad) def test_bad_aton3(self): def bad(): - aton6('1:2:3:4:5:6:7:8:9') + aton6("1:2:3:4:5:6:7:8:9") + self.assertRaises(dns.exception.SyntaxError, bad) def test_bad_aton4(self): def bad(): - aton4('001.002.003.004') + aton4("001.002.003.004") + self.assertRaises(dns.exception.SyntaxError, bad) def test_aton6(self): - a = aton6('::') - self.assertEqual(a, b'\x00' * 16) + a = aton6("::") + self.assertEqual(a, b"\x00" * 16) def test_aton7(self): - a = aton6('::1') - self.assertEqual(a, b'\x00' * 15 + b'\x01') + a = aton6("::1") + self.assertEqual(a, b"\x00" * 15 + b"\x01") def test_aton8(self): - a = aton6('::10.0.0.1') - self.assertEqual(a, b'\x00' * 12 + b'\x0a\x00\x00\x01') + a = aton6("::10.0.0.1") + self.assertEqual(a, b"\x00" * 12 + b"\x0a\x00\x00\x01") def test_aton9(self): - a = aton6('abcd::dcba') - self.assertEqual(a, b'\xab\xcd' + b'\x00' * 12 + b'\xdc\xba') + a = aton6("abcd::dcba") + self.assertEqual(a, b"\xab\xcd" + b"\x00" * 12 + b"\xdc\xba") def test_ntoa1(self): - b = binascii.unhexlify(b'00010002000300040005000600070008') + b = binascii.unhexlify(b"00010002000300040005000600070008") t = ntoa6(b) - self.assertEqual(t, '1:2:3:4:5:6:7:8') + self.assertEqual(t, "1:2:3:4:5:6:7:8") def test_ntoa2(self): - b = b'\x00' * 16 + b = b"\x00" * 16 t = ntoa6(b) - self.assertEqual(t, '::') + self.assertEqual(t, "::") def test_ntoa3(self): - b = b'\x00' * 15 + b'\x01' + b = b"\x00" * 15 + b"\x01" t = ntoa6(b) - self.assertEqual(t, '::1') + self.assertEqual(t, "::1") def test_ntoa4(self): - b = b'\x80' + b'\x00' * 15 + b = b"\x80" + b"\x00" * 15 t = ntoa6(b) - self.assertEqual(t, '8000::') + self.assertEqual(t, "8000::") def test_ntoa5(self): - b = b'\x01\xcd' + b'\x00' * 12 + b'\x03\xef' + b = b"\x01\xcd" + b"\x00" * 12 + b"\x03\xef" t = ntoa6(b) - self.assertEqual(t, '1cd::3ef') + self.assertEqual(t, "1cd::3ef") def test_ntoa6(self): - b = binascii.unhexlify(b'ffff00000000ffff000000000000ffff') + b = binascii.unhexlify(b"ffff00000000ffff000000000000ffff") t = ntoa6(b) - self.assertEqual(t, 'ffff:0:0:ffff::ffff') + self.assertEqual(t, "ffff:0:0:ffff::ffff") def test_ntoa7(self): - b = binascii.unhexlify(b'00000000ffff000000000000ffffffff') + b = binascii.unhexlify(b"00000000ffff000000000000ffffffff") t = ntoa6(b) - self.assertEqual(t, '0:0:ffff::ffff:ffff') + self.assertEqual(t, "0:0:ffff::ffff:ffff") def test_ntoa8(self): - b = binascii.unhexlify(b'ffff0000ffff00000000ffff00000000') + b = binascii.unhexlify(b"ffff0000ffff00000000ffff00000000") t = ntoa6(b) - self.assertEqual(t, 'ffff:0:ffff::ffff:0:0') + self.assertEqual(t, "ffff:0:ffff::ffff:0:0") def test_ntoa9(self): - b = binascii.unhexlify(b'0000000000000000000000000a000001') + b = binascii.unhexlify(b"0000000000000000000000000a000001") t = ntoa6(b) - self.assertEqual(t, '::10.0.0.1') + self.assertEqual(t, "::10.0.0.1") def test_ntoa10(self): - b = binascii.unhexlify(b'0000000000000000000000010a000001') + b = binascii.unhexlify(b"0000000000000000000000010a000001") t = ntoa6(b) - self.assertEqual(t, '::1:a00:1') + self.assertEqual(t, "::1:a00:1") def test_ntoa11(self): - b = binascii.unhexlify(b'00000000000000000000ffff0a000001') + b = binascii.unhexlify(b"00000000000000000000ffff0a000001") t = ntoa6(b) - self.assertEqual(t, '::ffff:10.0.0.1') + self.assertEqual(t, "::ffff:10.0.0.1") def test_ntoa12(self): - b = binascii.unhexlify(b'000000000000000000000000ffffffff') + b = binascii.unhexlify(b"000000000000000000000000ffffffff") t = ntoa6(b) - self.assertEqual(t, '::255.255.255.255') + self.assertEqual(t, "::255.255.255.255") def test_ntoa13(self): - b = binascii.unhexlify(b'00000000000000000000ffffffffffff') + b = binascii.unhexlify(b"00000000000000000000ffffffffffff") t = ntoa6(b) - self.assertEqual(t, '::ffff:255.255.255.255') + self.assertEqual(t, "::ffff:255.255.255.255") def test_ntoa14(self): - b = binascii.unhexlify(b'0000000000000000000000000001ffff') + b = binascii.unhexlify(b"0000000000000000000000000001ffff") t = ntoa6(b) - self.assertEqual(t, '::0.1.255.255') + self.assertEqual(t, "::0.1.255.255") def test_ntoa15(self): # This exercises the current_len > best_len branch in the <= case. - b = binascii.unhexlify(b'0000ffff00000000ffff00000000ffff') + b = binascii.unhexlify(b"0000ffff00000000ffff00000000ffff") t = ntoa6(b) - self.assertEqual(t, '0:ffff::ffff:0:0:ffff') + self.assertEqual(t, "0:ffff::ffff:0:0:ffff") def test_bad_ntoa1(self): def bad(): - ntoa6(b'') + ntoa6(b"") + self.assertRaises(ValueError, bad) def test_bad_ntoa2(self): def bad(): - ntoa6(b'\x00' * 17) + ntoa6(b"\x00" * 17) + self.assertRaises(ValueError, bad) def test_bad_ntoa3(self): def bad(): - ntoa4(b'\x00' * 5) + ntoa4(b"\x00" * 5) + # Ideally we'd have been consistent and raised ValueError as # we do for IPv6, but oh well! self.assertRaises(dns.exception.SyntaxError, bad) def test_good_v4_aton(self): - pairs = [('1.2.3.4', b'\x01\x02\x03\x04'), - ('255.255.255.255', b'\xff\xff\xff\xff'), - ('0.0.0.0', b'\x00\x00\x00\x00')] + pairs = [ + ("1.2.3.4", b"\x01\x02\x03\x04"), + ("255.255.255.255", b"\xff\xff\xff\xff"), + ("0.0.0.0", b"\x00\x00\x00\x00"), + ] for (t, b) in pairs: b1 = aton4(t) t1 = ntoa4(b1) @@ -200,43 +215,47 @@ class NtoAAtoNTestCase(unittest.TestCase): def make_bad(a): def bad(): return aton4(a) + return bad + for addr in v4_bad_addrs: self.assertRaises(dns.exception.SyntaxError, make_bad(addr)) def test_bad_v6_aton(self): - addrs = ['+::0', '0::0::', '::0::', '1:2:3:4:5:6:7:8:9', - ':::::::'] - embedded = ['::' + x for x in v4_bad_addrs] + addrs = ["+::0", "0::0::", "::0::", "1:2:3:4:5:6:7:8:9", ":::::::"] + embedded = ["::" + x for x in v4_bad_addrs] addrs.extend(embedded) + def make_bad(a): def bad(): x = aton6(a) + return bad + for addr in addrs: self.assertRaises(dns.exception.SyntaxError, make_bad(addr)) def test_rfc5952_section_4_2_2(self): - addr = '2001:db8:0:1:1:1:1:1' + addr = "2001:db8:0:1:1:1:1:1" b1 = aton6(addr) t1 = ntoa6(b1) self.assertEqual(t1, addr) def test_is_mapped(self): - t1 = '2001:db8:0:1:1:1:1:1' - t2 = '::ffff:127.0.0.1' - t3 = '1::ffff:127.0.0.1' + t1 = "2001:db8:0:1:1:1:1:1" + t2 = "::ffff:127.0.0.1" + t3 = "1::ffff:127.0.0.1" self.assertFalse(dns.ipv6.is_mapped(aton6(t1))) self.assertTrue(dns.ipv6.is_mapped(aton6(t2))) self.assertFalse(dns.ipv6.is_mapped(aton6(t3))) def test_is_multicast(self): - t1 = '223.0.0.1' - t2 = '240.0.0.1' - t3 = '224.0.0.1' - t4 = '239.0.0.1' - t5 = 'fe00::1' - t6 = 'ff00::1' + t1 = "223.0.0.1" + t2 = "240.0.0.1" + t3 = "224.0.0.1" + t4 = "239.0.0.1" + t5 = "fe00::1" + t6 = "ff00::1" self.assertFalse(dns.inet.is_multicast(t1)) self.assertFalse(dns.inet.is_multicast(t2)) self.assertTrue(dns.inet.is_multicast(t3)) @@ -246,52 +265,58 @@ class NtoAAtoNTestCase(unittest.TestCase): def test_is_multicast_bad_input(self): def bad(): - dns.inet.is_multicast('hello world') + dns.inet.is_multicast("hello world") + self.assertRaises(ValueError, bad) def test_ignore_scope(self): - t1 = 'fe80::1%lo0' - t2 = 'fe80::1' + t1 = "fe80::1%lo0" + t2 = "fe80::1" self.assertEqual(aton6(t1, True), aton6(t2)) def test_do_not_ignore_scope(self): def bad(): - t1 = 'fe80::1%lo0' + t1 = "fe80::1%lo0" aton6(t1) + self.assertRaises(dns.exception.SyntaxError, bad) def test_multiple_scopes_bad(self): def bad(): - t1 = 'fe80::1%lo0%lo1' + t1 = "fe80::1%lo0%lo1" aton6(t1, True) + self.assertRaises(dns.exception.SyntaxError, bad) def test_ptontop(self): - for (af, a) in [(socket.AF_INET, '1.2.3.4'), - (socket.AF_INET6, '2001:db8:0:1:1:1:1:1')]: - self.assertEqual(dns.inet.inet_ntop(af, dns.inet.inet_pton(af, a)), - a) + for (af, a) in [ + (socket.AF_INET, "1.2.3.4"), + (socket.AF_INET6, "2001:db8:0:1:1:1:1:1"), + ]: + self.assertEqual(dns.inet.inet_ntop(af, dns.inet.inet_pton(af, a)), a) def test_isaddress(self): - for (t, e) in [('1.2.3.4', True), - ('2001:db8:0:1:1:1:1:1', True), - ('hello world', False), - ('http://www.dnspython.org', False), - ('1.2.3.4a', False), - ('2001:db8:0:1:1:1:1:q1', False)]: + for (t, e) in [ + ("1.2.3.4", True), + ("2001:db8:0:1:1:1:1:1", True), + ("hello world", False), + ("http://www.dnspython.org", False), + ("1.2.3.4a", False), + ("2001:db8:0:1:1:1:1:q1", False), + ]: self.assertEqual(dns.inet.is_address(t), e) def test_low_level_address_tuple(self): - t = dns.inet.low_level_address_tuple(('1.2.3.4', 53)) - self.assertEqual(t, ('1.2.3.4', 53)) - t = dns.inet.low_level_address_tuple(('2600::1', 53)) - self.assertEqual(t, ('2600::1', 53, 0, 0)) - t = dns.inet.low_level_address_tuple(('1.2.3.4', 53), socket.AF_INET) - self.assertEqual(t, ('1.2.3.4', 53)) - t = dns.inet.low_level_address_tuple(('2600::1', 53), socket.AF_INET6) - self.assertEqual(t, ('2600::1', 53, 0, 0)) - t = dns.inet.low_level_address_tuple(('fd80::1%2', 53), socket.AF_INET6) - self.assertEqual(t, ('fd80::1', 53, 0, 2)) + t = dns.inet.low_level_address_tuple(("1.2.3.4", 53)) + self.assertEqual(t, ("1.2.3.4", 53)) + t = dns.inet.low_level_address_tuple(("2600::1", 53)) + self.assertEqual(t, ("2600::1", 53, 0, 0)) + t = dns.inet.low_level_address_tuple(("1.2.3.4", 53), socket.AF_INET) + self.assertEqual(t, ("1.2.3.4", 53)) + t = dns.inet.low_level_address_tuple(("2600::1", 53), socket.AF_INET6) + self.assertEqual(t, ("2600::1", 53, 0, 0)) + t = dns.inet.low_level_address_tuple(("fd80::1%2", 53), socket.AF_INET6) + self.assertEqual(t, ("fd80::1", 53, 0, 2)) try: # This can fail on windows for python < 3.8, so we tolerate # the failure and only test if we have something we can work @@ -307,20 +332,24 @@ class NtoAAtoNTestCase(unittest.TestCase): pair = p break if pair: - address = 'fd80::1%' + pair[1] - t = dns.inet.low_level_address_tuple((address, 53), - socket.AF_INET6) - self.assertEqual(t, ('fd80::1', 53, 0, pair[0])) + address = "fd80::1%" + pair[1] + t = dns.inet.low_level_address_tuple((address, 53), socket.AF_INET6) + self.assertEqual(t, ("fd80::1", 53, 0, pair[0])) + def bad(): bogus = socket.AF_INET + socket.AF_INET6 + 1 - t = dns.inet.low_level_address_tuple(('2600::1', 53), bogus) + t = dns.inet.low_level_address_tuple(("2600::1", 53), bogus) + self.assertRaises(NotImplementedError, bad) def test_bogus_family(self): - self.assertRaises(NotImplementedError, - lambda: dns.inet.inet_pton(12345, 'bogus')) - self.assertRaises(NotImplementedError, - lambda: dns.inet.inet_ntop(12345, b'bogus')) + self.assertRaises( + NotImplementedError, lambda: dns.inet.inet_pton(12345, "bogus") + ) + self.assertRaises( + NotImplementedError, lambda: dns.inet.inet_ntop(12345, b"bogus") + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_processing_order.py b/tests/test_processing_order.py index 76754dde..d2025d68 100644 --- a/tests/test_processing_order.py +++ b/tests/test_processing_order.py @@ -1,12 +1,10 @@ - import dns.rdata import dns.rdataset import dns.rdtypes.IN.SRV def test_processing_order_shuffle(): - rds = dns.rdataset.from_text('in', 'a', 300, - '10.0.0.1', '10.0.0.2', '10.0.0.3') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1", "10.0.0.2", "10.0.0.3") seen = set() for i in range(100): po = rds.processing_order() @@ -18,8 +16,7 @@ def test_processing_order_shuffle(): def test_processing_order_priority_mx(): - rds = dns.rdataset.from_text('in', 'mx', 300, - '10 a', '20 b', '20 c') + rds = dns.rdataset.from_text("in", "mx", 300, "10 a", "20 b", "20 c") seen = set() for i in range(100): po = rds.processing_order() @@ -32,8 +29,9 @@ def test_processing_order_priority_mx(): def test_processing_order_priority_weighted(): - rds = dns.rdataset.from_text('in', 'srv', 300, - '1 10 1234 a', '2 90 1234 b', '2 10 1234 c') + rds = dns.rdataset.from_text( + "in", "srv", 300, "1 10 1234 a", "2 90 1234 b", "2 10 1234 c" + ) seen = set() weight_90_count = 0 weight_10_count = 0 @@ -58,9 +56,15 @@ def test_processing_order_priority_weighted(): def test_processing_order_priority_naptr(): - rds = dns.rdataset.from_text('in', 'naptr', 300, - '1 10 a b c foo.', '1 20 a b c foo.', - '2 10 a b c foo.', '2 10 d e f bar.') + rds = dns.rdataset.from_text( + "in", + "naptr", + 300, + "1 10 a b c foo.", + "1 20 a b c foo.", + "2 10 a b c foo.", + "2 10 d e f bar.", + ) seen = set() for i in range(100): po = rds.processing_order() @@ -74,26 +78,27 @@ def test_processing_order_priority_naptr(): def test_processing_order_empty(): - rds = dns.rdataset.from_text('in', 'naptr', 300) + rds = dns.rdataset.from_text("in", "naptr", 300) po = rds.processing_order() assert po == [] def test_processing_singleton_priority(): - rds = dns.rdataset.from_text('in', 'mx', 300, '10 a') + rds = dns.rdataset.from_text("in", "mx", 300, "10 a") po = rds.processing_order() assert po == [rds[0]] def test_processing_singleton_weighted(): - rds = dns.rdataset.from_text('in', 'srv', 300, '1 10 1234 a') + rds = dns.rdataset.from_text("in", "srv", 300, "1 10 1234 a") po = rds.processing_order() assert po == [rds[0]] def test_processing_all_zero_weight_srv(): - rds = dns.rdataset.from_text('in', 'srv', 300, - '1 0 1234 a', '1 0 1234 b', '1 0 1234 c') + rds = dns.rdataset.from_text( + "in", "srv", 300, "1 0 1234 a", "1 0 1234 b", "1 0 1234 c" + ) seen = set() for i in range(100): po = rds.processing_order() @@ -108,10 +113,14 @@ def test_processing_order_uri(): # We're testing here just to provide coverage for URI methods; the # testing of the weighting algorithm is done above in tests with # SRV. - rds = dns.rdataset.from_text('in', 'uri', 300, - '1 1 "ftp://ftp1.example.com/public"', - '2 2 "ftp://ftp2.example.com/public"', - '3 3 "ftp://ftp3.example.com/public"') + rds = dns.rdataset.from_text( + "in", + "uri", + 300, + '1 1 "ftp://ftp1.example.com/public"', + '2 2 "ftp://ftp2.example.com/public"', + '3 3 "ftp://ftp3.example.com/public"', + ) po = rds.processing_order() assert len(po) == 3 for i in range(3): @@ -122,10 +131,14 @@ def test_processing_order_svcb(): # We're testing here just to provide coverage for SVCB methods; the # testing of the priority algorithm is done above in tests with # MX and NAPTR. - rds = dns.rdataset.from_text('in', 'svcb', 300, - "1 . mandatory=alpn alpn=h2", - "2 . mandatory=alpn alpn=h2", - "3 . mandatory=alpn alpn=h2") + rds = dns.rdataset.from_text( + "in", + "svcb", + 300, + "1 . mandatory=alpn alpn=h2", + "2 . mandatory=alpn alpn=h2", + "3 . mandatory=alpn alpn=h2", + ) po = rds.processing_order() assert len(po) == 3 for i in range(3): diff --git a/tests/test_query.py b/tests/test_query.py index 2d954e36..e8a53902 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -22,6 +22,7 @@ import unittest try: import ssl + have_ssl = True except Exception: have_ssl = False @@ -40,7 +41,7 @@ import dns.zone # skip those if it's not there. _network_available = True try: - socket.gethostbyname('dnspython.org') + socket.gethostbyname("dnspython.org") except socket.gaierror: _network_available = False @@ -49,16 +50,21 @@ except socket.gaierror: # those tests. try: from .nanonameserver import Server + _nanonameserver_available = True except ImportError: _nanonameserver_available = False + class Server(object): pass + # Probe for IPv4 and IPv6 query_addresses = [] -for (af, address) in ((socket.AF_INET, '8.8.8.8'), - (socket.AF_INET6, '2001:4860:4860::8888')): +for (af, address) in ( + (socket.AF_INET, "8.8.8.8"), + (socket.AF_INET6, "2001:4860:4860::8888"), +): try: with socket.socket(af, socket.SOCK_DGRAM) as s: # Connecting a UDP socket is supposed to return ENETUNREACH if @@ -68,85 +74,93 @@ for (af, address) in ((socket.AF_INET, '8.8.8.8'), except Exception: pass -keyring = dns.tsigkeyring.from_text({'name': 'tDz6cfXXGtNivRpQ98hr6A=='}) +keyring = dns.tsigkeyring.from_text({"name": "tDz6cfXXGtNivRpQ98hr6A=="}) + @unittest.skipIf(not _network_available, "Internet not reachable") class QueryTests(unittest.TestCase): - def testQueryUDP(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") q = dns.message.make_query(qname, dns.rdatatype.A) response = dns.query.udp(q, address, timeout=2) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryUDPWithSocket(self): for address in query_addresses: - with socket.socket(dns.inet.af_for_address(address), - socket.SOCK_DGRAM) as s: + with socket.socket( + dns.inet.af_for_address(address), socket.SOCK_DGRAM + ) as s: s.setblocking(0) - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") q = dns.message.make_query(qname, dns.rdatatype.A) response = dns.query.udp(q, address, sock=s, timeout=2) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryTCP(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") q = dns.message.make_query(qname, dns.rdatatype.A) response = dns.query.tcp(q, address, timeout=2) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryTCPWithSocket(self): for address in query_addresses: - with socket.socket(dns.inet.af_for_address(address), - socket.SOCK_STREAM) as s: + with socket.socket( + dns.inet.af_for_address(address), socket.SOCK_STREAM + ) as s: ll = dns.inet.low_level_address_tuple((address, 53)) s.settimeout(2) s.connect(ll) s.setblocking(0) - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") q = dns.message.make_query(qname, dns.rdatatype.A) response = dns.query.tcp(q, None, sock=s, timeout=2) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryTLS(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") q = dns.message.make_query(qname, dns.rdatatype.A) response = dns.query.tls(q, address, timeout=2) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) @unittest.skipUnless(have_ssl, "No SSL support") def testQueryTLSWithSocket(self): for address in query_addresses: - with socket.socket(dns.inet.af_for_address(address), - socket.SOCK_STREAM) as base_s: + with socket.socket( + dns.inet.af_for_address(address), socket.SOCK_STREAM + ) as base_s: ll = dns.inet.low_level_address_tuple((address, 853)) base_s.settimeout(2) base_s.connect(ll) @@ -155,21 +169,24 @@ class QueryTests(unittest.TestCase): ctx.minimum_version = ssl.TLSVersion.TLSv1_2 else: ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 - with ctx.wrap_socket(base_s, server_hostname='dns.google') as s: # lgtm[py/insecure-protocol] + with ctx.wrap_socket( + base_s, server_hostname="dns.google" + ) as s: # lgtm[py/insecure-protocol] s.setblocking(0) - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") q = dns.message.make_query(qname, dns.rdatatype.A) response = dns.query.tls(q, None, sock=s, timeout=2) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testQueryUDPFallback(self): for address in query_addresses: - qname = dns.name.from_text('.') + qname = dns.name.from_text(".") q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2) self.assertTrue(tcp) @@ -184,116 +201,120 @@ class QueryTests(unittest.TestCase): tcp_s.settimeout(2) tcp_s.connect(ll) tcp_s.setblocking(0) - qname = dns.name.from_text('.') + qname = dns.name.from_text(".") q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) - (_, tcp) = dns.query.udp_with_fallback(q, address, - udp_sock=udp_s, - tcp_sock=tcp_s, - timeout=2) + (_, tcp) = dns.query.udp_with_fallback( + q, address, udp_sock=udp_s, tcp_sock=tcp_s, timeout=2 + ) self.assertTrue(tcp) def testQueryUDPFallbackNoFallback(self): for address in query_addresses: - qname = dns.name.from_text('dns.google.') + qname = dns.name.from_text("dns.google.") q = dns.message.make_query(qname, dns.rdatatype.A) (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2) self.assertFalse(tcp) def testUDPReceiveQuery(self): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener: - listener.bind(('127.0.0.1', 0)) + listener.bind(("127.0.0.1", 0)) with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender: - sender.bind(('127.0.0.1', 0)) - q = dns.message.make_query('dns.google', dns.rdatatype.A) + sender.bind(("127.0.0.1", 0)) + q = dns.message.make_query("dns.google", dns.rdatatype.A) dns.query.send_udp(sender, q, listener.getsockname()) expiration = time.time() + 2 - (q, _, addr) = dns.query.receive_udp(listener, - expiration=expiration) + (q, _, addr) = dns.query.receive_udp(listener, expiration=expiration) self.assertEqual(addr, sender.getsockname()) # for brevity _d_and_s = dns.query._destination_and_source -class DestinationAndSourceTests(unittest.TestCase): +class DestinationAndSourceTests(unittest.TestCase): def test_af_inferred_from_where(self): - (af, d, s) = _d_and_s('1.2.3.4', 53, None, 0) + (af, d, s) = _d_and_s("1.2.3.4", 53, None, 0) self.assertEqual(af, socket.AF_INET) def test_af_inferred_from_where(self): - (af, d, s) = _d_and_s('1::2', 53, None, 0) + (af, d, s) = _d_and_s("1::2", 53, None, 0) self.assertEqual(af, socket.AF_INET6) def test_af_inferred_from_source(self): - (af, d, s) = _d_and_s('https://example/dns-query', 443, - '1.2.3.4', 0, False) + (af, d, s) = _d_and_s("https://example/dns-query", 443, "1.2.3.4", 0, False) self.assertEqual(af, socket.AF_INET) def test_af_mismatch(self): def bad(): - (af, d, s) = _d_and_s('1::2', 53, '1.2.3.4', 0) + (af, d, s) = _d_and_s("1::2", 53, "1.2.3.4", 0) + self.assertRaises(ValueError, bad) def test_source_port_but_no_af_inferred(self): def bad(): - (af, d, s) = _d_and_s('https://example/dns-query', 443, - None, 12345, False) + (af, d, s) = _d_and_s("https://example/dns-query", 443, None, 12345, False) + self.assertRaises(ValueError, bad) def test_where_must_be_an_address(self): def bad(): - (af, d, s) = _d_and_s('not a valid address', 53, '1.2.3.4', 0) + (af, d, s) = _d_and_s("not a valid address", 53, "1.2.3.4", 0) + self.assertRaises(ValueError, bad) def test_destination_is_none_of_where_url(self): - (af, d, s) = _d_and_s('https://example/dns-query', 443, None, 0, False) + (af, d, s) = _d_and_s("https://example/dns-query", 443, None, 0, False) self.assertEqual(d, None) def test_v4_wildcard_source_set(self): - (af, d, s) = _d_and_s('1.2.3.4', 53, None, 12345) - self.assertEqual(s, ('0.0.0.0', 12345)) + (af, d, s) = _d_and_s("1.2.3.4", 53, None, 12345) + self.assertEqual(s, ("0.0.0.0", 12345)) def test_v6_wildcard_source_set(self): - (af, d, s) = _d_and_s('1::2', 53, None, 12345) - self.assertEqual(s, ('::', 12345, 0, 0)) + (af, d, s) = _d_and_s("1::2", 53, None, 12345) + self.assertEqual(s, ("::", 12345, 0, 0)) class AddressesEqualTestCase(unittest.TestCase): - def test_v4(self): - self.assertTrue(dns.query._addresses_equal(socket.AF_INET, - ('10.0.0.1', 53), - ('10.0.0.1', 53))) - self.assertFalse(dns.query._addresses_equal(socket.AF_INET, - ('10.0.0.1', 53), - ('10.0.0.2', 53))) + self.assertTrue( + dns.query._addresses_equal( + socket.AF_INET, ("10.0.0.1", 53), ("10.0.0.1", 53) + ) + ) + self.assertFalse( + dns.query._addresses_equal( + socket.AF_INET, ("10.0.0.1", 53), ("10.0.0.2", 53) + ) + ) def test_v6(self): - self.assertTrue(dns.query._addresses_equal(socket.AF_INET6, - ('1::1', 53), - ('0001:0000::1', 53))) - self.assertFalse(dns.query._addresses_equal(socket.AF_INET6, - ('::1', 53), - ('::2', 53))) + self.assertTrue( + dns.query._addresses_equal( + socket.AF_INET6, ("1::1", 53), ("0001:0000::1", 53) + ) + ) + self.assertFalse( + dns.query._addresses_equal(socket.AF_INET6, ("::1", 53), ("::2", 53)) + ) def test_mixed(self): - self.assertFalse(dns.query._addresses_equal(socket.AF_INET, - ('10.0.0.1', 53), - ('::2', 53))) + self.assertFalse( + dns.query._addresses_equal(socket.AF_INET, ("10.0.0.1", 53), ("::2", 53)) + ) -axfr_zone = ''' +axfr_zone = """ $TTL 300 @ SOA ns1 root 1 7200 900 1209600 86400 @ NS ns1 @ NS ns2 ns1 A 10.0.0.1 ns2 A 10.0.0.1 -''' +""" -class AXFRNanoNameserver(Server): +class AXFRNanoNameserver(Server): def handle(self, request): self.zone = dns.zone.from_text(axfr_zone, origin=self.origin) self.origin = self.zone.origin @@ -307,11 +328,11 @@ class AXFRNanoNameserver(Server): response.question = [] response.flags |= dns.flags.AA for (name, rdataset) in self.zone.iterate_rdatasets(): - if rdataset.rdtype == dns.rdatatype.SOA and \ - name == dns.name.empty: + if rdataset.rdtype == dns.rdatatype.SOA and name == dns.name.empty: continue - rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype, - rdataset.covers) + rrset = dns.rrset.RRset( + name, rdataset.rdclass, rdataset.rdtype, rdataset.covers + ) rrset.update(rdataset) response.answer.append(rrset) items.append(response) @@ -322,7 +343,8 @@ class AXFRNanoNameserver(Server): items.append(response) return items -ixfr_message = '''id 12345 + +ixfr_message = """id 12345 opcode QUERY rcode NOERROR flags AA @@ -340,11 +362,11 @@ example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400 example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400 added2.example. 300 IN A 10.0.0.5 example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400 -''' +""" -ixfr_trailing_junk = ixfr_message + 'junk.example. 300 IN A 10.0.0.6' +ixfr_trailing_junk = ixfr_message + "junk.example. 300 IN A 10.0.0.6" -ixfr_up_to_date_message = '''id 12345 +ixfr_up_to_date_message = """id 12345 opcode QUERY rcode NOERROR flags AA @@ -352,9 +374,9 @@ flags AA example. IN IXFR ;ANSWER example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400 -''' +""" -axfr_trailing_junk = '''id 12345 +axfr_trailing_junk = """id 12345 opcode QUERY rcode NOERROR flags AA @@ -367,10 +389,10 @@ added2.example. 300 IN A 10.0.0.5 changed.example. 300 IN A 10.0.0.4 example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400 junk.example. 300 IN A 10.0.0.6 -''' +""" -class IXFRNanoNameserver(Server): +class IXFRNanoNameserver(Server): def __init__(self, response_text): super().__init__() self.response_text = response_text @@ -383,137 +405,167 @@ class IXFRNanoNameserver(Server): except Exception: pass + @unittest.skipIf(not _nanonameserver_available, "nanonameserver required") class XfrTests(unittest.TestCase): - def test_axfr(self): - expected = dns.zone.from_text(axfr_zone, origin='example') - with AXFRNanoNameserver(origin='example') as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - port=ns.tcp_address[1]) + expected = dns.zone.from_text(axfr_zone, origin="example") + with AXFRNanoNameserver(origin="example") as ns: + xfr = dns.query.xfr(ns.tcp_address[0], "example", port=ns.tcp_address[1]) zone = dns.zone.from_xfr(xfr) self.assertEqual(zone, expected) def test_axfr_tsig(self): - expected = dns.zone.from_text(axfr_zone, origin='example') - with AXFRNanoNameserver(origin='example', keyring=keyring) as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - port=ns.tcp_address[1], - keyring=keyring, keyname='name') + expected = dns.zone.from_text(axfr_zone, origin="example") + with AXFRNanoNameserver(origin="example", keyring=keyring) as ns: + xfr = dns.query.xfr( + ns.tcp_address[0], + "example", + port=ns.tcp_address[1], + keyring=keyring, + keyname="name", + ) zone = dns.zone.from_xfr(xfr) self.assertEqual(zone, expected) def test_axfr_root_tsig(self): - expected = dns.zone.from_text(axfr_zone, origin='.') - with AXFRNanoNameserver(origin='.', keyring=keyring) as ns: - xfr = dns.query.xfr(ns.tcp_address[0], '.', - port=ns.tcp_address[1], - keyring=keyring, keyname='name') + expected = dns.zone.from_text(axfr_zone, origin=".") + with AXFRNanoNameserver(origin=".", keyring=keyring) as ns: + xfr = dns.query.xfr( + ns.tcp_address[0], + ".", + port=ns.tcp_address[1], + keyring=keyring, + keyname="name", + ) zone = dns.zone.from_xfr(xfr) self.assertEqual(zone, expected) def test_axfr_udp(self): def bad(): - with AXFRNanoNameserver(origin='example') as ns: - xfr = dns.query.xfr(ns.udp_address[0], 'example', - port=ns.udp_address[1], use_udp=True) + with AXFRNanoNameserver(origin="example") as ns: + xfr = dns.query.xfr( + ns.udp_address[0], "example", port=ns.udp_address[1], use_udp=True + ) l = list(xfr) + self.assertRaises(ValueError, bad) def test_axfr_bad_rcode(self): def bad(): # We just use Server here as by default it will refuse. with Server() as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - port=ns.tcp_address[1]) + xfr = dns.query.xfr( + ns.tcp_address[0], "example", port=ns.tcp_address[1] + ) l = list(xfr) + self.assertRaises(dns.query.TransferError, bad) def test_axfr_trailing_junk(self): # we use the IXFR server here as it returns messages def bad(): with IXFRNanoNameserver(axfr_trailing_junk) as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - dns.rdatatype.AXFR, - port=ns.tcp_address[1]) + xfr = dns.query.xfr( + ns.tcp_address[0], + "example", + dns.rdatatype.AXFR, + port=ns.tcp_address[1], + ) l = list(xfr) + self.assertRaises(dns.exception.FormError, bad) def test_ixfr_tcp(self): with IXFRNanoNameserver(ixfr_message) as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - dns.rdatatype.IXFR, - port=ns.tcp_address[1], - serial=2, - relativize=False) + xfr = dns.query.xfr( + ns.tcp_address[0], + "example", + dns.rdatatype.IXFR, + port=ns.tcp_address[1], + serial=2, + relativize=False, + ) l = list(xfr) self.assertEqual(len(l), 1) - expected = dns.message.from_text(ixfr_message, - one_rr_per_rrset=True) + expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True) expected.id = l[0].id self.assertEqual(l[0], expected) def test_ixfr_udp(self): with IXFRNanoNameserver(ixfr_message) as ns: - xfr = dns.query.xfr(ns.udp_address[0], 'example', - dns.rdatatype.IXFR, - port=ns.udp_address[1], - serial=2, - relativize=False, use_udp=True) + xfr = dns.query.xfr( + ns.udp_address[0], + "example", + dns.rdatatype.IXFR, + port=ns.udp_address[1], + serial=2, + relativize=False, + use_udp=True, + ) l = list(xfr) self.assertEqual(len(l), 1) - expected = dns.message.from_text(ixfr_message, - one_rr_per_rrset=True) + expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True) expected.id = l[0].id self.assertEqual(l[0], expected) def test_ixfr_up_to_date(self): with IXFRNanoNameserver(ixfr_up_to_date_message) as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - dns.rdatatype.IXFR, - port=ns.tcp_address[1], - serial=2, - relativize=False) + xfr = dns.query.xfr( + ns.tcp_address[0], + "example", + dns.rdatatype.IXFR, + port=ns.tcp_address[1], + serial=2, + relativize=False, + ) l = list(xfr) self.assertEqual(len(l), 1) - expected = dns.message.from_text(ixfr_up_to_date_message, - one_rr_per_rrset=True) + expected = dns.message.from_text( + ixfr_up_to_date_message, one_rr_per_rrset=True + ) expected.id = l[0].id self.assertEqual(l[0], expected) def test_ixfr_trailing_junk(self): def bad(): with IXFRNanoNameserver(ixfr_trailing_junk) as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - dns.rdatatype.IXFR, - port=ns.tcp_address[1], - serial=2, - relativize=False) + xfr = dns.query.xfr( + ns.tcp_address[0], + "example", + dns.rdatatype.IXFR, + port=ns.tcp_address[1], + serial=2, + relativize=False, + ) l = list(xfr) + self.assertRaises(dns.exception.FormError, bad) def test_ixfr_base_serial_mismatch(self): def bad(): with IXFRNanoNameserver(ixfr_message) as ns: - xfr = dns.query.xfr(ns.tcp_address[0], 'example', - dns.rdatatype.IXFR, - port=ns.tcp_address[1], - serial=1, - relativize=False) + xfr = dns.query.xfr( + ns.tcp_address[0], + "example", + dns.rdatatype.IXFR, + port=ns.tcp_address[1], + serial=1, + relativize=False, + ) l = list(xfr) + self.assertRaises(dns.exception.FormError, bad) -class TSIGNanoNameserver(Server): +class TSIGNanoNameserver(Server): def handle(self, request): response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.REFUSED) response.flags |= dns.flags.RA try: - if request.qtype == dns.rdatatype.A and \ - request.qclass == dns.rdataclass.IN: - rrs = dns.rrset.from_text(request.qname, 300, - 'IN', 'A', '1.2.3.4') + if request.qtype == dns.rdatatype.A and request.qclass == dns.rdataclass.IN: + rrs = dns.rrset.from_text(request.qname, 300, "IN", "A", "1.2.3.4") response.answer.append(rrs) response.set_rcode(dns.rcode.NOERROR) response.flags |= dns.flags.AA @@ -521,27 +573,26 @@ class TSIGNanoNameserver(Server): pass return response + @unittest.skipIf(not _nanonameserver_available, "nanonameserver required") class TsigTests(unittest.TestCase): - def test_tsig(self): with TSIGNanoNameserver(keyring=keyring) as ns: - qname = dns.name.from_text('example.com') - q = dns.message.make_query(qname, 'A') - q.use_tsig(keyring=keyring, keyname='name') - response = dns.query.udp(q, ns.udp_address[0], - port=ns.udp_address[1]) + qname = dns.name.from_text("example.com") + q = dns.message.make_query(qname, "A") + q.use_tsig(keyring=keyring, keyname="name") + response = dns.query.udp(q, ns.udp_address[0], port=ns.udp_address[1]) self.assertTrue(response.had_tsig) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) self.assertTrue(rrs is not None) seen = set([rdata.address for rdata in rrs]) - self.assertTrue('1.2.3.4' in seen) + self.assertTrue("1.2.3.4" in seen) -@unittest.skipIf(sys.platform == 'win32', - 'low level tests do not work on win32') -class LowLevelWaitTests(unittest.TestCase): +@unittest.skipIf(sys.platform == "win32", "low level tests do not work on win32") +class LowLevelWaitTests(unittest.TestCase): def test_wait_for(self): try: (l, r) = socket.socketpair() @@ -560,28 +611,32 @@ class LowLevelWaitTests(unittest.TestCase): class MiscTests(unittest.TestCase): def test_matches_destination(self): - self.assertTrue(dns.query._matches_destination(socket.AF_INET, - ('10.0.0.1', 1234), - ('10.0.0.1', 1234), - True)) - self.assertTrue(dns.query._matches_destination(socket.AF_INET6, - ('1::2', 1234), - ('0001::2', 1234), - True)) - self.assertTrue(dns.query._matches_destination(socket.AF_INET, - ('10.0.0.1', 1234), - None, - True)) - self.assertFalse(dns.query._matches_destination(socket.AF_INET, - ('10.0.0.1', 1234), - ('10.0.0.2', 1234), - True)) - self.assertFalse(dns.query._matches_destination(socket.AF_INET, - ('10.0.0.1', 1234), - ('10.0.0.1', 1235), - True)) + self.assertTrue( + dns.query._matches_destination( + socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1234), True + ) + ) + self.assertTrue( + dns.query._matches_destination( + socket.AF_INET6, ("1::2", 1234), ("0001::2", 1234), True + ) + ) + self.assertTrue( + dns.query._matches_destination( + socket.AF_INET, ("10.0.0.1", 1234), None, True + ) + ) + self.assertFalse( + dns.query._matches_destination( + socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.2", 1234), True + ) + ) + self.assertFalse( + dns.query._matches_destination( + socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), True + ) + ) with self.assertRaises(dns.query.UnexpectedSource): - dns.query._matches_destination(socket.AF_INET, - ('10.0.0.1', 1234), - ('10.0.0.1', 1235), - False) + dns.query._matches_destination( + socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False + ) diff --git a/tests/test_rdata.py b/tests/test_rdata.py index c002e7ab..a1c066af 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -44,49 +44,46 @@ import tests.ttxt_module import tests.md_module from tests.util import here -class RdataTestCase(unittest.TestCase): +class RdataTestCase(unittest.TestCase): def test_str(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - "1.2.3.4") + rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4") self.assertEqual(rdata.address, "1.2.3.4") def test_unicode(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - u"1.2.3.4") + rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4") self.assertEqual(rdata.address, "1.2.3.4") def test_module_registration(self): TTXT = 64001 - dns.rdata.register_type(tests.ttxt_module, TTXT, 'TTXT') - rdata = dns.rdata.from_text(dns.rdataclass.IN, TTXT, 'hello world') - self.assertEqual(rdata.strings, (b'hello', b'world')) - self.assertEqual(dns.rdatatype.to_text(TTXT), 'TTXT') - self.assertEqual(dns.rdatatype.from_text('TTXT'), TTXT) + dns.rdata.register_type(tests.ttxt_module, TTXT, "TTXT") + rdata = dns.rdata.from_text(dns.rdataclass.IN, TTXT, "hello world") + self.assertEqual(rdata.strings, (b"hello", b"world")) + self.assertEqual(dns.rdatatype.to_text(TTXT), "TTXT") + self.assertEqual(dns.rdatatype.from_text("TTXT"), TTXT) def test_module_reregistration(self): def bad(): TTXTTWO = dns.rdatatype.TXT - dns.rdata.register_type(tests.ttxt_module, TTXTTWO, 'TTXTTWO') + dns.rdata.register_type(tests.ttxt_module, TTXTTWO, "TTXTTWO") + self.assertRaises(dns.rdata.RdatatypeExists, bad) def test_module_registration_singleton(self): STXT = 64002 - dns.rdata.register_type(tests.stxt_module, STXT, 'STXT', - is_singleton=True) - rdata1 = dns.rdata.from_text(dns.rdataclass.IN, STXT, 'hello') - rdata2 = dns.rdata.from_text(dns.rdataclass.IN, STXT, 'world') + dns.rdata.register_type(tests.stxt_module, STXT, "STXT", is_singleton=True) + rdata1 = dns.rdata.from_text(dns.rdataclass.IN, STXT, "hello") + rdata2 = dns.rdata.from_text(dns.rdataclass.IN, STXT, "world") rdataset = dns.rdataset.from_rdata(3600, rdata1, rdata2) self.assertEqual(len(rdataset), 1) - self.assertEqual(rdataset[0].strings, (b'world',)) + self.assertEqual(rdataset[0].strings, (b"world",)) def test_replace(self): a1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4") a2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "2.3.4.5") self.assertEqual(a1.replace(address="2.3.4.5"), a2) - mx = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - "10 foo.example") + mx = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "10 foo.example") name = dns.name.from_text("bar.example") self.assertEqual(mx.replace(preference=20).preference, 20) self.assertEqual(mx.replace(preference=20).exchange, mx.exchange) @@ -103,8 +100,7 @@ class RdataTestCase(unittest.TestCase): a1.replace(address="bogus") def test_replace_comment(self): - a1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - "1.2.3.4 ;foo") + a1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4 ;foo") self.assertEqual(a1.rdcomment, "foo") a2 = a1.replace(rdcomment="bar") self.assertEqual(a1, a2) @@ -120,16 +116,18 @@ class RdataTestCase(unittest.TestCase): def test_to_generic(self): a = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4") - self.assertEqual(str(a.to_generic()), r'\# 4 01020304') + self.assertEqual(str(a.to_generic()), r"\# 4 01020304") mx = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "10 foo.") - self.assertEqual(str(mx.to_generic()), r'\# 7 000a03666f6f00') + self.assertEqual(str(mx.to_generic()), r"\# 7 000a03666f6f00") - origin = dns.name.from_text('example') - ns = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - "foo.example.", relativize_to=origin) - self.assertEqual(str(ns.to_generic(origin=origin)), - r'\# 13 03666f6f076578616d706c6500') + origin = dns.name.from_text("example") + ns = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.NS, "foo.example.", relativize_to=origin + ) + self.assertEqual( + str(ns.to_generic(origin=origin)), r"\# 13 03666f6f076578616d706c6500" + ) def test_txt_unicode(self): # TXT records are not defined for Unicode, but if we get @@ -138,8 +136,9 @@ class RdataTestCase(unittest.TestCase): # to_text(), it does NOT convert embedded UTF-8 back to # Unicode; it's just treated as binary TXT data. Probably # there should be a TXT-like record with an encoding field. - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.TXT, - '"foo\u200bbar"') + rdata = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.TXT, '"foo\u200bbar"' + ) self.assertEqual(str(rdata), '"foo\\226\\128\\139bar"') # We used to encode UTF-8 in UTF-8 because we processed # escapes in quoted strings immediately. This meant that the @@ -148,28 +147,35 @@ class RdataTestCase(unittest.TestCase): # point, emitting \\195\\162 instead of \\226, and thus # from_text followed by to_text was not the equal to the # original input like it ought to be. - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.TXT, - '"foo\\226\\128\\139bar"') + rdata = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.TXT, '"foo\\226\\128\\139bar"' + ) self.assertEqual(str(rdata), '"foo\\226\\128\\139bar"') # Our fix for TXT-like records uses a new tokenizer method, # unescape_to_bytes(), which converts Unicode to UTF-8 only # once. - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.TXT, - '"foo\u200b\\123bar"') + rdata = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.TXT, '"foo\u200b\\123bar"' + ) self.assertEqual(str(rdata), '"foo\\226\\128\\139{bar"') def test_unicode_idna2003_in_rdata(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - "Königsgäßchen") - self.assertEqual(str(rdata.target), 'xn--knigsgsschen-lcb0w') + rdata = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.NS, "Königsgäßchen" + ) + self.assertEqual(str(rdata.target), "xn--knigsgsschen-lcb0w") - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def test_unicode_idna2008_in_rdata(self): - rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - "Königsgäßchen", - idna_codec=dns.name.IDNA_2008) - self.assertEqual(str(rdata.target), 'xn--knigsgchen-b4a3dun') + rdata = dns.rdata.from_text( + dns.rdataclass.IN, + dns.rdatatype.NS, + "Königsgäßchen", + idna_codec=dns.name.IDNA_2008, + ) + self.assertEqual(str(rdata.target), "xn--knigsgchen-b4a3dun") def test_digestable_downcasing(self): # Make sure all the types listed in RFC 4034 section 6.2 are @@ -183,32 +189,35 @@ class RdataTestCase(unittest.TestCase): # NSEC3, whose downcasing was removed by RFC 6840 section 5.1 # cases = [ - ('SOA', 'NAME NAME 1 2 3 4 5'), - ('AFSDB', '0 NAME'), - ('CNAME', 'NAME'), - ('DNAME', 'NAME'), - ('KX', '10 NAME'), - ('MX', '10 NAME'), - ('NS', 'NAME'), - ('NAPTR', '0 0 a B c NAME'), - ('PTR', 'NAME'), - ('PX', '65535 NAME NAME'), - ('RP', 'NAME NAME'), - ('RT', '0 NAME'), - ('SRV', '0 0 0 NAME'), - ('RRSIG', - 'A 1 3 3600 20200701000000 20200601000000 1 NAME Ym9ndXM=') + ("SOA", "NAME NAME 1 2 3 4 5"), + ("AFSDB", "0 NAME"), + ("CNAME", "NAME"), + ("DNAME", "NAME"), + ("KX", "10 NAME"), + ("MX", "10 NAME"), + ("NS", "NAME"), + ("NAPTR", "0 0 a B c NAME"), + ("PTR", "NAME"), + ("PX", "65535 NAME NAME"), + ("RP", "NAME NAME"), + ("RT", "0 NAME"), + ("SRV", "0 0 0 NAME"), + ("RRSIG", "A 1 3 3600 20200701000000 20200601000000 1 NAME Ym9ndXM="), ] for rdtype, text in cases: - upper_origin = dns.name.from_text('EXAMPLE') - lower_origin = dns.name.from_text('example') - canonical_text = text.replace('NAME', 'name') - rdata = dns.rdata.from_text(dns.rdataclass.IN, rdtype, text, - origin=upper_origin, relativize=False) - canonical_rdata = dns.rdata.from_text(dns.rdataclass.IN, rdtype, - canonical_text, - origin=lower_origin, - relativize=False) + upper_origin = dns.name.from_text("EXAMPLE") + lower_origin = dns.name.from_text("example") + canonical_text = text.replace("NAME", "name") + rdata = dns.rdata.from_text( + dns.rdataclass.IN, rdtype, text, origin=upper_origin, relativize=False + ) + canonical_rdata = dns.rdata.from_text( + dns.rdataclass.IN, + rdtype, + canonical_text, + origin=lower_origin, + relativize=False, + ) digestable_wire = rdata.to_digestable() f = io.BytesIO() canonical_rdata.to_wire(f) @@ -221,23 +230,22 @@ class RdataTestCase(unittest.TestCase): # handled properly. # cases = [ - ('HIP', '2 200100107B1A74DF365639CC39F1D578 Ym9ndXM= NAME name'), - ('IPSECKEY', '10 3 2 NAME Ym9ndXM='), - ('NSEC', 'NAME A'), + ("HIP", "2 200100107B1A74DF365639CC39F1D578 Ym9ndXM= NAME name"), + ("IPSECKEY", "10 3 2 NAME Ym9ndXM="), + ("NSEC", "NAME A"), ] for rdtype, text in cases: - origin = dns.name.from_text('example') - rdata = dns.rdata.from_text(dns.rdataclass.IN, rdtype, text, - origin=origin, relativize=False) + origin = dns.name.from_text("example") + rdata = dns.rdata.from_text( + dns.rdataclass.IN, rdtype, text, origin=origin, relativize=False + ) digestable_wire = rdata.to_digestable(origin) expected_wire = rdata.to_wire(origin=origin) self.assertEqual(digestable_wire, expected_wire) def test_basic_relations(self): - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2") self.assertTrue(r1 == r1) self.assertTrue(r1 != r2) self.assertTrue(r1 < r2) @@ -246,10 +254,8 @@ class RdataTestCase(unittest.TestCase): self.assertTrue(r2 >= r1) def test_incompatible_relations(self): - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.AAAA, - '::1') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.AAAA, "::1") for oper in [operator.lt, operator.le, operator.ge, operator.gt]: self.assertRaises(TypeError, lambda: oper(r1, r2)) self.assertFalse(r1 == r2) @@ -257,32 +263,34 @@ class RdataTestCase(unittest.TestCase): def test_immutability(self): def bad1(): - r = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - r.address = '10.0.0.2' + r = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + r.address = "10.0.0.2" + self.assertRaises(TypeError, bad1) + def bad2(): - r = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') + r = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") del r.address + self.assertRaises(TypeError, bad2) def test_pickle(self): - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") p = pickle.dumps(r1) r2 = pickle.loads(p) self.assertEqual(r1, r2) # Pickle something with a longer inheritance chain - r3 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '10 mail.example.') + r3 = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.MX, "10 mail.example." + ) p = pickle.dumps(r3) r4 = pickle.loads(p) self.assertEqual(r3, r4) def test_AFSDB_properties(self): - rd = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.AFSDB, - '0 afsdb.example.') + rd = dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.AFSDB, "0 afsdb.example." + ) self.assertEqual(rd.preference, rd.subtype) self.assertEqual(rd.exchange, rd.hostname) @@ -293,93 +301,118 @@ class RdataTestCase(unittest.TestCase): def test_misc_good_LOC_text(self): # test just degrees - self.equal_loc('60 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 0 0 N 24 39 0.000 E 10.00m 20m 2000m 20m') - self.equal_loc('60 0 0 N 24 E 10.00m 20m 2000m 20m', - '60 0 0 N 24 0 0 E 10.00m 20m 2000m 20m') + self.equal_loc( + "60 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 0 0 N 24 39 0.000 E 10.00m 20m 2000m 20m", + ) + self.equal_loc( + "60 0 0 N 24 E 10.00m 20m 2000m 20m", + "60 0 0 N 24 0 0 E 10.00m 20m 2000m 20m", + ) # test variable length latitude - self.equal_loc('60 9 0.510 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 0.51 N 24 39 0.000 E 10.00m 20m 2000m 20m') - self.equal_loc('60 9 0.500 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 0.5 N 24 39 0.000 E 10.00m 20m 2000m 20m') - self.equal_loc('60 9 1.000 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 1 N 24 39 0.000 E 10.00m 20m 2000m 20m') + self.equal_loc( + "60 9 0.510 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.51 N 24 39 0.000 E 10.00m 20m 2000m 20m", + ) + self.equal_loc( + "60 9 0.500 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.5 N 24 39 0.000 E 10.00m 20m 2000m 20m", + ) + self.equal_loc( + "60 9 1.000 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 1 N 24 39 0.000 E 10.00m 20m 2000m 20m", + ) # test variable length longtitude - self.equal_loc('60 9 0.000 N 24 39 0.510 E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 39 0.51 E 10.00m 20m 2000m 20m') - self.equal_loc('60 9 0.000 N 24 39 0.500 E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 39 0.5 E 10.00m 20m 2000m 20m') - self.equal_loc('60 9 0.000 N 24 39 1.000 E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 39 1 E 10.00m 20m 2000m 20m') + self.equal_loc( + "60 9 0.000 N 24 39 0.510 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 0.51 E 10.00m 20m 2000m 20m", + ) + self.equal_loc( + "60 9 0.000 N 24 39 0.500 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 0.5 E 10.00m 20m 2000m 20m", + ) + self.equal_loc( + "60 9 0.000 N 24 39 1.000 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 1 E 10.00m 20m 2000m 20m", + ) # test siz, hp, vp defaults - self.equal_loc('60 9 0.510 N 24 39 0.000 E 10.00m', - '60 9 0.51 N 24 39 0.000 E 10.00m 1m 10000m 10m') - self.equal_loc('60 9 0.510 N 24 39 0.000 E 10.00m 2m', - '60 9 0.51 N 24 39 0.000 E 10.00m 2m 10000m 10m') - self.equal_loc('60 9 0.510 N 24 39 0.000 E 10.00m 2m 2000m', - '60 9 0.51 N 24 39 0.000 E 10.00m 2m 2000m 10m') + self.equal_loc( + "60 9 0.510 N 24 39 0.000 E 10.00m", + "60 9 0.51 N 24 39 0.000 E 10.00m 1m 10000m 10m", + ) + self.equal_loc( + "60 9 0.510 N 24 39 0.000 E 10.00m 2m", + "60 9 0.51 N 24 39 0.000 E 10.00m 2m 10000m 10m", + ) + self.equal_loc( + "60 9 0.510 N 24 39 0.000 E 10.00m 2m 2000m", + "60 9 0.51 N 24 39 0.000 E 10.00m 2m 2000m 10m", + ) # test siz, hp, vp optional units - self.equal_loc('60 9 0.510 N 24 39 0.000 E 1m 20m 2000m 20m', - '60 9 0.51 N 24 39 0.000 E 1 20 2000 20') + self.equal_loc( + "60 9 0.510 N 24 39 0.000 E 1m 20m 2000m 20m", + "60 9 0.51 N 24 39 0.000 E 1 20 2000 20", + ) def test_LOC_to_text_SW_hemispheres(self): # As an extra, we test int->float conversion in the constructor loc = LOC(dns.rdataclass.IN, dns.rdatatype.LOC, -60, -24, 1) - text = '60 0 0.000 S 24 0 0.000 W 0.01m' + text = "60 0 0.000 S 24 0 0.000 W 0.01m" self.assertEqual(loc.to_text(), text) def test_zero_size(self): # This is to exercise the 0 path in _exponent_of. - loc = dns.rdata.from_text('in', 'loc', '60 S 24 W 1 0') + loc = dns.rdata.from_text("in", "loc", "60 S 24 W 1 0") self.assertEqual(loc.size, 0.0) def test_bad_LOC_text(self): - bad_locs = ['60 9 a.000 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 60.000 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 0.00a N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 0.0001 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 0.000 Z 24 39 0.000 E 10.00m 20m 2000m 20m', - '91 9 0.000 N 24 39 0.000 E 10.00m 20m 2000m 20m', - '60 60 0.000 N 24 39 0.000 E 10.00m 20m 2000m 20m', - - '60 9 0.000 N 24 39 a.000 E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 39 60.000 E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 39 0.00a E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 39 0.0001 E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 39 0.000 Z 10.00m 20m 2000m 20m', - '60 9 0.000 N 181 39 0.000 E 10.00m 20m 2000m 20m', - '60 9 0.000 N 24 60 0.000 E 10.00m 20m 2000m 20m', - - '60 9 0.000 N 24 39 0.000 E 10.00m 100000000m 2000m 20m', - '60 9 0.000 N 24 39 0.000 E 10.00m 20m 100000000m 20m', - '60 9 0.000 N 24 39 0.000 E 10.00m 20m 20m 100000000m', - ] + bad_locs = [ + "60 9 a.000 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 60.000 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.00a N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.0001 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.000 Z 24 39 0.000 E 10.00m 20m 2000m 20m", + "91 9 0.000 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 60 0.000 N 24 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 a.000 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 60.000 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 0.00a E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 0.0001 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 0.000 Z 10.00m 20m 2000m 20m", + "60 9 0.000 N 181 39 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 60 0.000 E 10.00m 20m 2000m 20m", + "60 9 0.000 N 24 39 0.000 E 10.00m 100000000m 2000m 20m", + "60 9 0.000 N 24 39 0.000 E 10.00m 20m 100000000m 20m", + "60 9 0.000 N 24 39 0.000 E 10.00m 20m 20m 100000000m", + ] for loc in bad_locs: with self.assertRaises(dns.exception.SyntaxError): dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.LOC, loc) def test_bad_LOC_wire(self): - bad_locs = [(0, 0, 0, 0x934fd901, 0x80000000, 100), - (0, 0, 0, 0x6cb026ff, 0x80000000, 100), - (0, 0, 0, 0x80000000, 0xa69fb201, 100), - (0, 0, 0, 0x80000000, 0x59604dff, 100), - (0xa0, 0, 0, 0x80000000, 0x80000000, 100), - (0x0a, 0, 0, 0x80000000, 0x80000000, 100), - (0, 0xa0, 0, 0x80000000, 0x80000000, 100), - (0, 0x0a, 0, 0x80000000, 0x80000000, 100), - (0, 0, 0xa0, 0x80000000, 0x80000000, 100), - (0, 0, 0x0a, 0x80000000, 0x80000000, 100), - ] + bad_locs = [ + (0, 0, 0, 0x934FD901, 0x80000000, 100), + (0, 0, 0, 0x6CB026FF, 0x80000000, 100), + (0, 0, 0, 0x80000000, 0xA69FB201, 100), + (0, 0, 0, 0x80000000, 0x59604DFF, 100), + (0xA0, 0, 0, 0x80000000, 0x80000000, 100), + (0x0A, 0, 0, 0x80000000, 0x80000000, 100), + (0, 0xA0, 0, 0x80000000, 0x80000000, 100), + (0, 0x0A, 0, 0x80000000, 0x80000000, 100), + (0, 0, 0xA0, 0x80000000, 0x80000000, 100), + (0, 0, 0x0A, 0x80000000, 0x80000000, 100), + ] for t in bad_locs: with self.assertRaises(dns.exception.FormError): - wire = struct.pack('!BBBBIII', 0, t[0], t[1], t[2], - t[3], t[4], t[5]) - dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.LOC, - wire, 0, len(wire)) + wire = struct.pack("!BBBBIII", 0, t[0], t[1], t[2], t[3], t[4], t[5]) + dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.LOC, wire, 0, len(wire) + ) with self.assertRaises(dns.exception.FormError): - wire = struct.pack('!BBBBIII', 1, 0, 0, 0, 0, 0, 0) - dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.LOC, - wire, 0, len(wire)) + wire = struct.pack("!BBBBIII", 1, 0, 0, 0, 0, 0, 0) + dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.LOC, wire, 0, len(wire) + ) def equal_wks(self, a, b): rda = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.WKS, a) @@ -387,19 +420,20 @@ class RdataTestCase(unittest.TestCase): self.assertEqual(rda, rdb) def test_misc_good_WKS_text(self): - self.equal_wks('10.0.0.1 tcp ( http )', '10.0.0.1 6 ( 80 )') - self.equal_wks('10.0.0.1 udp ( domain )', '10.0.0.1 17 ( 53 )') + self.equal_wks("10.0.0.1 tcp ( http )", "10.0.0.1 6 ( 80 )") + self.equal_wks("10.0.0.1 udp ( domain )", "10.0.0.1 17 ( 53 )") def test_misc_bad_WKS_text(self): try: - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.WKS, - '10.0.0.1 132 ( domain )') + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.WKS, "10.0.0.1 132 ( domain )" + ) self.assertTrue(False) # should not happen except dns.exception.SyntaxError as e: self.assertIsInstance(e.__cause__, NotImplementedError) def test_GPOS_float_converters(self): - rd = dns.rdata.from_text('in', 'gpos', '49 0 0') + rd = dns.rdata.from_text("in", "gpos", "49 0 0") self.assertEqual(rd.float_latitude, 49.0) self.assertEqual(rd.float_longitude, 0.0) self.assertEqual(rd.float_altitude, 0.0) @@ -415,233 +449,233 @@ class RdataTestCase(unittest.TestCase): self.assertEqual(rd.float_altitude, 0.0) def test_bad_GPOS_text(self): - bad_gpos = ['"-" "116.8652" "250"', - '"+" "116.8652" "250"', - '"" "116.8652" "250"', - '"." "116.8652" "250"', - '".a" "116.8652" "250"', - '"a." "116.8652" "250"', - '"a.a" "116.8652" "250"', - # We don't need to test all the bad permutations again - # but we do want to test that badness is detected - # in the other strings - '"0" "a" "250"', - '"0" "0" "a"', - # finally test bounds - '"90.1" "0" "0"', - '"-90.1" "0" "0"', - '"0" "180.1" "0"', - '"0" "-180.1" "0"', - ] + bad_gpos = [ + '"-" "116.8652" "250"', + '"+" "116.8652" "250"', + '"" "116.8652" "250"', + '"." "116.8652" "250"', + '".a" "116.8652" "250"', + '"a." "116.8652" "250"', + '"a.a" "116.8652" "250"', + # We don't need to test all the bad permutations again + # but we do want to test that badness is detected + # in the other strings + '"0" "a" "250"', + '"0" "0" "a"', + # finally test bounds + '"90.1" "0" "0"', + '"-90.1" "0" "0"', + '"0" "180.1" "0"', + '"0" "-180.1" "0"', + ] for gpos in bad_gpos: with self.assertRaises(dns.exception.SyntaxError): dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.GPOS, gpos) def test_bad_GPOS_wire(self): - bad_gpos = [b'\x01', - b'\x01\x31\x01', - b'\x01\x31\x01\x31\x01', - ] + bad_gpos = [ + b"\x01", + b"\x01\x31\x01", + b"\x01\x31\x01\x31\x01", + ] for wire in bad_gpos: - self.assertRaises(dns.exception.FormError, - lambda: dns.rdata.from_wire(dns.rdataclass.IN, - dns.rdatatype.GPOS, - wire, 0, len(wire))) + self.assertRaises( + dns.exception.FormError, + lambda: dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.GPOS, wire, 0, len(wire) + ), + ) def test_chaos(self): # avoid red spot on our coverage :) - r1 = dns.rdata.from_text(dns.rdataclass.CH, dns.rdatatype.A, - 'chaos. 12345') + r1 = dns.rdata.from_text(dns.rdataclass.CH, dns.rdatatype.A, "chaos. 12345") w = r1.to_wire() - r2 = dns.rdata.from_wire(dns.rdataclass.CH, dns.rdatatype.A, w, 0, - len(w)) + r2 = dns.rdata.from_wire(dns.rdataclass.CH, dns.rdatatype.A, w, 0, len(w)) self.assertEqual(r1, r2) - self.assertEqual(r1.domain, dns.name.from_text('chaos')) + self.assertEqual(r1.domain, dns.name.from_text("chaos")) # the address input is octal self.assertEqual(r1.address, 0o12345) - self.assertEqual(r1.to_text(), 'chaos. 12345') + self.assertEqual(r1.to_text(), "chaos. 12345") def test_opt_repr(self): opt = OPT(4096, dns.rdatatype.OPT, ()) - self.assertEqual(repr(opt), '') + self.assertEqual(repr(opt), "") def test_opt_short_lengths(self): with self.assertRaises(dns.exception.FormError): - parser = dns.wire.Parser(bytes.fromhex('f00102')) + parser = dns.wire.Parser(bytes.fromhex("f00102")) OPT.from_wire_parser(4096, dns.rdatatype.OPT, parser) with self.assertRaises(dns.exception.FormError): - parser = dns.wire.Parser(bytes.fromhex('f00100030000')) + parser = dns.wire.Parser(bytes.fromhex("f00100030000")) OPT.from_wire_parser(4096, dns.rdatatype.OPT, parser) def test_from_wire_parser(self): - wire = bytes.fromhex('01020304') - rdata = dns.rdata.from_wire('in', 'a', wire, 0, 4) - self.assertEqual(rdata, dns.rdata.from_text('in', 'a', '1.2.3.4')) + wire = bytes.fromhex("01020304") + rdata = dns.rdata.from_wire("in", "a", wire, 0, 4) + self.assertEqual(rdata, dns.rdata.from_text("in", "a", "1.2.3.4")) def test_unpickle(self): - expected_mx = dns.rdata.from_text('in', 'mx', '10 mx.example.') - with open(here('mx-2-0.pickle'), 'rb') as f: + expected_mx = dns.rdata.from_text("in", "mx", "10 mx.example.") + with open(here("mx-2-0.pickle"), "rb") as f: mx = pickle.load(f) self.assertEqual(mx, expected_mx) self.assertIsNone(mx.rdcomment) def test_escaped_newline_in_quoted_string(self): - rd = dns.rdata.from_text('in', 'txt', '"foo\\\nbar"') - self.assertEqual(rd.strings, (b'foo\nbar',)) + rd = dns.rdata.from_text("in", "txt", '"foo\\\nbar"') + self.assertEqual(rd.strings, (b"foo\nbar",)) self.assertEqual(rd.to_text(), '"foo\\010bar"') def test_escaped_newline_in_nonquoted_string(self): with self.assertRaises(dns.exception.UnexpectedEnd): - dns.rdata.from_text('in', 'txt', 'foo\\\nbar') + dns.rdata.from_text("in", "txt", "foo\\\nbar") def test_wordbreak(self): - text = b'abcdefgh' - self.assertEqual(dns.rdata._wordbreak(text, 4), 'abcd efgh') - self.assertEqual(dns.rdata._wordbreak(text, 0), 'abcdefgh') + text = b"abcdefgh" + self.assertEqual(dns.rdata._wordbreak(text, 4), "abcd efgh") + self.assertEqual(dns.rdata._wordbreak(text, 0), "abcdefgh") def test_escapify(self): - self.assertEqual(dns.rdata._escapify('abc'), 'abc') - self.assertEqual(dns.rdata._escapify(b'abc'), 'abc') - self.assertEqual(dns.rdata._escapify(bytearray(b'abc')), 'abc') + self.assertEqual(dns.rdata._escapify("abc"), "abc") + self.assertEqual(dns.rdata._escapify(b"abc"), "abc") + self.assertEqual(dns.rdata._escapify(bytearray(b"abc")), "abc") self.assertEqual(dns.rdata._escapify(b'ab"c'), 'ab\\"c') - self.assertEqual(dns.rdata._escapify(b'ab\\c'), 'ab\\\\c') - self.assertEqual(dns.rdata._escapify(b'ab\x01c'), 'ab\\001c') + self.assertEqual(dns.rdata._escapify(b"ab\\c"), "ab\\\\c") + self.assertEqual(dns.rdata._escapify(b"ab\x01c"), "ab\\001c") def test_truncate_bitmap(self): - self.assertEqual(dns.rdata._truncate_bitmap(b'\x00\x01\x00\x00'), - b'\x00\x01') - self.assertEqual(dns.rdata._truncate_bitmap(b'\x00\x01\x00\x01'), - b'\x00\x01\x00\x01') - self.assertEqual(dns.rdata._truncate_bitmap(b'\x00\x00\x00\x00'), - b'\x00') + self.assertEqual(dns.rdata._truncate_bitmap(b"\x00\x01\x00\x00"), b"\x00\x01") + self.assertEqual( + dns.rdata._truncate_bitmap(b"\x00\x01\x00\x01"), b"\x00\x01\x00\x01" + ) + self.assertEqual(dns.rdata._truncate_bitmap(b"\x00\x00\x00\x00"), b"\x00") def test_covers_and_extended_rdatatype(self): - rd = dns.rdata.from_text('in', 'a', '10.0.0.1') + rd = dns.rdata.from_text("in", "a", "10.0.0.1") self.assertEqual(rd.covers(), dns.rdatatype.NONE) self.assertEqual(rd.extended_rdatatype(), 0x00000001) - rd = dns.rdata.from_text('in', 'rrsig', - 'NSEC 1 3 3600 ' + - '20200101000000 20030101000000 ' + - '2143 foo Ym9ndXM=') + rd = dns.rdata.from_text( + "in", + "rrsig", + "NSEC 1 3 3600 " + "20200101000000 20030101000000 " + "2143 foo Ym9ndXM=", + ) self.assertEqual(rd.covers(), dns.rdatatype.NSEC) - self.assertEqual(rd.extended_rdatatype(), 0x002f002e) + self.assertEqual(rd.extended_rdatatype(), 0x002F002E) def test_uncomparable(self): - rd = dns.rdata.from_text('in', 'a', '10.0.0.1') - self.assertFalse(rd == 'a') - self.assertTrue(rd != 'a') + rd = dns.rdata.from_text("in", "a", "10.0.0.1") + self.assertFalse(rd == "a") + self.assertTrue(rd != "a") def test_bad_generic(self): # does not start with \# with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'type45678', '# 7 000a03666f6f00') + dns.rdata.from_text("in", "type45678", "# 7 000a03666f6f00") # wrong length with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'type45678', '\\# 6 000a03666f6f00') + dns.rdata.from_text("in", "type45678", "\\# 6 000a03666f6f00") def test_empty_generic(self): - dns.rdata.from_text('in', 'type45678', r'\# 0') + dns.rdata.from_text("in", "type45678", r"\# 0") def test_covered_repr(self): - text = 'NSEC 1 3 3600 20190101000000 20030101000000 ' + \ - '2143 foo Ym9ndXM=' - rd = dns.rdata.from_text('in', 'rrsig', text) - self.assertEqual(repr(rd), '') + text = "NSEC 1 3 3600 20190101000000 20030101000000 " + "2143 foo Ym9ndXM=" + rd = dns.rdata.from_text("in", "rrsig", text) + self.assertEqual(repr(rd), "") def test_bad_registration_implementing_known_type_with_wrong_name(self): # Try to register an implementation at the MG codepoint that isn't # called "MG" with self.assertRaises(dns.rdata.RdatatypeExists): - dns.rdata.register_type(None, dns.rdatatype.MG, 'NOTMG') + dns.rdata.register_type(None, dns.rdatatype.MG, "NOTMG") def test_registration_implementing_known_type_with_right_name(self): # Try to register an implementation at the MD codepoint - dns.rdata.register_type(tests.md_module, dns.rdatatype.MD, 'MD') - rd = dns.rdata.from_text('in', 'md', 'foo.') - self.assertEqual(rd.target, dns.name.from_text('foo.')) + dns.rdata.register_type(tests.md_module, dns.rdatatype.MD, "MD") + rd = dns.rdata.from_text("in", "md", "foo.") + self.assertEqual(rd.target, dns.name.from_text("foo.")) def test_CERT_with_string_type(self): - rd = dns.rdata.from_text('in', 'cert', 'SPKI 1 PRIVATEOID Ym9ndXM=') - self.assertEqual(rd.to_text(), 'SPKI 1 PRIVATEOID Ym9ndXM=') + rd = dns.rdata.from_text("in", "cert", "SPKI 1 PRIVATEOID Ym9ndXM=") + self.assertEqual(rd.to_text(), "SPKI 1 PRIVATEOID Ym9ndXM=") def test_CERT_algorithm(self): - rd = dns.rdata.from_text('in', 'cert', 'SPKI 1 0 Ym9ndXM=') + rd = dns.rdata.from_text("in", "cert", "SPKI 1 0 Ym9ndXM=") self.assertEqual(rd.algorithm, 0) with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'cert', 'SPKI 1 -1 Ym9ndXM=') + dns.rdata.from_text("in", "cert", "SPKI 1 -1 Ym9ndXM=") with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'cert', 'SPKI 1 256 Ym9ndXM=') + dns.rdata.from_text("in", "cert", "SPKI 1 256 Ym9ndXM=") with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'cert', 'SPKI 1 BOGUS Ym9ndXM=') + dns.rdata.from_text("in", "cert", "SPKI 1 BOGUS Ym9ndXM=") def test_bad_URI_text(self): # empty target with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'uri', '10 1 ""') + dns.rdata.from_text("in", "uri", '10 1 ""') # no target with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'uri', '10 1') + dns.rdata.from_text("in", "uri", "10 1") def test_bad_URI_wire(self): - wire = bytes.fromhex('000a0001') + wire = bytes.fromhex("000a0001") with self.assertRaises(dns.exception.FormError): - dns.rdata.from_wire('in', 'uri', wire, 0, 4) + dns.rdata.from_wire("in", "uri", wire, 0, 4) def test_bad_NSAP_text(self): # does not start with 0x with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'nsap', '0y4700') + dns.rdata.from_text("in", "nsap", "0y4700") # odd hex string length with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'nsap', '0x470') + dns.rdata.from_text("in", "nsap", "0x470") def test_bad_CAA_text(self): # tag too long with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'caa', - '0 ' + 'a' * 256 + ' "ca.example.net"') + dns.rdata.from_text("in", "caa", "0 " + "a" * 256 + ' "ca.example.net"') # tag not alphanumeric with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'caa', - '0 a-b "ca.example.net"') + dns.rdata.from_text("in", "caa", '0 a-b "ca.example.net"') def test_bad_HIP_text(self): # hit too long with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'hip', - '2 ' + - '00' * 256 + - ' Ym9ndXM=') + dns.rdata.from_text("in", "hip", "2 " + "00" * 256 + " Ym9ndXM=") def test_bad_sigtime(self): try: - dns.rdata.from_text('in', 'rrsig', - 'NSEC 1 3 3600 ' + - '202001010000000 20030101000000 ' + - '2143 foo Ym9ndXM=') + dns.rdata.from_text( + "in", + "rrsig", + "NSEC 1 3 3600 " + + "202001010000000 20030101000000 " + + "2143 foo Ym9ndXM=", + ) self.assertTrue(False) # should not happen except dns.exception.SyntaxError as e: - self.assertIsInstance(e.__cause__, - dns.rdtypes.ANY.RRSIG.BadSigTime) + self.assertIsInstance(e.__cause__, dns.rdtypes.ANY.RRSIG.BadSigTime) try: - dns.rdata.from_text('in', 'rrsig', - 'NSEC 1 3 3600 ' + - '20200101000000 2003010100000 ' + - '2143 foo Ym9ndXM=') + dns.rdata.from_text( + "in", + "rrsig", + "NSEC 1 3 3600 " + + "20200101000000 2003010100000 " + + "2143 foo Ym9ndXM=", + ) self.assertTrue(False) # should not happen except dns.exception.SyntaxError as e: - self.assertIsInstance(e.__cause__, - dns.rdtypes.ANY.RRSIG.BadSigTime) + self.assertIsInstance(e.__cause__, dns.rdtypes.ANY.RRSIG.BadSigTime) def test_empty_TXT(self): # hit too long with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'txt', '') + dns.rdata.from_text("in", "txt", "") def test_too_long_TXT(self): # hit too long with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'txt', 'a' * 256) + dns.rdata.from_text("in", "txt", "a" * 256) def equal_smimea(self, a, b): a = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SMIMEA, a) @@ -649,63 +683,65 @@ class RdataTestCase(unittest.TestCase): self.assertEqual(a, b) def test_good_SMIMEA(self): - self.equal_smimea('3 0 1 aabbccddeeff', '3 0 01 AABBCCDDEEFF') + self.equal_smimea("3 0 1 aabbccddeeff", "3 0 01 AABBCCDDEEFF") def test_bad_SMIMEA(self): with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SMIMEA, '1 1 1 aGVsbG8gd29ybGQh') + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SMIMEA, "1 1 1 aGVsbG8gd29ybGQh" + ) def test_bad_APLItem_address_length(self): with self.assertRaises(ValueError): # 9999 is used in as an "unknown" address family. In the unlikely # event it is ever defined, we should switch the test to another # value. - dns.rdtypes.IN.APL.APLItem(9999, False, b'0xff' * 128, 255) + dns.rdtypes.IN.APL.APLItem(9999, False, b"0xff" * 128, 255) def test_DNSKEY_chunking(self): inputs = ( # each with chunking as given by dig, unusual chunking, and no chunking # example 1 ( - '257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocKmnS1iDSFZNORnQuHKtJ9Wpyz+kNryq uB78Pyk/NTEoai5bxoipVQQXzHlzyg==', - '257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocK mnS1iDSFZNORnQuHKtJ9Wpyz+kNryquB78Pyk/ NTEoai5bxoipVQQXzHlzyg==', - '257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocKmnS1iDSFZNORnQuHKtJ9Wpyz+kNryquB78Pyk/NTEoai5bxoipVQQXzHlzyg==', + "257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocKmnS1iDSFZNORnQuHKtJ9Wpyz+kNryq uB78Pyk/NTEoai5bxoipVQQXzHlzyg==", + "257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocK mnS1iDSFZNORnQuHKtJ9Wpyz+kNryquB78Pyk/ NTEoai5bxoipVQQXzHlzyg==", + "257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocKmnS1iDSFZNORnQuHKtJ9Wpyz+kNryquB78Pyk/NTEoai5bxoipVQQXzHlzyg==", ), # example 2 ( - '257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeqHy9mvL5qGQTuaG5TSrNqEAR6b/ qvxDx6my4JmEmjUPA1JeEI9YfTUieMr2UZflu7aIbZFLw0vqiYrywCGr CHXLalOrEOmrvAxLvq4vHtuTlH7JIszzYBSes8g1vle6KG7xXiP3U5Ll 96Qiu6bZ31rlMQSPB20xbqJJh6psNSrQs41QvdcXAej+K2Hl1Wd8kPri ec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFaW2m7N/Wy4qcFU13roWKDEAst bxH5CHPoBfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lvu9TAiZPc0oysY6as lO7jXv16Gws=', - '257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeq Hy9mvL5qGQTuaG5TSrNqEA R6b/qvxDx6my4JmEmjUPA1JeEI9Y fTUieMr2UZflu7aIbZFLw0vqiYrywCGrC HXLalOrEOmrvAxLvq4vHtuTlH7JIszzYBSes8g1vle6KG7 xXiP3U5Ll 96Qiu6bZ31rlMQSPB20xbqJJh6psNSrQs41QvdcXAej+K2Hl1Wd8kPriec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFaW2m7N/Wy4qcFU13roWKDEAst bxH5CHPoBfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lvu9TAiZPc0oysY6as lO7jXv16Gws=', - '257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeqHy9mvL5qGQTuaG5TSrNqEAR6b/qvxDx6my4JmEmjUPA1JeEI9YfTUieMr2UZflu7aIbZFLw0vqiYrywCGrCHXLalOrEOmrvAxLvq4vHtuTlH7JIszzYBSes8g1vle6KG7xXiP3U5Ll96Qiu6bZ31rlMQSPB20xbqJJh6psNSrQs41QvdcXAej+K2Hl1Wd8kPriec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFaW2m7N/Wy4qcFU13roWKDEAstbxH5CHPoBfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lvu9TAiZPc0oysY6aslO7jXv16Gws=', + "257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeqHy9mvL5qGQTuaG5TSrNqEAR6b/ qvxDx6my4JmEmjUPA1JeEI9YfTUieMr2UZflu7aIbZFLw0vqiYrywCGr CHXLalOrEOmrvAxLvq4vHtuTlH7JIszzYBSes8g1vle6KG7xXiP3U5Ll 96Qiu6bZ31rlMQSPB20xbqJJh6psNSrQs41QvdcXAej+K2Hl1Wd8kPri ec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFaW2m7N/Wy4qcFU13roWKDEAst bxH5CHPoBfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lvu9TAiZPc0oysY6as lO7jXv16Gws=", + "257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeq Hy9mvL5qGQTuaG5TSrNqEA R6b/qvxDx6my4JmEmjUPA1JeEI9Y fTUieMr2UZflu7aIbZFLw0vqiYrywCGrC HXLalOrEOmrvAxLvq4vHtuTlH7JIszzYBSes8g1vle6KG7 xXiP3U5Ll 96Qiu6bZ31rlMQSPB20xbqJJh6psNSrQs41QvdcXAej+K2Hl1Wd8kPriec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFaW2m7N/Wy4qcFU13roWKDEAst bxH5CHPoBfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lvu9TAiZPc0oysY6as lO7jXv16Gws=", + "257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeqHy9mvL5qGQTuaG5TSrNqEAR6b/qvxDx6my4JmEmjUPA1JeEI9YfTUieMr2UZflu7aIbZFLw0vqiYrywCGrCHXLalOrEOmrvAxLvq4vHtuTlH7JIszzYBSes8g1vle6KG7xXiP3U5Ll96Qiu6bZ31rlMQSPB20xbqJJh6psNSrQs41QvdcXAej+K2Hl1Wd8kPriec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFaW2m7N/Wy4qcFU13roWKDEAstbxH5CHPoBfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lvu9TAiZPc0oysY6aslO7jXv16Gws=", ), # example 3 ( - '256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5Ofv4akjQGN2zY5AgB/2jmdR/+ 1PvXFqzKCAGJv4wjABEBNWLLFm7ew1hHMDZEKVL17aml0EBKI6Dsz6Mx t6n7ScvLtHaFRKaxT4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0P+2F+TLK l3D0L/cD', - '256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5Ofv4akjQGN2zY5 AgB/2jmdR/+1PvXFqzKCAGJv4wjABEBNWLLFm7ew1hHMDZEKVL17aml0EBKI6Dsz6Mxt6n7ScvLtHaFRKaxT4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0P+2F+ TLKl3D0L/cD', - '256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5Ofv4akjQGN2zY5AgB/2jmdR/+1PvXFqzKCAGJv4wjABEBNWLLFm7ew1hHMDZEKVL17aml0EBKI6Dsz6Mxt6n7ScvLtHaFRKaxT4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0P+2F+TLKl3D0L/cD', + "256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5Ofv4akjQGN2zY5AgB/2jmdR/+ 1PvXFqzKCAGJv4wjABEBNWLLFm7ew1hHMDZEKVL17aml0EBKI6Dsz6Mx t6n7ScvLtHaFRKaxT4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0P+2F+TLK l3D0L/cD", + "256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5Ofv4akjQGN2zY5 AgB/2jmdR/+1PvXFqzKCAGJv4wjABEBNWLLFm7ew1hHMDZEKVL17aml0EBKI6Dsz6Mxt6n7ScvLtHaFRKaxT4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0P+2F+ TLKl3D0L/cD", + "256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5Ofv4akjQGN2zY5AgB/2jmdR/+1PvXFqzKCAGJv4wjABEBNWLLFm7ew1hHMDZEKVL17aml0EBKI6Dsz6Mxt6n7ScvLtHaFRKaxT4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0P+2F+TLKl3D0L/cD", ), ) output_map = { 32: ( - '257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocKmnS1iD SFZNORnQuHKtJ9Wpyz+kNryquB78Pyk/ NTEoai5bxoipVQQXzHlzyg==', - '257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeqHy 9mvL5qGQTuaG5TSrNqEAR6b/qvxDx6my 4JmEmjUPA1JeEI9YfTUieMr2UZflu7aI bZFLw0vqiYrywCGrCHXLalOrEOmrvAxL vq4vHtuTlH7JIszzYBSes8g1vle6KG7x XiP3U5Ll96Qiu6bZ31rlMQSPB20xbqJJ h6psNSrQs41QvdcXAej+K2Hl1Wd8kPri ec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFa W2m7N/Wy4qcFU13roWKDEAstbxH5CHPo BfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lv u9TAiZPc0oysY6aslO7jXv16Gws=', - '256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5O fv4akjQGN2zY5AgB/2jmdR/+1PvXFqzK CAGJv4wjABEBNWLLFm7ew1hHMDZEKVL1 7aml0EBKI6Dsz6Mxt6n7ScvLtHaFRKax T4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0 P+2F+TLKl3D0L/cD', + "257 3 13 aCoEWYBBVsP9Fek2oC8yqU8ocKmnS1iD SFZNORnQuHKtJ9Wpyz+kNryquB78Pyk/ NTEoai5bxoipVQQXzHlzyg==", + "257 3 8 AwEAAcw5QLr0IjC0wKbGoBPQv4qmeqHy 9mvL5qGQTuaG5TSrNqEAR6b/qvxDx6my 4JmEmjUPA1JeEI9YfTUieMr2UZflu7aI bZFLw0vqiYrywCGrCHXLalOrEOmrvAxL vq4vHtuTlH7JIszzYBSes8g1vle6KG7x XiP3U5Ll96Qiu6bZ31rlMQSPB20xbqJJ h6psNSrQs41QvdcXAej+K2Hl1Wd8kPri ec4AgiBEh8sk5Pp8W9ROLQ7PcbqqttFa W2m7N/Wy4qcFU13roWKDEAstbxH5CHPo BfZSbIwK4KM6BK/uDHpSPIbiOvOCW+lv u9TAiZPc0oysY6aslO7jXv16Gws=", + "256 3 8 AwEAAday3UX323uVzQqtOMQ7EHQYfD5O fv4akjQGN2zY5AgB/2jmdR/+1PvXFqzK CAGJv4wjABEBNWLLFm7ew1hHMDZEKVL1 7aml0EBKI6Dsz6Mxt6n7ScvLtHaFRKax T4i2JxiuVhKdQR9XGMiWAPQKrRM5SLG0 P+2F+TLKl3D0L/cD", ), 56: (t[0] for t in inputs), - 0: (t[0][:12] + t[0][12:].replace(' ', '') for t in inputs) + 0: (t[0][:12] + t[0][12:].replace(" ", "") for t in inputs), } for chunksize, outputs in output_map.items(): for input, output in zip(inputs, outputs): for input_variation in input: - rr = dns.rdata.from_text('IN', 'DNSKEY', input_variation) + rr = dns.rdata.from_text("IN", "DNSKEY", input_variation) new_text = rr.to_text(chunksize=chunksize) self.assertEqual(output, new_text) - + def test_relative_vs_absolute_compare_unstrict(self): try: saved = dns.rdata._allow_relative_comparisons dns.rdata._allow_relative_comparisons = True - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.') - r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "www.") + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "www") self.assertFalse(r1 == r2) self.assertTrue(r1 != r2) self.assertFalse(r1 < r2) @@ -723,18 +759,23 @@ class RdataTestCase(unittest.TestCase): try: saved = dns.rdata._allow_relative_comparisons dns.rdata._allow_relative_comparisons = False - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.') - r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "www.") + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "www") self.assertFalse(r1 == r2) self.assertTrue(r1 != r2) + def bad1(): r1 < r2 + def bad2(): r1 <= r2 + def bad3(): r1 > r2 + def bad4(): r1 >= r2 + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad1) self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad2) self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad3) @@ -743,8 +784,8 @@ class RdataTestCase(unittest.TestCase): dns.rdata._allow_relative_comparisons = saved def test_absolute_vs_absolute_compare(self): - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.') - r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx.') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "www.") + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "xxx.") self.assertFalse(r1 == r2) self.assertTrue(r1 != r2) self.assertTrue(r1 < r2) @@ -756,8 +797,8 @@ class RdataTestCase(unittest.TestCase): try: saved = dns.rdata._allow_relative_comparisons dns.rdata._allow_relative_comparisons = True - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') - r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "www") + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "xxx") self.assertFalse(r1 == r2) self.assertTrue(r1 != r2) self.assertTrue(r1 < r2) @@ -771,18 +812,23 @@ class RdataTestCase(unittest.TestCase): try: saved = dns.rdata._allow_relative_comparisons dns.rdata._allow_relative_comparisons = False - r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') - r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx') + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "www") + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "xxx") self.assertFalse(r1 == r2) self.assertTrue(r1 != r2) + def bad1(): r1 < r2 + def bad2(): r1 <= r2 + def bad3(): r1 > r2 + def bad4(): r1 >= r2 + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad1) self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad2) self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad3) @@ -790,15 +836,15 @@ class RdataTestCase(unittest.TestCase): finally: dns.rdata._allow_relative_comparisons = saved -class UtilTestCase(unittest.TestCase): +class UtilTestCase(unittest.TestCase): def test_Gateway_bad_type0(self): with self.assertRaises(SyntaxError): - dns.rdtypes.util.Gateway(0, 'bad.') + dns.rdtypes.util.Gateway(0, "bad.") def test_Gateway_bad_type3(self): with self.assertRaises(SyntaxError): - dns.rdtypes.util.Gateway(3, 'bad.') + dns.rdtypes.util.Gateway(3, "bad.") def test_Gateway_type4(self): with self.assertRaises(SyntaxError): @@ -808,7 +854,7 @@ class UtilTestCase(unittest.TestCase): def test_Bitmap(self): b = dns.rdtypes.util.Bitmap - tok = dns.tokenizer.Tokenizer('A MX') + tok = dns.tokenizer.Tokenizer("A MX") windows = b.from_text(tok).windows ba = bytearray() ba.append(0x40) # bit 1, for A @@ -817,7 +863,7 @@ class UtilTestCase(unittest.TestCase): def test_Bitmap_with_duplicate_types(self): b = dns.rdtypes.util.Bitmap - tok = dns.tokenizer.Tokenizer('A MX A A MX') + tok = dns.tokenizer.Tokenizer("A MX A A MX") windows = b.from_text(tok).windows ba = bytearray() ba.append(0x40) # bit 1, for A @@ -826,7 +872,7 @@ class UtilTestCase(unittest.TestCase): def test_Bitmap_with_out_of_order_types(self): b = dns.rdtypes.util.Bitmap - tok = dns.tokenizer.Tokenizer('MX A') + tok = dns.tokenizer.Tokenizer("MX A") windows = b.from_text(tok).windows ba = bytearray() ba.append(0x40) # bit 1, for A @@ -835,7 +881,7 @@ class UtilTestCase(unittest.TestCase): def test_Bitmap_zero_padding_works(self): b = dns.rdtypes.util.Bitmap - tok = dns.tokenizer.Tokenizer('SRV') + tok = dns.tokenizer.Tokenizer("SRV") windows = b.from_text(tok).windows ba = bytearray() ba.append(0) @@ -848,49 +894,48 @@ class UtilTestCase(unittest.TestCase): def test_Bitmap_has_type_0_set(self): b = dns.rdtypes.util.Bitmap with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('NONE A MX') + tok = dns.tokenizer.Tokenizer("NONE A MX") b.from_text(tok) def test_Bitmap_empty_window_not_written(self): b = dns.rdtypes.util.Bitmap - tok = dns.tokenizer.Tokenizer('URI CAA') # types 256 and 257 + tok = dns.tokenizer.Tokenizer("URI CAA") # types 256 and 257 windows = b.from_text(tok).windows ba = bytearray() - ba.append(0xc0) # bits 0 and 1 in window 1 + ba.append(0xC0) # bits 0 and 1 in window 1 self.assertEqual(windows, [(1, bytes(ba))]) def test_Bitmap_ok_parse(self): - parser = dns.wire.Parser(b'\x00\x01\x40') + parser = dns.wire.Parser(b"\x00\x01\x40") b = dns.rdtypes.util.Bitmap([]) windows = b.from_wire_parser(parser).windows - self.assertEqual(windows, [(0, b'@')]) + self.assertEqual(windows, [(0, b"@")]) def test_Bitmap_0_length_window_parse(self): - parser = dns.wire.Parser(b'\x00\x00') + parser = dns.wire.Parser(b"\x00\x00") with self.assertRaises(ValueError): b = dns.rdtypes.util.Bitmap([]) b.from_wire_parser(parser) def test_Bitmap_too_long_parse(self): - parser = dns.wire.Parser(b'\x00\x21' + b'\x01' * 33) + parser = dns.wire.Parser(b"\x00\x21" + b"\x01" * 33) with self.assertRaises(ValueError): b = dns.rdtypes.util.Bitmap([]) b.from_wire_parser(parser) def test_compressed_in_generic_is_bad(self): with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - r'\# 4 000aC000') + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, r"\# 4 000aC000") def test_rdataset_ttl_conversion(self): - rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + rds1 = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") self.assertEqual(rds1.ttl, 300) - rds2 = dns.rdataset.from_text('in', 'a', '5m', '10.0.0.1') + rds2 = dns.rdataset.from_text("in", "a", "5m", "10.0.0.1") self.assertEqual(rds2.ttl, 300) with self.assertRaises(ValueError): - dns.rdataset.from_text('in', 'a', 1.6, '10.0.0.1') + dns.rdataset.from_text("in", "a", 1.6, "10.0.0.1") with self.assertRaises(dns.ttl.BadTTL): - dns.rdataset.from_text('in', 'a', '10.0.0.1', '10.0.0.2') + dns.rdataset.from_text("in", "a", "10.0.0.1", "10.0.0.2") Rdata = dns.rdata.Rdata @@ -898,16 +943,16 @@ Rdata = dns.rdata.Rdata class RdataConvertersTestCase(unittest.TestCase): def test_as_name(self): - n = dns.name.from_text('hi') + n = dns.name.from_text("hi") self.assertEqual(Rdata._as_name(n), n) - self.assertEqual(Rdata._as_name('hi'), n) + self.assertEqual(Rdata._as_name("hi"), n) with self.assertRaises(ValueError): Rdata._as_name(100) def test_as_uint8(self): self.assertEqual(Rdata._as_uint8(0), 0) with self.assertRaises(ValueError): - Rdata._as_uint8('hi') + Rdata._as_uint8("hi") with self.assertRaises(ValueError): Rdata._as_uint8(-1) with self.assertRaises(ValueError): @@ -916,7 +961,7 @@ class RdataConvertersTestCase(unittest.TestCase): def test_as_uint16(self): self.assertEqual(Rdata._as_uint16(0), 0) with self.assertRaises(ValueError): - Rdata._as_uint16('hi') + Rdata._as_uint16("hi") with self.assertRaises(ValueError): Rdata._as_uint16(-1) with self.assertRaises(ValueError): @@ -925,25 +970,25 @@ class RdataConvertersTestCase(unittest.TestCase): def test_as_uint32(self): self.assertEqual(Rdata._as_uint32(0), 0) with self.assertRaises(ValueError): - Rdata._as_uint32('hi') + Rdata._as_uint32("hi") with self.assertRaises(ValueError): Rdata._as_uint32(-1) with self.assertRaises(ValueError): - Rdata._as_uint32(2 ** 32) + Rdata._as_uint32(2**32) def test_as_uint48(self): self.assertEqual(Rdata._as_uint48(0), 0) with self.assertRaises(ValueError): - Rdata._as_uint48('hi') + Rdata._as_uint48("hi") with self.assertRaises(ValueError): Rdata._as_uint48(-1) with self.assertRaises(ValueError): - Rdata._as_uint48(2 ** 48) + Rdata._as_uint48(2**48) def test_as_int(self): self.assertEqual(Rdata._as_int(0, 0, 10), 0) with self.assertRaises(ValueError): - Rdata._as_int('hi', 0, 10) + Rdata._as_int("hi", 0, 10) with self.assertRaises(ValueError): Rdata._as_int(-1, 0, 10) with self.assertRaises(ValueError): @@ -953,18 +998,19 @@ class RdataConvertersTestCase(unittest.TestCase): self.assertEqual(Rdata._as_bool(True), True) self.assertEqual(Rdata._as_bool(False), False) with self.assertRaises(ValueError): - Rdata._as_bool('hi') + Rdata._as_bool("hi") def test_as_ttl(self): self.assertEqual(Rdata._as_ttl(300), 300) - self.assertEqual(Rdata._as_ttl('5m'), 300) + self.assertEqual(Rdata._as_ttl("5m"), 300) self.assertEqual(Rdata._as_ttl(dns.ttl.MAX_TTL), dns.ttl.MAX_TTL) with self.assertRaises(dns.ttl.BadTTL): - Rdata._as_ttl('hi') + Rdata._as_ttl("hi") with self.assertRaises(ValueError): Rdata._as_ttl(1.9) with self.assertRaises(ValueError): Rdata._as_ttl(dns.ttl.MAX_TTL + 1) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_rdataset.py b/tests/test_rdataset.py index 69ec6ded..4c602f89 100644 --- a/tests/test_rdataset.py +++ b/tests/test_rdataset.py @@ -9,123 +9,119 @@ import dns.rdataclass import dns.rdataset import dns.rdatatype -class RdatasetTestCase(unittest.TestCase): +class RdatasetTestCase(unittest.TestCase): def testCodec2003(self): - r1 = dns.rdataset.from_text_list('in', 'ns', 30, - ['Königsgäßchen']) - r2 = dns.rdataset.from_text_list('in', 'ns', 30, - ['xn--knigsgsschen-lcb0w']) + r1 = dns.rdataset.from_text_list("in", "ns", 30, ["Königsgäßchen"]) + r2 = dns.rdataset.from_text_list("in", "ns", 30, ["xn--knigsgsschen-lcb0w"]) self.assertEqual(r1, r2) - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testCodec2008(self): - r1 = dns.rdataset.from_text_list('in', 'ns', 30, - ['Königsgäßchen'], - idna_codec=dns.name.IDNA_2008) - r2 = dns.rdataset.from_text_list('in', 'ns', 30, - ['xn--knigsgchen-b4a3dun'], - idna_codec=dns.name.IDNA_2008) + r1 = dns.rdataset.from_text_list( + "in", "ns", 30, ["Königsgäßchen"], idna_codec=dns.name.IDNA_2008 + ) + r2 = dns.rdataset.from_text_list( + "in", "ns", 30, ["xn--knigsgchen-b4a3dun"], idna_codec=dns.name.IDNA_2008 + ) self.assertEqual(r1, r2) def testCopy(self): - r1 = dns.rdataset.from_text_list('in', 'a', 30, - ['10.0.0.1', '10.0.0.2']) + r1 = dns.rdataset.from_text_list("in", "a", 30, ["10.0.0.1", "10.0.0.2"]) r2 = r1.copy() self.assertFalse(r1 is r2) self.assertTrue(r1 == r2) def testAddIncompatible(self): rds = dns.rdataset.Rdataset(dns.rdataclass.IN, dns.rdatatype.A) - rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.AAAA, - '::1') + rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.AAAA, "::1") rds.add(rd1, 30) - self.assertRaises(dns.rdataset.IncompatibleTypes, - lambda: rds.add(rd2, 30)) + self.assertRaises(dns.rdataset.IncompatibleTypes, lambda: rds.add(rd2, 30)) def testDifferingCovers(self): - rds = dns.rdataset.Rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, - dns.rdatatype.A) + rds = dns.rdataset.Rdataset( + dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.A + ) rd1 = dns.rdata.from_text( - dns.rdataclass.IN, dns.rdatatype.RRSIG, - 'A 1 3 3600 20200101000000 20030101000000 2143 foo Ym9ndXM=') + dns.rdataclass.IN, + dns.rdatatype.RRSIG, + "A 1 3 3600 20200101000000 20030101000000 2143 foo Ym9ndXM=", + ) rd2 = dns.rdata.from_text( - dns.rdataclass.IN, dns.rdatatype.RRSIG, - 'AAAA 1 3 3600 20200101000000 20030101000000 2143 foo Ym9ndXM=') + dns.rdataclass.IN, + dns.rdatatype.RRSIG, + "AAAA 1 3 3600 20200101000000 20030101000000 2143 foo Ym9ndXM=", + ) rds.add(rd1, 30) - self.assertRaises(dns.rdataset.DifferingCovers, - lambda: rds.add(rd2, 30)) + self.assertRaises(dns.rdataset.DifferingCovers, lambda: rds.add(rd2, 30)) def testUnionUpdate(self): - rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') - rds2 = dns.rdataset.from_text('in', 'a', 30, '10.0.0.2') - rdse = dns.rdataset.from_text('in', 'a', 30, '10.0.0.1', '10.0.0.2') + rds1 = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") + rds2 = dns.rdataset.from_text("in", "a", 30, "10.0.0.2") + rdse = dns.rdataset.from_text("in", "a", 30, "10.0.0.1", "10.0.0.2") rds1.union_update(rds2) self.assertEqual(rds1, rdse) def testIntersectionUpdate(self): - rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2') - rds2 = dns.rdataset.from_text('in', 'a', 30, '10.0.0.2') - rdse = dns.rdataset.from_text('in', 'a', 30, '10.0.0.2') + rds1 = dns.rdataset.from_text("in", "a", 300, "10.0.0.1", "10.0.0.2") + rds2 = dns.rdataset.from_text("in", "a", 30, "10.0.0.2") + rdse = dns.rdataset.from_text("in", "a", 30, "10.0.0.2") rds1.intersection_update(rds2) self.assertEqual(rds1, rdse) def testNoEqualToOther(self): - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") self.assertFalse(rds == 123) def testEmptyRdataList(self): - self.assertRaises(ValueError, - lambda: dns.rdataset.from_rdata_list(300, [])) + self.assertRaises(ValueError, lambda: dns.rdataset.from_rdata_list(300, [])) def testToTextNoName(self): - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") text = rds.to_text() - self.assertEqual(text, '300 IN A 10.0.0.1') + self.assertEqual(text, "300 IN A 10.0.0.1") def testToTextOverrideClass(self): - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") text = rds.to_text(override_rdclass=dns.rdataclass.NONE) - self.assertEqual(text, '300 NONE A 10.0.0.1') + self.assertEqual(text, "300 NONE A 10.0.0.1") def testRepr(self): - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") self.assertEqual(repr(rds), "]>") def testTruncatedRepr(self): - rds = dns.rdataset.from_text('in', 'txt', 300, - 'a' * 200) + rds = dns.rdataset.from_text("in", "txt", 300, "a" * 200) # * 99 not * 100 below as the " counts as one of the 100 chars - self.assertEqual(repr(rds), - ']>') + self.assertEqual(repr(rds), ']>") def testStr(self): - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") self.assertEqual(str(rds), "300 IN A 10.0.0.1") def testMultilineToText(self): - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1", "10.0.0.2") self.assertEqual(rds.to_text(), "300 IN A 10.0.0.1\n300 IN A 10.0.0.2") def testCoveredRepr(self): - rds = dns.rdataset.from_text('in', 'rrsig', 300, - 'NSEC 1 3 3600 ' + - '20190101000000 20030101000000 ' + - '2143 foo Ym9ndXM=') + rds = dns.rdataset.from_text( + "in", + "rrsig", + 300, + "NSEC 1 3 3600 " + "20190101000000 20030101000000 " + "2143 foo Ym9ndXM=", + ) # Using startswith as I don't care about the repr of the rdata, # just the covers - self.assertTrue(repr(rds).startswith( - ' None - '''Test that all defined flags are recognized.''' - good_s = {'SEP', 'REVOKE', 'ZONE'} +class RdtypeAnyDnskeyTestCase(unittest.TestCase): + def testFlagsAll(self): # type: () -> None + """Test that all defined flags are recognized.""" + good_s = {"SEP", "REVOKE", "ZONE"} good_f = 0x181 - self.assertEqual(dns.rdtypes.ANY.DNSKEY.SEP | - dns.rdtypes.ANY.DNSKEY.REVOKE | - dns.rdtypes.ANY.DNSKEY.ZONE, good_f) - - def testFlagsRRToText(self): # type: () -> None - '''Test that RR method returns correct flags.''' - rr = dns.rrset.from_text('foo', 300, 'IN', 'DNSKEY', '257 3 8 KEY=')[0] - self.assertEqual(dns.rdtypes.ANY.DNSKEY.ZONE | - dns.rdtypes.ANY.DNSKEY.SEP, - rr.flags) - - -if __name__ == '__main__': + self.assertEqual( + dns.rdtypes.ANY.DNSKEY.SEP + | dns.rdtypes.ANY.DNSKEY.REVOKE + | dns.rdtypes.ANY.DNSKEY.ZONE, + good_f, + ) + + def testFlagsRRToText(self): # type: () -> None + """Test that RR method returns correct flags.""" + rr = dns.rrset.from_text("foo", 300, "IN", "DNSKEY", "257 3 8 KEY=")[0] + self.assertEqual( + dns.rdtypes.ANY.DNSKEY.ZONE | dns.rdtypes.ANY.DNSKEY.SEP, rr.flags + ) + + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_rdtypeanyeui.py b/tests/test_rdtypeanyeui.py index 08527273..b65442a1 100644 --- a/tests/test_rdtypeanyeui.py +++ b/tests/test_rdtypeanyeui.py @@ -24,181 +24,161 @@ import dns.exception class RdtypeAnyEUI48TestCase(unittest.TestCase): def testInstOk(self): - '''Valid binary input.''' - eui = b'\x01\x23\x45\x67\x89\xab' - inst = dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, - dns.rdatatype.EUI48, - eui) + """Valid binary input.""" + eui = b"\x01\x23\x45\x67\x89\xab" + inst = dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, dns.rdatatype.EUI48, eui) self.assertEqual(inst.eui, eui) def testInstLength(self): - '''Incorrect input length.''' - eui = b'\x01\x23\x45\x67\x89\xab\xcd' + """Incorrect input length.""" + eui = b"\x01\x23\x45\x67\x89\xab\xcd" with self.assertRaises(dns.exception.FormError): - dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, - dns.rdatatype.EUI48, - eui) + dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, dns.rdatatype.EUI48, eui) def testFromTextOk(self): - '''Valid text input.''' - r1 = dns.rrset.from_text('foo', 300, 'IN', 'EUI48', - '01-23-45-67-89-ab') - eui = b'\x01\x23\x45\x67\x89\xab' + """Valid text input.""" + r1 = dns.rrset.from_text("foo", 300, "IN", "EUI48", "01-23-45-67-89-ab") + eui = b"\x01\x23\x45\x67\x89\xab" self.assertEqual(r1[0].eui, eui) def testFromTextLength(self): - '''Invalid input length.''' + """Invalid input length.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI48', - '00-01-23-45-67-89-ab') + dns.rrset.from_text("foo", 300, "IN", "EUI48", "00-01-23-45-67-89-ab") def testFromTextDelim(self): - '''Invalid delimiter.''' + """Invalid delimiter.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI48', '01_23-45-67-89-ab') + dns.rrset.from_text("foo", 300, "IN", "EUI48", "01_23-45-67-89-ab") def testFromTextExtraDash(self): - '''Extra dash instead of hex digit.''' + """Extra dash instead of hex digit.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI48', '0--23-45-67-89-ab') + dns.rrset.from_text("foo", 300, "IN", "EUI48", "0--23-45-67-89-ab") def testFromTextMultipleTokens(self): - '''Invalid input divided to multiple tokens.''' + """Invalid input divided to multiple tokens.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI48', '01 23-45-67-89-ab') + dns.rrset.from_text("foo", 300, "IN", "EUI48", "01 23-45-67-89-ab") def testFromTextInvalidHex(self): - '''Invalid hexadecimal input.''' + """Invalid hexadecimal input.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI48', 'g0-23-45-67-89-ab') + dns.rrset.from_text("foo", 300, "IN", "EUI48", "g0-23-45-67-89-ab") def testToTextOk(self): - '''Valid text output.''' - eui = b'\x01\x23\x45\x67\x89\xab' - exp_text = '01-23-45-67-89-ab' - inst = dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, - dns.rdatatype.EUI48, - eui) + """Valid text output.""" + eui = b"\x01\x23\x45\x67\x89\xab" + exp_text = "01-23-45-67-89-ab" + inst = dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, dns.rdatatype.EUI48, eui) text = inst.to_text() self.assertEqual(exp_text, text) def testToWire(self): - '''Valid wire format.''' - eui = b'\x01\x23\x45\x67\x89\xab' - inst = dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, - dns.rdatatype.EUI48, - eui) + """Valid wire format.""" + eui = b"\x01\x23\x45\x67\x89\xab" + inst = dns.rdtypes.ANY.EUI48.EUI48(dns.rdataclass.IN, dns.rdatatype.EUI48, eui) self.assertEqual(inst.to_wire(), eui) def testFromWireOk(self): - '''Valid wire format.''' - eui = b'\x01\x23\x45\x67\x89\xab' + """Valid wire format.""" + eui = b"\x01\x23\x45\x67\x89\xab" pad_len = 100 - wire = b'x' * pad_len + eui + b'y' * pad_len * 2 - inst = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI48, - wire, pad_len, len(eui)) + wire = b"x" * pad_len + eui + b"y" * pad_len * 2 + inst = dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.EUI48, wire, pad_len, len(eui) + ) self.assertEqual(inst.eui, eui) def testFromWireLength(self): - '''Valid wire format.''' - eui = b'\x01\x23\x45\x67\x89' + """Valid wire format.""" + eui = b"\x01\x23\x45\x67\x89" pad_len = 100 - wire = b'x' * pad_len + eui + b'y' * pad_len * 2 + wire = b"x" * pad_len + eui + b"y" * pad_len * 2 with self.assertRaises(dns.exception.FormError): - dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI48, - wire, pad_len, len(eui)) + dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.EUI48, wire, pad_len, len(eui) + ) class RdtypeAnyEUI64TestCase(unittest.TestCase): def testInstOk(self): - '''Valid binary input.''' - eui = b'\x01\x23\x45\x67\x89\xab\xcd\xef' - inst = dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, - dns.rdatatype.EUI64, - eui) + """Valid binary input.""" + eui = b"\x01\x23\x45\x67\x89\xab\xcd\xef" + inst = dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, dns.rdatatype.EUI64, eui) self.assertEqual(inst.eui, eui) def testInstLength(self): - '''Incorrect input length.''' - eui = b'\x01\x23\x45\x67\x89\xab' + """Incorrect input length.""" + eui = b"\x01\x23\x45\x67\x89\xab" with self.assertRaises(dns.exception.FormError): - dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, - dns.rdatatype.EUI64, - eui) + dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, dns.rdatatype.EUI64, eui) def testFromTextOk(self): - '''Valid text input.''' - r1 = dns.rrset.from_text('foo', 300, 'IN', 'EUI64', - '01-23-45-67-89-ab-cd-ef') - eui = b'\x01\x23\x45\x67\x89\xab\xcd\xef' + """Valid text input.""" + r1 = dns.rrset.from_text("foo", 300, "IN", "EUI64", "01-23-45-67-89-ab-cd-ef") + eui = b"\x01\x23\x45\x67\x89\xab\xcd\xef" self.assertEqual(r1[0].eui, eui) def testFromTextLength(self): - '''Invalid input length.''' + """Invalid input length.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI64', - '01-23-45-67-89-ab') + dns.rrset.from_text("foo", 300, "IN", "EUI64", "01-23-45-67-89-ab") def testFromTextDelim(self): - '''Invalid delimiter.''' + """Invalid delimiter.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI64', - '01_23-45-67-89-ab-cd-ef') + dns.rrset.from_text("foo", 300, "IN", "EUI64", "01_23-45-67-89-ab-cd-ef") def testFromTextExtraDash(self): - '''Extra dash instead of hex digit.''' + """Extra dash instead of hex digit.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI64', - '0--23-45-67-89-ab-cd-ef') + dns.rrset.from_text("foo", 300, "IN", "EUI64", "0--23-45-67-89-ab-cd-ef") def testFromTextMultipleTokens(self): - '''Invalid input divided to multiple tokens.''' + """Invalid input divided to multiple tokens.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI64', - '01 23-45-67-89-ab-cd-ef') + dns.rrset.from_text("foo", 300, "IN", "EUI64", "01 23-45-67-89-ab-cd-ef") def testFromTextInvalidHex(self): - '''Invalid hexadecimal input.''' + """Invalid hexadecimal input.""" with self.assertRaises(dns.exception.SyntaxError): - dns.rrset.from_text('foo', 300, 'IN', 'EUI64', - 'g0-23-45-67-89-ab-cd-ef') + dns.rrset.from_text("foo", 300, "IN", "EUI64", "g0-23-45-67-89-ab-cd-ef") def testToTextOk(self): - '''Valid text output.''' - eui = b'\x01\x23\x45\x67\x89\xab\xcd\xef' - exp_text = '01-23-45-67-89-ab-cd-ef' - inst = dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, - dns.rdatatype.EUI64, - eui) + """Valid text output.""" + eui = b"\x01\x23\x45\x67\x89\xab\xcd\xef" + exp_text = "01-23-45-67-89-ab-cd-ef" + inst = dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, dns.rdatatype.EUI64, eui) text = inst.to_text() self.assertEqual(exp_text, text) def testToWire(self): - '''Valid wire format.''' - eui = b'\x01\x23\x45\x67\x89\xab\xcd\xef' - inst = dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, - dns.rdatatype.EUI64, - eui) + """Valid wire format.""" + eui = b"\x01\x23\x45\x67\x89\xab\xcd\xef" + inst = dns.rdtypes.ANY.EUI64.EUI64(dns.rdataclass.IN, dns.rdatatype.EUI64, eui) self.assertEqual(inst.to_wire(), eui) def testFromWireOk(self): - '''Valid wire format.''' - eui = b'\x01\x23\x45\x67\x89\xab\xcd\xef' + """Valid wire format.""" + eui = b"\x01\x23\x45\x67\x89\xab\xcd\xef" pad_len = 100 - wire = b'x' * pad_len + eui + b'y' * pad_len * 2 - inst = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI64, - wire, pad_len, len(eui)) + wire = b"x" * pad_len + eui + b"y" * pad_len * 2 + inst = dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.EUI64, wire, pad_len, len(eui) + ) self.assertEqual(inst.eui, eui) def testFromWireLength(self): - '''Valid wire format.''' - eui = b'\x01\x23\x45\x67\x89' + """Valid wire format.""" + eui = b"\x01\x23\x45\x67\x89" pad_len = 100 - wire = b'x' * pad_len + eui + b'y' * pad_len * 2 + wire = b"x" * pad_len + eui + b"y" * pad_len * 2 with self.assertRaises(dns.exception.FormError): - dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI64, - wire, pad_len, len(eui)) + dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.EUI64, wire, pad_len, len(eui) + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_rdtypeanyloc.py b/tests/test_rdtypeanyloc.py index 23a1f68c..8fe210bf 100644 --- a/tests/test_rdtypeanyloc.py +++ b/tests/test_rdtypeanyloc.py @@ -19,52 +19,84 @@ import unittest import dns.rrset import dns.rdtypes.ANY.LOC -class RdtypeAnyLocTestCase(unittest.TestCase): +class RdtypeAnyLocTestCase(unittest.TestCase): def testEqual1(self): - '''Test default values for size, horizontal and vertical precision.''' - r1 = dns.rrset.from_text('foo', 300, 'IN', 'LOC', - '49 11 42.400 N 16 36 29.600 E 227.64m') - r2 = dns.rrset.from_text('FOO', 600, 'in', 'loc', - '49 11 42.400 N 16 36 29.600 E 227.64m ' - '1.00m 10000.00m 10.00m') + """Test default values for size, horizontal and vertical precision.""" + r1 = dns.rrset.from_text( + "foo", 300, "IN", "LOC", "49 11 42.400 N 16 36 29.600 E 227.64m" + ) + r2 = dns.rrset.from_text( + "FOO", + 600, + "in", + "loc", + "49 11 42.400 N 16 36 29.600 E 227.64m " "1.00m 10000.00m 10.00m", + ) self.assertEqual(r1, r2, '"{}" != "{}"'.format(r1, r2)) def testEqual2(self): - '''Test default values for size, horizontal and vertical precision.''' - r1 = dns.rdtypes.ANY.LOC.LOC(1, 29, (49, 11, 42, 400, 1), - (16, 36, 29, 600, 1), - 22764.0) # centimeters - r2 = dns.rdtypes.ANY.LOC.LOC(1, 29, (49, 11, 42, 400, 1), - (16, 36, 29, 600, 1), - 22764.0, # centimeters - 100.0, 1000000.00, 1000.0) # centimeters + """Test default values for size, horizontal and vertical precision.""" + r1 = dns.rdtypes.ANY.LOC.LOC( + 1, 29, (49, 11, 42, 400, 1), (16, 36, 29, 600, 1), 22764.0 + ) # centimeters + r2 = dns.rdtypes.ANY.LOC.LOC( + 1, + 29, + (49, 11, 42, 400, 1), + (16, 36, 29, 600, 1), + 22764.0, # centimeters + 100.0, + 1000000.00, + 1000.0, + ) # centimeters self.assertEqual(r1, r2, '"{}" != "{}"'.format(r1, r2)) def testEqual3(self): - '''Test size, horizontal and vertical precision parsers: 100 cm == 1 m. + """Test size, horizontal and vertical precision parsers: 100 cm == 1 m. - Parsers in from_text() and __init__() have to produce equal results.''' - r1 = dns.rdtypes.ANY.LOC.LOC(1, 29, (49, 11, 42, 400, 1), - (16, 36, 29, 600, 1), 22764.0, - 200.0, 1000.00, 200.0) # centimeters - r2 = dns.rrset.from_text('FOO', 600, 'in', 'loc', - '49 11 42.400 N 16 36 29.600 E 227.64m ' - '2.00m 10.00m 2.00m')[0] + Parsers in from_text() and __init__() have to produce equal results.""" + r1 = dns.rdtypes.ANY.LOC.LOC( + 1, + 29, + (49, 11, 42, 400, 1), + (16, 36, 29, 600, 1), + 22764.0, + 200.0, + 1000.00, + 200.0, + ) # centimeters + r2 = dns.rrset.from_text( + "FOO", + 600, + "in", + "loc", + "49 11 42.400 N 16 36 29.600 E 227.64m " "2.00m 10.00m 2.00m", + )[0] self.assertEqual(r1, r2, '"{}" != "{}"'.format(r1, r2)) def testEqual4(self): - '''Test size, horizontal and vertical precision parsers without unit. + """Test size, horizontal and vertical precision parsers without unit. Parsers in from_text() and __init__() have produce equal result - for values with and without trailing "m".''' - r1 = dns.rdtypes.ANY.LOC.LOC(1, 29, (49, 11, 42, 400, 1), - (16, 36, 29, 600, 1), 22764.0, - 200.0, 1000.00, 200.0) # centimeters - r2 = dns.rrset.from_text('FOO', 600, 'in', 'loc', - '49 11 42.400 N 16 36 29.600 E 227.64 ' - '2 10 2')[0] # meters without explicit unit + for values with and without trailing "m".""" + r1 = dns.rdtypes.ANY.LOC.LOC( + 1, + 29, + (49, 11, 42, 400, 1), + (16, 36, 29, 600, 1), + 22764.0, + 200.0, + 1000.00, + 200.0, + ) # centimeters + r2 = dns.rrset.from_text( + "FOO", 600, "in", "loc", "49 11 42.400 N 16 36 29.600 E 227.64 " "2 10 2" + )[ + 0 + ] # meters without explicit unit self.assertEqual(r1, r2, '"{}" != "{}"'.format(r1, r2)) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_rdtypeanytkey.py b/tests/test_rdtypeanytkey.py index 3a3ca57d..f3d70071 100644 --- a/tests/test_rdtypeanytkey.py +++ b/tests/test_rdtypeanytkey.py @@ -27,70 +27,94 @@ from dns.rdatatype import RdataType class RdtypeAnyTKeyTestCase(unittest.TestCase): - tkey_rdata_text = 'gss-tsig. 1594203795 1594206664 3 0 KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY OTHEROTHEROTHEROTHEROTHEROTHEROT' - tkey_rdata_text_no_other = 'gss-tsig. 1594203795 1594206664 3 0 KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY' + tkey_rdata_text = "gss-tsig. 1594203795 1594206664 3 0 KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY OTHEROTHEROTHEROTHEROTHEROTHEROT" + tkey_rdata_text_no_other = ( + "gss-tsig. 1594203795 1594206664 3 0 KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY" + ) def testTextOptionalData(self): # construct the rdata from text and extract the TKEY tkey = dns.rdata.from_text( - RdataClass.ANY, RdataType.TKEY, - RdtypeAnyTKeyTestCase.tkey_rdata_text, origin='.') + RdataClass.ANY, + RdataType.TKEY, + RdtypeAnyTKeyTestCase.tkey_rdata_text, + origin=".", + ) self.assertEqual(type(tkey), dns.rdtypes.ANY.TKEY.TKEY) # go to text and compare tkey_out_text = tkey.to_text(relativize=False) - self.assertEqual(tkey_out_text, - RdtypeAnyTKeyTestCase.tkey_rdata_text) + self.assertEqual(tkey_out_text, RdtypeAnyTKeyTestCase.tkey_rdata_text) def testTextNoOptionalData(self): # construct the rdata from text and extract the TKEY tkey = dns.rdata.from_text( - RdataClass.ANY, RdataType.TKEY, - RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other, origin='.') + RdataClass.ANY, + RdataType.TKEY, + RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other, + origin=".", + ) self.assertEqual(type(tkey), dns.rdtypes.ANY.TKEY.TKEY) # go to text and compare tkey_out_text = tkey.to_text(relativize=False) - self.assertEqual(tkey_out_text, - RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other) + self.assertEqual(tkey_out_text, RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other) def testWireOptionalData(self): - key = base64.b64decode('KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY') - other = base64.b64decode('OTHEROTHEROTHEROTHEROTHEROTHEROT') + key = base64.b64decode("KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY") + other = base64.b64decode("OTHEROTHEROTHEROTHEROTHEROTHEROT") # construct the TKEY and compare the text output - tkey = dns.rdtypes.ANY.TKEY.TKEY(dns.rdataclass.ANY, - dns.rdatatype.TKEY, - dns.name.from_text('gss-tsig.'), - 1594203795, 1594206664, - 3, 0, key, other) - self.assertEqual(tkey.to_text(relativize=False), - RdtypeAnyTKeyTestCase.tkey_rdata_text) + tkey = dns.rdtypes.ANY.TKEY.TKEY( + dns.rdataclass.ANY, + dns.rdatatype.TKEY, + dns.name.from_text("gss-tsig."), + 1594203795, + 1594206664, + 3, + 0, + key, + other, + ) + self.assertEqual( + tkey.to_text(relativize=False), RdtypeAnyTKeyTestCase.tkey_rdata_text + ) # go to/from wire and compare the text output wire = tkey.to_wire() - tkey_out_wire = dns.rdata.from_wire(dns.rdataclass.ANY, - dns.rdatatype.TKEY, - wire, 0, len(wire)) - self.assertEqual(tkey_out_wire.to_text(relativize=False), - RdtypeAnyTKeyTestCase.tkey_rdata_text) + tkey_out_wire = dns.rdata.from_wire( + dns.rdataclass.ANY, dns.rdatatype.TKEY, wire, 0, len(wire) + ) + self.assertEqual( + tkey_out_wire.to_text(relativize=False), + RdtypeAnyTKeyTestCase.tkey_rdata_text, + ) def testWireNoOptionalData(self): - key = base64.b64decode('KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY') + key = base64.b64decode("KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY") # construct the TKEY with no 'other' data and compare the text output - tkey = dns.rdtypes.ANY.TKEY.TKEY(dns.rdataclass.ANY, - dns.rdatatype.TKEY, - dns.name.from_text('gss-tsig.'), - 1594203795, 1594206664, - 3, 0, key) - self.assertEqual(tkey.to_text(relativize=False), - RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other) + tkey = dns.rdtypes.ANY.TKEY.TKEY( + dns.rdataclass.ANY, + dns.rdatatype.TKEY, + dns.name.from_text("gss-tsig."), + 1594203795, + 1594206664, + 3, + 0, + key, + ) + self.assertEqual( + tkey.to_text(relativize=False), + RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other, + ) # go to/from wire and compare the text output wire = tkey.to_wire() - tkey_out_wire = dns.rdata.from_wire(dns.rdataclass.ANY, - dns.rdatatype.TKEY, - wire, 0, len(wire)) - self.assertEqual(tkey_out_wire.to_text(relativize=False), - RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other) + tkey_out_wire = dns.rdata.from_wire( + dns.rdataclass.ANY, dns.rdatatype.TKEY, wire, 0, len(wire) + ) + self.assertEqual( + tkey_out_wire.to_text(relativize=False), + RdtypeAnyTKeyTestCase.tkey_rdata_text_no_other, + ) diff --git a/tests/test_renderer.py b/tests/test_renderer.py index c60ccf95..ca5a85e6 100644 --- a/tests/test_renderer.py +++ b/tests/test_renderer.py @@ -9,8 +9,7 @@ import dns.renderer import dns.tsig import dns.tsigkeyring -basic_answer = \ - """flags QR +basic_answer = """flags QR edns 0 payload 4096 ;QUESTION @@ -20,12 +19,13 @@ foo.example. 30 IN A 10.0.0.1 foo.example. 30 IN A 10.0.0.2 """ + class RendererTestCase(unittest.TestCase): def test_basic(self): r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512) - qname = dns.name.from_text('foo.example') + qname = dns.name.from_text("foo.example") r.add_question(qname, dns.rdatatype.A) - rds = dns.rdataset.from_text('in', 'a', 30, '10.0.0.1', '10.0.0.2') + rds = dns.rdataset.from_text("in", "a", 30, "10.0.0.1", "10.0.0.2") r.add_rdataset(dns.renderer.ANSWER, qname, rds) r.add_edns(0, 0, 4096) r.write_header() @@ -39,13 +39,14 @@ class RendererTestCase(unittest.TestCase): def test_tsig(self): r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512) - qname = dns.name.from_text('foo.example') + qname = dns.name.from_text("foo.example") r.add_question(qname, dns.rdatatype.A) - keyring = dns.tsigkeyring.from_text({'key' : '12345678'}) + keyring = dns.tsigkeyring.from_text({"key": "12345678"}) keyname = next(iter(keyring)) r.write_header() - r.add_tsig(keyname, keyring[keyname], 300, r.id, 0, b'', b'', - dns.tsig.HMAC_SHA256) + r.add_tsig( + keyname, keyring[keyname], 300, r.id, 0, b"", b"", dns.tsig.HMAC_SHA256 + ) wire = r.get_wire() message = dns.message.from_wire(wire, keyring=keyring) expected = dns.message.make_query(qname, dns.rdatatype.A) @@ -53,15 +54,24 @@ class RendererTestCase(unittest.TestCase): self.assertEqual(message, expected) def test_multi_tsig(self): - qname = dns.name.from_text('foo.example') - keyring = dns.tsigkeyring.from_text({'key' : '12345678'}) + qname = dns.name.from_text("foo.example") + keyring = dns.tsigkeyring.from_text({"key": "12345678"}) keyname = next(iter(keyring)) r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512) r.add_question(qname, dns.rdatatype.A) r.write_header() - ctx = r.add_multi_tsig(None, keyname, keyring[keyname], 300, r.id, 0, - b'', b'', dns.tsig.HMAC_SHA256) + ctx = r.add_multi_tsig( + None, + keyname, + keyring[keyname], + 300, + r.id, + 0, + b"", + b"", + dns.tsig.HMAC_SHA256, + ) wire = r.get_wire() message = dns.message.from_wire(wire, keyring=keyring, multi=True) expected = dns.message.make_query(qname, dns.rdatatype.A) @@ -71,22 +81,25 @@ class RendererTestCase(unittest.TestCase): r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512) r.add_question(qname, dns.rdatatype.A) r.write_header() - ctx = r.add_multi_tsig(ctx, keyname, keyring[keyname], 300, r.id, 0, - b'', b'', dns.tsig.HMAC_SHA256) + ctx = r.add_multi_tsig( + ctx, keyname, keyring[keyname], 300, r.id, 0, b"", b"", dns.tsig.HMAC_SHA256 + ) wire = r.get_wire() - message = dns.message.from_wire(wire, keyring=keyring, - tsig_ctx=message.tsig_ctx, multi=True) + message = dns.message.from_wire( + wire, keyring=keyring, tsig_ctx=message.tsig_ctx, multi=True + ) expected = dns.message.make_query(qname, dns.rdatatype.A) expected.id = message.id self.assertEqual(message, expected) - def test_going_backwards_fails(self): r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512) - qname = dns.name.from_text('foo.example') + qname = dns.name.from_text("foo.example") r.add_question(qname, dns.rdatatype.A) r.add_edns(0, 0, 4096) - rds = dns.rdataset.from_text('in', 'a', 30, '10.0.0.1', '10.0.0.2') + rds = dns.rdataset.from_text("in", "a", 30, "10.0.0.1", "10.0.0.2") + def bad(): r.add_rdataset(dns.renderer.ANSWER, qname, rds) + self.assertRaises(dns.exception.FormError, bad) diff --git a/tests/test_resolution.py b/tests/test_resolution.py index 731090be..d2819a12 100644 --- a/tests/test_resolution.py +++ b/tests/test_resolution.py @@ -13,15 +13,16 @@ import dns.tsigkeyring # Test the resolver's Resolution, i.e. the business logic of the resolver. + class ResolutionTestCase(unittest.TestCase): def setUp(self): self.resolver = dns.resolver.Resolver(configure=False) - self.resolver.nameservers = ['10.0.0.1', '10.0.0.2'] - self.resolver.domain = dns.name.from_text('example') - self.qname = dns.name.from_text('www.dnspython.org') - self.resn = dns.resolver._Resolution(self.resolver, self.qname, - 'A', 'IN', - False, True, False) + self.resolver.nameservers = ["10.0.0.1", "10.0.0.2"] + self.resolver.domain = dns.name.from_text("example") + self.qname = dns.name.from_text("www.dnspython.org") + self.resn = dns.resolver._Resolution( + self.resolver, self.qname, "A", "IN", False, True, False + ) def test_next_request_abs(self): (request, answer) = self.resn.next_request() @@ -30,11 +31,11 @@ class ResolutionTestCase(unittest.TestCase): self.assertEqual(request.question[0].rdtype, dns.rdatatype.A) def test_next_request_rel_with_search(self): - qname = dns.name.from_text('www.dnspython.org', None) - abs_qname_1 = dns.name.from_text('www.dnspython.org.example') - self.resn = dns.resolver._Resolution(self.resolver, qname, - 'A', 'IN', - False, True, True) + qname = dns.name.from_text("www.dnspython.org", None) + abs_qname_1 = dns.name.from_text("www.dnspython.org.example") + self.resn = dns.resolver._Resolution( + self.resolver, qname, "A", "IN", False, True, True + ) (request, answer) = self.resn.next_request() self.assertTrue(answer is None) self.assertEqual(request.question[0].name, self.qname) @@ -43,44 +44,60 @@ class ResolutionTestCase(unittest.TestCase): self.assertTrue(answer is None) self.assertEqual(request.question[0].name, abs_qname_1) self.assertEqual(request.question[0].rdtype, dns.rdatatype.A) + def bad(): (request, answer) = self.resn.next_request() + self.assertRaises(dns.resolver.NXDOMAIN, bad) def test_next_request_rel_without_search(self): - qname = dns.name.from_text('www.dnspython.org', None) - abs_qname_1 = dns.name.from_text('www.dnspython.org.example') - self.resn = dns.resolver._Resolution(self.resolver, qname, - 'A', 'IN', - False, True, False) + qname = dns.name.from_text("www.dnspython.org", None) + abs_qname_1 = dns.name.from_text("www.dnspython.org.example") + self.resn = dns.resolver._Resolution( + self.resolver, qname, "A", "IN", False, True, False + ) (request, answer) = self.resn.next_request() self.assertTrue(answer is None) self.assertEqual(request.question[0].name, self.qname) self.assertEqual(request.question[0].rdtype, dns.rdatatype.A) + def bad(): (request, answer) = self.resn.next_request() + self.assertRaises(dns.resolver.NXDOMAIN, bad) def test_next_request_exhaust_causes_nxdomain(self): def bad(): (request, answer) = self.resn.next_request() + (request, answer) = self.resn.next_request() self.assertRaises(dns.resolver.NXDOMAIN, bad) def make_address_response(self, q): r = dns.message.make_response(q) - rrs = r.get_rrset(r.answer, self.qname, dns.rdataclass.IN, - dns.rdatatype.A, create=True) - rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1'), 300) + rrs = r.get_rrset( + r.answer, self.qname, dns.rdataclass.IN, dns.rdatatype.A, create=True + ) + rrs.add( + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), 300 + ) return r def make_negative_response(self, q, nxdomain=False): r = dns.message.make_response(q) - rrs = r.get_rrset(r.authority, q.question[0].name, dns.rdataclass.IN, - dns.rdatatype.SOA, create=True) - rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - '. . 1 2 3 4 300'), 300) + rrs = r.get_rrset( + r.authority, + q.question[0].name, + dns.rdataclass.IN, + dns.rdatatype.SOA, + create=True, + ) + rrs.add( + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, ". . 1 2 3 4 300" + ), + 300, + ) if nxdomain: r.set_rcode(dns.rcode.NXDOMAIN) return r @@ -89,26 +106,33 @@ class ResolutionTestCase(unittest.TestCase): r = dns.message.make_response(q) name = self.qname for i in range(count): - rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN, - dns.rdatatype.CNAME, create=True) - tname = dns.name.from_text(f'target{i}.') - rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, - str(tname)), 300) + rrs = r.get_rrset( + r.answer, name, dns.rdataclass.IN, dns.rdatatype.CNAME, create=True + ) + tname = dns.name.from_text(f"target{i}.") + rrs.add( + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, str(tname)), + 300, + ) name = tname - rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN, - dns.rdatatype.A, create=True) - rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1'), 300) + rrs = r.get_rrset( + r.answer, name, dns.rdataclass.IN, dns.rdatatype.A, create=True + ) + rrs.add( + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), 300 + ) return r def test_next_request_cache_hit(self): self.resolver.cache = dns.resolver.Cache() q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_address_response(q) - cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A, - dns.rdataclass.IN, r) - self.resolver.cache.put((self.qname, dns.rdatatype.A, - dns.rdataclass.IN), cache_answer) + cache_answer = dns.resolver.Answer( + self.qname, dns.rdatatype.A, dns.rdataclass.IN, r + ) + self.resolver.cache.put( + (self.qname, dns.rdatatype.A, dns.rdataclass.IN), cache_answer + ) (request, answer) = self.resn.next_request() self.assertTrue(request is None) self.assertTrue(answer is cache_answer) @@ -120,36 +144,42 @@ class ResolutionTestCase(unittest.TestCase): # Note we need an SOA so the cache doesn't expire the answer # immediately, but our negative response code does that. r = self.make_negative_response(q) - cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A, - dns.rdataclass.IN, r) - self.resolver.cache.put((self.qname, dns.rdatatype.A, - dns.rdataclass.IN), cache_answer) + cache_answer = dns.resolver.Answer( + self.qname, dns.rdatatype.A, dns.rdataclass.IN, r + ) + self.resolver.cache.put( + (self.qname, dns.rdatatype.A, dns.rdataclass.IN), cache_answer + ) + def bad(): (request, answer) = self.resn.next_request() + self.assertRaises(dns.resolver.NoAnswer, bad) # If raise_on_no_answer is False, we should get a cache hit. - self.resn = dns.resolver._Resolution(self.resolver, self.qname, - 'A', 'IN', - False, False, False) + self.resn = dns.resolver._Resolution( + self.resolver, self.qname, "A", "IN", False, False, False + ) (request, answer) = self.resn.next_request() self.assertTrue(request is None) self.assertTrue(answer is cache_answer) def test_next_request_cached_nxdomain_without_search(self): # use a relative qname - qname = dns.name.from_text('www.dnspython.org', None) - self.resn = dns.resolver._Resolution(self.resolver, qname, - 'A', 'IN', - False, True, False) - qname1 = dns.name.from_text('www.dnspython.org.') + qname = dns.name.from_text("www.dnspython.org", None) + self.resn = dns.resolver._Resolution( + self.resolver, qname, "A", "IN", False, True, False + ) + qname1 = dns.name.from_text("www.dnspython.org.") # Arrange to get NXDOMAIN hits on it. self.resolver.cache = dns.resolver.Cache() q1 = dns.message.make_query(qname1, dns.rdatatype.A) r1 = self.make_negative_response(q1, True) - cache_answer = dns.resolver.Answer(qname1, dns.rdatatype.ANY, - dns.rdataclass.IN, r1) - self.resolver.cache.put((qname1, dns.rdatatype.ANY, - dns.rdataclass.IN), cache_answer) + cache_answer = dns.resolver.Answer( + qname1, dns.rdatatype.ANY, dns.rdataclass.IN, r1 + ) + self.resolver.cache.put( + (qname1, dns.rdatatype.ANY, dns.rdataclass.IN), cache_answer + ) try: (request, answer) = self.resn.next_request() self.assertTrue(False) # should not happen! @@ -158,27 +188,31 @@ class ResolutionTestCase(unittest.TestCase): def test_next_request_cached_nxdomain_with_search(self): # use a relative qname so we have two qnames to try - qname = dns.name.from_text('www.dnspython.org', None) + qname = dns.name.from_text("www.dnspython.org", None) # also enable search mode or we'll only see www.dnspython.org. - self.resn = dns.resolver._Resolution(self.resolver, qname, - 'A', 'IN', - False, True, True) - qname1 = dns.name.from_text('www.dnspython.org.example.') - qname2 = dns.name.from_text('www.dnspython.org.') + self.resn = dns.resolver._Resolution( + self.resolver, qname, "A", "IN", False, True, True + ) + qname1 = dns.name.from_text("www.dnspython.org.example.") + qname2 = dns.name.from_text("www.dnspython.org.") # Arrange to get NXDOMAIN hits on both of those qnames. self.resolver.cache = dns.resolver.Cache() q1 = dns.message.make_query(qname1, dns.rdatatype.A) r1 = self.make_negative_response(q1, True) - cache_answer = dns.resolver.Answer(qname1, dns.rdatatype.ANY, - dns.rdataclass.IN, r1) - self.resolver.cache.put((qname1, dns.rdatatype.ANY, - dns.rdataclass.IN), cache_answer) + cache_answer = dns.resolver.Answer( + qname1, dns.rdatatype.ANY, dns.rdataclass.IN, r1 + ) + self.resolver.cache.put( + (qname1, dns.rdatatype.ANY, dns.rdataclass.IN), cache_answer + ) q2 = dns.message.make_query(qname2, dns.rdatatype.A) r2 = self.make_negative_response(q2, True) - cache_answer = dns.resolver.Answer(qname2, dns.rdatatype.ANY, - dns.rdataclass.IN, r2) - self.resolver.cache.put((qname2, dns.rdatatype.ANY, - dns.rdataclass.IN), cache_answer) + cache_answer = dns.resolver.Answer( + qname2, dns.rdatatype.ANY, dns.rdataclass.IN, r2 + ) + self.resolver.cache.put( + (qname2, dns.rdatatype.ANY, dns.rdataclass.IN), cache_answer + ) try: (request, answer) = self.resn.next_request() self.assertTrue(False) # should not happen! @@ -188,8 +222,8 @@ class ResolutionTestCase(unittest.TestCase): def test_next_request_rotate(self): self.resolver.rotate = True - order1 = ['10.0.0.1', '10.0.0.2'] - order2 = ['10.0.0.2', '10.0.0.1'] + order1 = ["10.0.0.1", "10.0.0.2"] + order2 = ["10.0.0.2", "10.0.0.1"] seen1 = False seen2 = False # We're not interested in testing the randomness, but we'd @@ -197,9 +231,9 @@ class ResolutionTestCase(unittest.TestCase): # both orders at least once. This test can fail even with # correct code, but it is *extremely* unlikely. for count in range(0, 50): - self.resn = dns.resolver._Resolution(self.resolver, self.qname, - 'A', 'IN', - False, True, False) + self.resn = dns.resolver._Resolution( + self.resolver, self.qname, "A", "IN", False, True, False + ) self.resn.next_request() if self.resn.nameservers == order1: seen1 = True @@ -212,11 +246,11 @@ class ResolutionTestCase(unittest.TestCase): self.assertTrue(seen1 and seen2) def test_next_request_TSIG(self): - self.resolver.keyring = dns.tsigkeyring.from_text({ - 'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ==' - }) + self.resolver.keyring = dns.tsigkeyring.from_text( + {"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="} + ) (keyname, secret) = next(iter(self.resolver.keyring.items())) - self.resolver.keyname = dns.name.from_text('keyname.') + self.resolver.keyname = dns.name.from_text("keyname.") (request, answer) = self.resn.next_request() self.assertFalse(request is None) self.assertEqual(request.keyring.name, keyname) @@ -283,16 +317,22 @@ class ResolutionTestCase(unittest.TestCase): self.resn.nameservers.remove(nameserver) (nameserver, _, _, _) = self.resn.next_nameserver() self.resn.nameservers.remove(nameserver) + def bad(): (nameserver, _, _, _) = self.resn.next_nameserver() + self.assertRaises(dns.resolver.NoNameservers, bad) def test_query_result_nameserver_removing_exceptions(self): # add some nameservers so we have enough to remove :) - self.resolver.nameservers.extend(['10.0.0.3', '10.0.0.4']) + self.resolver.nameservers.extend(["10.0.0.3", "10.0.0.4"]) (request, _) = self.resn.next_request() - exceptions = [dns.exception.FormError(), EOFError(), - NotImplementedError(), dns.message.Truncated()] + exceptions = [ + dns.exception.FormError(), + EOFError(), + NotImplementedError(), + dns.message.Truncated(), + ] for i in range(4): (nameserver, _, _, _) = self.resn.next_nameserver() if i == 3: @@ -349,8 +389,9 @@ class ResolutionTestCase(unittest.TestCase): (_, _, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertFalse(answer is None) - cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.A, - dns.rdataclass.IN)) + cache_answer = self.resolver.cache.get( + (self.qname, dns.rdatatype.A, dns.rdataclass.IN) + ) self.assertTrue(answer is cache_answer) def test_query_result_no_error_no_data(self): @@ -358,8 +399,10 @@ class ResolutionTestCase(unittest.TestCase): r = self.make_negative_response(q) (_, _) = self.resn.next_request() (_, _, _, _) = self.resn.next_nameserver() + def bad(): (answer, done) = self.resn.query_result(r, None) + self.assertRaises(dns.resolver.NoAnswer, bad) def test_query_result_nxdomain(self): @@ -410,8 +453,9 @@ class ResolutionTestCase(unittest.TestCase): (answer, done) = self.resn.query_result(r, None) self.assertTrue(answer is None) self.assertTrue(done) - cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.ANY, - dns.rdataclass.IN)) + cache_answer = self.resolver.cache.get( + (self.qname, dns.rdatatype.ANY, dns.rdataclass.IN) + ) self.assertTrue(cache_answer.response is r) def test_query_result_yxdomain(self): @@ -420,8 +464,10 @@ class ResolutionTestCase(unittest.TestCase): r.set_rcode(dns.rcode.YXDOMAIN) (_, _) = self.resn.next_request() (_, _, _, _) = self.resn.next_nameserver() + def bad(): (answer, done) = self.resn.query_result(r, None) + self.assertRaises(dns.resolver.YXDOMAIN, bad) def test_query_result_servfail_no_retry(self): @@ -461,12 +507,14 @@ class ResolutionTestCase(unittest.TestCase): def test_no_metaqueries(self): def bad1(): - self.resn = dns.resolver._Resolution(self.resolver, self.qname, - 'ANY', 'IN', - False, True, False) + self.resn = dns.resolver._Resolution( + self.resolver, self.qname, "ANY", "IN", False, True, False + ) + def bad2(): - self.resn = dns.resolver._Resolution(self.resolver, self.qname, - 'A', 'ANY', - False, True, False) + self.resn = dns.resolver._Resolution( + self.resolver, self.qname, "A", "ANY", False, True, False + ) + self.assertRaises(dns.resolver.NoMetaqueries, bad1) self.assertRaises(dns.resolver.NoMetaqueries, bad2) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 0f6a6384..dc2dde4a 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -37,7 +37,7 @@ import dns.tsigkeyring # skip those if it's not there. _network_available = True try: - socket.gethostbyname('dnspython.org') + socket.gethostbyname("dnspython.org") except socket.gaierror: _network_available = False @@ -46,25 +46,28 @@ except socket.gaierror: # those tests. try: from .nanonameserver import Server + _nanonameserver_available = True except ImportError: _nanonameserver_available = False + class Server(object): pass + # Look for systemd-resolved, as it does dangling CNAME responses incorrectly. # # Currently we simply check if the nameserver is 127.0.0.53. _systemd_resolved_present = False try: _resolver = dns.resolver.Resolver() - if _resolver.nameservers == ['127.0.0.53']: + if _resolver.nameservers == ["127.0.0.53"]: _systemd_resolved_present = True except Exception: pass -resolv_conf = u""" +resolv_conf = """ /t/t # comment 1 ; comment 2 @@ -220,19 +223,18 @@ class FakeTime: class BaseResolverTests(unittest.TestCase): - def testRead(self): f = StringIO(resolv_conf) r = dns.resolver.Resolver(configure=False) r.read_resolv_conf(f) - self.assertEqual(r.nameservers, ['10.0.0.1', '10.0.0.2']) - self.assertEqual(r.domain, dns.name.from_text('foo')) + self.assertEqual(r.nameservers, ["10.0.0.1", "10.0.0.2"]) + self.assertEqual(r.domain, dns.name.from_text("foo")) def testReadOptions(self): f = StringIO(resolv_conf_options1) r = dns.resolver.Resolver(configure=False) r.read_resolv_conf(f) - self.assertEqual(r.nameservers, ['10.0.0.1', '10.0.0.2']) + self.assertEqual(r.nameservers, ["10.0.0.1", "10.0.0.2"]) self.assertTrue(r.rotate) self.assertEqual(r.timeout, 1) self.assertEqual(r.ndots, 2) @@ -275,140 +277,147 @@ class BaseResolverTests(unittest.TestCase): f = StringIO(unknown_and_bad_directives) r = dns.resolver.Resolver(configure=False) r.read_resolv_conf(f) - self.assertEqual(r.nameservers, ['10.0.0.1']) + self.assertEqual(r.nameservers, ["10.0.0.1"]) def testReadUnknownOption(self): # The real test here is ignoring the unknown option f = StringIO(unknown_option) r = dns.resolver.Resolver(configure=False) r.read_resolv_conf(f) - self.assertEqual(r.nameservers, ['10.0.0.1']) + self.assertEqual(r.nameservers, ["10.0.0.1"]) def testCacheExpiration(self): with FakeTime() as fake_time: message = dns.message.from_text(message_text) - name = dns.name.from_text('example.') - answer = dns.resolver.Answer(name, dns.rdatatype.A, - dns.rdataclass.IN, message) + name = dns.name.from_text("example.") + answer = dns.resolver.Answer( + name, dns.rdatatype.A, dns.rdataclass.IN, message + ) cache = dns.resolver.Cache() cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer) fake_time.sleep(2) - self.assertTrue(cache.get((name, dns.rdatatype.A, - dns.rdataclass.IN)) - is None) + self.assertTrue( + cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None + ) def testCacheCleaning(self): with FakeTime() as fake_time: message = dns.message.from_text(message_text) - name = dns.name.from_text('example.') - answer = dns.resolver.Answer(name, dns.rdatatype.A, - dns.rdataclass.IN, message) + name = dns.name.from_text("example.") + answer = dns.resolver.Answer( + name, dns.rdatatype.A, dns.rdataclass.IN, message + ) cache = dns.resolver.Cache(cleaning_interval=1.0) cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer) fake_time.sleep(2) cache._maybe_clean() - self.assertTrue(cache.data.get((name, dns.rdatatype.A, - dns.rdataclass.IN)) - is None) + self.assertTrue( + cache.data.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None + ) def testCacheNonCleaning(self): with FakeTime() as fake_time: message = dns.message.from_text(message_text) - name = dns.name.from_text('example.') - answer = dns.resolver.Answer(name, dns.rdatatype.A, - dns.rdataclass.IN, message) + name = dns.name.from_text("example.") + answer = dns.resolver.Answer( + name, dns.rdatatype.A, dns.rdataclass.IN, message + ) # override TTL as we're testing non-cleaning answer.expiration = fake_time.time() + 100 cache = dns.resolver.Cache(cleaning_interval=1.0) cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer) fake_time.sleep(1.1) - self.assertEqual(cache.get((name, dns.rdatatype.A, - dns.rdataclass.IN)), answer) + self.assertEqual( + cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)), answer + ) def testIndexErrorOnEmptyRRsetAccess(self): def bad(): message = dns.message.from_text(message_text_mx) - name = dns.name.from_text('example.') - answer = dns.resolver.Answer(name, dns.rdatatype.MX, - dns.rdataclass.IN, message) + name = dns.name.from_text("example.") + answer = dns.resolver.Answer( + name, dns.rdatatype.MX, dns.rdataclass.IN, message + ) return answer[0] + self.assertRaises(IndexError, bad) def testIndexErrorOnEmptyRRsetDelete(self): def bad(): message = dns.message.from_text(message_text_mx) - name = dns.name.from_text('example.') - answer = dns.resolver.Answer(name, dns.rdatatype.MX, - dns.rdataclass.IN, message) + name = dns.name.from_text("example.") + answer = dns.resolver.Answer( + name, dns.rdatatype.MX, dns.rdataclass.IN, message + ) del answer[0] + self.assertRaises(IndexError, bad) def testRRsetDelete(self): message = dns.message.from_text(message_text) - name = dns.name.from_text('example.') - answer = dns.resolver.Answer(name, dns.rdatatype.A, - dns.rdataclass.IN, message) + name = dns.name.from_text("example.") + answer = dns.resolver.Answer(name, dns.rdatatype.A, dns.rdataclass.IN, message) del answer[0] self.assertEqual(len(answer), 0) def testLRUReplace(self): cache = dns.resolver.LRUCache(4) for i in range(0, 5): - name = dns.name.from_text('example%d.' % i) + name = dns.name.from_text("example%d." % i) answer = FakeAnswer(time.time() + 1) cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer) for i in range(0, 5): - name = dns.name.from_text('example%d.' % i) + name = dns.name.from_text("example%d." % i) if i == 0: - self.assertTrue(cache.get((name, dns.rdatatype.A, - dns.rdataclass.IN)) - is None) + self.assertTrue( + cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None + ) else: - self.assertTrue(not cache.get((name, dns.rdatatype.A, - dns.rdataclass.IN)) - is None) + self.assertTrue( + not cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None + ) def testLRUDoesLRU(self): cache = dns.resolver.LRUCache(4) for i in range(0, 4): - name = dns.name.from_text('example%d.' % i) + name = dns.name.from_text("example%d." % i) answer = FakeAnswer(time.time() + 1) cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer) - name = dns.name.from_text('example0.') + name = dns.name.from_text("example0.") cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) # The LRU is now example1. - name = dns.name.from_text('example4.') + name = dns.name.from_text("example4.") answer = FakeAnswer(time.time() + 1) cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer) for i in range(0, 5): - name = dns.name.from_text('example%d.' % i) + name = dns.name.from_text("example%d." % i) if i == 1: - self.assertTrue(cache.get((name, dns.rdatatype.A, - dns.rdataclass.IN)) - is None) + self.assertTrue( + cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None + ) else: - self.assertTrue(not cache.get((name, dns.rdatatype.A, - dns.rdataclass.IN)) - is None) + self.assertTrue( + not cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None + ) def testLRUExpiration(self): with FakeTime() as fake_time: cache = dns.resolver.LRUCache(4) for i in range(0, 4): - name = dns.name.from_text('example%d.' % i) + name = dns.name.from_text("example%d." % i) answer = FakeAnswer(time.time() + 1) cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer) fake_time.sleep(2) for i in range(0, 4): - name = dns.name.from_text('example%d.' % i) - self.assertTrue(cache.get((name, dns.rdatatype.A, - dns.rdataclass.IN)) - is None) + name = dns.name.from_text("example%d." % i) + self.assertTrue( + cache.get((name, dns.rdatatype.A, dns.rdataclass.IN)) is None + ) def test_cache_flush(self): - name1 = dns.name.from_text('name1') - name2 = dns.name.from_text('name2') - name3 = dns.name.from_text('name3') + name1 = dns.name.from_text("name1") + name2 = dns.name.from_text("name2") + name3 = dns.name.from_text("name3") basic_cache = dns.resolver.Cache() lru_cache = dns.resolver.LRUCache(100) for cache in [basic_cache, lru_cache]: @@ -453,10 +462,11 @@ class BaseResolverTests(unittest.TestCase): return True cnode = cnode.next return False + cache = dns.resolver.LRUCache(4) answer1 = FakeAnswer(time.time() + 10) answer2 = FakeAnswer(time.time() + 10) - key = (dns.name.from_text('key.'), dns.rdatatype.A, dns.rdataclass.IN) + key = (dns.name.from_text("key."), dns.rdatatype.A, dns.rdataclass.IN) cache.put(key, answer1) canswer = cache.get(key) self.assertTrue(canswer is answer1) @@ -469,8 +479,8 @@ class BaseResolverTests(unittest.TestCase): def test_cache_stats(self): caches = [dns.resolver.Cache(), dns.resolver.LRUCache(4)] - key1 = (dns.name.from_text('key1.'), dns.rdatatype.A, dns.rdataclass.IN) - key2 = (dns.name.from_text('key2.'), dns.rdatatype.A, dns.rdataclass.IN) + key1 = (dns.name.from_text("key1."), dns.rdatatype.A, dns.rdataclass.IN) + key2 = (dns.name.from_text("key2."), dns.rdatatype.A, dns.rdataclass.IN) for cache in caches: answer1 = FakeAnswer(time.time() + 10) answer2 = FakeAnswer(10) # expired! @@ -507,66 +517,83 @@ class BaseResolverTests(unittest.TestCase): # with an empty answer section. Other than that it doesn't # apply. message = dns.message.from_text(dangling_cname_0_message_text) - name = dns.name.from_text('example.') - answer = dns.resolver.Answer(name, dns.rdatatype.A, dns.rdataclass.IN, - message) + name = dns.name.from_text("example.") + answer = dns.resolver.Answer(name, dns.rdatatype.A, dns.rdataclass.IN, message) + def test_python_internal_truth(answer): if answer: return True else: return False + self.assertFalse(test_python_internal_truth(answer)) for a in answer: pass def testSearchListsRelative(self): res = dns.resolver.Resolver(configure=False) - res.domain = dns.name.from_text('example') - res.search = [dns.name.from_text(x) for x in - ['dnspython.org', 'dnspython.net']] - qname = dns.name.from_text('www', None) + res.domain = dns.name.from_text("example") + res.search = [dns.name.from_text(x) for x in ["dnspython.org", "dnspython.net"]] + qname = dns.name.from_text("www", None) qnames = res._get_qnames_to_try(qname, True) - self.assertEqual(qnames, - [dns.name.from_text(x) for x in - ['www.dnspython.org', 'www.dnspython.net', 'www.']]) + self.assertEqual( + qnames, + [ + dns.name.from_text(x) + for x in ["www.dnspython.org", "www.dnspython.net", "www."] + ], + ) qnames = res._get_qnames_to_try(qname, False) - self.assertEqual(qnames, - [dns.name.from_text('www.')]) + self.assertEqual(qnames, [dns.name.from_text("www.")]) qnames = res._get_qnames_to_try(qname, None) - self.assertEqual(qnames, - [dns.name.from_text('www.')]) + self.assertEqual(qnames, [dns.name.from_text("www.")]) # # Now change search default on resolver to True # res.use_search_by_default = True qnames = res._get_qnames_to_try(qname, None) - self.assertEqual(qnames, - [dns.name.from_text(x) for x in - ['www.dnspython.org', 'www.dnspython.net', 'www.']]) + self.assertEqual( + qnames, + [ + dns.name.from_text(x) + for x in ["www.dnspython.org", "www.dnspython.net", "www."] + ], + ) # # Now test ndots # - qname = dns.name.from_text('a.b', None) + qname = dns.name.from_text("a.b", None) res.ndots = 1 qnames = res._get_qnames_to_try(qname, True) - self.assertEqual(qnames, - [dns.name.from_text(x) for x in - ['a.b', 'a.b.dnspython.org', 'a.b.dnspython.net']]) + self.assertEqual( + qnames, + [ + dns.name.from_text(x) + for x in ["a.b", "a.b.dnspython.org", "a.b.dnspython.net"] + ], + ) res.ndots = 2 qnames = res._get_qnames_to_try(qname, True) - self.assertEqual(qnames, - [dns.name.from_text(x) for x in - ['a.b.dnspython.org', 'a.b.dnspython.net', 'a.b']]) - qname = dns.name.from_text('a.b.c', None) + self.assertEqual( + qnames, + [ + dns.name.from_text(x) + for x in ["a.b.dnspython.org", "a.b.dnspython.net", "a.b"] + ], + ) + qname = dns.name.from_text("a.b.c", None) qnames = res._get_qnames_to_try(qname, True) - self.assertEqual(qnames, - [dns.name.from_text(x) for x in - ['a.b.c', 'a.b.c.dnspython.org', - 'a.b.c.dnspython.net']]) + self.assertEqual( + qnames, + [ + dns.name.from_text(x) + for x in ["a.b.c", "a.b.c.dnspython.org", "a.b.c.dnspython.net"] + ], + ) def testSearchListsAbsolute(self): res = dns.resolver.Resolver(configure=False) - qname = dns.name.from_text('absolute') + qname = dns.name.from_text("absolute") qnames = res._get_qnames_to_try(qname, True) self.assertEqual(qnames, [qname]) qnames = res._get_qnames_to_try(qname, False) @@ -590,91 +617,94 @@ class BaseResolverTests(unittest.TestCase): self.assertEqual(r.flags, flags) def testUseTSIG(self): - keyring = dns.tsigkeyring.from_text( - { - 'keyname.': 'NjHwPsMKjdN++dOfE5iAiQ==' - } - ) + keyring = dns.tsigkeyring.from_text({"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="}) r = dns.resolver.Resolver(configure=False) r.use_tsig(keyring) self.assertEqual(r.keyring, keyring) self.assertEqual(r.keyname, None) self.assertEqual(r.keyalgorithm, dns.tsig.default_algorithm) -keyname = dns.name.from_text('keyname') +keyname = dns.name.from_text("keyname") @unittest.skipIf(not _network_available, "Internet not reachable") class LiveResolverTests(unittest.TestCase): def testZoneForName1(self): - name = dns.name.from_text('www.dnspython.org.') - ezname = dns.name.from_text('dnspython.org.') + name = dns.name.from_text("www.dnspython.org.") + ezname = dns.name.from_text("dnspython.org.") zname = dns.resolver.zone_for_name(name) self.assertEqual(zname, ezname) def testZoneForName2(self): - name = dns.name.from_text('a.b.www.dnspython.org.') - ezname = dns.name.from_text('dnspython.org.') + name = dns.name.from_text("a.b.www.dnspython.org.") + ezname = dns.name.from_text("dnspython.org.") zname = dns.resolver.zone_for_name(name) self.assertEqual(zname, ezname) def testZoneForName3(self): - ezname = dns.name.from_text('dnspython.org.') - zname = dns.resolver.zone_for_name('dnspython.org.') + ezname = dns.name.from_text("dnspython.org.") + zname = dns.resolver.zone_for_name("dnspython.org.") self.assertEqual(zname, ezname) def testZoneForName4(self): def bad(): - name = dns.name.from_text('dnspython.org', None) + name = dns.name.from_text("dnspython.org", None) dns.resolver.zone_for_name(name) + self.assertRaises(dns.resolver.NotAbsolute, bad) def testResolve(self): - answer = dns.resolver.resolve('dns.google.', 'A') + answer = dns.resolver.resolve("dns.google.", "A") seen = set([rdata.address for rdata in answer]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testResolveTCP(self): - answer = dns.resolver.resolve('dns.google.', 'A', tcp=True) + answer = dns.resolver.resolve("dns.google.", "A", tcp=True) seen = set([rdata.address for rdata in answer]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) def testResolveAddress(self): - answer = dns.resolver.resolve_address('8.8.8.8') - dnsgoogle = dns.name.from_text('dns.google.') + answer = dns.resolver.resolve_address("8.8.8.8") + dnsgoogle = dns.name.from_text("dns.google.") self.assertEqual(answer[0].target, dnsgoogle) - @patch.object(dns.message.Message, 'use_edns') + @patch.object(dns.message.Message, "use_edns") def testResolveEdnsOptions(self, message_use_edns_mock): resolver = dns.resolver.Resolver() - options = [dns.edns.ECSOption('1.1.1.1')] + options = [dns.edns.ECSOption("1.1.1.1")] resolver.use_edns(True, options=options) - resolver.resolve('dns.google.', 'A') - assert {'options': options} in message_use_edns_mock.call_args + resolver.resolve("dns.google.", "A") + assert {"options": options} in message_use_edns_mock.call_args def testResolveNodataException(self): def bad(): - dns.resolver.resolve('dnspython.org.', 'SRV') + dns.resolver.resolve("dnspython.org.", "SRV") + self.assertRaises(dns.resolver.NoAnswer, bad) def testResolveNodataAnswer(self): - qname = dns.name.from_text('dnspython.org') - qclass = dns.rdataclass.from_text('IN') - qtype = dns.rdatatype.from_text('SRV') + qname = dns.name.from_text("dnspython.org") + qclass = dns.rdataclass.from_text("IN") + qtype = dns.rdatatype.from_text("SRV") answer = dns.resolver.resolve(qname, qtype, raise_on_no_answer=False) - self.assertRaises(KeyError, - lambda: answer.response.find_rrset(answer.response.answer, - qname, qclass, qtype)) + self.assertRaises( + KeyError, + lambda: answer.response.find_rrset( + answer.response.answer, qname, qclass, qtype + ), + ) def testResolveNXDOMAIN(self): - qname = dns.name.from_text('nxdomain.dnspython.org') - qclass = dns.rdataclass.from_text('IN') - qtype = dns.rdatatype.from_text('A') + qname = dns.name.from_text("nxdomain.dnspython.org") + qclass = dns.rdataclass.from_text("IN") + qtype = dns.rdatatype.from_text("A") + def bad(): answer = dns.resolver.resolve(qname, qtype) + try: dns.resolver.resolve(qname, qtype) self.assertTrue(False) # should not happen! @@ -684,36 +714,36 @@ class LiveResolverTests(unittest.TestCase): def testResolveCacheHit(self): res = dns.resolver.Resolver(configure=False) - res.nameservers = ['8.8.8.8'] + res.nameservers = ["8.8.8.8"] res.cache = dns.resolver.Cache() - answer1 = res.resolve('dns.google.', 'A') + answer1 = res.resolve("dns.google.", "A") seen = set([rdata.address for rdata in answer1]) - self.assertIn('8.8.8.8', seen) - self.assertIn('8.8.4.4', seen) - answer2 = res.resolve('dns.google.', 'A') + self.assertIn("8.8.8.8", seen) + self.assertIn("8.8.4.4", seen) + answer2 = res.resolve("dns.google.", "A") self.assertIs(answer2, answer1) def testCanonicalNameNoCNAME(self): - cname = dns.name.from_text('www.google.com') - self.assertEqual(dns.resolver.canonical_name('www.google.com'), cname) + cname = dns.name.from_text("www.google.com") + self.assertEqual(dns.resolver.canonical_name("www.google.com"), cname) def testCanonicalNameCNAME(self): - name = dns.name.from_text('www.dnspython.org') - cname = dns.name.from_text('dmfrjf4ips8xa.cloudfront.net') + name = dns.name.from_text("www.dnspython.org") + cname = dns.name.from_text("dmfrjf4ips8xa.cloudfront.net") self.assertEqual(dns.resolver.canonical_name(name), cname) @unittest.skipIf(_systemd_resolved_present, "systemd-resolved in use") def testCanonicalNameDangling(self): - name = dns.name.from_text('dangling-cname.dnspython.org') - cname = dns.name.from_text('dangling-target.dnspython.org') + name = dns.name.from_text("dangling-cname.dnspython.org") + cname = dns.name.from_text("dangling-target.dnspython.org") self.assertEqual(dns.resolver.canonical_name(name), cname) def testNameserverSetting(self): res = dns.resolver.Resolver(configure=False) - ns = ['1.2.3.4', '::1', 'https://ns.example'] + ns = ["1.2.3.4", "::1", "https://ns.example"] res.nameservers = ns[:] self.assertEqual(res.nameservers, ns) - for ns in ['999.999.999.999', 'ns.example.', 'bogus://ns.example']: + for ns in ["999.999.999.999", "ns.example.", "bogus://ns.example"]: with self.assertRaises(ValueError): res.nameservers = [ns] @@ -731,13 +761,18 @@ class PollingMonkeyPatchMixin(object): unittest.TestCase.tearDown(self) -class SelectResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase): +class SelectResolverTestCase( + PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase +): def selector_class(self): return selectors.SelectSelector -if hasattr(selectors, 'PollSelector'): - class PollResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase): +if hasattr(selectors, "PollSelector"): + + class PollResolverTestCase( + PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase + ): def selector_class(self): return selectors.PollSelector @@ -747,35 +782,35 @@ class NXDOMAINExceptionTestCase(unittest.TestCase): # pylint: disable=broad-except def test_nxdomain_compatible(self): - n1 = dns.name.Name(('a', 'b', '')) - n2 = dns.name.Name(('a', 'b', 's', '')) + n1 = dns.name.Name(("a", "b", "")) + n2 = dns.name.Name(("a", "b", "s", "")) try: raise dns.resolver.NXDOMAIN except dns.exception.DNSException as e: self.assertEqual(e.args, (e.__doc__,)) - self.assertTrue(('kwargs' in dir(e))) + self.assertTrue(("kwargs" in dir(e))) self.assertEqual(str(e), e.__doc__, str(e)) - self.assertTrue(('qnames' not in e.kwargs)) - self.assertTrue(('responses' not in e.kwargs)) + self.assertTrue(("qnames" not in e.kwargs)) + self.assertTrue(("responses" not in e.kwargs)) try: raise dns.resolver.NXDOMAIN("errmsg") except dns.exception.DNSException as e: self.assertEqual(e.args, ("errmsg",)) - self.assertTrue(('kwargs' in dir(e))) + self.assertTrue(("kwargs" in dir(e))) self.assertEqual(str(e), "errmsg", str(e)) - self.assertTrue(('qnames' not in e.kwargs)) - self.assertTrue(('responses' not in e.kwargs)) + self.assertTrue(("qnames" not in e.kwargs)) + self.assertTrue(("responses" not in e.kwargs)) try: raise dns.resolver.NXDOMAIN("errmsg", -1) except dns.exception.DNSException as e: self.assertEqual(e.args, ("errmsg", -1)) - self.assertTrue(('kwargs' in dir(e))) + self.assertTrue(("kwargs" in dir(e))) self.assertEqual(str(e), "('errmsg', -1)", str(e)) - self.assertTrue(('qnames' not in e.kwargs)) - self.assertTrue(('responses' not in e.kwargs)) + self.assertTrue(("qnames" not in e.kwargs)) + self.assertTrue(("responses" not in e.kwargs)) try: raise dns.resolver.NXDOMAIN(qnames=None) @@ -797,12 +832,12 @@ class NXDOMAINExceptionTestCase(unittest.TestCase): except dns.exception.DNSException as e: MSG = "The DNS query name does not exist: a.b." self.assertEqual(e.args, (MSG,), repr(e.args)) - self.assertTrue(('kwargs' in dir(e))) + self.assertTrue(("kwargs" in dir(e))) self.assertEqual(str(e), MSG, str(e)) - self.assertTrue(('qnames' in e.kwargs)) - self.assertEqual(e.kwargs['qnames'], [n1]) - self.assertTrue(('responses' in e.kwargs)) - self.assertEqual(e.kwargs['responses'], {}) + self.assertTrue(("qnames" in e.kwargs)) + self.assertEqual(e.kwargs["qnames"], [n1]) + self.assertTrue(("responses" in e.kwargs)) + self.assertEqual(e.kwargs["responses"], {}) try: raise dns.resolver.NXDOMAIN(qnames=[n2, n1]) @@ -811,50 +846,49 @@ class NXDOMAINExceptionTestCase(unittest.TestCase): e = e0 + e MSG = "None of DNS query names exist: a.b.s., a.b." self.assertEqual(e.args, (MSG,), repr(e.args)) - self.assertTrue(('kwargs' in dir(e))) + self.assertTrue(("kwargs" in dir(e))) self.assertEqual(str(e), MSG, str(e)) - self.assertTrue(('qnames' in e.kwargs)) - self.assertEqual(e.kwargs['qnames'], [n2, n1]) - self.assertTrue(('responses' in e.kwargs)) - self.assertEqual(e.kwargs['responses'], {}) + self.assertTrue(("qnames" in e.kwargs)) + self.assertEqual(e.kwargs["qnames"], [n2, n1]) + self.assertTrue(("responses" in e.kwargs)) + self.assertEqual(e.kwargs["responses"], {}) try: - raise dns.resolver.NXDOMAIN(qnames=[n1], responses=['r1.1']) + raise dns.resolver.NXDOMAIN(qnames=[n1], responses=["r1.1"]) except Exception as e: self.assertTrue((isinstance(e, AttributeError))) try: - raise dns.resolver.NXDOMAIN(qnames=[n1], responses={n1: 'r1.1'}) + raise dns.resolver.NXDOMAIN(qnames=[n1], responses={n1: "r1.1"}) except dns.resolver.NXDOMAIN as e: MSG = "The DNS query name does not exist: a.b." self.assertEqual(e.args, (MSG,), repr(e.args)) - self.assertTrue(('kwargs' in dir(e))) + self.assertTrue(("kwargs" in dir(e))) self.assertEqual(str(e), MSG, str(e)) - self.assertTrue(('qnames' in e.kwargs)) - self.assertEqual(e.kwargs['qnames'], [n1]) - self.assertTrue(('responses' in e.kwargs)) - self.assertEqual(e.kwargs['responses'], {n1: 'r1.1'}) + self.assertTrue(("qnames" in e.kwargs)) + self.assertEqual(e.kwargs["qnames"], [n1]) + self.assertTrue(("responses" in e.kwargs)) + self.assertEqual(e.kwargs["responses"], {n1: "r1.1"}) def test_nxdomain_merge(self): - n1 = dns.name.Name(('a', 'b', '')) - n2 = dns.name.Name(('a', 'b', '')) - n3 = dns.name.Name(('a', 'b', 'c', '')) - n4 = dns.name.Name(('a', 'b', 'd', '')) - responses1 = {n1: 'r1.1', n2: 'r1.2', n4: 'r1.4'} - qnames1 = [n1, n4] # n2 == n1 - responses2 = {n2: 'r2.2', n3: 'r2.3'} + n1 = dns.name.Name(("a", "b", "")) + n2 = dns.name.Name(("a", "b", "")) + n3 = dns.name.Name(("a", "b", "c", "")) + n4 = dns.name.Name(("a", "b", "d", "")) + responses1 = {n1: "r1.1", n2: "r1.2", n4: "r1.4"} + qnames1 = [n1, n4] # n2 == n1 + responses2 = {n2: "r2.2", n3: "r2.3"} qnames2 = [n2, n3] e0 = dns.resolver.NXDOMAIN() e1 = dns.resolver.NXDOMAIN(qnames=qnames1, responses=responses1) e2 = dns.resolver.NXDOMAIN(qnames=qnames2, responses=responses2) e = e1 + e0 + e2 self.assertRaises(AttributeError, lambda: e0 + e0) - self.assertEqual(e.kwargs['qnames'], [n1, n4, n3], - repr(e.kwargs['qnames'])) - self.assertTrue(e.kwargs['responses'][n1].startswith('r2.')) - self.assertTrue(e.kwargs['responses'][n2].startswith('r2.')) - self.assertTrue(e.kwargs['responses'][n3].startswith('r2.')) - self.assertTrue(e.kwargs['responses'][n4].startswith('r1.')) + self.assertEqual(e.kwargs["qnames"], [n1, n4, n3], repr(e.kwargs["qnames"])) + self.assertTrue(e.kwargs["responses"][n1].startswith("r2.")) + self.assertTrue(e.kwargs["responses"][n2].startswith("r2.")) + self.assertTrue(e.kwargs["responses"][n3].startswith("r2.")) + self.assertTrue(e.kwargs["responses"][n4].startswith("r1.")) def test_nxdomain_canonical_name(self): cname1 = "91.11.8-22.17.172.in-addr.arpa." @@ -877,35 +911,39 @@ class NXDOMAINExceptionTestCase(unittest.TestCase): class ResolverMiscTestCase(unittest.TestCase): - if sys.platform != 'win32': + if sys.platform != "win32": + def test_read_nonexistent_config(self): res = dns.resolver.Resolver(configure=False) - pathname = '/etc/nonexistent-resolv.conf' - self.assertRaises(dns.resolver.NoResolverConfiguration, - lambda: res.read_resolv_conf(pathname)) + pathname = "/etc/nonexistent-resolv.conf" + self.assertRaises( + dns.resolver.NoResolverConfiguration, + lambda: res.read_resolv_conf(pathname), + ) def test_compute_timeout(self): res = dns.resolver.Resolver(configure=False) now = time.time() - self.assertRaises(dns.resolver.Timeout, - lambda: res._compute_timeout(now + 10000)) - self.assertRaises(dns.resolver.Timeout, - lambda: res._compute_timeout(0)) + self.assertRaises( + dns.resolver.Timeout, lambda: res._compute_timeout(now + 10000) + ) + self.assertRaises(dns.resolver.Timeout, lambda: res._compute_timeout(0)) # not raising is the test res._compute_timeout(now + 0.5) - if sys.platform == 'win32': + if sys.platform == "win32": + def test_configure_win32_domain(self): - n = dns.name.from_text('home.') - self.assertEqual(n, dns.win32util._config_domain('home')) - self.assertEqual(n, dns.win32util._config_domain('.home')) + n = dns.name.from_text("home.") + self.assertEqual(n, dns.win32util._config_domain("home")) + self.assertEqual(n, dns.win32util._config_domain(".home")) class ResolverNameserverValidTypeTestCase(unittest.TestCase): def test_set_nameservers_to_list(self): resolver = dns.resolver.Resolver(configure=False) - resolver.nameservers = ['1.2.3.4'] - self.assertEqual(resolver.nameservers, ['1.2.3.4']) + resolver.nameservers = ["1.2.3.4"] + self.assertEqual(resolver.nameservers, ["1.2.3.4"]) def test_set_namservers_to_empty_list(self): resolver = dns.resolver.Resolver(configure=False) @@ -914,27 +952,35 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase): def test_set_nameservers_invalid_type(self): resolver = dns.resolver.Resolver(configure=False) - invalid_nameservers = [None, '1.2.3.4', 1234, (1, 2, 3, 4), {'invalid': 'nameserver'}] + invalid_nameservers = [ + None, + "1.2.3.4", + 1234, + (1, 2, 3, 4), + {"invalid": "nameserver"}, + ] for invalid_nameserver in invalid_nameservers: with self.assertRaises(ValueError): resolver.nameservers = invalid_nameserver class NaptrNanoNameserver(Server): - def handle(self, request): response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.REFUSED) response.flags |= dns.flags.RA try: - zero_subdomain = dns.e164.from_e164('0') + zero_subdomain = dns.e164.from_e164("0") if request.qname.is_subdomain(zero_subdomain): response.set_rcode(dns.rcode.NXDOMAIN) response.flags |= dns.flags.AA - elif request.qtype == dns.rdatatype.NAPTR and \ - request.qclass == dns.rdataclass.IN: - rrs = dns.rrset.from_text(request.qname, 300, 'IN', 'NAPTR', - '0 0 "" "" "" .') + elif ( + request.qtype == dns.rdatatype.NAPTR + and request.qclass == dns.rdataclass.IN + ): + rrs = dns.rrset.from_text( + request.qname, 300, "IN", "NAPTR", '0 0 "" "" "" .' + ) response.answer.append(rrs) response.set_rcode(dns.rcode.NOERROR) response.flags |= dns.flags.AA @@ -943,28 +989,28 @@ class NaptrNanoNameserver(Server): return response -@unittest.skipIf(not (_network_available and _nanonameserver_available), - "Internet and NanoAuth required") +@unittest.skipIf( + not (_network_available and _nanonameserver_available), + "Internet and NanoAuth required", +) class NanoTests(unittest.TestCase): - def testE164Query(self): with NaptrNanoNameserver() as na: res = dns.resolver.Resolver(configure=False) res.port = na.udp_address[1] res.nameservers = [na.udp_address[0]] - answer = dns.e164.query('1650551212', ['e164.arpa'], res) + answer = dns.e164.query("1650551212", ["e164.arpa"], res) self.assertEqual(answer[0].order, 0) self.assertEqual(answer[0].preference, 0) - self.assertEqual(answer[0].flags, b'') - self.assertEqual(answer[0].service, b'') - self.assertEqual(answer[0].regexp, b'') + self.assertEqual(answer[0].flags, b"") + self.assertEqual(answer[0].service, b"") + self.assertEqual(answer[0].regexp, b"") self.assertEqual(answer[0].replacement, dns.name.root) with self.assertRaises(dns.resolver.NXDOMAIN): - dns.e164.query('0123456789', ['e164.arpa'], res) + dns.e164.query("0123456789", ["e164.arpa"], res) class AlwaysType3NXDOMAINNanoNameserver(Server): - def handle(self, request): response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.NXDOMAIN) @@ -973,56 +1019,65 @@ class AlwaysType3NXDOMAINNanoNameserver(Server): class AlwaysNXDOMAINNanoNameserver(Server): - def handle(self, request): response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.NXDOMAIN) response.flags |= dns.flags.RA - origin = dns.name.from_text('example.') - soa_rrset = response.find_rrset(response.authority, origin, - dns.rdataclass.IN, dns.rdatatype.SOA, - create=True) - rdata = dns.rdata.from_text('IN', 'SOA', - 'ns.example. root.example. 1 2 3 4 5') + origin = dns.name.from_text("example.") + soa_rrset = response.find_rrset( + response.authority, + origin, + dns.rdataclass.IN, + dns.rdatatype.SOA, + create=True, + ) + rdata = dns.rdata.from_text("IN", "SOA", "ns.example. root.example. 1 2 3 4 5") soa_rrset.add(rdata) soa_rrset.update_ttl(300) return response -class AlwaysNoErrorNoDataNanoNameserver(Server): +class AlwaysNoErrorNoDataNanoNameserver(Server): def handle(self, request): response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.NOERROR) response.flags |= dns.flags.RA - origin = dns.name.from_text('example.') - soa_rrset = response.find_rrset(response.authority, origin, - dns.rdataclass.IN, dns.rdatatype.SOA, - create=True) - rdata = dns.rdata.from_text('IN', 'SOA', - 'ns.example. root.example. 1 2 3 4 5') + origin = dns.name.from_text("example.") + soa_rrset = response.find_rrset( + response.authority, + origin, + dns.rdataclass.IN, + dns.rdatatype.SOA, + create=True, + ) + rdata = dns.rdata.from_text("IN", "SOA", "ns.example. root.example. 1 2 3 4 5") soa_rrset.add(rdata) soa_rrset.update_ttl(300) return response -@unittest.skipIf(not (_network_available and _nanonameserver_available), - "Internet and NanoAuth required") -class ZoneForNameTests(unittest.TestCase): + +@unittest.skipIf( + not (_network_available and _nanonameserver_available), + "Internet and NanoAuth required", +) +class ZoneForNameTests(unittest.TestCase): def testNoRootSOA(self): with AlwaysType3NXDOMAINNanoNameserver() as na: res = dns.resolver.Resolver(configure=False) res.port = na.udp_address[1] res.nameservers = [na.udp_address[0]] with self.assertRaises(dns.resolver.NoRootSOA): - dns.resolver.zone_for_name('www.foo.bar.', resolver=res) + dns.resolver.zone_for_name("www.foo.bar.", resolver=res) def testHelpfulNXDOMAIN(self): with AlwaysNXDOMAINNanoNameserver() as na: res = dns.resolver.Resolver(configure=False) res.port = na.udp_address[1] res.nameservers = [na.udp_address[0]] - expected = dns.name.from_text('example.') - name = dns.resolver.zone_for_name('1.2.3.4.5.6.7.8.9.10.example.', - resolver=res) + expected = dns.name.from_text("example.") + name = dns.resolver.zone_for_name( + "1.2.3.4.5.6.7.8.9.10.example.", resolver=res + ) self.assertEqual(name, expected) def testHelpfulNoErrorNoData(self): @@ -1030,19 +1085,19 @@ class ZoneForNameTests(unittest.TestCase): res = dns.resolver.Resolver(configure=False) res.port = na.udp_address[1] res.nameservers = [na.udp_address[0]] - expected = dns.name.from_text('example.') - name = dns.resolver.zone_for_name('1.2.3.4.5.6.7.8.9.10.example.', - resolver=res) + expected = dns.name.from_text("example.") + name = dns.resolver.zone_for_name( + "1.2.3.4.5.6.7.8.9.10.example.", resolver=res + ) self.assertEqual(name, expected) -class DroppingNanoNameserver(Server): +class DroppingNanoNameserver(Server): def handle(self, request): return None class FormErrNanoNameserver(Server): - def handle(self, request): r = dns.message.make_response(request.message) r.set_rcode(dns.rcode.FORMERR) @@ -1052,8 +1107,11 @@ class FormErrNanoNameserver(Server): # we use pytest for these so we can have a "slow" mark later if we want to # (right now it's still fast enough we don't really need it) -@pytest.mark.skipif(not (_network_available and _nanonameserver_available), - reason="Internet and NanoAuth required") + +@pytest.mark.skipif( + not (_network_available and _nanonameserver_available), + reason="Internet and NanoAuth required", +) def testResolverTimeout(): with DroppingNanoNameserver() as na: res = dns.resolver.Resolver(configure=False) @@ -1062,13 +1120,13 @@ def testResolverTimeout(): res.timeout = 0.2 try: lifetime = 1.0 - a = res.resolve('www.dnspython.org', lifetime=lifetime) + a = res.resolve("www.dnspython.org", lifetime=lifetime) assert False # should never happen except dns.resolver.LifetimeTimeout as e: - assert e.kwargs['timeout'] >= lifetime + assert e.kwargs["timeout"] >= lifetime # The length of errors can vary based on how slow things are, # but it ought to be > 1, so we assert that. - errors = e.kwargs['errors'] + errors = e.kwargs["errors"] assert len(errors) > 1 for error in errors: assert error[0] == na.udp_address[0] # address @@ -1076,28 +1134,30 @@ def testResolverTimeout(): assert error[2] == na.udp_address[1] # port assert isinstance(error[3], dns.exception.Timeout) # exception -@pytest.mark.skipif(not (_network_available and _nanonameserver_available), - reason="Internet and NanoAuth required") + +@pytest.mark.skipif( + not (_network_available and _nanonameserver_available), + reason="Internet and NanoAuth required", +) def testResolverNoNameservers(): with FormErrNanoNameserver() as na: res = dns.resolver.Resolver(configure=False) res.port = na.udp_address[1] res.nameservers = [na.udp_address[0]] try: - a = res.resolve('www.dnspython.org') + a = res.resolve("www.dnspython.org") assert False # should never happen except dns.resolver.NoNameservers as e: - errors = e.kwargs['errors'] + errors = e.kwargs["errors"] assert len(errors) == 1 for error in errors: assert error[0] == na.udp_address[0] # address assert not error[1] # not TCP assert error[2] == na.udp_address[1] # port - assert error[3] == 'FORMERR' + assert error[3] == "FORMERR" class SlowAlwaysType3NXDOMAINNanoNameserver(Server): - def handle(self, request): response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.NXDOMAIN) @@ -1106,13 +1166,16 @@ class SlowAlwaysType3NXDOMAINNanoNameserver(Server): return response -@pytest.mark.skipif(not (_network_available and _nanonameserver_available), - reason="Internet and NanoAuth required") +@pytest.mark.skipif( + not (_network_available and _nanonameserver_available), + reason="Internet and NanoAuth required", +) def testZoneForNameLifetimeTimeout(): with SlowAlwaysType3NXDOMAINNanoNameserver() as na: res = dns.resolver.Resolver(configure=False) res.port = na.udp_address[1] res.nameservers = [na.udp_address[0]] with pytest.raises(dns.resolver.LifetimeTimeout): - dns.resolver.zone_for_name('1.2.3.4.5.6.7.8.9.10.example.', - resolver=res, lifetime=1.0) + dns.resolver.zone_for_name( + "1.2.3.4.5.6.7.8.9.10.example.", resolver=res, lifetime=1.0 + ) diff --git a/tests/test_resolver_override.py b/tests/test_resolver_override.py index ac93316a..3d79445d 100644 --- a/tests/test_resolver_override.py +++ b/tests/test_resolver_override.py @@ -13,17 +13,16 @@ import dns.resolver # skip those if it's not there. _network_available = True try: - socket.gethostbyname('dnspython.org') + socket.gethostbyname("dnspython.org") except socket.gaierror: _network_available = False @unittest.skipIf(not _network_available, "Internet not reachable") class OverrideSystemResolverTestCase(unittest.TestCase): - def setUp(self): self.res = dns.resolver.Resolver(configure=False) - self.res.nameservers = ['8.8.8.8'] + self.res.nameservers = ["8.8.8.8"] self.res.cache = dns.resolver.LRUCache() dns.resolver.override_system_resolver(self.res) @@ -32,34 +31,50 @@ class OverrideSystemResolverTestCase(unittest.TestCase): self.res = None def test_override(self): - self.assertTrue(socket.getaddrinfo is - dns.resolver._getaddrinfo) - socket.gethostbyname('www.dnspython.org') - answer = self.res.cache.get((dns.name.from_text('www.dnspython.org.'), - dns.rdatatype.A, dns.rdataclass.IN)) + self.assertTrue(socket.getaddrinfo is dns.resolver._getaddrinfo) + socket.gethostbyname("www.dnspython.org") + answer = self.res.cache.get( + ( + dns.name.from_text("www.dnspython.org."), + dns.rdatatype.A, + dns.rdataclass.IN, + ) + ) self.assertTrue(answer is not None) self.res.cache.flush() - socket.gethostbyname_ex('www.dnspython.org') - answer = self.res.cache.get((dns.name.from_text('www.dnspython.org.'), - dns.rdatatype.A, dns.rdataclass.IN)) + socket.gethostbyname_ex("www.dnspython.org") + answer = self.res.cache.get( + ( + dns.name.from_text("www.dnspython.org."), + dns.rdatatype.A, + dns.rdataclass.IN, + ) + ) self.assertTrue(answer is not None) self.res.cache.flush() - socket.getfqdn('8.8.8.8') + socket.getfqdn("8.8.8.8") answer = self.res.cache.get( - (dns.name.from_text('8.8.8.8.in-addr.arpa.'), - dns.rdatatype.PTR, dns.rdataclass.IN)) + ( + dns.name.from_text("8.8.8.8.in-addr.arpa."), + dns.rdatatype.PTR, + dns.rdataclass.IN, + ) + ) self.assertTrue(answer is not None) self.res.cache.flush() - socket.gethostbyaddr('8.8.8.8') + socket.gethostbyaddr("8.8.8.8") answer = self.res.cache.get( - (dns.name.from_text('8.8.8.8.in-addr.arpa.'), - dns.rdatatype.PTR, dns.rdataclass.IN)) + ( + dns.name.from_text("8.8.8.8.in-addr.arpa."), + dns.rdatatype.PTR, + dns.rdataclass.IN, + ) + ) self.assertTrue(answer is not None) # restoring twice is harmless, so we restore now instead of # waiting for tearDown so we can assert that it worked dns.resolver.restore_system_resolver() - self.assertTrue(socket.getaddrinfo is - dns.resolver._original_getaddrinfo) + self.assertTrue(socket.getaddrinfo is dns.resolver._original_getaddrinfo) def equivalent_info(self, a, b): if len(a) != len(b): @@ -70,7 +85,7 @@ class OverrideSystemResolverTestCase(unittest.TestCase): # looking for a zero protocol. y = (x[0], x[1], 0, x[3], x[4]) if y not in b: - print('NOT EQUIVALENT') + print("NOT EQUIVALENT") print(a) print(b) return False @@ -81,93 +96,114 @@ class OverrideSystemResolverTestCase(unittest.TestCase): b = dns.resolver._original_getaddrinfo(*args, **kwargs) return self.equivalent_info(a, b) - @unittest.skipIf(sys.platform == 'win32', - 'avoid windows original getaddrinfo issues') + @unittest.skipIf( + sys.platform == "win32", "avoid windows original getaddrinfo issues" + ) def test_basic_getaddrinfo(self): - self.assertTrue(self.equivalent('dns.google', 53, socket.AF_INET, - socket.SOCK_DGRAM)) - self.assertTrue(self.equivalent('dns.google', 53, socket.AF_INET6, - socket.SOCK_DGRAM)) - self.assertTrue(self.equivalent('dns.google', None, socket.AF_UNSPEC, - socket.SOCK_DGRAM)) - self.assertTrue(self.equivalent('8.8.8.8', 53, socket.AF_INET, - socket.SOCK_DGRAM)) - self.assertTrue(self.equivalent('2001:4860:4860::8888', 53, - socket.AF_INET6, socket.SOCK_DGRAM)) - self.assertTrue(self.equivalent('8.8.8.8', 53, socket.AF_INET, - socket.SOCK_DGRAM, - flags=socket.AI_NUMERICHOST)) - self.assertTrue(self.equivalent('2001:4860:4860::8888', 53, - socket.AF_INET6, socket.SOCK_DGRAM, - flags=socket.AI_NUMERICHOST)) + self.assertTrue( + self.equivalent("dns.google", 53, socket.AF_INET, socket.SOCK_DGRAM) + ) + self.assertTrue( + self.equivalent("dns.google", 53, socket.AF_INET6, socket.SOCK_DGRAM) + ) + self.assertTrue( + self.equivalent("dns.google", None, socket.AF_UNSPEC, socket.SOCK_DGRAM) + ) + self.assertTrue( + self.equivalent("8.8.8.8", 53, socket.AF_INET, socket.SOCK_DGRAM) + ) + self.assertTrue( + self.equivalent( + "2001:4860:4860::8888", 53, socket.AF_INET6, socket.SOCK_DGRAM + ) + ) + self.assertTrue( + self.equivalent( + "8.8.8.8", + 53, + socket.AF_INET, + socket.SOCK_DGRAM, + flags=socket.AI_NUMERICHOST, + ) + ) + self.assertTrue( + self.equivalent( + "2001:4860:4860::8888", + 53, + socket.AF_INET6, + socket.SOCK_DGRAM, + flags=socket.AI_NUMERICHOST, + ) + ) def test_getaddrinfo_nxdomain(self): try: - socket.getaddrinfo('nxdomain.dnspython.org.', 53) - self.assertTrue(False) # should not happen! + socket.getaddrinfo("nxdomain.dnspython.org.", 53) + self.assertTrue(False) # should not happen! except socket.gaierror as e: self.assertEqual(e.errno, socket.EAI_NONAME) def test_getaddrinfo_service(self): - a = socket.getaddrinfo('dns.google', 'domain') - b = socket.getaddrinfo('dns.google', 53) + a = socket.getaddrinfo("dns.google", "domain") + b = socket.getaddrinfo("dns.google", 53) self.assertTrue(self.equivalent_info(a, b)) try: - socket.getaddrinfo('dns.google', 'domain', - flags=socket.AI_NUMERICSERV) - self.assertTrue(False) # should not happen! + socket.getaddrinfo("dns.google", "domain", flags=socket.AI_NUMERICSERV) + self.assertTrue(False) # should not happen! except socket.gaierror as e: self.assertEqual(e.errno, socket.EAI_NONAME) def test_getaddrinfo_only_service(self): - infos = socket.getaddrinfo(service=53, family=socket.AF_INET, - socktype=socket.SOCK_DGRAM, - proto=socket.IPPROTO_UDP) + infos = socket.getaddrinfo( + service=53, + family=socket.AF_INET, + socktype=socket.SOCK_DGRAM, + proto=socket.IPPROTO_UDP, + ) self.assertEqual(len(infos), 1) info = infos[0] self.assertEqual(info[0], socket.AF_INET) self.assertEqual(info[1], socket.SOCK_DGRAM) self.assertEqual(info[2], socket.IPPROTO_UDP) - self.assertEqual(info[4], ('127.0.0.1', 53)) + self.assertEqual(info[4], ("127.0.0.1", 53)) def test_unknown_service_fails(self): with self.assertRaises(socket.gaierror): - socket.getaddrinfo('dns.google.', 'bogus-service') + socket.getaddrinfo("dns.google.", "bogus-service") def test_getnameinfo_tcp(self): - info = socket.getnameinfo(('8.8.8.8', 53)) - self.assertEqual(info, ('dns.google', 'domain')) + info = socket.getnameinfo(("8.8.8.8", 53)) + self.assertEqual(info, ("dns.google", "domain")) def test_getnameinfo_udp(self): - info = socket.getnameinfo(('8.8.8.8', 53), socket.NI_DGRAM) - self.assertEqual(info, ('dns.google', 'domain')) - - -# Give up on testing this for now as all of the names I've considered -# using for testing are part of CDNs and there is deep magic in -# gethostbyaddr() that python's getfqdn() is using. At any rate, -# the problem is that dnspython just gives up whereas the native python -# code is looking up www.dnspython.org, picking a CDN IPv4 address -# (sometimes) and returning the reverse lookup of that address (i.e. -# the domain name of the CDN server). This isn't what I'd consider the -# FQDN of www.dnspython.org to be! -# -# def test_getfqdn(self): -# b = socket.getfqdn('www.dnspython.org') -# # we do this now because python's original getfqdn calls -# # gethostbyaddr() and we don't want it to call us! -# dns.resolver.restore_system_resolver() -# a = dns.resolver._original_getfqdn('www.dnspython.org') -# self.assertEqual(dns.name.from_text(a), dns.name.from_text(b)) + info = socket.getnameinfo(("8.8.8.8", 53), socket.NI_DGRAM) + self.assertEqual(info, ("dns.google", "domain")) + + # Give up on testing this for now as all of the names I've considered + # using for testing are part of CDNs and there is deep magic in + # gethostbyaddr() that python's getfqdn() is using. At any rate, + # the problem is that dnspython just gives up whereas the native python + # code is looking up www.dnspython.org, picking a CDN IPv4 address + # (sometimes) and returning the reverse lookup of that address (i.e. + # the domain name of the CDN server). This isn't what I'd consider the + # FQDN of www.dnspython.org to be! + # + # def test_getfqdn(self): + # b = socket.getfqdn('www.dnspython.org') + # # we do this now because python's original getfqdn calls + # # gethostbyaddr() and we don't want it to call us! + # dns.resolver.restore_system_resolver() + # a = dns.resolver._original_getfqdn('www.dnspython.org') + # self.assertEqual(dns.name.from_text(a), dns.name.from_text(b)) def test_gethostbyaddr(self): - a = dns.resolver._original_gethostbyaddr('8.8.8.8') - b = socket.gethostbyaddr('8.8.8.8') + a = dns.resolver._original_gethostbyaddr("8.8.8.8") + b = socket.gethostbyaddr("8.8.8.8") # We only test elements 0 and 2 as we don't set aliases currently! self.assertEqual(a[0], b[0]) self.assertEqual(a[2], b[2]) - a = dns.resolver._original_gethostbyaddr('2001:4860:4860::8888') - b = socket.gethostbyaddr('2001:4860:4860::8888') + a = dns.resolver._original_gethostbyaddr("2001:4860:4860::8888") + b = socket.gethostbyaddr("2001:4860:4860::8888") self.assertEqual(a[0], b[0]) self.assertEqual(a[2], b[2]) @@ -178,7 +214,6 @@ class FakeResolver: class OverrideSystemResolverUsingFakeResolverTestCase(unittest.TestCase): - def setUp(self): self.res = FakeResolver() dns.resolver.override_system_resolver(self.res) @@ -189,7 +224,7 @@ class OverrideSystemResolverUsingFakeResolverTestCase(unittest.TestCase): def test_temporary_failure(self): with self.assertRaises(socket.gaierror): - socket.getaddrinfo('dns.google') + socket.getaddrinfo("dns.google") # We don't need the fake resolver for the following tests, but we # don't need the live network either, so we're testing here. @@ -200,16 +235,15 @@ class OverrideSystemResolverUsingFakeResolverTestCase(unittest.TestCase): def test_AI_ADDRCONFIG_fails(self): with self.assertRaises(socket.gaierror): - socket.getaddrinfo('dns.google', flags=socket.AI_ADDRCONFIG) + socket.getaddrinfo("dns.google", flags=socket.AI_ADDRCONFIG) def test_gethostbyaddr_of_name_fails(self): with self.assertRaises(socket.gaierror): - socket.gethostbyaddr('bogus') + socket.gethostbyaddr("bogus") @unittest.skipIf(not _network_available, "Internet not reachable") class OverrideSystemResolverUsingDefaultResolverTestCase(unittest.TestCase): - def setUp(self): self.res = FakeResolver() dns.resolver.override_system_resolver() diff --git a/tests/test_rrset.py b/tests/test_rrset.py index 5c3f17da..23e33bb9 100644 --- a/tests/test_rrset.py +++ b/tests/test_rrset.py @@ -21,142 +21,174 @@ import unittest import dns.name import dns.rrset -class RRsetTestCase(unittest.TestCase): +class RRsetTestCase(unittest.TestCase): def testEqual1(self): - r1 = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.1', '10.0.0.2') - r2 = dns.rrset.from_text('FOO', 300, 'in', 'a', '10.0.0.2', '10.0.0.1') + r1 = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1", "10.0.0.2") + r2 = dns.rrset.from_text("FOO", 300, "in", "a", "10.0.0.2", "10.0.0.1") self.assertEqual(r1, r2) def testEqual2(self): - r1 = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.1', '10.0.0.2') - r2 = dns.rrset.from_text('FOO', 600, 'in', 'a', '10.0.0.2', '10.0.0.1') + r1 = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1", "10.0.0.2") + r2 = dns.rrset.from_text("FOO", 600, "in", "a", "10.0.0.2", "10.0.0.1") self.assertEqual(r1, r2) def testNotEqual1(self): - r1 = dns.rrset.from_text('fooa', 30, 'in', 'a', '10.0.0.1', '10.0.0.2') - r2 = dns.rrset.from_text('FOO', 30, 'in', 'a', '10.0.0.2', '10.0.0.1') + r1 = dns.rrset.from_text("fooa", 30, "in", "a", "10.0.0.1", "10.0.0.2") + r2 = dns.rrset.from_text("FOO", 30, "in", "a", "10.0.0.2", "10.0.0.1") self.assertNotEqual(r1, r2) def testNotEqual2(self): - r1 = dns.rrset.from_text('foo', 30, 'in', 'a', '10.0.0.1', '10.0.0.3') - r2 = dns.rrset.from_text('FOO', 30, 'in', 'a', '10.0.0.2', '10.0.0.1') + r1 = dns.rrset.from_text("foo", 30, "in", "a", "10.0.0.1", "10.0.0.3") + r2 = dns.rrset.from_text("FOO", 30, "in", "a", "10.0.0.2", "10.0.0.1") self.assertNotEqual(r1, r2) def testNotEqual3(self): - r1 = dns.rrset.from_text('foo', 30, 'in', 'a', '10.0.0.1', '10.0.0.2', - '10.0.0.3') - r2 = dns.rrset.from_text('FOO', 30, 'in', 'a', '10.0.0.2', '10.0.0.1') + r1 = dns.rrset.from_text( + "foo", 30, "in", "a", "10.0.0.1", "10.0.0.2", "10.0.0.3" + ) + r2 = dns.rrset.from_text("FOO", 30, "in", "a", "10.0.0.2", "10.0.0.1") self.assertNotEqual(r1, r2) def testNotEqual4(self): - r1 = dns.rrset.from_text('foo', 30, 'in', 'a', '10.0.0.1') - r2 = dns.rrset.from_text('FOO', 30, 'in', 'a', '10.0.0.2', '10.0.0.1') + r1 = dns.rrset.from_text("foo", 30, "in", "a", "10.0.0.1") + r2 = dns.rrset.from_text("FOO", 30, "in", "a", "10.0.0.2", "10.0.0.1") self.assertNotEqual(r1, r2) def testCodec2003(self): - r1 = dns.rrset.from_text_list('Königsgäßchen', 30, 'in', 'ns', - ['Königsgäßchen']) - r2 = dns.rrset.from_text_list('xn--knigsgsschen-lcb0w', 30, 'in', 'ns', - ['xn--knigsgsschen-lcb0w']) + r1 = dns.rrset.from_text_list( + "Königsgäßchen", 30, "in", "ns", ["Königsgäßchen"] + ) + r2 = dns.rrset.from_text_list( + "xn--knigsgsschen-lcb0w", 30, "in", "ns", ["xn--knigsgsschen-lcb0w"] + ) self.assertEqual(r1, r2) - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testCodec2008(self): - r1 = dns.rrset.from_text_list('Königsgäßchen', 30, 'in', 'ns', - ['Königsgäßchen'], - idna_codec=dns.name.IDNA_2008) - r2 = dns.rrset.from_text_list('xn--knigsgchen-b4a3dun', 30, 'in', 'ns', - ['xn--knigsgchen-b4a3dun'], - idna_codec=dns.name.IDNA_2008) + r1 = dns.rrset.from_text_list( + "Königsgäßchen", + 30, + "in", + "ns", + ["Königsgäßchen"], + idna_codec=dns.name.IDNA_2008, + ) + r2 = dns.rrset.from_text_list( + "xn--knigsgchen-b4a3dun", + 30, + "in", + "ns", + ["xn--knigsgchen-b4a3dun"], + idna_codec=dns.name.IDNA_2008, + ) self.assertEqual(r1, r2) def testCopy(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) r2 = r1.copy() self.assertFalse(r1 is r2) self.assertTrue(r1 == r2) def testFullMatch1(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) - self.assertTrue(r1.full_match(r1.name, dns.rdataclass.IN, - dns.rdatatype.A, dns.rdatatype.NONE)) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) + self.assertTrue( + r1.full_match( + r1.name, dns.rdataclass.IN, dns.rdatatype.A, dns.rdatatype.NONE + ) + ) def testFullMatch2(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) r1.deleting = dns.rdataclass.NONE - self.assertTrue(r1.full_match(r1.name, dns.rdataclass.IN, - dns.rdatatype.A, dns.rdatatype.NONE, - dns.rdataclass.NONE)) + self.assertTrue( + r1.full_match( + r1.name, + dns.rdataclass.IN, + dns.rdatatype.A, + dns.rdatatype.NONE, + dns.rdataclass.NONE, + ) + ) def testNoFullMatch1(self): - n = dns.name.from_text('bar', None) - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) - self.assertFalse(r1.full_match(n, dns.rdataclass.IN, - dns.rdatatype.A, dns.rdatatype.NONE, - dns.rdataclass.ANY)) + n = dns.name.from_text("bar", None) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) + self.assertFalse( + r1.full_match( + n, + dns.rdataclass.IN, + dns.rdatatype.A, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + ) + ) def testNoFullMatch2(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) r1.deleting = dns.rdataclass.NONE - self.assertFalse(r1.full_match(r1.name, dns.rdataclass.IN, - dns.rdatatype.A, dns.rdatatype.NONE, - dns.rdataclass.ANY)) + self.assertFalse( + r1.full_match( + r1.name, + dns.rdataclass.IN, + dns.rdatatype.A, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + ) + ) def testNoFullMatch3(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) - self.assertFalse(r1.full_match(r1.name, dns.rdataclass.IN, - dns.rdatatype.MX, dns.rdatatype.NONE, - dns.rdataclass.ANY)) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) + self.assertFalse( + r1.full_match( + r1.name, + dns.rdataclass.IN, + dns.rdatatype.MX, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + ) + ) def testMatchCompatibilityWithFullMatch(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) - self.assertTrue(r1.match(r1.name, dns.rdataclass.IN, - dns.rdatatype.A, dns.rdatatype.NONE)) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) + self.assertTrue( + r1.match(r1.name, dns.rdataclass.IN, dns.rdatatype.A, dns.rdatatype.NONE) + ) def testMatchCompatibilityWithRdatasetMatch(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) - self.assertTrue(r1.match(dns.rdataclass.IN, dns.rdatatype.A, - dns.rdatatype.NONE)) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) + self.assertTrue( + r1.match(dns.rdataclass.IN, dns.rdatatype.A, dns.rdatatype.NONE) + ) def testToRdataset(self): - r1 = dns.rrset.from_text_list('foo', 30, 'in', 'a', - ['10.0.0.1', '10.0.0.2']) - r2 = dns.rdataset.from_text_list('in', 'a', 30, - ['10.0.0.1', '10.0.0.2']) + r1 = dns.rrset.from_text_list("foo", 30, "in", "a", ["10.0.0.1", "10.0.0.2"]) + r2 = dns.rdataset.from_text_list("in", "a", 30, ["10.0.0.1", "10.0.0.2"]) self.assertEqual(r1.to_rdataset(), r2) def testFromRdata(self): - rdata1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - rdata2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2') - expected_rrs = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.1', - '10.0.0.2') - rrs = dns.rrset.from_rdata('foo', 300, rdata1, rdata2) + rdata1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + rdata2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2") + expected_rrs = dns.rrset.from_text( + "foo", 300, "in", "a", "10.0.0.1", "10.0.0.2" + ) + rrs = dns.rrset.from_rdata("foo", 300, rdata1, rdata2) self.assertEqual(rrs, expected_rrs) def testEmptyList(self): def bad(): - rrs = dns.rrset.from_rdata_list('foo', 300, []) + rrs = dns.rrset.from_rdata_list("foo", 300, []) + self.assertRaises(ValueError, bad) def testTTLMinimization(self): - rrs = dns.rrset.RRset(dns.name.from_text('foo'), - dns.rdataclass.IN, dns.rdatatype.A) - rdata1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - rdata2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2') + rrs = dns.rrset.RRset( + dns.name.from_text("foo"), dns.rdataclass.IN, dns.rdatatype.A + ) + rdata1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + rdata2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2") rrs.add(rdata1, 300) self.assertEqual(rrs.ttl, 300) rrs.add(rdata2, 30) @@ -166,26 +198,33 @@ class RRsetTestCase(unittest.TestCase): self.assertEqual(rrs.ttl, 3) def testNotEqualOtherType(self): - rrs = dns.rrset.RRset(dns.name.from_text('foo'), - dns.rdataclass.IN, dns.rdatatype.A) + rrs = dns.rrset.RRset( + dns.name.from_text("foo"), dns.rdataclass.IN, dns.rdatatype.A + ) self.assertFalse(rrs == 123) def testRepr(self): - rrset = dns.rrset.from_text('foo', 30, 'in', 'a', '10.0.0.1', - '10.0.0.2') - self.assertEqual(repr(rrset), - ', <10.0.0.2>]>') + rrset = dns.rrset.from_text("foo", 30, "in", "a", "10.0.0.1", "10.0.0.2") + self.assertEqual(repr(rrset), ", <10.0.0.2>]>") rrset.deleting = dns.rdataclass.NONE - self.assertEqual(repr(rrset), - ', <10.0.0.2>]>') + self.assertEqual( + repr(rrset), + ", <10.0.0.2>]>", + ) rrset = dns.rrset.from_text( - 'foo', 30, 'in', 'rrsig', - 'A 1 3 3600 20200701000000 20200601000000 1 NAME Ym9ndXM=') - self.assertEqual(repr(rrset), - ']>') - -if __name__ == '__main__': + "foo", + 30, + "in", + "rrsig", + "A 1 3 3600 20200701000000 20200601000000 1 NAME Ym9ndXM=", + ) + self.assertEqual( + repr(rrset), + "]>", + ) + + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_rrset_reader.py b/tests/test_rrset_reader.py index 8d4255e2..3ae942e2 100644 --- a/tests/test_rrset_reader.py +++ b/tests/test_rrset_reader.py @@ -3,17 +3,17 @@ import pytest import dns.rrset from dns.zonefile import read_rrsets -expected_mx_1= dns.rrset.from_text('name.', 300, 'in', 'mx', '10 a.', '20 b.') -expected_mx_2 = dns.rrset.from_text('name.', 10, 'in', 'mx', '10 a.', '20 b.') -expected_mx_3 = dns.rrset.from_text('foo.', 10, 'in', 'mx', '10 a.') -expected_mx_4 = dns.rrset.from_text('bar.', 10, 'in', 'mx', '20 b.') -expected_mx_5 = dns.rrset.from_text('foo.example.', 10, 'in', 'mx', - '10 a.example.') -expected_mx_6 = dns.rrset.from_text('bar.example.', 10, 'in', 'mx', '20 b.') -expected_mx_7 = dns.rrset.from_text('foo', 10, 'in', 'mx', '10 a') -expected_mx_8 = dns.rrset.from_text('bar', 10, 'in', 'mx', '20 b.') -expected_ns_1 = dns.rrset.from_text('name.', 300, 'in', 'ns', 'hi.') -expected_ns_2 = dns.rrset.from_text('name.', 300, 'ch', 'ns', 'hi.') +expected_mx_1 = dns.rrset.from_text("name.", 300, "in", "mx", "10 a.", "20 b.") +expected_mx_2 = dns.rrset.from_text("name.", 10, "in", "mx", "10 a.", "20 b.") +expected_mx_3 = dns.rrset.from_text("foo.", 10, "in", "mx", "10 a.") +expected_mx_4 = dns.rrset.from_text("bar.", 10, "in", "mx", "20 b.") +expected_mx_5 = dns.rrset.from_text("foo.example.", 10, "in", "mx", "10 a.example.") +expected_mx_6 = dns.rrset.from_text("bar.example.", 10, "in", "mx", "20 b.") +expected_mx_7 = dns.rrset.from_text("foo", 10, "in", "mx", "10 a") +expected_mx_8 = dns.rrset.from_text("bar", 10, "in", "mx", "20 b.") +expected_ns_1 = dns.rrset.from_text("name.", 300, "in", "ns", "hi.") +expected_ns_2 = dns.rrset.from_text("name.", 300, "ch", "ns", "hi.") + def equal_rrsets(a, b): # return True iff. a and b have the same rrsets regardless of order @@ -24,108 +24,119 @@ def equal_rrsets(a, b): return False return True + def test_name_ttl_rdclass_forced(): - input='''; + input = """; mx 10 a mx 20 b. -ns hi''' - rrsets = read_rrsets(input, name='name', ttl=300) +ns hi""" + rrsets = read_rrsets(input, name="name", ttl=300) assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1]) assert rrsets[0].ttl == 300 assert rrsets[1].ttl == 300 + def test_name_ttl_rdclass_forced_rdata_split(): - input='''; + input = """; mx 10 a ns hi -mx 20 b.''' - rrsets = read_rrsets(input, name='name', ttl=300) +mx 20 b.""" + rrsets = read_rrsets(input, name="name", ttl=300) assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1]) + def test_name_ttl_rdclass_rdtype_forced(): - input='''; + input = """; 10 a -20 b.''' - rrsets = read_rrsets(input, name='name', ttl=300, rdtype='mx') +20 b.""" + rrsets = read_rrsets(input, name="name", ttl=300, rdtype="mx") assert equal_rrsets(rrsets, [expected_mx_1]) + def test_name_rdclass_forced(): - input = '''30 mx 10 a + input = """30 mx 10 a 10 mx 20 b. -''' - rrsets = read_rrsets(input, name='name') +""" + rrsets = read_rrsets(input, name="name") assert equal_rrsets(rrsets, [expected_mx_2]) assert rrsets[0].ttl == 10 + def test_rdclass_forced(): - input = '''; + input = """; foo 20 mx 10 a bar 30 mx 20 b. -''' +""" rrsets = read_rrsets(input) assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4]) + def test_rdclass_forced_with_origin(): - input = '''; + input = """; foo 20 mx 10 a bar.example. 30 mx 20 b. -''' - rrsets = read_rrsets(input, origin='example') +""" + rrsets = read_rrsets(input, origin="example") assert equal_rrsets(rrsets, [expected_mx_5, expected_mx_6]) def test_rdclass_forced_with_origin_relativized(): - input = '''; + input = """; foo 20 mx 10 a.example. bar.example. 30 mx 20 b. -''' - rrsets = read_rrsets(input, origin='example', relativize=True) +""" + rrsets = read_rrsets(input, origin="example", relativize=True) assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8]) + def test_rdclass_matching_default_tolerated(): - input = '''; + input = """; foo 20 mx 10 a.example. bar.example. 30 in mx 20 b. -''' - rrsets = read_rrsets(input, origin='example', relativize=True, - rdclass=None) +""" + rrsets = read_rrsets(input, origin="example", relativize=True, rdclass=None) assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8]) + def test_rdclass_not_matching_default_rejected(): - input = '''; + input = """; foo 20 mx 10 a.example. bar.example. 30 ch mx 20 b. -''' +""" with pytest.raises(dns.exception.SyntaxError): - rrsets = read_rrsets(input, origin='example', relativize=True, - rdclass=None) + rrsets = read_rrsets(input, origin="example", relativize=True, rdclass=None) + def test_default_rdclass_is_none(): - input = '' + input = "" with pytest.raises(TypeError): - rrsets = read_rrsets(input, default_rdclass=None, origin='example', - relativize=True) + rrsets = read_rrsets( + input, default_rdclass=None, origin="example", relativize=True + ) + def test_name_rdclass_rdtype_force(): # No real-world usage should do this, but it can be specified so we test it. - input = '''; + input = """; 30 10 a 10 20 b. -''' - rrsets = read_rrsets(input, name='name', rdtype='mx') +""" + rrsets = read_rrsets(input, name="name", rdtype="mx") assert equal_rrsets(rrsets, [expected_mx_1]) assert rrsets[0].ttl == 10 + def test_rdclass_rdtype_force(): # No real-world usage should do this, but it can be specified so we test it. - input = '''; + input = """; foo 30 10 a bar 30 20 b. -''' - rrsets = read_rrsets(input, rdtype='mx') +""" + rrsets = read_rrsets(input, rdtype="mx") assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4]) + # also weird but legal -#input5 = '''foo 30 10 a -#bar 10 20 foo. +# input5 = '''foo 30 10 a +# bar 10 20 foo. #''' diff --git a/tests/test_serial.py b/tests/test_serial.py index a9ef2df0..8e6c4396 100644 --- a/tests/test_serial.py +++ b/tests/test_serial.py @@ -4,12 +4,15 @@ import unittest import dns.serial + def S2(v): return dns.serial.Serial(v, bits=2) + def S8(v): return dns.serial.Serial(v, bits=8) + class SerialTestCase(unittest.TestCase): def test_rfc_1982_2_bit_cases(self): self.assertEqual(S2(0) + S2(1), S2(1)) @@ -67,13 +70,17 @@ class SerialTestCase(unittest.TestCase): def test_addition_bounds(self): self.assertRaises(ValueError, lambda: S8(0) + 128) self.assertRaises(ValueError, lambda: S8(0) - 128) + def bad1(): v = S8(0) v += 128 + self.assertRaises(ValueError, bad1) + def bad2(): v = S8(0) v -= 128 + self.assertRaises(ValueError, bad2) def test_casting(self): @@ -85,32 +92,36 @@ class SerialTestCase(unittest.TestCase): self.assertTrue(S8(0) >= 255) def test_uncastable(self): - self.assertRaises(ValueError, lambda: S8(0) + 'a') - self.assertRaises(ValueError, lambda: S8(0) - 'a') + self.assertRaises(ValueError, lambda: S8(0) + "a") + self.assertRaises(ValueError, lambda: S8(0) - "a") + def bad1(): v = S8(0) - v += 'a' + v += "a" + self.assertRaises(ValueError, bad1) + def bad2(): v = S8(0) - v -= 'a' + v -= "a" + self.assertRaises(ValueError, bad2) def test_uncomparable(self): self.assertFalse(S8(0) == S2(0)) - self.assertFalse(S8(0) == 'a') - self.assertTrue(S8(0) != 'a') - self.assertRaises(TypeError, lambda: S8(0) < 'a') - self.assertRaises(TypeError, lambda: S8(0) <= 'a') - self.assertRaises(TypeError, lambda: S8(0) > 'a') - self.assertRaises(TypeError, lambda: S8(0) >= 'a') + self.assertFalse(S8(0) == "a") + self.assertTrue(S8(0) != "a") + self.assertRaises(TypeError, lambda: S8(0) < "a") + self.assertRaises(TypeError, lambda: S8(0) <= "a") + self.assertRaises(TypeError, lambda: S8(0) > "a") + self.assertRaises(TypeError, lambda: S8(0) >= "a") def test_modulo(self): self.assertEqual(S8(-1), 255) self.assertEqual(S8(257), 1) def test_repr(self): - self.assertEqual(repr(S8(1)), 'dns.serial.Serial(1, 8)') + self.assertEqual(repr(S8(1)), "dns.serial.Serial(1, 8)") def test_not_equal(self): self.assertNotEqual(S8(0), S8(1)) diff --git a/tests/test_set.py b/tests/test_set.py index 103a6e9f..4db97ba2 100644 --- a/tests/test_set.py +++ b/tests/test_set.py @@ -23,8 +23,8 @@ import dns.set # for convenience S = dns.set.Set -class SetTestCase(unittest.TestCase): +class SetTestCase(unittest.TestCase): def testLen1(self): s1 = S() self.assertEqual(len(s1), 0) @@ -334,5 +334,6 @@ class SetTestCase(unittest.TestCase): s &= S([1, 2]) self.assertEqual(s, S([1, 2])) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_svcb.py b/tests/test_svcb.py index 34fc9ad3..d2d8dd79 100644 --- a/tests/test_svcb.py +++ b/tests/test_svcb.py @@ -10,18 +10,19 @@ from dns.tokenizer import Tokenizer from tests.util import here + class SVCBTestCase(unittest.TestCase): def check_valid_inputs(self, inputs): expected = inputs[0] for text in inputs: - rr = dns.rdata.from_text('IN', 'SVCB', text) + rr = dns.rdata.from_text("IN", "SVCB", text) new_text = rr.to_text() self.assertEqual(expected, new_text) def check_invalid_inputs(self, inputs): for text in inputs: with self.assertRaises((dns.exception.SyntaxError, ValueError)): - dns.rdata.from_text('IN', 'SVCB', text) + dns.rdata.from_text("IN", "SVCB", text) def test_svcb_general_invalid(self): invalid_inputs = ( @@ -29,19 +30,19 @@ class SVCBTestCase(unittest.TestCase): "1 . alpn=h2 alpn=h3", "1 . alpn=h2 key1=h3", # Quoted keys - "1 . \"alpn=h2\"", + '1 . "alpn=h2"', # Invalid space "1 . alpn= h2", "1 . alpn =h2", "1 . alpn = h2", - "1 . alpn= \"h2\"", + '1 . alpn= "h2"', "1 . =alpn", ) self.check_invalid_inputs(invalid_inputs) def test_svcb_mandatory(self): valid_inputs = ( - "1 . mandatory=\"alpn,no-default-alpn\" alpn=\"h2\" no-default-alpn", + '1 . mandatory="alpn,no-default-alpn" alpn="h2" no-default-alpn', "1 . mandatory=alpn,no-default-alpn alpn=h2 no-default-alpn", "1 . mandatory=key1,key2 alpn=h2 no-default-alpn", "1 . mandatory=alpn,no-default-alpn key1=\\002h2 key2", @@ -81,12 +82,12 @@ class SVCBTestCase(unittest.TestCase): def test_svcb_alpn(self): valid_inputs_two_items = ( - "1 . alpn=\"h2,h3\"", + '1 . alpn="h2,h3"', "1 . alpn=h2,h3", "1 . alpn=h\\050,h3", - "1 . alpn=\"h\\050,h3\"", + '1 . alpn="h\\050,h3"', "1 . alpn=\\h2,h3", - "1 . alpn=\"h2\\,h3\"", + '1 . alpn="h2\\,h3"', "1 . alpn=h2\\,h3", "1 . alpn=h2\\044h3", "1 . key1=\\002h2\\002h3", @@ -94,7 +95,7 @@ class SVCBTestCase(unittest.TestCase): self.check_valid_inputs(valid_inputs_two_items) valid_inputs_one_item = ( - "1 . alpn=\"h2\\\\,h3\"", + '1 . alpn="h2\\\\,h3"', "1 . alpn=h2\\\\,h3", "1 . alpn=h2\\092\\044h3", "1 . key1=\\005h2,h3", @@ -106,13 +107,13 @@ class SVCBTestCase(unittest.TestCase): "1 . alpn=", "1 . alpn=h2,,h3", "1 . alpn=01234567890abcdef01234567890abcdef01234567890abcdef" - "01234567890abcdef01234567890abcdef01234567890abcdef" - "01234567890abcdef01234567890abcdef01234567890abcdef" - "01234567890abcdef01234567890abcdef01234567890abcdef" - "01234567890abcdef01234567890abcdef01234567890abcdef" - "01234567890abcdef", - "1 . alpn=\",h2,h3\"", - "1 . alpn=\"h2,h3,\"", + "01234567890abcdef01234567890abcdef01234567890abcdef" + "01234567890abcdef01234567890abcdef01234567890abcdef" + "01234567890abcdef01234567890abcdef01234567890abcdef" + "01234567890abcdef01234567890abcdef01234567890abcdef" + "01234567890abcdef", + '1 . alpn=",h2,h3"', + '1 . alpn="h2,h3,"', "1 . key1", "1 . key1=", "1 . key1=\\000", @@ -122,18 +123,18 @@ class SVCBTestCase(unittest.TestCase): def test_svcb_no_default_alpn(self): valid_inputs = ( - "1 . alpn=\"h2\" no-default-alpn", - "1 . alpn=\"h2\" no-default-alpn=\"\"", - "1 . alpn=\"h2\" key2", - "1 . alpn=\"h2\" key2=\"\"", + '1 . alpn="h2" no-default-alpn', + '1 . alpn="h2" no-default-alpn=""', + '1 . alpn="h2" key2', + '1 . alpn="h2" key2=""', ) self.check_valid_inputs(valid_inputs) invalid_inputs = ( "1 . no-default-alpn", - "1 . no-default-alpn=\"\"", + '1 . no-default-alpn=""', "1 . key2", - "1 . key2=\"\"", + '1 . key2=""', "1 . alpn=h2 no-default-alpn=foo", "1 . alpn=h2 no-default-alpn=", "1 . alpn=h2 key2=foo", @@ -143,7 +144,7 @@ class SVCBTestCase(unittest.TestCase): def test_svcb_port(self): valid_inputs = ( - "1 . port=\"53\"", + '1 . port="53"', "1 . port=53", "1 . key3=\\000\\053", ) @@ -165,7 +166,7 @@ class SVCBTestCase(unittest.TestCase): def test_svcb_ipv4hint(self): valid_inputs = ( - "1 . ipv4hint=\"0.0.0.0,1.1.1.1\"", + '1 . ipv4hint="0.0.0.0,1.1.1.1"', "1 . ipv4hint=0.0.0.0,1.1.1.1", "1 . key4=\\000\\000\\000\\000\\001\\001\\001\\001", ) @@ -184,7 +185,7 @@ class SVCBTestCase(unittest.TestCase): def test_svcb_ech(self): valid_inputs = ( - "1 . ech=\"Zm9vMA==\"", + '1 . ech="Zm9vMA=="', "1 . ech=Zm9vMA==", "1 . key5=foo0", "1 . key5=\\102\\111\\111\\048", @@ -203,12 +204,12 @@ class SVCBTestCase(unittest.TestCase): def test_svcb_ipv6hint(self): valid_inputs = ( - "1 . ipv6hint=\"::4,1::\"", + '1 . ipv6hint="::4,1::"', "1 . ipv6hint=::4,1::", "1 . key6=\\000\\000\\000\\000\\000\\000\\000\\000" - "\\000\\000\\000\\000\\000\\000\\000\\004" - "\\000\\001\\000\\000\\000\\000\\000\\000" - "\\000\\000\\000\\000\\000\\000\\000\\000", + "\\000\\000\\000\\000\\000\\000\\000\\004" + "\\000\\001\\000\\000\\000\\000\\000\\000" + "\\000\\000\\000\\000\\000\\000\\000\\000", ) self.check_valid_inputs(valid_inputs) @@ -227,17 +228,17 @@ class SVCBTestCase(unittest.TestCase): def test_svcb_unknown(self): valid_inputs_one_key = ( - "1 . key23=\"key45\"", + '1 . key23="key45"', "1 . key23=key45", "1 . key23=key\\052\\053", - "1 . key23=\"key\\052\\053\"", + '1 . key23="key\\052\\053"', "1 . key23=\\107\\101\\121\\052\\053", ) self.check_valid_inputs(valid_inputs_one_key) valid_inputs_one_key_empty = ( "1 . key23", - "1 . key23=\"\"", + '1 . key23=""', ) self.check_valid_inputs(valid_inputs_one_key_empty) @@ -249,24 +250,25 @@ class SVCBTestCase(unittest.TestCase): valid_inputs_two_keys = ( "1 . key24 key48", - "1 . key24=\"\" key48", + '1 . key24="" key48', ) self.check_valid_inputs(valid_inputs_two_keys) def test_svcb_wire(self): valid_inputs = ( - "1 . mandatory=\"alpn,port\" alpn=\"h2\" port=\"257\"", + '1 . mandatory="alpn,port" alpn="h2" port="257"', "\\# 24 0001 00 0000000400010003 00010003026832 000300020101", ) self.check_valid_inputs(valid_inputs) - everything = \ - "100 foo.com. mandatory=\"alpn,port\" alpn=\"h2,h3\" " \ - " no-default-alpn port=\"12345\" ech=\"abcd\" " \ - " ipv4hint=1.2.3.4,4.3.2.1 ipv6hint=1::2,3::4" \ - " key12345=\"foo\"" - rr = dns.rdata.from_text('IN', 'SVCB', everything) - rr2 = dns.rdata.from_text('IN', 'SVCB', rr.to_generic().to_text()) + everything = ( + '100 foo.com. mandatory="alpn,port" alpn="h2,h3" ' + ' no-default-alpn port="12345" ech="abcd" ' + " ipv4hint=1.2.3.4,4.3.2.1 ipv6hint=1::2,3::4" + ' key12345="foo"' + ) + rr = dns.rdata.from_text("IN", "SVCB", everything) + rr2 = dns.rdata.from_text("IN", "SVCB", rr.to_generic().to_text()) self.assertEqual(rr, rr2) invalid_inputs = ( @@ -283,19 +285,19 @@ class SVCBTestCase(unittest.TestCase): self.check_invalid_inputs(invalid_inputs) def test_misc_escape(self): - rdata = dns.rdata.from_text('in', 'svcb', '1 . alpn=\\010\\010') + rdata = dns.rdata.from_text("in", "svcb", "1 . alpn=\\010\\010") expected = '1 . alpn="\\\\010\\\\010"' self.assertEqual(rdata.to_text(), expected) with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'svcb', '1 . alpn=\\0') + dns.rdata.from_text("in", "svcb", "1 . alpn=\\0") with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'svcb', '1 . alpn=\\00') + dns.rdata.from_text("in", "svcb", "1 . alpn=\\00") with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'svcb', '1 . alpn=\\00q') + dns.rdata.from_text("in", "svcb", "1 . alpn=\\00q") with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'svcb', '1 . alpn=\\256') + dns.rdata.from_text("in", "svcb", "1 . alpn=\\256") # This doesn't usually get exercised, so we do it directly. - gp = dns.rdtypes.svcbbase.GenericParam.from_value('\\001\\002') + gp = dns.rdtypes.svcbbase.GenericParam.from_value("\\001\\002") expected = '"\\001\\002"' self.assertEqual(gp.to_text(), expected) @@ -322,8 +324,8 @@ class SVCBTestCase(unittest.TestCase): self.assertTrue(text_token.is_identifier) text_tokenizer.unget(text_token) generic_tokenizer.unget(generic_token) - text_rdata = dns.rdata.from_text('IN', 'SVCB', text_tokenizer) - generic_rdata = dns.rdata.from_text('IN', 'SVCB', generic_tokenizer) + text_rdata = dns.rdata.from_text("IN", "SVCB", text_tokenizer) + generic_rdata = dns.rdata.from_text("IN", "SVCB", generic_tokenizer) self.assertEqual(text_rdata, generic_rdata) def test_svcb_spec_failure_cases(self): @@ -347,43 +349,42 @@ class SVCBTestCase(unittest.TestCase): # mandatory list (Section 7). "1 foo.example.com. mandatory=key123,key123 key123=abc", ) - self.check_invalid_inputs(failure_cases); + self.check_invalid_inputs(failure_cases) def test_alias_mode(self): - rd = dns.rdata.from_text('in', 'svcb', '0 .') + rd = dns.rdata.from_text("in", "svcb", "0 .") self.assertEqual(len(rd.params), 0) self.assertEqual(rd.target, dns.name.root) - self.assertEqual(rd.to_text(), '0 .') - rd = dns.rdata.from_text('in', 'svcb', '0 elsewhere.') - self.assertEqual(rd.target, dns.name.from_text('elsewhere.')) + self.assertEqual(rd.to_text(), "0 .") + rd = dns.rdata.from_text("in", "svcb", "0 elsewhere.") + self.assertEqual(rd.target, dns.name.from_text("elsewhere.")) self.assertEqual(len(rd.params), 0) # provoke 'parameters in AliasMode' from text. with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('in', 'svcb', '0 elsewhere. alpn=h2') + dns.rdata.from_text("in", "svcb", "0 elsewhere. alpn=h2") # provoke 'parameters in AliasMode' from wire too. - wire = bytes.fromhex('0000000000000400010003') + wire = bytes.fromhex("0000000000000400010003") with self.assertRaises(dns.exception.FormError): - dns.rdata.from_wire('in', 'svcb', wire, 0, len(wire)) + dns.rdata.from_wire("in", "svcb", wire, 0, len(wire)) def test_immutability(self): - alpn = dns.rdtypes.svcbbase.ALPNParam.from_value(['h2', 'h3']) + alpn = dns.rdtypes.svcbbase.ALPNParam.from_value(["h2", "h3"]) with self.assertRaises(TypeError): - alpn.ids[0] = 'foo' + alpn.ids[0] = "foo" with self.assertRaises(TypeError): del alpn.ids[0] with self.assertRaises(TypeError): - alpn.ids = 'foo' + alpn.ids = "foo" with self.assertRaises(TypeError): del alpn.ids def test_alias_not_compressed(self): - rrs = dns.rrset.from_text('elsewhere.', 300, 'in', 'svcb', - '0 elseWhere.') + rrs = dns.rrset.from_text("elsewhere.", 300, "in", "svcb", "0 elseWhere.") output = io.BytesIO() compress = {} rrs.to_wire(output, compress) wire = output.getvalue() # Just one of these assertions is enough, but we do both to show # the bug we're checking is fixed. - assert not wire.endswith(b'\xc0\x00') - assert wire.endswith(b'\x09elseWhere\x00') + assert not wire.endswith(b"\xc0\x00") + assert wire.endswith(b"\x09elseWhere\x00") diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 06f41776..d8b1723d 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -22,22 +22,22 @@ import dns.tokenizer Token = dns.tokenizer.Token -class TokenizerTestCase(unittest.TestCase): +class TokenizerTestCase(unittest.TestCase): def testStr(self): - tok = dns.tokenizer.Tokenizer('foo') + tok = dns.tokenizer.Tokenizer("foo") token = tok.get() - self.assertEqual(token, Token(dns.tokenizer.IDENTIFIER, 'foo')) + self.assertEqual(token, Token(dns.tokenizer.IDENTIFIER, "foo")) def testQuotedString1(self): tok = dns.tokenizer.Tokenizer(r'"foo"') token = tok.get() - self.assertEqual(token, Token(dns.tokenizer.QUOTED_STRING, 'foo')) + self.assertEqual(token, Token(dns.tokenizer.QUOTED_STRING, "foo")) def testQuotedString2(self): tok = dns.tokenizer.Tokenizer(r'""') token = tok.get() - self.assertEqual(token, Token(dns.tokenizer.QUOTED_STRING, '')) + self.assertEqual(token, Token(dns.tokenizer.QUOTED_STRING, "")) def testQuotedString3(self): tok = dns.tokenizer.Tokenizer(r'"\"foo\""') @@ -47,8 +47,7 @@ class TokenizerTestCase(unittest.TestCase): def testQuotedString4(self): tok = dns.tokenizer.Tokenizer(r'"foo\010bar"') token = tok.get() - self.assertEqual(token, Token(dns.tokenizer.QUOTED_STRING, - 'foo\\010bar')) + self.assertEqual(token, Token(dns.tokenizer.QUOTED_STRING, "foo\\010bar")) def testQuotedString5(self): with self.assertRaises(dns.exception.UnexpectedEnd): @@ -66,223 +65,233 @@ class TokenizerTestCase(unittest.TestCase): tok.get() def testEmpty1(self): - tok = dns.tokenizer.Tokenizer('') + tok = dns.tokenizer.Tokenizer("") token = tok.get() self.assertTrue(token.is_eof()) def testEmpty2(self): - tok = dns.tokenizer.Tokenizer('') + tok = dns.tokenizer.Tokenizer("") token1 = tok.get() token2 = tok.get() self.assertTrue(token1.is_eof() and token2.is_eof()) def testEOL(self): - tok = dns.tokenizer.Tokenizer('\n') + tok = dns.tokenizer.Tokenizer("\n") token1 = tok.get() token2 = tok.get() self.assertTrue(token1.is_eol() and token2.is_eof()) def testWS1(self): - tok = dns.tokenizer.Tokenizer(' \n') + tok = dns.tokenizer.Tokenizer(" \n") token1 = tok.get() self.assertTrue(token1.is_eol()) def testWS2(self): - tok = dns.tokenizer.Tokenizer(' \n') + tok = dns.tokenizer.Tokenizer(" \n") token1 = tok.get(want_leading=True) self.assertTrue(token1.is_whitespace()) def testComment1(self): - tok = dns.tokenizer.Tokenizer(' ;foo\n') + tok = dns.tokenizer.Tokenizer(" ;foo\n") token1 = tok.get() self.assertTrue(token1.is_eol()) def testComment2(self): - tok = dns.tokenizer.Tokenizer(' ;foo\n') + tok = dns.tokenizer.Tokenizer(" ;foo\n") token1 = tok.get(want_comment=True) token2 = tok.get() - self.assertEqual(token1, Token(dns.tokenizer.COMMENT, 'foo')) + self.assertEqual(token1, Token(dns.tokenizer.COMMENT, "foo")) self.assertTrue(token2.is_eol()) def testComment3(self): - tok = dns.tokenizer.Tokenizer(' ;foo bar\n') + tok = dns.tokenizer.Tokenizer(" ;foo bar\n") token1 = tok.get(want_comment=True) token2 = tok.get() - self.assertEqual(token1, Token(dns.tokenizer.COMMENT, 'foo bar')) + self.assertEqual(token1, Token(dns.tokenizer.COMMENT, "foo bar")) self.assertTrue(token2.is_eol()) def testMultiline1(self): - tok = dns.tokenizer.Tokenizer('( foo\n\n bar\n)') + tok = dns.tokenizer.Tokenizer("( foo\n\n bar\n)") tokens = list(iter(tok)) - self.assertEqual(tokens, [Token(dns.tokenizer.IDENTIFIER, 'foo'), - Token(dns.tokenizer.IDENTIFIER, 'bar')]) + self.assertEqual( + tokens, + [ + Token(dns.tokenizer.IDENTIFIER, "foo"), + Token(dns.tokenizer.IDENTIFIER, "bar"), + ], + ) def testMultiline2(self): - tok = dns.tokenizer.Tokenizer('( foo\n\n bar\n)\n') + tok = dns.tokenizer.Tokenizer("( foo\n\n bar\n)\n") tokens = list(iter(tok)) - self.assertEqual(tokens, [Token(dns.tokenizer.IDENTIFIER, 'foo'), - Token(dns.tokenizer.IDENTIFIER, 'bar'), - Token(dns.tokenizer.EOL, '\n')]) + self.assertEqual( + tokens, + [ + Token(dns.tokenizer.IDENTIFIER, "foo"), + Token(dns.tokenizer.IDENTIFIER, "bar"), + Token(dns.tokenizer.EOL, "\n"), + ], + ) def testMultiline3(self): with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('foo)') + tok = dns.tokenizer.Tokenizer("foo)") list(iter(tok)) def testMultiline4(self): with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('((foo)') + tok = dns.tokenizer.Tokenizer("((foo)") list(iter(tok)) def testUnget1(self): - tok = dns.tokenizer.Tokenizer('foo') + tok = dns.tokenizer.Tokenizer("foo") t1 = tok.get() tok.unget(t1) t2 = tok.get() self.assertEqual(t1, t2) self.assertEqual(t1.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t1.value, 'foo') + self.assertEqual(t1.value, "foo") def testUnget2(self): with self.assertRaises(dns.tokenizer.UngetBufferFull): - tok = dns.tokenizer.Tokenizer('foo') + tok = dns.tokenizer.Tokenizer("foo") t1 = tok.get() tok.unget(t1) tok.unget(t1) def testGetEOL1(self): - tok = dns.tokenizer.Tokenizer('\n') + tok = dns.tokenizer.Tokenizer("\n") t = tok.get_eol() - self.assertEqual(t, '\n') + self.assertEqual(t, "\n") def testGetEOL2(self): - tok = dns.tokenizer.Tokenizer('') + tok = dns.tokenizer.Tokenizer("") t = tok.get_eol() - self.assertEqual(t, '') + self.assertEqual(t, "") def testEscapedDelimiter1(self): - tok = dns.tokenizer.Tokenizer(r'ch\ ld') + tok = dns.tokenizer.Tokenizer(r"ch\ ld") t = tok.get() self.assertEqual(t.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t.value, r'ch\ ld') + self.assertEqual(t.value, r"ch\ ld") def testEscapedDelimiter2(self): - tok = dns.tokenizer.Tokenizer(r'ch\032ld') + tok = dns.tokenizer.Tokenizer(r"ch\032ld") t = tok.get() self.assertEqual(t.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t.value, r'ch\032ld') + self.assertEqual(t.value, r"ch\032ld") def testEscapedDelimiter3(self): - tok = dns.tokenizer.Tokenizer(r'ch\ild') + tok = dns.tokenizer.Tokenizer(r"ch\ild") t = tok.get() self.assertEqual(t.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t.value, r'ch\ild') + self.assertEqual(t.value, r"ch\ild") def testEscapedDelimiter1u(self): - tok = dns.tokenizer.Tokenizer(r'ch\ ld') + tok = dns.tokenizer.Tokenizer(r"ch\ ld") t = tok.get().unescape() self.assertEqual(t.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t.value, r'ch ld') + self.assertEqual(t.value, r"ch ld") def testEscapedDelimiter2u(self): - tok = dns.tokenizer.Tokenizer(r'ch\032ld') + tok = dns.tokenizer.Tokenizer(r"ch\032ld") t = tok.get().unescape() self.assertEqual(t.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t.value, 'ch ld') + self.assertEqual(t.value, "ch ld") def testEscapedDelimiter3u(self): - tok = dns.tokenizer.Tokenizer(r'ch\ild') + tok = dns.tokenizer.Tokenizer(r"ch\ild") t = tok.get().unescape() self.assertEqual(t.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t.value, r'child') + self.assertEqual(t.value, r"child") def testGetUInt(self): - tok = dns.tokenizer.Tokenizer('1234') + tok = dns.tokenizer.Tokenizer("1234") v = tok.get_int() self.assertEqual(v, 1234) with self.assertRaises(dns.exception.SyntaxError): tok = dns.tokenizer.Tokenizer('"1234"') tok.get_int() with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('q1234') + tok = dns.tokenizer.Tokenizer("q1234") tok.get_int() with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('281474976710656') + tok = dns.tokenizer.Tokenizer("281474976710656") tok.get_uint48() with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('4294967296') + tok = dns.tokenizer.Tokenizer("4294967296") tok.get_uint32() with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('65536') + tok = dns.tokenizer.Tokenizer("65536") tok.get_uint16() with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('256') + tok = dns.tokenizer.Tokenizer("256") tok.get_uint8() # Even though it is badly named get_int(), it's really get_unit! with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('-1234') + tok = dns.tokenizer.Tokenizer("-1234") tok.get_int() # get_uint16 can do other bases too, and has a custom error # for base 8. - tok = dns.tokenizer.Tokenizer('177777') + tok = dns.tokenizer.Tokenizer("177777") self.assertEqual(tok.get_uint16(base=8), 65535) with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('200000') + tok = dns.tokenizer.Tokenizer("200000") tok.get_uint16(base=8) def testGetString(self): - tok = dns.tokenizer.Tokenizer('foo') + tok = dns.tokenizer.Tokenizer("foo") v = tok.get_string() - self.assertEqual(v, 'foo') + self.assertEqual(v, "foo") tok = dns.tokenizer.Tokenizer('"foo"') v = tok.get_string() - self.assertEqual(v, 'foo') - tok = dns.tokenizer.Tokenizer('abcdefghij') + self.assertEqual(v, "foo") + tok = dns.tokenizer.Tokenizer("abcdefghij") v = tok.get_string(max_length=10) - self.assertEqual(v, 'abcdefghij') + self.assertEqual(v, "abcdefghij") with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('abcdefghij') + tok = dns.tokenizer.Tokenizer("abcdefghij") tok.get_string(max_length=9) - tok = dns.tokenizer.Tokenizer('') + tok = dns.tokenizer.Tokenizer("") with self.assertRaises(dns.exception.SyntaxError): tok.get_string() def testMultiLineWithComment(self): - tok = dns.tokenizer.Tokenizer('( ; abc\n)') + tok = dns.tokenizer.Tokenizer("( ; abc\n)") tok.get_eol() # Nothing to assert here, as we're testing tok.get_eol() does NOT # raise. def testEOLAfterComment(self): - tok = dns.tokenizer.Tokenizer('; abc\n') + tok = dns.tokenizer.Tokenizer("; abc\n") t = tok.get() self.assertTrue(t.is_eol()) def testEOFAfterComment(self): - tok = dns.tokenizer.Tokenizer('; abc') + tok = dns.tokenizer.Tokenizer("; abc") t = tok.get() self.assertTrue(t.is_eof()) def testMultiLineWithEOFAfterComment(self): with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('( ; abc') + tok = dns.tokenizer.Tokenizer("( ; abc") tok.get_eol() def testEscapeUnexpectedEnd(self): with self.assertRaises(dns.exception.UnexpectedEnd): - tok = dns.tokenizer.Tokenizer('\\') + tok = dns.tokenizer.Tokenizer("\\") tok.get() def testEscapeBounds(self): with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('\\256') + tok = dns.tokenizer.Tokenizer("\\256") tok.get().unescape() with self.assertRaises(dns.exception.SyntaxError): - tok = dns.tokenizer.Tokenizer('\\256') + tok = dns.tokenizer.Tokenizer("\\256") tok.get().unescape_to_bytes() def testGetUngetRegetComment(self): - tok = dns.tokenizer.Tokenizer(';comment') + tok = dns.tokenizer.Tokenizer(";comment") t1 = tok.get(want_comment=True) tok.unget(t1) t2 = tok.get(want_comment=True) @@ -314,12 +323,12 @@ class TokenizerTestCase(unittest.TestCase): tok.get().unescape_to_bytes() def testTokenMisc(self): - t1 = dns.tokenizer.Token(dns.tokenizer.IDENTIFIER, 'hi') - t2 = dns.tokenizer.Token(dns.tokenizer.IDENTIFIER, 'hi') - t3 = dns.tokenizer.Token(dns.tokenizer.IDENTIFIER, 'there') + t1 = dns.tokenizer.Token(dns.tokenizer.IDENTIFIER, "hi") + t2 = dns.tokenizer.Token(dns.tokenizer.IDENTIFIER, "hi") + t3 = dns.tokenizer.Token(dns.tokenizer.IDENTIFIER, "there") self.assertEqual(t1, t2) - self.assertFalse(t1 == 'hi') # not NotEqual because we want to use == - self.assertNotEqual(t1, 'hi') + self.assertFalse(t1 == "hi") # not NotEqual because we want to use == + self.assertNotEqual(t1, "hi") self.assertNotEqual(t1, t3) self.assertEqual(str(t1), '3 "hi"') @@ -330,17 +339,17 @@ class TokenizerTestCase(unittest.TestCase): def testStdinFilename(self): tok = dns.tokenizer.Tokenizer() - self.assertEqual(tok.filename, '') + self.assertEqual(tok.filename, "") def testBytesLiteral(self): - tok = dns.tokenizer.Tokenizer(b'this is input') - self.assertEqual(tok.get().value, 'this') - self.assertEqual(tok.filename, '') - tok = dns.tokenizer.Tokenizer(b'this is input', 'myfilename') - self.assertEqual(tok.filename, 'myfilename') + tok = dns.tokenizer.Tokenizer(b"this is input") + self.assertEqual(tok.get().value, "this") + self.assertEqual(tok.filename, "") + tok = dns.tokenizer.Tokenizer(b"this is input", "myfilename") + self.assertEqual(tok.filename, "myfilename") def testUngetBranches(self): - tok = dns.tokenizer.Tokenizer(b' this is input') + tok = dns.tokenizer.Tokenizer(b" this is input") t = tok.get(want_leading=True) tok.unget(t) t = tok.get(want_leading=True) @@ -348,8 +357,8 @@ class TokenizerTestCase(unittest.TestCase): tok.unget(t) t = tok.get() self.assertEqual(t.ttype, dns.tokenizer.IDENTIFIER) - self.assertEqual(t.value, 'this') - tok = dns.tokenizer.Tokenizer(b'; this is input\n') + self.assertEqual(t.value, "this") + tok = dns.tokenizer.Tokenizer(b"; this is input\n") t = tok.get(want_comment=True) tok.unget(t) t = tok.get(want_comment=True) @@ -358,5 +367,6 @@ class TokenizerTestCase(unittest.TestCase): t = tok.get() self.assertEqual(t.ttype, dns.tokenizer.EOL) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index ce533c51..8e2744ab 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -25,7 +25,7 @@ class DB(dns.transaction.TransactionManager): return Transaction(self, replacement, False) def origin_information(self): - return (dns.name.from_text('example'), True, dns.name.empty) + return (dns.name.from_text("example"), True, dns.name.empty) def get_class(self): return dns.rdataclass.IN @@ -79,136 +79,130 @@ class Transaction(dns.transaction.Transaction): def _set_origin(self, origin): pass + @pytest.fixture def db(): db = DB() - rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content') + rrset = dns.rrset.from_text("content", 300, "in", "txt", "content") db.rdatasets[(rrset.name, rrset.rdtype, 0)] = rrset return db + def test_basic(db): # successful txn with db.writer() as txn: - rrset = dns.rrset.from_text('foo', 300, 'in', 'a', - '10.0.0.1', '10.0.0.2') + rrset = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1", "10.0.0.2") txn.add(rrset) assert txn.name_exists(rrset.name) - assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ - rrset + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == rrset # rollback with pytest.raises(Exception): with db.writer() as txn: - rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a', - '10.0.0.3', '10.0.0.4') + rrset2 = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.3", "10.0.0.4") txn.add(rrset2) raise Exception() - assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ - rrset + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == rrset with db.writer() as txn: txn.delete(rrset.name) - assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) \ - is None + assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) is None + def test_get(db): with db.writer() as txn: - content = dns.name.from_text('content', None) + content = dns.name.from_text("content", None) rdataset = txn.get(content, dns.rdatatype.TXT) assert rdataset is not None - assert rdataset[0].strings == (b'content',) + assert rdataset[0].strings == (b"content",) assert isinstance(rdataset, dns.rdataset.ImmutableRdataset) + def test_add(db): with db.writer() as txn: - rrset = dns.rrset.from_text('foo', 300, 'in', 'a', - '10.0.0.1', '10.0.0.2') + rrset = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1", "10.0.0.2") txn.add(rrset) - rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a', - '10.0.0.3', '10.0.0.4') + rrset2 = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.3", "10.0.0.4") txn.add(rrset2) - expected = dns.rrset.from_text('foo', 300, 'in', 'a', - '10.0.0.1', '10.0.0.2', - '10.0.0.3', '10.0.0.4') - assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ - expected + expected = dns.rrset.from_text( + "foo", 300, "in", "a", "10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4" + ) + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == expected + def test_replacement(db): with db.writer() as txn: - rrset = dns.rrset.from_text('foo', 300, 'in', 'a', - '10.0.0.1', '10.0.0.2') + rrset = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1", "10.0.0.2") txn.add(rrset) - rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a', - '10.0.0.3', '10.0.0.4') + rrset2 = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.3", "10.0.0.4") txn.replace(rrset2) - assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ - rrset2 + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == rrset2 + def test_delete(db): with db.writer() as txn: - txn.delete(dns.name.from_text('nonexistent', None)) - content = dns.name.from_text('content', None) - content2 = dns.name.from_text('content2', None) + txn.delete(dns.name.from_text("nonexistent", None)) + content = dns.name.from_text("content", None) + content2 = dns.name.from_text("content2", None) txn.delete(content) assert not txn.name_exists(content) txn.delete(content2, dns.rdatatype.TXT) - rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'new-content') + rrset = dns.rrset.from_text("content", 300, "in", "txt", "new-content") txn.add(rrset) assert txn.name_exists(content) txn.delete(content, dns.rdatatype.TXT) assert not txn.name_exists(content) - rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'new-content') + rrset = dns.rrset.from_text("content2", 300, "in", "txt", "new-content") txn.delete(rrset) content_keys = [k for k in db.rdatasets if k[0] == content] assert len(content_keys) == 0 + def test_delete_exact(db): with db.writer() as txn: - rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'bad-content') + rrset = dns.rrset.from_text("content", 300, "in", "txt", "bad-content") with pytest.raises(dns.transaction.DeleteNotExact): txn.delete_exact(rrset) - rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'bad-content') + rrset = dns.rrset.from_text("content2", 300, "in", "txt", "bad-content") with pytest.raises(dns.transaction.DeleteNotExact): txn.delete_exact(rrset) with pytest.raises(dns.transaction.DeleteNotExact): txn.delete_exact(rrset.name) with pytest.raises(dns.transaction.DeleteNotExact): txn.delete_exact(rrset.name, dns.rdatatype.TXT) - rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content') + rrset = dns.rrset.from_text("content", 300, "in", "txt", "content") txn.delete_exact(rrset) - assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) \ - is None + assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) is None + def test_parameter_forms(db): with db.writer() as txn: - foo = dns.name.from_text('foo', None) - rdataset = dns.rdataset.from_text('in', 'a', 300, - '10.0.0.1', '10.0.0.2') - rdata1 = dns.rdata.from_text('in', 'a', '10.0.0.3') - rdata2 = dns.rdata.from_text('in', 'a', '10.0.0.4') + foo = dns.name.from_text("foo", None) + rdataset = dns.rdataset.from_text("in", "a", 300, "10.0.0.1", "10.0.0.2") + rdata1 = dns.rdata.from_text("in", "a", "10.0.0.3") + rdata2 = dns.rdata.from_text("in", "a", "10.0.0.4") txn.add(foo, rdataset) txn.add(foo, 100, rdata1) txn.add(foo, 30, rdata2) - expected = dns.rrset.from_text('foo', 30, 'in', 'a', - '10.0.0.1', '10.0.0.2', - '10.0.0.3', '10.0.0.4') - assert db.rdatasets[(foo, rdataset.rdtype, 0)] == \ - expected + expected = dns.rrset.from_text( + "foo", 30, "in", "a", "10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4" + ) + assert db.rdatasets[(foo, rdataset.rdtype, 0)] == expected with db.writer() as txn: txn.delete(foo, rdataset) txn.delete(foo, rdata1) txn.delete(foo, rdata2) - assert db.rdatasets.get((foo, rdataset.rdtype, 0)) \ - is None + assert db.rdatasets.get((foo, rdataset.rdtype, 0)) is None + def test_bad_parameters(db): with db.writer() as txn: with pytest.raises(TypeError): txn.add(1) with pytest.raises(TypeError): - rrset = dns.rrset.from_text('bar', 300, 'in', 'txt', 'bar') + rrset = dns.rrset.from_text("bar", 300, "in", "txt", "bar") txn.add(rrset, 1) with pytest.raises(ValueError): - foo = dns.name.from_text('foo', None) - rdata = dns.rdata.from_text('in', 'a', '10.0.0.3') + foo = dns.name.from_text("foo", None) + rdata = dns.rdata.from_text("in", "a", "10.0.0.3") txn.add(foo, 0x100000000, rdata) with pytest.raises(TypeError): txn.add(foo) @@ -217,21 +211,22 @@ def test_bad_parameters(db): with pytest.raises(TypeError): txn.add(foo, 300) with pytest.raises(TypeError): - txn.add(foo, 300, 'hi') + txn.add(foo, 300, "hi") with pytest.raises(TypeError): - txn.add(foo, 'hi') + txn.add(foo, "hi") with pytest.raises(TypeError): txn.delete() with pytest.raises(TypeError): txn.delete(1) + def test_cannot_store_non_origin_soa(db): with pytest.raises(ValueError): with db.writer() as txn: - rrset = dns.rrset.from_text('foo', 300, 'in', 'SOA', - '. . 1 2 3 4 5') + rrset = dns.rrset.from_text("foo", 300, "in", "SOA", ". . 1 2 3 4 5") txn.add(rrset) + example_text = """$TTL 3600 $ORIGIN example. @ soa foo bar 1 2 3 4 5 @@ -253,55 +248,58 @@ ns2 3600 IN A 10.0.0.2 ns3 3600 IN A 10.0.0.3 """ + @pytest.fixture(params=[dns.zone.Zone, dns.versioned.Zone]) def zone(request): return dns.zone.from_text(example_text, zone_factory=request.param) + def test_zone_basic(zone): with zone.writer() as txn: - txn.delete(dns.name.from_text('bar.foo', None)) - rd = dns.rdata.from_text('in', 'ns', 'ns3') + txn.delete(dns.name.from_text("bar.foo", None)) + rd = dns.rdata.from_text("in", "ns", "ns3") txn.add(dns.name.empty, 3600, rd) - rd = dns.rdata.from_text('in', 'a', '10.0.0.3') - txn.add(dns.name.from_text('ns3', None), 3600, rd) + rd = dns.rdata.from_text("in", "a", "10.0.0.3") + txn.add(dns.name.from_text("ns3", None), 3600, rd) output = zone.to_text() assert output == example_text_output + def test_explicit_rollback_and_commit(zone): with zone.writer() as txn: assert not txn.changed() - txn.delete(dns.name.from_text('bar.foo', None)) + txn.delete(dns.name.from_text("bar.foo", None)) txn.rollback() - assert zone.get_node('bar.foo') is not None + assert zone.get_node("bar.foo") is not None with zone.writer() as txn: assert not txn.changed() - txn.delete(dns.name.from_text('bar.foo', None)) + txn.delete(dns.name.from_text("bar.foo", None)) txn.commit() - assert zone.get_node('bar.foo') is None + assert zone.get_node("bar.foo") is None with pytest.raises(dns.transaction.AlreadyEnded): with zone.writer() as txn: txn.rollback() - txn.delete(dns.name.from_text('bar.foo', None)) + txn.delete(dns.name.from_text("bar.foo", None)) with pytest.raises(dns.transaction.AlreadyEnded): with zone.writer() as txn: txn.rollback() - txn.add('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + txn.add("bar.foo", 300, dns.rdata.from_text("in", "txt", "hi")) with pytest.raises(dns.transaction.AlreadyEnded): with zone.writer() as txn: txn.rollback() - txn.replace('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + txn.replace("bar.foo", 300, dns.rdata.from_text("in", "txt", "hi")) with pytest.raises(dns.transaction.AlreadyEnded): with zone.reader() as txn: txn.rollback() - txn.get('bar.foo', 'in', 'mx') + txn.get("bar.foo", "in", "mx") with pytest.raises(dns.transaction.AlreadyEnded): with zone.writer() as txn: txn.rollback() - txn.delete_exact('bar.foo') + txn.delete_exact("bar.foo") with pytest.raises(dns.transaction.AlreadyEnded): with zone.writer() as txn: txn.rollback() - txn.name_exists('bar.foo') + txn.name_exists("bar.foo") with pytest.raises(dns.transaction.AlreadyEnded): with zone.writer() as txn: txn.rollback() @@ -324,6 +322,7 @@ def test_explicit_rollback_and_commit(zone): for rdataset in txn: pass + def test_zone_changed(zone): # Read-only is not changed! with zone.reader() as txn: @@ -331,57 +330,60 @@ def test_zone_changed(zone): # delete an existing name with zone.writer() as txn: assert not txn.changed() - txn.delete(dns.name.from_text('bar.foo', None)) + txn.delete(dns.name.from_text("bar.foo", None)) assert txn.changed() # delete a nonexistent name with zone.writer() as txn: assert not txn.changed() - txn.delete(dns.name.from_text('unknown.bar.foo', None)) + txn.delete(dns.name.from_text("unknown.bar.foo", None)) assert not txn.changed() # delete a nonexistent rdataset from an extant node with zone.writer() as txn: assert not txn.changed() - txn.delete(dns.name.from_text('bar.foo', None), 'txt') + txn.delete(dns.name.from_text("bar.foo", None), "txt") assert not txn.changed() # add an rdataset to an extant Node with zone.writer() as txn: assert not txn.changed() - txn.add('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + txn.add("bar.foo", 300, dns.rdata.from_text("in", "txt", "hi")) assert txn.changed() # add an rdataset to a nonexistent Node with zone.writer() as txn: assert not txn.changed() - txn.add('foo.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + txn.add("foo.foo", 300, dns.rdata.from_text("in", "txt", "hi")) assert txn.changed() + def test_zone_base_layer(zone): with zone.writer() as txn: # Get a set from the zone layer rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE) - expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2') + expected = dns.rdataset.from_text("in", "ns", 300, "ns1", "ns2") assert rdataset == expected + def test_zone_transaction_layer(zone): with zone.writer() as txn: # Make a change - rd = dns.rdata.from_text('in', 'ns', 'ns3') + rd = dns.rdata.from_text("in", "ns", "ns3") txn.add(dns.name.empty, 3600, rd) # Get a set from the transaction layer - expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2', 'ns3') + expected = dns.rdataset.from_text("in", "ns", 300, "ns1", "ns2", "ns3") rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE) assert rdataset == expected assert txn.name_exists(dns.name.empty) - ns1 = dns.name.from_text('ns1', None) + ns1 = dns.name.from_text("ns1", None) assert txn.name_exists(ns1) - ns99 = dns.name.from_text('ns99', None) + ns99 = dns.name.from_text("ns99", None) assert not txn.name_exists(ns99) + def test_zone_add_and_delete(zone): with zone.writer() as txn: - a99 = dns.name.from_text('a99', None) - a100 = dns.name.from_text('a100', None) - a101 = dns.name.from_text('a101', None) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') + a99 = dns.name.from_text("a99", None) + a100 = dns.name.from_text("a100", None) + a101 = dns.name.from_text("a101", None) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.99") txn.add(a99, rds) txn.delete(a99, dns.rdatatype.A) txn.delete(a100, dns.rdatatype.A) @@ -389,7 +391,7 @@ def test_zone_add_and_delete(zone): assert not txn.name_exists(a99) assert not txn.name_exists(a100) assert not txn.name_exists(a101) - ns1 = dns.name.from_text('ns1', None) + ns1 = dns.name.from_text("ns1", None) txn.delete(ns1, dns.rdatatype.A) assert not txn.name_exists(ns1) with zone.writer() as txn: @@ -402,32 +404,35 @@ def test_zone_add_and_delete(zone): assert not txn.name_exists(a99) assert txn.name_exists(a100) + def test_write_after_rollback(zone): with pytest.raises(ExpectedException): with zone.writer() as txn: - a99 = dns.name.from_text('a99', None) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') + a99 = dns.name.from_text("a99", None) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.99") txn.add(a99, rds) raise ExpectedException with zone.writer() as txn: - a99 = dns.name.from_text('a99', None) - rds = dns.rdataset.from_text('in', 'a', 300, '10.99.99.99') + a99 = dns.name.from_text("a99", None) + rds = dns.rdataset.from_text("in", "a", 300, "10.99.99.99") txn.add(a99, rds) - assert zone.get_rdataset('a99', 'a') == rds + assert zone.get_rdataset("a99", "a") == rds + def test_zone_get_deleted(zone): with zone.writer() as txn: - ns1 = dns.name.from_text('ns1', None) + ns1 = dns.name.from_text("ns1", None) assert txn.get(ns1, dns.rdatatype.A) is not None txn.delete(ns1) assert txn.get(ns1, dns.rdatatype.A) is None - ns2 = dns.name.from_text('ns2', None) + ns2 = dns.name.from_text("ns2", None) txn.delete(ns2, dns.rdatatype.A) assert txn.get(ns2, dns.rdatatype.A) is None + def test_zone_bad_class(zone): with zone.writer() as txn: - rds = dns.rdataset.from_text('ch', 'ns', 300, 'ns1', 'ns2') + rds = dns.rdataset.from_text("ch", "ns", 300, "ns1", "ns2") with pytest.raises(ValueError): txn.add(dns.name.empty, rds) with pytest.raises(ValueError): @@ -435,30 +440,31 @@ def test_zone_bad_class(zone): with pytest.raises(ValueError): txn.delete(dns.name.empty, rds) + def test_update_serial(zone): # basic with zone.writer() as txn: txn.update_serial() - rdataset = zone.find_rdataset('@', 'soa') + rdataset = zone.find_rdataset("@", "soa") assert rdataset[0].serial == 2 # max with zone.writer() as txn: - txn.update_serial(0xffffffff, False) - rdataset = zone.find_rdataset('@', 'soa') - assert rdataset[0].serial == 0xffffffff + txn.update_serial(0xFFFFFFFF, False) + rdataset = zone.find_rdataset("@", "soa") + assert rdataset[0].serial == 0xFFFFFFFF # wraparound to 1 with zone.writer() as txn: txn.update_serial() - rdataset = zone.find_rdataset('@', 'soa') + rdataset = zone.find_rdataset("@", "soa") assert rdataset[0].serial == 1 # trying to set to zero sets to 1 with zone.writer() as txn: txn.update_serial(0, False) - rdataset = zone.find_rdataset('@', 'soa') + rdataset = zone.find_rdataset("@", "soa") assert rdataset[0].serial == 1 with pytest.raises(KeyError): with zone.writer() as txn: - txn.update_serial(name=dns.name.from_text('unknown', None)) + txn.update_serial(name=dns.name.from_text("unknown", None)) with pytest.raises(ValueError): with zone.writer() as txn: txn.update_serial(-1) @@ -466,14 +472,16 @@ def test_update_serial(zone): with zone.writer() as txn: txn.update_serial(2**31) + class ExpectedException(Exception): pass + def test_zone_rollback(zone): - a99 = dns.name.from_text('a99.example.') + a99 = dns.name.from_text("a99.example.") try: with zone.writer() as txn: - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.99") txn.add(a99, rds) assert txn.name_exists(a99) raise ExpectedException @@ -481,12 +489,14 @@ def test_zone_rollback(zone): pass assert not zone.get_node(a99) + def test_zone_ooz_name(zone): with zone.writer() as txn: with pytest.raises(KeyError): - a99 = dns.name.from_text('a99.not-example.') + a99 = dns.name.from_text("a99.not-example.") assert txn.name_exists(a99) + def test_zone_iteration(zone): expected = {} for (name, rdataset) in zone.iterate_rdatasets(): @@ -497,8 +507,9 @@ def test_zone_iteration(zone): actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset assert actual == expected + def test_iteration_in_replacement_txn(zone): - rds = dns.rdataset.from_text('in', 'a', 300, '1.2.3.4', '5.6.7.8') + rds = dns.rdataset.from_text("in", "a", 300, "1.2.3.4", "5.6.7.8") expected = {} expected[(dns.name.empty, rds.rdtype, rds.covers)] = rds with zone.writer(True) as txn: @@ -508,8 +519,9 @@ def test_iteration_in_replacement_txn(zone): actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset assert actual == expected + def test_replacement_commit(zone): - rds = dns.rdataset.from_text('in', 'a', 300, '1.2.3.4', '5.6.7.8') + rds = dns.rdataset.from_text("in", "a", 300, "1.2.3.4", "5.6.7.8") expected = {} expected[(dns.name.empty, rds.rdtype, rds.covers)] = rds with zone.writer(True) as txn: @@ -520,9 +532,10 @@ def test_replacement_commit(zone): actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset assert actual == expected + def test_replacement_get(zone): with zone.writer(True) as txn: - rds = txn.get(dns.name.empty, 'soa') + rds = txn.get(dns.name.empty, "soa") assert rds is None @@ -530,14 +543,16 @@ def test_replacement_get(zone): def vzone(): return dns.zone.from_text(example_text, zone_factory=dns.versioned.Zone) + def test_vzone_read_only(vzone): with vzone.reader() as txn: rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE) - expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2') + expected = dns.rdataset.from_text("in", "ns", 300, "ns1", "ns2") assert rdataset == expected with pytest.raises(dns.transaction.ReadOnly): txn.replace(dns.name.empty, expected) + def test_vzone_multiple_versions(vzone): assert len(vzone._versions) == 1 vzone.set_max_versions(None) # unlimited! @@ -547,39 +562,41 @@ def test_vzone_multiple_versions(vzone): txn.update_serial() with vzone.writer() as txn: txn.update_serial(1000, False) - rdataset = vzone.find_rdataset('@', 'soa') + rdataset = vzone.find_rdataset("@", "soa") assert rdataset[0].serial == 1000 assert len(vzone._versions) == 4 with vzone.reader(id=5) as txn: assert txn.version.id == 5 - rdataset = txn.get('@', 'soa') + rdataset = txn.get("@", "soa") assert rdataset[0].serial == 1000 with vzone.reader(serial=1000) as txn: assert txn.version.id == 5 - rdataset = txn.get('@', 'soa') + rdataset = txn.get("@", "soa") assert rdataset[0].serial == 1000 vzone.set_max_versions(2) assert len(vzone._versions) == 2 # The ones that survived should be 3 and 1000 - rdataset = vzone._versions[0].get_rdataset(dns.name.empty, - dns.rdatatype.SOA, - dns.rdatatype.NONE) + rdataset = vzone._versions[0].get_rdataset( + dns.name.empty, dns.rdatatype.SOA, dns.rdatatype.NONE + ) assert rdataset[0].serial == 3 - rdataset = vzone._versions[1].get_rdataset(dns.name.empty, - dns.rdatatype.SOA, - dns.rdatatype.NONE) + rdataset = vzone._versions[1].get_rdataset( + dns.name.empty, dns.rdatatype.SOA, dns.rdatatype.NONE + ) assert rdataset[0].serial == 1000 with pytest.raises(ValueError): vzone.set_max_versions(0) + # for debugging if needed def _dump(zone): for v in zone._versions: - print('VERSION', v.id) + print("VERSION", v.id) for (name, n) in v.nodes.items(): for rdataset in n: print(rdataset.to_text(name)) + def test_vzone_open_txn_pins_versions(vzone): assert len(vzone._versions) == 1 vzone.set_max_versions(None) # unlimited! @@ -592,11 +609,11 @@ def test_vzone_open_txn_pins_versions(vzone): with vzone.reader(id=2) as txn: vzone.set_max_versions(1) with vzone.reader(id=3) as txn: - rdataset = txn.get('@', 'soa') + rdataset = txn.get("@", "soa") assert rdataset[0].serial == 2 assert len(vzone._versions) == 4 assert len(vzone._versions) == 1 - rdataset = vzone.find_rdataset('@', 'soa') + rdataset = vzone.find_rdataset("@", "soa") assert vzone._versions[0].id == 5 assert rdataset[0].serial == 4 @@ -612,16 +629,16 @@ try: # wait until two blocks while len(zone._write_waiters) == 0: time.sleep(0.01) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.98') - txn.add('a98', rds) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.98") + txn.add("a98", rds) def run_two(zone): # wait until one has the lock so we know we will block if we # get the call done before the sleep in one completes one_got_lock.wait() with zone.writer() as txn: - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') - txn.add('a99', rds) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.99") + txn.add("a99", rds) def test_vzone_concurrency(vzone): t1 = threading.Thread(target=run_one, args=(vzone,)) @@ -631,8 +648,8 @@ try: t1.join() t2.join() with vzone.reader() as txn: - assert txn.name_exists('a98') - assert txn.name_exists('a99') + assert txn.name_exists("a98") + assert txn.name_exists("a99") except ImportError: # pragma: no cover pass diff --git a/tests/test_tsig.py b/tests/test_tsig.py index 4c793d53..6571d5b0 100644 --- a/tests/test_tsig.py +++ b/tests/test_tsig.py @@ -11,30 +11,25 @@ import dns.tsigkeyring import dns.message import dns.rdtypes.ANY.TKEY -keyring = dns.tsigkeyring.from_text( - { - 'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ==' - } -) +keyring = dns.tsigkeyring.from_text({"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="}) -keyname = dns.name.from_text('keyname') +keyname = dns.name.from_text("keyname") class TSIGTestCase(unittest.TestCase): - def test_get_context(self): - key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha256') + key = dns.tsig.Key("foo.com", "abcd", "hmac-sha256") ctx = dns.tsig.get_context(key) - self.assertEqual(ctx.name, 'hmac-sha256') - key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512') + self.assertEqual(ctx.name, "hmac-sha256") + key = dns.tsig.Key("foo.com", "abcd", "hmac-sha512") ctx = dns.tsig.get_context(key) - self.assertEqual(ctx.name, 'hmac-sha512') - bogus = dns.tsig.Key('foo.com', 'abcd', 'bogus') + self.assertEqual(ctx.name, "hmac-sha512") + bogus = dns.tsig.Key("foo.com", "abcd", "bogus") with self.assertRaises(NotImplementedError): dns.tsig.get_context(bogus) def test_tsig_message_properties(self): - m = dns.message.make_query('example', 'a') + m = dns.message.make_query("example", "a") self.assertIsNone(m.keyname) self.assertIsNone(m.keyalgorithm) self.assertIsNone(m.tsig_error) @@ -42,77 +37,76 @@ class TSIGTestCase(unittest.TestCase): self.assertEqual(m.keyname, keyname) self.assertEqual(m.keyalgorithm, dns.tsig.default_algorithm) self.assertEqual(m.tsig_error, dns.rcode.NOERROR) - m = dns.message.make_query('example', 'a') + m = dns.message.make_query("example", "a") m.use_tsig(keyring, keyname, tsig_error=dns.rcode.BADKEY) self.assertEqual(m.tsig_error, dns.rcode.BADKEY) def test_verify_mac_for_context(self): - key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512') + key = dns.tsig.Key("foo.com", "abcd", "hmac-sha512") ctx = dns.tsig.get_context(key) - bad_expected = b'xxxxxxxxxx' + bad_expected = b"xxxxxxxxxx" with self.assertRaises(dns.tsig.BadSignature): ctx.verify(bad_expected) def test_validate(self): # make message and grab the TSIG - m = dns.message.make_query('example', 'a') + m = dns.message.make_query("example", "a") m.use_tsig(keyring, keyname, algorithm=dns.tsig.HMAC_SHA256) w = m.to_wire() tsig = m.tsig[0] # get the time and create a key with matching characteristics now = int(time.time()) - key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha256') + key = dns.tsig.Key("foo.com", "abcd", "hmac-sha256") # add enough to the time to take it over the fudge amount with self.assertRaises(dns.tsig.BadTime): - dns.tsig.validate(w, key, dns.name.from_text('foo.com'), - tsig, now + 1000, b'', 0) + dns.tsig.validate( + w, key, dns.name.from_text("foo.com"), tsig, now + 1000, b"", 0 + ) # change the key name with self.assertRaises(dns.tsig.BadKey): - dns.tsig.validate(w, key, dns.name.from_text('bar.com'), - tsig, now, b'', 0) + dns.tsig.validate(w, key, dns.name.from_text("bar.com"), tsig, now, b"", 0) # change the key algorithm - key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512') + key = dns.tsig.Key("foo.com", "abcd", "hmac-sha512") with self.assertRaises(dns.tsig.BadAlgorithm): - dns.tsig.validate(w, key, dns.name.from_text('foo.com'), - tsig, now, b'', 0) + dns.tsig.validate(w, key, dns.name.from_text("foo.com"), tsig, now, b"", 0) def test_gssapi_context(self): def verify_signature(data, mac): - if data == b'throw': + if data == b"throw": raise Exception return None # mock out the gssapi context to return some dummy values gssapi_context_mock = Mock() - gssapi_context_mock.get_signature.return_value = b'xxxxxxxxxxx' + gssapi_context_mock.get_signature.return_value = b"xxxxxxxxxxx" gssapi_context_mock.verify_signature.side_effect = verify_signature # create the key and add it to the keyring - keyname = 'gsstsigtest' - key = dns.tsig.Key(keyname, gssapi_context_mock, 'gss-tsig') + keyname = "gsstsigtest" + key = dns.tsig.Key(keyname, gssapi_context_mock, "gss-tsig") ctx = dns.tsig.get_context(key) - self.assertEqual(ctx.name, 'gss-tsig') + self.assertEqual(ctx.name, "gss-tsig") gsskeyname = dns.name.from_text(keyname) keyring[gsskeyname] = key # make sure we can get the keyring (no exception == success) text = dns.tsigkeyring.to_text(keyring) - self.assertNotEqual(text, '') + self.assertNotEqual(text, "") # test exceptional case for _verify_mac_for_context with self.assertRaises(dns.tsig.BadSignature): - ctx.update(b'throw') - ctx.verify(b'bogus') + ctx.update(b"throw") + ctx.verify(b"bogus") gssapi_context_mock.verify_signature.assert_called() self.assertEqual(gssapi_context_mock.verify_signature.call_count, 1) # simulate case where TKEY message is used to establish the context; # first, the query from the client - tkey_message = dns.message.make_query(keyname, 'tkey', 'any') + tkey_message = dns.message.make_query(keyname, "tkey", "any") # test existent/non-existent keys in the keyring adapted_keyring = dns.tsig.GSSTSigAdapter(keyring) @@ -127,20 +121,28 @@ class TSIGTestCase(unittest.TestCase): # create a response, TKEY and turn it into bytes, simulating the server # sending the response to the query tkey_response = dns.message.make_response(tkey_message) - key = base64.b64decode('KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY') - tkey = dns.rdtypes.ANY.TKEY.TKEY(dns.rdataclass.ANY, - dns.rdatatype.TKEY, - dns.name.from_text('gss-tsig.'), - 1594203795, 1594206664, - 3, 0, key) + key = base64.b64decode("KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY") + tkey = dns.rdtypes.ANY.TKEY.TKEY( + dns.rdataclass.ANY, + dns.rdatatype.TKEY, + dns.name.from_text("gss-tsig."), + 1594203795, + 1594206664, + 3, + 0, + key, + ) # add the TKEY answer and sign it tkey_response.set_rcode(dns.rcode.NOERROR) tkey_response.answer = [ - dns.rrset.from_rdata(dns.name.from_text(keyname), 0, tkey)] - tkey_response.use_tsig(keyring=dns.tsig.GSSTSigAdapter(keyring), - keyname=gsskeyname, - algorithm=dns.tsig.GSS_TSIG) + dns.rrset.from_rdata(dns.name.from_text(keyname), 0, tkey) + ] + tkey_response.use_tsig( + keyring=dns.tsig.GSSTSigAdapter(keyring), + keyname=gsskeyname, + algorithm=dns.tsig.GSS_TSIG, + ) # "send" it to the client tkey_wire = tkey_response.to_wire() @@ -158,7 +160,7 @@ class TSIGTestCase(unittest.TestCase): # create example message and go to/from wire to simulate sign/verify # of regular messages - a_message = dns.message.make_query('example', 'a') + a_message = dns.message.make_query("example", "a") a_message.use_tsig(dns.tsig.GSSTSigAdapter(keyring), gsskeyname) a_wire = a_message.to_wire() # not raising is passing @@ -171,14 +173,14 @@ class TSIGTestCase(unittest.TestCase): self.assertEqual(gssapi_context_mock.verify_signature.call_count, 3) def test_sign_and_validate(self): - m = dns.message.make_query('example', 'a') + m = dns.message.make_query("example", "a") m.use_tsig(keyring, keyname) w = m.to_wire() # not raising is passing dns.message.from_wire(w, keyring) def test_validate_with_bad_keyring(self): - m = dns.message.make_query('example', 'a') + m = dns.message.make_query("example", "a") m.use_tsig(keyring, keyname) w = m.to_wire() @@ -190,14 +192,14 @@ class TSIGTestCase(unittest.TestCase): dns.message.from_wire(w, lambda m, n: None) def test_sign_and_validate_with_other_data(self): - m = dns.message.make_query('example', 'a') - m.use_tsig(keyring, keyname, other_data=b'other') + m = dns.message.make_query("example", "a") + m.use_tsig(keyring, keyname, other_data=b"other") w = m.to_wire() # not raising is passing dns.message.from_wire(w, keyring) def test_sign_respond_and_validate(self): - mq = dns.message.make_query('example', 'a') + mq = dns.message.make_query("example", "a") mq.use_tsig(keyring, keyname) wq = mq.to_wire() mq_with_tsig = dns.message.from_wire(wq, keyring) @@ -206,30 +208,33 @@ class TSIGTestCase(unittest.TestCase): wr = mr.to_wire() dns.message.from_wire(wr, keyring, request_mac=mq_with_tsig.mac) - def make_message_pair(self, qname='example', rdtype='A', tsig_error=0): + def make_message_pair(self, qname="example", rdtype="A", tsig_error=0): q = dns.message.make_query(qname, rdtype) q.use_tsig(keyring=keyring, keyname=keyname) q.to_wire() # to set q.mac r = dns.message.make_response(q, tsig_error=tsig_error) - return(q, r) + return (q, r) def test_peer_errors(self): - items = [(dns.rcode.BADSIG, dns.tsig.PeerBadSignature), - (dns.rcode.BADKEY, dns.tsig.PeerBadKey), - (dns.rcode.BADTIME, dns.tsig.PeerBadTime), - (dns.rcode.BADTRUNC, dns.tsig.PeerBadTruncation), - (99, dns.tsig.PeerError), - ] + items = [ + (dns.rcode.BADSIG, dns.tsig.PeerBadSignature), + (dns.rcode.BADKEY, dns.tsig.PeerBadKey), + (dns.rcode.BADTIME, dns.tsig.PeerBadTime), + (dns.rcode.BADTRUNC, dns.tsig.PeerBadTruncation), + (99, dns.tsig.PeerError), + ] for err, ex in items: q, r = self.make_message_pair(tsig_error=err) w = r.to_wire() + def bad(): dns.message.from_wire(w, keyring=keyring, request_mac=q.mac) + self.assertRaises(ex, bad) def _test_truncated_algorithm(self, alg, length): - key = dns.tsig.Key('foo', b'abcdefg', algorithm=alg) - q = dns.message.make_query('example', 'a') + key = dns.tsig.Key("foo", b"abcdefg", algorithm=alg) + q = dns.message.make_query("example", "a") q.use_tsig(key) q2 = dns.message.from_wire(q.to_wire(), keyring=key) @@ -247,22 +252,22 @@ class TSIGTestCase(unittest.TestCase): self._test_truncated_algorithm(dns.tsig.HMAC_SHA512_256, 256) def _test_text_format(self, alg): - key = dns.tsig.Key('foo', b'abcdefg', algorithm=alg) - q = dns.message.make_query('example', 'a') + key = dns.tsig.Key("foo", b"abcdefg", algorithm=alg) + q = dns.message.make_query("example", "a") q.use_tsig(key) _ = q.to_wire() text = q.tsig[0].to_text() - tsig2 = dns.rdata.from_text('ANY', 'TSIG', text) + tsig2 = dns.rdata.from_text("ANY", "TSIG", text) self.assertEqual(tsig2, q.tsig[0]) - q = dns.message.make_query('example', 'a') - q.use_tsig(key, other_data=b'abc') + q = dns.message.make_query("example", "a") + q.use_tsig(key, other_data=b"abc") q.use_tsig(key) _ = q.to_wire() text = q.tsig[0].to_text() - tsig2 = dns.rdata.from_text('ANY', 'TSIG', text) + tsig2 = dns.rdata.from_text("ANY", "TSIG", text) self.assertEqual(tsig2, q.tsig[0]) def test_text_hmac_sha256_128(self): @@ -275,15 +280,16 @@ class TSIGTestCase(unittest.TestCase): self._test_text_format(dns.tsig.HMAC_SHA512_256) def test_non_gss_key_repr(self): - key = dns.tsig.Key('foo', b'0123456789abcdef' * 2, - algorithm=dns.tsig.HMAC_SHA256) - self.assertEqual(repr(key), - "") + key = dns.tsig.Key("foo", None, algorithm=dns.tsig.GSS_TSIG) + self.assertEqual(repr(key), "") diff --git a/tests/test_tsigkeyring.py b/tests/test_tsigkeyring.py index 47f88067..f8c889e8 100644 --- a/tests/test_tsigkeyring.py +++ b/tests/test_tsigkeyring.py @@ -6,26 +6,20 @@ import unittest import dns.tsig import dns.tsigkeyring -text_keyring = { - 'keyname.' : ('hmac-sha256.', 'NjHwPsMKjdN++dOfE5iAiQ==') -} +text_keyring = {"keyname.": ("hmac-sha256.", "NjHwPsMKjdN++dOfE5iAiQ==")} -alt_text_keyring = { - 'keyname.' : (dns.tsig.HMAC_SHA256, 'NjHwPsMKjdN++dOfE5iAiQ==') -} +alt_text_keyring = {"keyname.": (dns.tsig.HMAC_SHA256, "NjHwPsMKjdN++dOfE5iAiQ==")} -old_text_keyring = { - 'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ==' -} +old_text_keyring = {"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="} -key = dns.tsig.Key('keyname.', 'NjHwPsMKjdN++dOfE5iAiQ==') +key = dns.tsig.Key("keyname.", "NjHwPsMKjdN++dOfE5iAiQ==") -rich_keyring = { key.name : key } +rich_keyring = {key.name: key} -old_rich_keyring = { key.name : key.secret } +old_rich_keyring = {key.name: key.secret} -class TSIGKeyRingTestCase(unittest.TestCase): +class TSIGKeyRingTestCase(unittest.TestCase): def test_from_text(self): """text keyring -> rich keyring""" rkeyring = dns.tsigkeyring.from_text(text_keyring) diff --git a/tests/test_ttl.py b/tests/test_ttl.py index 2bf298ef..d566f41b 100644 --- a/tests/test_ttl.py +++ b/tests/test_ttl.py @@ -4,33 +4,33 @@ import unittest import dns.ttl -class TTLTestCase(unittest.TestCase): +class TTLTestCase(unittest.TestCase): def test_bind_style_ok(self): - ttl = dns.ttl.from_text('2w1d1h1m1s') + ttl = dns.ttl.from_text("2w1d1h1m1s") self.assertEqual(ttl, 2 * 604800 + 86400 + 3600 + 60 + 1) def test_bind_style_ok2(self): # no one should do this, but it is legal! :) - ttl = dns.ttl.from_text('1s2w1m1d1h') + ttl = dns.ttl.from_text("1s2w1m1d1h") self.assertEqual(ttl, 2 * 604800 + 86400 + 3600 + 60 + 1) def test_bind_style_bad_unit(self): with self.assertRaises(dns.ttl.BadTTL): - dns.ttl.from_text('5y') + dns.ttl.from_text("5y") def test_bind_style_no_unit(self): with self.assertRaises(dns.ttl.BadTTL): - dns.ttl.from_text('1d5') + dns.ttl.from_text("1d5") def test_bind_style_leading_unit(self): with self.assertRaises(dns.ttl.BadTTL): - dns.ttl.from_text('s') + dns.ttl.from_text("s") def test_bind_style_unit_without_digits(self): with self.assertRaises(dns.ttl.BadTTL): - dns.ttl.from_text('1mw') + dns.ttl.from_text("1mw") def test_empty(self): with self.assertRaises(dns.ttl.BadTTL): - dns.ttl.from_text('') + dns.ttl.from_text("") diff --git a/tests/test_update.py b/tests/test_update.py index 3abec93f..c1dea824 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -23,72 +23,70 @@ import dns.rdata import dns.rdataset import dns.tsigkeyring + def hextowire(hex): - return binascii.unhexlify(hex.replace(' ', '').encode()) + return binascii.unhexlify(hex.replace(" ", "").encode()) + goodwire = hextowire( - '0001 2800 0001 0005 0007 0000' - '076578616d706c6500 0006 0001' - '03666f6fc00c 00ff 00ff 00000000 0000' - 'c019 0001 00ff 00000000 0000' - '03626172c00c 0001 0001 00000000 0004 0a000005' - '05626c617a32c00c 00ff 00fe 00000000 0000' - 'c049 0001 00fe 00000000 0000' - 'c019 0001 00ff 00000000 0000' - 'c019 0001 0001 0000012c 0004 0a000001' - 'c019 0001 0001 0000012c 0004 0a000002' - 'c035 0001 0001 0000012c 0004 0a000003' - 'c035 0001 00fe 00000000 0004 0a000004' - '04626c617ac00c 0001 00ff 00000000 0000' - 'c049 00ff 00ff 00000000 0000' + "0001 2800 0001 0005 0007 0000" + "076578616d706c6500 0006 0001" + "03666f6fc00c 00ff 00ff 00000000 0000" + "c019 0001 00ff 00000000 0000" + "03626172c00c 0001 0001 00000000 0004 0a000005" + "05626c617a32c00c 00ff 00fe 00000000 0000" + "c049 0001 00fe 00000000 0000" + "c019 0001 00ff 00000000 0000" + "c019 0001 0001 0000012c 0004 0a000001" + "c019 0001 0001 0000012c 0004 0a000002" + "c035 0001 0001 0000012c 0004 0a000003" + "c035 0001 00fe 00000000 0004 0a000004" + "04626c617ac00c 0001 00ff 00000000 0000" + "c049 00ff 00ff 00000000 0000" ) goodwirenone = hextowire( - '0001 2800 0001 0000 0001 0000' - '076578616d706c6500 0006 0001' - '03666f6fc00c 0001 00fe 00000000 0004 01020304' + "0001 2800 0001 0000 0001 0000" + "076578616d706c6500 0006 0001" + "03666f6fc00c 0001 00fe 00000000 0004 01020304" ) badwirenone = hextowire( - '0001 2800 0001 0003 0000 0000' - '076578616d706c6500 0006 0001' - '03666f6fc00c 00ff 00ff 00000000 0000' - 'c019 0001 00ff 00000000 0000' - 'c019 0001 00fe 00000000 0004 01020304' + "0001 2800 0001 0003 0000 0000" + "076578616d706c6500 0006 0001" + "03666f6fc00c 00ff 00ff 00000000 0000" + "c019 0001 00ff 00000000 0000" + "c019 0001 00fe 00000000 0004 01020304" ) badwireany = hextowire( - '0001 2800 0001 0002 0000 0000' - '076578616d706c6500 0006 0001' - '03666f6fc00c 00ff 00ff 00000000 0000' - 'c019 0001 00ff 00000000 0004 01020304' + "0001 2800 0001 0002 0000 0000" + "076578616d706c6500 0006 0001" + "03666f6fc00c 00ff 00ff 00000000 0000" + "c019 0001 00ff 00000000 0004 01020304" ) badwireanyany = hextowire( - '0001 2800 0001 0001 0000 0000' - '076578616d706c6500 0006 0001' - '03666f6fc00c 00ff 00ff 00000000 0004 01020304' + "0001 2800 0001 0001 0000 0000" + "076578616d706c6500 0006 0001" + "03666f6fc00c 00ff 00ff 00000000 0004 01020304" ) badwirezonetype = hextowire( - '0001 2800 0001 0000 0000 0000' - '076578616d706c6500 0001 0001' + "0001 2800 0001 0000 0000 0000" "076578616d706c6500 0001 0001" ) badwirezoneclass = hextowire( - '0001 2800 0001 0000 0000 0000' - '076578616d706c6500 0006 00ff' + "0001 2800 0001 0000 0000 0000" "076578616d706c6500 0006 00ff" ) badwirezonemulti = hextowire( - '0001 2800 0002 0000 0000 0000' - '076578616d706c6500 0006 0001' - 'c019 0006 0001' + "0001 2800 0002 0000 0000 0000" "076578616d706c6500 0006 0001" "c019 0006 0001" ) badwirenozone = hextowire( - '0001 2800 0000 0000 0001 0000' - '03666f6f076578616d706c6500 0001 0001 00000030 0004 01020304' + "0001 2800 0000 0000 0001 0000" + "03666f6f076578616d706c6500 0001 0001 00000030 0004 01020304" ) update_text = """id 1 @@ -140,116 +138,120 @@ foo 0 NONE A 10.0.0.1 foo 0 NONE A 10.0.0.2 """ -class UpdateTestCase(unittest.TestCase): - def test_to_wire1(self): # type: () -> None - update = dns.update.Update('example') +class UpdateTestCase(unittest.TestCase): + def test_to_wire1(self): # type: () -> None + update = dns.update.Update("example") update.id = 1 - update.present('foo') - update.present('foo', 'a') - update.present('bar', 'a', '10.0.0.5') - update.absent('blaz2') - update.absent('blaz2', 'a') - update.replace('foo', 300, 'a', '10.0.0.1', '10.0.0.2') - update.add('bar', 300, 'a', '10.0.0.3') - update.delete('bar', 'a', '10.0.0.4') - update.delete('blaz', 'a') - update.delete('blaz2') + update.present("foo") + update.present("foo", "a") + update.present("bar", "a", "10.0.0.5") + update.absent("blaz2") + update.absent("blaz2", "a") + update.replace("foo", 300, "a", "10.0.0.1", "10.0.0.2") + update.add("bar", 300, "a", "10.0.0.3") + update.delete("bar", "a", "10.0.0.4") + update.delete("blaz", "a") + update.delete("blaz2") self.assertEqual(update.to_wire(), goodwire) - def test_to_wire2(self): # type: () -> None - update = dns.update.Update('example') + def test_to_wire2(self): # type: () -> None + update = dns.update.Update("example") update.id = 1 - update.present('foo') - update.present('foo', 'a') - update.present('bar', 'a', '10.0.0.5') - update.absent('blaz2') - update.absent('blaz2', 'a') - update.replace('foo', 300, 'a', '10.0.0.1', '10.0.0.2') - update.add('bar', 300, dns.rdata.from_text(1, 1, '10.0.0.3')) - update.delete('bar', 'a', '10.0.0.4') - update.delete('blaz', 'a') - update.delete('blaz2') + update.present("foo") + update.present("foo", "a") + update.present("bar", "a", "10.0.0.5") + update.absent("blaz2") + update.absent("blaz2", "a") + update.replace("foo", 300, "a", "10.0.0.1", "10.0.0.2") + update.add("bar", 300, dns.rdata.from_text(1, 1, "10.0.0.3")) + update.delete("bar", "a", "10.0.0.4") + update.delete("blaz", "a") + update.delete("blaz2") self.assertEqual(update.to_wire(), goodwire) - def test_to_wire3(self): # type: () -> None - update = dns.update.Update('example') + def test_to_wire3(self): # type: () -> None + update = dns.update.Update("example") update.id = 1 - update.present('foo') - update.present('foo', 'a') - update.present('bar', 'a', '10.0.0.5') - update.absent('blaz2') - update.absent('blaz2', 'a') - update.replace('foo', 300, 'a', '10.0.0.1', '10.0.0.2') - update.add('bar', dns.rdataset.from_text(1, 1, 300, '10.0.0.3')) - update.delete('bar', 'a', '10.0.0.4') - update.delete('blaz', 'a') - update.delete('blaz2') + update.present("foo") + update.present("foo", "a") + update.present("bar", "a", "10.0.0.5") + update.absent("blaz2") + update.absent("blaz2", "a") + update.replace("foo", 300, "a", "10.0.0.1", "10.0.0.2") + update.add("bar", dns.rdataset.from_text(1, 1, 300, "10.0.0.3")) + update.delete("bar", "a", "10.0.0.4") + update.delete("blaz", "a") + update.delete("blaz2") self.assertEqual(update.to_wire(), goodwire) - def test_from_text1(self): # type: () -> None + def test_from_text1(self): # type: () -> None update = dns.message.from_text(update_text) self.assertTrue(isinstance(update, dns.update.UpdateMessage)) - w = update.to_wire(origin=dns.name.from_text('example'), - want_shuffle=False) + w = update.to_wire(origin=dns.name.from_text("example"), want_shuffle=False) self.assertEqual(w, goodwire) def test_from_wire(self): - origin = dns.name.from_text('example') + origin = dns.name.from_text("example") u1 = dns.message.from_wire(goodwire, origin=origin) u2 = dns.message.from_text(update_text, origin=origin) self.assertEqual(u1, u2) def test_good_explicit_delete_wire(self): - name = dns.name.from_text('foo.example') + name = dns.name.from_text("foo.example") u = dns.message.from_wire(goodwirenone) self.assertEqual(u.update[0].name, name) self.assertEqual(u.update[0].rdtype, dns.rdatatype.A) self.assertEqual(u.update[0].rdclass, dns.rdataclass.IN) self.assertTrue(u.update[0].deleting) - self.assertEqual(u.update[0][0].address, '1.2.3.4') + self.assertEqual(u.update[0][0].address, "1.2.3.4") def test_none_with_rdata_from_wire(self): def bad(): dns.message.from_wire(badwirenone) + self.assertRaises(dns.exception.FormError, bad) def test_any_with_rdata_from_wire(self): def bad(): dns.message.from_wire(badwireany) + self.assertRaises(dns.exception.FormError, bad) def test_any_any_with_rdata_from_wire(self): def bad(): dns.message.from_wire(badwireanyany) + self.assertRaises(dns.exception.FormError, bad) def test_bad_zone_type_from_wire(self): def bad(): dns.message.from_wire(badwirezonetype) + self.assertRaises(dns.exception.FormError, bad) def test_bad_zone_class_from_wire(self): def bad(): dns.message.from_wire(badwirezoneclass) + self.assertRaises(dns.exception.FormError, bad) def test_bad_zone_multi_from_wire(self): def bad(): dns.message.from_wire(badwirezonemulti) + self.assertRaises(dns.exception.FormError, bad) def test_no_zone_section_from_wire(self): def bad(): dns.message.from_wire(badwirenozone) + self.assertRaises(dns.exception.FormError, bad) def test_TSIG(self): - keyring = dns.tsigkeyring.from_text({ - 'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ==' - }) - update = dns.update.Update('example.', keyring=keyring) - update.replace('host.example.', 300, 'A', '1.2.3.4') + keyring = dns.tsigkeyring.from_text({"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="}) + update = dns.update.Update("example.", keyring=keyring) + update.replace("host.example.", 300, "A", "1.2.3.4") wire = update.to_wire() update2 = dns.message.from_wire(wire, keyring) self.assertEqual(update, update2) @@ -262,16 +264,18 @@ class UpdateTestCase(unittest.TestCase): self.assertTrue(update.is_response(r)) def test_making_UpdateSection(self): - self.assertEqual(dns.update.UpdateSection.make(0), - dns.update.UpdateSection.make('ZONE')) + self.assertEqual( + dns.update.UpdateSection.make(0), dns.update.UpdateSection.make("ZONE") + ) with self.assertRaises(ValueError): dns.update.UpdateSection.make(99) def test_setters(self): u = dns.update.UpdateMessage(id=1) - qrrset = dns.rrset.RRset(dns.name.from_text('example'), - dns.rdataclass.IN, dns.rdatatype.SOA) - rrset = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.1') + qrrset = dns.rrset.RRset( + dns.name.from_text("example"), dns.rdataclass.IN, dns.rdatatype.SOA + ) + rrset = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1") u.zone = [qrrset] self.assertEqual(u.sections[0], [qrrset]) self.assertEqual(u.sections[1], []) @@ -289,58 +293,53 @@ class UpdateTestCase(unittest.TestCase): self.assertEqual(u.sections[3], []) def test_added_rdataset(self): - u = dns.update.UpdateMessage('example.', id=1) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2') - u.add('foo', rds) + u = dns.update.UpdateMessage("example.", id=1) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1", "10.0.0.2") + u.add("foo", rds) expected = dns.message.from_text(added_text) self.assertEqual(u, expected) def test_replaced_rdataset(self): - u = dns.update.UpdateMessage('example.', id=1) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2') - u.replace('foo', rds) + u = dns.update.UpdateMessage("example.", id=1) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1", "10.0.0.2") + u.replace("foo", rds) expected = dns.message.from_text(replaced_text) self.assertEqual(u, expected) def test_delete_rdataset(self): - u = dns.update.UpdateMessage('example.', id=1) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2') - u.delete('foo', rds) + u = dns.update.UpdateMessage("example.", id=1) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1", "10.0.0.2") + u.delete("foo", rds) expected = dns.message.from_text(deleted_text) self.assertEqual(u, expected) def test_added_rdata(self): - u = dns.update.UpdateMessage('example.', id=1) - rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2') - u.add('foo', 300, rd1) - u.add('foo', 300, rd2) + u = dns.update.UpdateMessage("example.", id=1) + rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2") + u.add("foo", 300, rd1) + u.add("foo", 300, rd2) expected = dns.message.from_text(added_text) self.assertEqual(u, expected) def test_replaced_rdata(self): - u = dns.update.UpdateMessage('example.', id=1) - rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2') - u.replace('foo', 300, rd1) - u.add('foo', 300, rd2) + u = dns.update.UpdateMessage("example.", id=1) + rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2") + u.replace("foo", 300, rd1) + u.add("foo", 300, rd2) expected = dns.message.from_text(replaced_text) self.assertEqual(u, expected) def test_deleted_rdata(self): - u = dns.update.UpdateMessage('example.', id=1) - rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1') - rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2') - u.delete('foo', rd1) - u.delete('foo', rd2) + u = dns.update.UpdateMessage("example.", id=1) + rd1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1") + rd2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2") + u.delete("foo", rd1) + u.delete("foo", rd2) expected = dns.message.from_text(deleted_text) self.assertEqual(u, expected) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_wire.py b/tests/test_wire.py index a4b5991b..6d9df182 100644 --- a/tests/test_wire.py +++ b/tests/test_wire.py @@ -8,9 +8,8 @@ import dns.name class BinaryTestCase(unittest.TestCase): - def test_basic(self): - wire = bytes.fromhex('0102010203040102') + wire = bytes.fromhex("0102010203040102") p = dns.wire.Parser(wire) self.assertEqual(p.get_uint16(), 0x0102) with p.restrict_to(5): @@ -26,8 +25,8 @@ class BinaryTestCase(unittest.TestCase): def test_name(self): # www.dnspython.org NS IN question - wire = b'\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01' - expected = dns.name.from_text('www.dnspython.org') + wire = b"\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01" + expected = dns.name.from_text("www.dnspython.org") p = dns.wire.Parser(wire) self.assertEqual(p.get_name(), expected) self.assertEqual(p.get_uint16(), 2) @@ -36,18 +35,18 @@ class BinaryTestCase(unittest.TestCase): def test_relativized_name(self): # www.dnspython.org NS IN question - wire = b'\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01' - origin = dns.name.from_text('dnspython.org') - expected = dns.name.from_text('www', None) + wire = b"\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01" + origin = dns.name.from_text("dnspython.org") + expected = dns.name.from_text("www", None) p = dns.wire.Parser(wire) self.assertEqual(p.get_name(origin), expected) self.assertEqual(p.remaining(), 4) def test_compressed_name(self): # www.dnspython.org NS IN question - wire = b'\x09dnspython\x03org\x00\x03www\xc0\x00' - expected1 = dns.name.from_text('dnspython.org') - expected2 = dns.name.from_text('www.dnspython.org') + wire = b"\x09dnspython\x03org\x00\x03www\xc0\x00" + expected1 = dns.name.from_text("dnspython.org") + expected2 = dns.name.from_text("www.dnspython.org") p = dns.wire.Parser(wire) self.assertEqual(p.get_name(), expected1) self.assertEqual(p.get_name(), expected2) @@ -56,7 +55,7 @@ class BinaryTestCase(unittest.TestCase): self.assertEqual(p.current, len(wire)) def test_seek(self): - wire = b'\x09dnspython\x03org\x00' + wire = b"\x09dnspython\x03org\x00" p = dns.wire.Parser(wire) p.seek(10) self.assertEqual(p.get_uint8(), 3) @@ -72,7 +71,7 @@ class BinaryTestCase(unittest.TestCase): p.seek(len(wire) + 1) def test_not_reading_everything_in_restriction(self): - wire = bytes.fromhex('0102010203040102') + wire = bytes.fromhex("0102010203040102") p = dns.wire.Parser(wire) with self.assertRaises(dns.exception.FormError): with p.restrict_to(5): @@ -81,7 +80,7 @@ class BinaryTestCase(unittest.TestCase): # don't read the other 4 bytes def test_restriction_does_not_mask_exception(self): - wire = bytes.fromhex('0102010203040102') + wire = bytes.fromhex("0102010203040102") p = dns.wire.Parser(wire) with self.assertRaises(NotImplementedError): with p.restrict_to(5): diff --git a/tests/test_xfr.py b/tests/test_xfr.py index 3cf4c913..2a739eaf 100644 --- a/tests/test_xfr.py +++ b/tests/test_xfr.py @@ -17,13 +17,16 @@ import dns.xfr # those tests. try: from .nanonameserver import Server + _nanonameserver_available = True except ImportError: _nanonameserver_available = False + class Server(object): pass -axfr = '''id 1 + +axfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -37,9 +40,9 @@ bar.foo 300 IN MX 0 blaz.foo ns1 3600 IN A 10.0.0.1 ns2 3600 IN A 10.0.0.2 @ 3600 IN SOA foo bar 1 2 3 4 5 -''' +""" -axfr1 = '''id 1 +axfr1 = """id 1 opcode QUERY rcode NOERROR flags AA @@ -49,8 +52,8 @@ example. IN AXFR @ 3600 IN SOA foo bar 1 2 3 4 5 @ 3600 IN NS ns1 @ 3600 IN NS ns2 -''' -axfr2 = '''id 1 +""" +axfr2 = """id 1 opcode QUERY rcode NOERROR flags AA @@ -59,7 +62,7 @@ bar.foo 300 IN MX 0 blaz.foo ns1 3600 IN A 10.0.0.1 ns2 3600 IN A 10.0.0.2 @ 3600 IN SOA foo bar 1 2 3 4 5 -''' +""" base = """@ 3600 IN SOA foo bar 1 2 3 4 5 @ 3600 IN NS ns1 @@ -69,7 +72,7 @@ ns1 3600 IN A 10.0.0.1 ns2 3600 IN A 10.0.0.2 """ -axfr_unexpected_origin = '''id 1 +axfr_unexpected_origin = """id 1 opcode QUERY rcode NOERROR flags AA @@ -78,9 +81,9 @@ example. IN AXFR ;ANSWER @ 3600 IN SOA foo bar 1 2 3 4 5 @ 3600 IN SOA foo bar 1 2 3 4 7 -''' +""" -ixfr = '''id 1 +ixfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -100,9 +103,9 @@ ns3 3600 IN A 10.0.0.3 @ 3600 IN NS ns2 @ 3600 IN SOA foo bar 4 2 3 4 5 @ 3600 IN SOA foo bar 4 2 3 4 5 -''' +""" -compressed_ixfr = '''id 1 +compressed_ixfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -118,7 +121,7 @@ ns2 3600 IN A 10.0.0.2 ns2 3600 IN A 10.0.0.4 ns3 3600 IN A 10.0.0.3 @ 3600 IN SOA foo bar 4 2 3 4 5 -''' +""" ixfr_expected = """@ 3600 IN SOA foo bar 4 2 3 4 5 @ 3600 IN NS ns1 @@ -127,7 +130,7 @@ ns2 3600 IN A 10.0.0.4 ns3 3600 IN A 10.0.0.3 """ -ixfr_first_message = '''id 1 +ixfr_first_message = """id 1 opcode QUERY rcode NOERROR flags AA @@ -135,34 +138,34 @@ flags AA example. IN IXFR ;ANSWER @ 3600 IN SOA foo bar 4 2 3 4 5 -''' +""" -ixfr_header = '''id 1 +ixfr_header = """id 1 opcode QUERY rcode NOERROR flags AA ;ANSWER -''' +""" ixfr_body = [ - '@ 3600 IN SOA foo bar 1 2 3 4 5', - 'bar.foo 300 IN MX 0 blaz.foo', - 'ns2 3600 IN A 10.0.0.2', - '@ 3600 IN SOA foo bar 2 2 3 4 5', - 'ns2 3600 IN A 10.0.0.4', - '@ 3600 IN SOA foo bar 2 2 3 4 5', - '@ 3600 IN SOA foo bar 3 2 3 4 5', - 'ns3 3600 IN A 10.0.0.3', - '@ 3600 IN SOA foo bar 3 2 3 4 5', - '@ 3600 IN NS ns2', - '@ 3600 IN SOA foo bar 4 2 3 4 5', - '@ 3600 IN SOA foo bar 4 2 3 4 5', + "@ 3600 IN SOA foo bar 1 2 3 4 5", + "bar.foo 300 IN MX 0 blaz.foo", + "ns2 3600 IN A 10.0.0.2", + "@ 3600 IN SOA foo bar 2 2 3 4 5", + "ns2 3600 IN A 10.0.0.4", + "@ 3600 IN SOA foo bar 2 2 3 4 5", + "@ 3600 IN SOA foo bar 3 2 3 4 5", + "ns3 3600 IN A 10.0.0.3", + "@ 3600 IN SOA foo bar 3 2 3 4 5", + "@ 3600 IN NS ns2", + "@ 3600 IN SOA foo bar 4 2 3 4 5", + "@ 3600 IN SOA foo bar 4 2 3 4 5", ] ixfrs = [ixfr_first_message] ixfrs.extend([ixfr_header + l for l in ixfr_body]) -good_empty_ixfr = '''id 1 +good_empty_ixfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -170,9 +173,9 @@ flags AA example. IN IXFR ;ANSWER @ 3600 IN SOA foo bar 1 2 3 4 5 -''' +""" -retry_tcp_ixfr = '''id 1 +retry_tcp_ixfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -180,9 +183,9 @@ flags AA example. IN IXFR ;ANSWER @ 3600 IN SOA foo bar 5 2 3 4 5 -''' +""" -bad_empty_ixfr = '''id 1 +bad_empty_ixfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -191,9 +194,9 @@ example. IN IXFR ;ANSWER @ 3600 IN SOA foo bar 4 2 3 4 5 @ 3600 IN SOA foo bar 4 2 3 4 5 -''' +""" -unexpected_end_ixfr = '''id 1 +unexpected_end_ixfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -209,9 +212,9 @@ ns2 3600 IN A 10.0.0.2 ns2 3600 IN A 10.0.0.4 ns3 3600 IN A 10.0.0.3 @ 3600 IN SOA foo bar 4 2 3 4 5 -''' +""" -unexpected_end_ixfr_2 = '''id 1 +unexpected_end_ixfr_2 = """id 1 opcode QUERY rcode NOERROR flags AA @@ -223,9 +226,9 @@ example. IN IXFR bar.foo 300 IN MX 0 blaz.foo ns2 3600 IN A 10.0.0.2 @ 3600 IN NS ns2 -''' +""" -bad_serial_ixfr = '''id 1 +bad_serial_ixfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -241,9 +244,9 @@ ns2 3600 IN A 10.0.0.2 ns2 3600 IN A 10.0.0.4 ns3 3600 IN A 10.0.0.3 @ 3600 IN SOA foo bar 4 2 3 4 5 -''' +""" -ixfr_axfr = '''id 1 +ixfr_axfr = """id 1 opcode QUERY rcode NOERROR flags AA @@ -257,213 +260,209 @@ bar.foo 300 IN MX 0 blaz.foo ns1 3600 IN A 10.0.0.1 ns2 3600 IN A 10.0.0.2 @ 3600 IN SOA foo bar 1 2 3 4 5 -''' +""" + def test_basic_axfr(): - z = dns.versioned.Zone('example.') - m = dns.message.from_text(axfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.versioned.Zone("example.") + m = dns.message.from_text(axfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: done = xfr.process_message(m) assert done - ez = dns.zone.from_text(base, 'example.') + ez = dns.zone.from_text(base, "example.") assert z == ez + def test_basic_axfr_unversioned(): - z = dns.zone.Zone('example.') - m = dns.message.from_text(axfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.Zone("example.") + m = dns.message.from_text(axfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: done = xfr.process_message(m) assert done - ez = dns.zone.from_text(base, 'example.') + ez = dns.zone.from_text(base, "example.") assert z == ez + def test_basic_axfr_two_parts(): - z = dns.versioned.Zone('example.') - m1 = dns.message.from_text(axfr1, origin=z.origin, - one_rr_per_rrset=True) - m2 = dns.message.from_text(axfr2, origin=z.origin, - one_rr_per_rrset=True) + z = dns.versioned.Zone("example.") + m1 = dns.message.from_text(axfr1, origin=z.origin, one_rr_per_rrset=True) + m2 = dns.message.from_text(axfr2, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: done = xfr.process_message(m1) assert not done done = xfr.process_message(m2) assert done - ez = dns.zone.from_text(base, 'example.') + ez = dns.zone.from_text(base, "example.") assert z == ez + def test_axfr_unexpected_origin(): - z = dns.versioned.Zone('example.') - m = dns.message.from_text(axfr_unexpected_origin, origin=z.origin, - one_rr_per_rrset=True) + z = dns.versioned.Zone("example.") + m = dns.message.from_text( + axfr_unexpected_origin, origin=z.origin, one_rr_per_rrset=True + ) with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_basic_ixfr(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: done = xfr.process_message(m) assert done - ez = dns.zone.from_text(ixfr_expected, 'example.') + ez = dns.zone.from_text(ixfr_expected, "example.") assert z == ez + def test_basic_ixfr_unversioned(): - z = dns.zone.from_text(base, 'example.') - m = dns.message.from_text(ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.") + m = dns.message.from_text(ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: done = xfr.process_message(m) assert done - ez = dns.zone.from_text(ixfr_expected, 'example.') + ez = dns.zone.from_text(ixfr_expected, "example.") assert z == ez + def test_compressed_ixfr(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(compressed_ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(compressed_ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: done = xfr.process_message(m) assert done - ez = dns.zone.from_text(ixfr_expected, 'example.') + ez = dns.zone.from_text(ixfr_expected, "example.") assert z == ez + def test_basic_ixfr_many_parts(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: done = False for text in ixfrs: assert not done - m = dns.message.from_text(text, origin=z.origin, - one_rr_per_rrset=True) + m = dns.message.from_text(text, origin=z.origin, one_rr_per_rrset=True) done = xfr.process_message(m) assert done - ez = dns.zone.from_text(ixfr_expected, 'example.') + ez = dns.zone.from_text(ixfr_expected, "example.") assert z == ez + def test_good_empty_ixfr(): - z = dns.zone.from_text(ixfr_expected, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(good_empty_ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(ixfr_expected, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(good_empty_ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: done = xfr.process_message(m) assert done - ez = dns.zone.from_text(ixfr_expected, 'example.') + ez = dns.zone.from_text(ixfr_expected, "example.") assert z == ez + def test_retry_tcp_ixfr(): - z = dns.zone.from_text(ixfr_expected, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(retry_tcp_ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(ixfr_expected, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(retry_tcp_ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1, is_udp=True) as xfr: with pytest.raises(dns.xfr.UseTCP): xfr.process_message(m) + def test_bad_empty_ixfr(): - z = dns.zone.from_text(ixfr_expected, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(bad_empty_ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(ixfr_expected, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_empty_ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=3) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_serial_went_backwards_ixfr(): - z = dns.zone.from_text(ixfr_expected, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(bad_empty_ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(ixfr_expected, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_empty_ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=5) as xfr: with pytest.raises(dns.xfr.SerialWentBackwards): xfr.process_message(m) + def test_ixfr_is_axfr(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(ixfr_axfr, origin=z.origin, - one_rr_per_rrset=True) - with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=0xffffffff) as xfr: + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(ixfr_axfr, origin=z.origin, one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=0xFFFFFFFF) as xfr: done = xfr.process_message(m) assert done - ez = dns.zone.from_text(base, 'example.') + ez = dns.zone.from_text(base, "example.") assert z == ez + def test_ixfr_requires_serial(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) with pytest.raises(ValueError): dns.xfr.Inbound(z, dns.rdatatype.IXFR) + def test_ixfr_unexpected_end_bad_diff_sequence(): # This is where we get the end serial, but haven't seen all of # the expected diffs - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(unexpected_end_ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text( + unexpected_end_ixfr, origin=z.origin, one_rr_per_rrset=True + ) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_udp_ixfr_unexpected_end_just_stops(): # This is where everything looks good, but the IXFR just stops # in the middle. - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(unexpected_end_ixfr_2, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text( + unexpected_end_ixfr_2, origin=z.origin, one_rr_per_rrset=True + ) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1, is_udp=True) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_ixfr_bad_serial(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(bad_serial_ixfr, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_serial_ixfr, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_no_udp_with_axfr(): - z = dns.versioned.Zone('example.') + z = dns.versioned.Zone("example.") with pytest.raises(ValueError): with dns.xfr.Inbound(z, dns.rdatatype.AXFR, is_udp=True) as xfr: pass -refused = '''id 1 + +refused = """id 1 opcode QUERY rcode REFUSED flags AA ;QUESTION example. IN AXFR -''' +""" -bad_qname = '''id 1 +bad_qname = """id 1 opcode QUERY rcode NOERROR flags AA ;QUESTION not-example. IN IXFR -''' +""" -bad_qtype = '''id 1 +bad_qtype = """id 1 opcode QUERY rcode NOERROR flags AA ;QUESTION example. IN AXFR -''' +""" -soa_not_first = '''id 1 +soa_not_first = """id 1 opcode QUERY rcode NOERROR flags AA @@ -471,9 +470,9 @@ flags AA example. IN IXFR ;ANSWER bar.foo 300 IN MX 0 blaz.foo -''' +""" -soa_not_first_2 = '''id 1 +soa_not_first_2 = """id 1 opcode QUERY rcode NOERROR flags AA @@ -481,9 +480,9 @@ flags AA example. IN IXFR ;ANSWER @ 300 IN MX 0 blaz.foo -''' +""" -no_answer = '''id 1 +no_answer = """id 1 opcode QUERY rcode NOERROR flags AA @@ -491,9 +490,9 @@ flags AA example. IN IXFR ;ADDITIONAL bar.foo 300 IN MX 0 blaz.foo -''' +""" -axfr_answers_after_final_soa = '''id 1 +axfr_answers_after_final_soa = """id 1 opcode QUERY rcode NOERROR flags AA @@ -508,76 +507,70 @@ ns1 3600 IN A 10.0.0.1 ns2 3600 IN A 10.0.0.2 @ 3600 IN SOA foo bar 1 2 3 4 5 ns3 3600 IN A 10.0.0.3 -''' +""" + def test_refused(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(refused, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(refused, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.xfr.TransferError): xfr.process_message(m) + def test_bad_qname(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(bad_qname, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_qname, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_bad_qtype(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(bad_qtype, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_qtype, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_soa_not_first(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(soa_not_first, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(soa_not_first, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) - m = dns.message.from_text(soa_not_first_2, origin=z.origin, - one_rr_per_rrset=True) + m = dns.message.from_text(soa_not_first_2, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_no_answer(): - z = dns.zone.from_text(base, 'example.', - zone_factory=dns.versioned.Zone) - m = dns.message.from_text(no_answer, origin=z.origin, - one_rr_per_rrset=True) + z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) + m = dns.message.from_text(no_answer, origin=z.origin, one_rr_per_rrset=True) with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) + def test_axfr_answers_after_final_soa(): - z = dns.versioned.Zone('example.') - m = dns.message.from_text(axfr_answers_after_final_soa, origin=z.origin, - one_rr_per_rrset=True) + z = dns.versioned.Zone("example.") + m = dns.message.from_text( + axfr_answers_after_final_soa, origin=z.origin, one_rr_per_rrset=True + ) with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: with pytest.raises(dns.exception.FormError): xfr.process_message(m) -keyring = dns.tsigkeyring.from_text( - { - 'keyname.': 'NjHwPsMKjdN++dOfE5iAiQ==' - } -) -keyname = dns.name.from_text('keyname') +keyring = dns.tsigkeyring.from_text({"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="}) + +keyname = dns.name.from_text("keyname") + def test_make_query_basic(): - z = dns.versioned.Zone('example.') + z = dns.versioned.Zone("example.") (q, s) = dns.xfr.make_query(z) assert q.question[0].rdtype == dns.rdatatype.AXFR assert s is None @@ -590,7 +583,7 @@ def test_make_query_basic(): assert q.authority[0][0].serial == 10 assert s == 10 with z.writer() as txn: - txn.add('@', 300, dns.rdata.from_text('in', 'soa', '. . 1 2 3 4 5')) + txn.add("@", 300, dns.rdata.from_text("in", "soa", ". . 1 2 3 4 5")) (q, s) = dns.xfr.make_query(z) assert q.question[0].rdtype == dns.rdatatype.IXFR assert q.authority[0].rdtype == dns.rdatatype.SOA @@ -605,16 +598,17 @@ def test_make_query_basic(): def test_make_query_bad_serial(): - z = dns.versioned.Zone('example.') + z = dns.versioned.Zone("example.") with pytest.raises(ValueError): - dns.xfr.make_query(z, serial='hi') + dns.xfr.make_query(z, serial="hi") with pytest.raises(ValueError): dns.xfr.make_query(z, serial=-1) with pytest.raises(ValueError): dns.xfr.make_query(z, serial=4294967296) + def test_extract_serial_from_query(): - z = dns.versioned.Zone('example.') + z = dns.versioned.Zone("example.") (q, s) = dns.xfr.make_query(z) xs = dns.xfr.extract_serial_from_query(q) assert s is None @@ -623,15 +617,14 @@ def test_extract_serial_from_query(): xs = dns.xfr.extract_serial_from_query(q) assert s == 10 assert s == xs - q = dns.message.make_query('example', 'a') + q = dns.message.make_query("example", "a") with pytest.raises(ValueError): dns.xfr.extract_serial_from_query(q) class XFRNanoNameserver(Server): - def __init__(self): - super().__init__(origin=dns.name.from_text('example')) + super().__init__(origin=dns.name.from_text("example")) def handle(self, request): try: @@ -639,45 +632,62 @@ class XFRNanoNameserver(Server): text = ixfr else: text = axfr - r = dns.message.from_text(text, one_rr_per_rrset=True, - origin=self.origin) + r = dns.message.from_text(text, one_rr_per_rrset=True, origin=self.origin) r.id = request.message.id return r except Exception: pass -@pytest.mark.skipif(not _nanonameserver_available, - reason="requires nanonameserver") + +@pytest.mark.skipif(not _nanonameserver_available, reason="requires nanonameserver") def test_sync_inbound_xfr(): with XFRNanoNameserver() as ns: - zone = dns.versioned.Zone('example') - dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - expected = dns.zone.from_text(ixfr_expected, 'example') + zone = dns.versioned.Zone("example") + dns.query.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + dns.query.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + expected = dns.zone.from_text(ixfr_expected, "example") assert zone == expected + async def async_inbound_xfr(): with XFRNanoNameserver() as ns: - zone = dns.versioned.Zone('example') - await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, - port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, - port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - expected = dns.zone.from_text(ixfr_expected, 'example') + zone = dns.versioned.Zone("example") + await dns.asyncquery.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + await dns.asyncquery.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + expected = dns.zone.from_text(ixfr_expected, "example") assert zone == expected -@pytest.mark.skipif(not _nanonameserver_available, - reason="requires nanonameserver") + +@pytest.mark.skipif(not _nanonameserver_available, reason="requires nanonameserver") def test_asyncio_inbound_xfr(): - dns.asyncbackend.set_default_backend('asyncio') + dns.asyncbackend.set_default_backend("asyncio") + async def run(): await async_inbound_xfr() + asyncio.run(run()) + # # We don't need to do this as it's all generic code, but # just for extra caution we do it for each backend. @@ -686,34 +696,37 @@ def test_asyncio_inbound_xfr(): try: import trio - @pytest.mark.skipif(not _nanonameserver_available, - reason="requires nanonameserver") + @pytest.mark.skipif(not _nanonameserver_available, reason="requires nanonameserver") def test_trio_inbound_xfr(): - dns.asyncbackend.set_default_backend('trio') + dns.asyncbackend.set_default_backend("trio") + async def run(): await async_inbound_xfr() + trio.run(run) + except ImportError: pass try: import curio - @pytest.mark.skipif(not _nanonameserver_available, - reason="requires nanonameserver") + @pytest.mark.skipif(not _nanonameserver_available, reason="requires nanonameserver") def test_curio_inbound_xfr(): - dns.asyncbackend.set_default_backend('curio') + dns.asyncbackend.set_default_backend("curio") + async def run(): await async_inbound_xfr() + curio.run(run) + except ImportError: pass class UDPXFRNanoNameserver(Server): - def __init__(self): - super().__init__(origin=dns.name.from_text('example')) + super().__init__(origin=dns.name.from_text("example")) self.did_truncation = False def handle(self, request): @@ -726,48 +739,66 @@ class UDPXFRNanoNameserver(Server): self.did_truncation = True else: text = axfr - r = dns.message.from_text(text, one_rr_per_rrset=True, - origin=self.origin) + r = dns.message.from_text(text, one_rr_per_rrset=True, origin=self.origin) r.id = request.message.id return r except Exception: pass -@pytest.mark.skipif(not _nanonameserver_available, - reason="requires nanonameserver") + +@pytest.mark.skipif(not _nanonameserver_available, reason="requires nanonameserver") def test_sync_retry_tcp_inbound_xfr(): with UDPXFRNanoNameserver() as ns: - zone = dns.versioned.Zone('example') - dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - expected = dns.zone.from_text(ixfr_expected, 'example') + zone = dns.versioned.Zone("example") + dns.query.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + dns.query.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + expected = dns.zone.from_text(ixfr_expected, "example") assert zone == expected + async def udp_async_inbound_xfr(): with UDPXFRNanoNameserver() as ns: - zone = dns.versioned.Zone('example') - await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, - port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, - port=ns.tcp_address[1], - udp_mode=dns.query.UDPMode.TRY_FIRST) - expected = dns.zone.from_text(ixfr_expected, 'example') + zone = dns.versioned.Zone("example") + await dns.asyncquery.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + await dns.asyncquery.inbound_xfr( + ns.tcp_address[0], + zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST, + ) + expected = dns.zone.from_text(ixfr_expected, "example") assert zone == expected -@pytest.mark.skipif(not _nanonameserver_available, - reason="requires nanonameserver") + +@pytest.mark.skipif(not _nanonameserver_available, reason="requires nanonameserver") def test_asyncio_retry_tcp_inbound_xfr(): - dns.asyncbackend.set_default_backend('asyncio') + dns.asyncbackend.set_default_backend("asyncio") + async def run(): await udp_async_inbound_xfr() + try: runner = asyncio.run except AttributeError: + def old_runner(awaitable): loop = asyncio.get_event_loop() return loop.run_until_complete(awaitable) + runner = old_runner runner(run()) diff --git a/tests/test_zone.py b/tests/test_zone.py index b7046b2d..de7ec015 100644 --- a/tests/test_zone.py +++ b/tests/test_zone.py @@ -124,7 +124,9 @@ $ORIGIN example. """ include_text = """$INCLUDE "%s" -""" % here("example") +""" % here( + "example" +) bad_directive_text = """$FOO bar $ORIGIN example. @@ -243,24 +245,28 @@ web a 10.0.0.4 _keep_output = True + def _rdata_sort(a): return (a[0], a[2].rdclass, a[2].to_text()) + def add_rdataset(msg, name, rds): - rrset = msg.get_rrset(msg.answer, name, rds.rdclass, rds.rdtype, - create=True, force_unique=True) + rrset = msg.get_rrset( + msg.answer, name, rds.rdclass, rds.rdtype, create=True, force_unique=True + ) for rd in rds: rrset.add(rd, ttl=rds.ttl) + def make_xfr(zone): - q = dns.message.make_query(zone.origin, 'AXFR') + q = dns.message.make_query(zone.origin, "AXFR") msg = dns.message.make_response(q) if zone.relativize: msg.origin = zone.origin soa_name = dns.name.empty else: soa_name = zone.origin - soa = zone.find_rdataset(soa_name, 'SOA') + soa = zone.find_rdataset(soa_name, "SOA") add_rdataset(msg, soa_name, soa) for (name, rds) in zone.iterate_rdatasets(): if rds.rdtype == dns.rdatatype.SOA: @@ -269,48 +275,48 @@ def make_xfr(zone): add_rdataset(msg, soa_name, soa) return [msg] + def compare_files(test_name, a_name, b_name): - with open(a_name, 'r') as a: - with open(b_name, 'r') as b: - differences = list(difflib.unified_diff(a.readlines(), - b.readlines())) + with open(a_name, "r") as a: + with open(b_name, "r") as b: + differences = list(difflib.unified_diff(a.readlines(), b.readlines())) if len(differences) == 0: return True else: - print(f'{test_name} differences:') + print(f"{test_name} differences:") sys.stdout.writelines(differences) return False -class ZoneTestCase(unittest.TestCase): +class ZoneTestCase(unittest.TestCase): def testFromFile1(self): - z = dns.zone.from_file(here('example'), 'example') + z = dns.zone.from_file(here("example"), "example") ok = False try: - z.to_file(here('example1.out'), nl=b'\x0a') - ok = compare_files('testFromFile1', - here('example1.out'), - here('example1.good')) + z.to_file(here("example1.out"), nl=b"\x0a") + ok = compare_files( + "testFromFile1", here("example1.out"), here("example1.good") + ) finally: if not _keep_output: - os.unlink(here('example1.out')) + os.unlink(here("example1.out")) self.assertTrue(ok) def testFromFile2(self): - z = dns.zone.from_file(here('example'), 'example', relativize=False) + z = dns.zone.from_file(here("example"), "example", relativize=False) ok = False try: - z.to_file(here('example2.out'), relativize=False, nl=b'\x0a') - ok = compare_files('testFromFile2', - here('example2.out'), - here('example2.good')) + z.to_file(here("example2.out"), relativize=False, nl=b"\x0a") + ok = compare_files( + "testFromFile2", here("example2.out"), here("example2.good") + ) finally: if not _keep_output: - os.unlink(here('example2.out')) + os.unlink(here("example2.out")) self.assertTrue(ok) def testToFileTextualStream(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) f = StringIO() z.to_file(f) out = f.getvalue() @@ -318,105 +324,111 @@ class ZoneTestCase(unittest.TestCase): self.assertEqual(out, example_text_output) def testToFileBinaryStream(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) f = BytesIO() - z.to_file(f, nl=b'\n') + z.to_file(f, nl=b"\n") out = f.getvalue() f.close() self.assertEqual(out, example_text_output.encode()) def testToFileTextual(self): - z = dns.zone.from_file(here('example'), 'example') + z = dns.zone.from_file(here("example"), "example") try: - f = open(here('example3-textual.out'), 'w') + f = open(here("example3-textual.out"), "w") z.to_file(f) f.close() - ok = compare_files('testToFileTextual', - here('example3-textual.out'), - here('example3.good')) + ok = compare_files( + "testToFileTextual", here("example3-textual.out"), here("example3.good") + ) finally: if not _keep_output: - os.unlink(here('example3-textual.out')) + os.unlink(here("example3-textual.out")) self.assertTrue(ok) def testToFileBinary(self): - z = dns.zone.from_file(here('example'), 'example') + z = dns.zone.from_file(here("example"), "example") try: - f = open(here('example3-binary.out'), 'wb') + f = open(here("example3-binary.out"), "wb") z.to_file(f) f.close() - ok = compare_files('testToFileBinary', - here('example3-binary.out'), - here('example3.good')) + ok = compare_files( + "testToFileBinary", here("example3-binary.out"), here("example3.good") + ) finally: if not _keep_output: - os.unlink(here('example3-binary.out')) + os.unlink(here("example3-binary.out")) self.assertTrue(ok) def testToFileFilename(self): - z = dns.zone.from_file(here('example'), 'example') + z = dns.zone.from_file(here("example"), "example") try: - z.to_file(here('example3-filename.out')) - ok = compare_files('testToFileFilename', - here('example3-filename.out'), - here('example3.good')) + z.to_file(here("example3-filename.out")) + ok = compare_files( + "testToFileFilename", + here("example3-filename.out"), + here("example3.good"), + ) finally: if not _keep_output: - os.unlink(here('example3-filename.out')) + os.unlink(here("example3-filename.out")) self.assertTrue(ok) def testToText(self): - z = dns.zone.from_file(here('example'), 'example') + z = dns.zone.from_file(here("example"), "example") ok = False try: - text_zone = z.to_text(nl='\x0a') - f = open(here('example3.out'), 'w') + text_zone = z.to_text(nl="\x0a") + f = open(here("example3.out"), "w") f.write(text_zone) f.close() - ok = compare_files('testToText', - here('example3.out'), - here('example3.good')) + ok = compare_files( + "testToText", here("example3.out"), here("example3.good") + ) finally: if not _keep_output: - os.unlink(here('example3.out')) + os.unlink(here("example3.out")) self.assertTrue(ok) def testToFileTextualWithOrigin(self): - z = dns.zone.from_file(here('example'), 'example') + z = dns.zone.from_file(here("example"), "example") try: - f = open(here('example4-textual.out'), 'w') + f = open(here("example4-textual.out"), "w") z.to_file(f, want_origin=True) f.close() - ok = compare_files('testToFileTextualWithOrigin', - here('example4-textual.out'), - here('example4.good')) + ok = compare_files( + "testToFileTextualWithOrigin", + here("example4-textual.out"), + here("example4.good"), + ) finally: if not _keep_output: - os.unlink(here('example4-textual.out')) + os.unlink(here("example4-textual.out")) self.assertTrue(ok) def testToFileBinaryWithOrigin(self): - z = dns.zone.from_file(here('example'), 'example') + z = dns.zone.from_file(here("example"), "example") try: - f = open(here('example4-binary.out'), 'wb') + f = open(here("example4-binary.out"), "wb") z.to_file(f, want_origin=True) f.close() - ok = compare_files('testToFileBinaryWithOrigin', - here('example4-binary.out'), - here('example4.good')) + ok = compare_files( + "testToFileBinaryWithOrigin", + here("example4-binary.out"), + here("example4.good"), + ) finally: if not _keep_output: - os.unlink(here('example4-binary.out')) + os.unlink(here("example4-binary.out")) self.assertTrue(ok) def testFromText(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) f = StringIO() names = list(z.nodes.keys()) names.sort() for n in names: f.write(z[n].to_text(n)) - f.write('\n') + f.write("\n") self.assertEqual(f.getvalue(), example_text_output) def testTorture1(self): @@ -426,8 +438,8 @@ class ZoneTestCase(unittest.TestCase): # and then back out, and see if we get equal rdatas. # f = BytesIO() - o = dns.name.from_text('example.') - z = dns.zone.from_file(here('example'), o) + o = dns.name.from_text("example.") + z = dns.zone.from_file(here("example"), o) for node in z.values(): for rds in node: for rd in rds: @@ -435,437 +447,507 @@ class ZoneTestCase(unittest.TestCase): f.truncate() rd.to_wire(f, origin=o) wire = f.getvalue() - rd2 = dns.rdata.from_wire(rds.rdclass, rds.rdtype, - wire, 0, len(wire), - origin=o) + rd2 = dns.rdata.from_wire( + rds.rdclass, rds.rdtype, wire, 0, len(wire), origin=o + ) self.assertEqual(rd, rd2) def testEqual(self): - z1 = dns.zone.from_text(example_text, 'example.', relativize=True) - z2 = dns.zone.from_text(example_text_output, 'example.', - relativize=True) + z1 = dns.zone.from_text(example_text, "example.", relativize=True) + z2 = dns.zone.from_text(example_text_output, "example.", relativize=True) self.assertEqual(z1, z2) def testNotEqual1(self): - z1 = dns.zone.from_text(example_text, 'example.', relativize=True) - z2 = dns.zone.from_text(something_quite_similar, 'example.', - relativize=True) + z1 = dns.zone.from_text(example_text, "example.", relativize=True) + z2 = dns.zone.from_text(something_quite_similar, "example.", relativize=True) self.assertNotEqual(z1, z2) def testNotEqual2(self): - z1 = dns.zone.from_text(example_text, 'example.', relativize=True) - z2 = dns.zone.from_text(something_different, 'example.', - relativize=True) + z1 = dns.zone.from_text(example_text, "example.", relativize=True) + z2 = dns.zone.from_text(something_different, "example.", relativize=True) self.assertNotEqual(z1, z2) def testNotEqual3(self): - z1 = dns.zone.from_text(example_text, 'example.', relativize=True) - z2 = dns.zone.from_text(something_different, 'example2.', - relativize=True) + z1 = dns.zone.from_text(example_text, "example.", relativize=True) + z2 = dns.zone.from_text(something_different, "example2.", relativize=True) self.assertNotEqual(z1, z2) def testFindRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rds = z.find_rdataset('@', 'soa') - exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo bar 1 2 3 4 5') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rds = z.find_rdataset("@", "soa") + exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5") self.assertEqual(rds, exrds) def testFindRdataset2(self): def bad(): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - z.find_rdataset('@', 'loc') + z = dns.zone.from_text(example_text, "example.", relativize=True) + z.find_rdataset("@", "loc") + self.assertRaises(KeyError, bad) def testFindRRset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rrs = z.find_rrset('@', 'soa') - exrrs = dns.rrset.from_text('@', 300, 'IN', 'SOA', 'foo bar 1 2 3 4 5') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rrs = z.find_rrset("@", "soa") + exrrs = dns.rrset.from_text("@", 300, "IN", "SOA", "foo bar 1 2 3 4 5") self.assertEqual(rrs, exrrs) def testFindRRset2(self): def bad(): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - z.find_rrset('@', 'loc') + z = dns.zone.from_text(example_text, "example.", relativize=True) + z.find_rrset("@", "loc") + self.assertRaises(KeyError, bad) def testGetRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rds = z.get_rdataset('@', 'soa') - exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo bar 1 2 3 4 5') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rds = z.get_rdataset("@", "soa") + exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5") self.assertEqual(rds, exrds) def testGetRdataset2(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rds = z.get_rdataset('@', 'loc') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rds = z.get_rdataset("@", "loc") self.assertTrue(rds is None) def testGetRdatasetWithRelativeNameFromAbsoluteZone(self): - z = dns.zone.from_text(example_text, 'example.', relativize=False) - rds = z.get_rdataset(dns.name.empty, 'soa') + z = dns.zone.from_text(example_text, "example.", relativize=False) + rds = z.get_rdataset(dns.name.empty, "soa") self.assertIsNotNone(rds) - exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo.example. bar.example. 1 2 3 4 5') + exrds = dns.rdataset.from_text( + "IN", "SOA", 300, "foo.example. bar.example. 1 2 3 4 5" + ) self.assertEqual(rds, exrds) def testGetRRset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rrs = z.get_rrset('@', 'soa') - exrrs = dns.rrset.from_text('@', 300, 'IN', 'SOA', 'foo bar 1 2 3 4 5') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rrs = z.get_rrset("@", "soa") + exrrs = dns.rrset.from_text("@", 300, "IN", "SOA", "foo bar 1 2 3 4 5") self.assertEqual(rrs, exrrs) def testGetRRset2(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rrs = z.get_rrset('@', 'loc') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rrs = z.get_rrset("@", "loc") self.assertTrue(rrs is None) def testReplaceRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rdataset = dns.rdataset.from_text('in', 'ns', 300, 'ns3', 'ns4') - z.replace_rdataset('@', rdataset) - rds = z.get_rdataset('@', 'ns') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rdataset = dns.rdataset.from_text("in", "ns", 300, "ns3", "ns4") + z.replace_rdataset("@", rdataset) + rds = z.get_rdataset("@", "ns") self.assertTrue(rds is rdataset) def testReplaceRdataset2(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rdataset = dns.rdataset.from_text('in', 'txt', 300, '"foo"') - z.replace_rdataset('@', rdataset) - rds = z.get_rdataset('@', 'txt') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rdataset = dns.rdataset.from_text("in", "txt", 300, '"foo"') + z.replace_rdataset("@", rdataset) + rds = z.get_rdataset("@", "txt") self.assertTrue(rds is rdataset) def testDeleteRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - z.delete_rdataset('@', 'ns') - rds = z.get_rdataset('@', 'ns') + z = dns.zone.from_text(example_text, "example.", relativize=True) + z.delete_rdataset("@", "ns") + rds = z.get_rdataset("@", "ns") self.assertTrue(rds is None) def testDeleteRdataset2(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - z.delete_rdataset('ns1', 'a') - node = z.get_node('ns1') + z = dns.zone.from_text(example_text, "example.", relativize=True) + z.delete_rdataset("ns1", "a") + node = z.get_node("ns1") self.assertTrue(node is None) def testNodeFindRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - node = z['@'] + z = dns.zone.from_text(example_text, "example.", relativize=True) + node = z["@"] rds = node.find_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA) - exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo bar 1 2 3 4 5') + exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5") self.assertEqual(rds, exrds) def testNodeFindRdataset2(self): def bad(): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - node = z['@'] + z = dns.zone.from_text(example_text, "example.", relativize=True) + node = z["@"] node.find_rdataset(dns.rdataclass.IN, dns.rdatatype.LOC) + self.assertRaises(KeyError, bad) def testNodeFindRdataset3(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - node = z['@'] - rds = node.find_rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, - dns.rdatatype.A, create=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) + node = z["@"] + rds = node.find_rdataset( + dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.A, create=True + ) self.assertEqual(rds.rdclass, dns.rdataclass.IN) self.assertEqual(rds.rdtype, dns.rdatatype.RRSIG) self.assertEqual(rds.covers, dns.rdatatype.A) def testNodeGetRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - node = z['@'] + z = dns.zone.from_text(example_text, "example.", relativize=True) + node = z["@"] rds = node.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA) - exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo bar 1 2 3 4 5') + exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5") self.assertEqual(rds, exrds) def testNodeGetRdataset2(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - node = z['@'] + z = dns.zone.from_text(example_text, "example.", relativize=True) + node = z["@"] rds = node.get_rdataset(dns.rdataclass.IN, dns.rdatatype.LOC) self.assertTrue(rds is None) def testNodeDeleteRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - node = z['@'] + z = dns.zone.from_text(example_text, "example.", relativize=True) + node = z["@"] node.delete_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA) rds = node.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA) self.assertTrue(rds is None) def testNodeDeleteRdataset2(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - node = z['@'] + z = dns.zone.from_text(example_text, "example.", relativize=True) + node = z["@"] node.delete_rdataset(dns.rdataclass.IN, dns.rdatatype.LOC) rds = node.get_rdataset(dns.rdataclass.IN, dns.rdatatype.LOC) self.assertTrue(rds is None) def testIterateNodes(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) count = 0 for n in z: count += 1 self.assertEqual(count, 4) def testIterateRdatasets(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - ns = [n for n, r in z.iterate_rdatasets('A')] + z = dns.zone.from_text(example_text, "example.", relativize=True) + ns = [n for n, r in z.iterate_rdatasets("A")] ns.sort() - self.assertEqual(ns, [dns.name.from_text('ns1', None), - dns.name.from_text('ns2', None)]) + self.assertEqual( + ns, [dns.name.from_text("ns1", None), dns.name.from_text("ns2", None)] + ) def testIterateAllRdatasets(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) ns = [n for n, r in z.iterate_rdatasets()] ns.sort() - self.assertEqual(ns, [dns.name.from_text('@', None), - dns.name.from_text('@', None), - dns.name.from_text('bar.foo', None), - dns.name.from_text('ns1', None), - dns.name.from_text('ns2', None)]) + self.assertEqual( + ns, + [ + dns.name.from_text("@", None), + dns.name.from_text("@", None), + dns.name.from_text("bar.foo", None), + dns.name.from_text("ns1", None), + dns.name.from_text("ns2", None), + ], + ) def testIterateRdatas(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - l = list(z.iterate_rdatas('A')) + z = dns.zone.from_text(example_text, "example.", relativize=True) + l = list(z.iterate_rdatas("A")) l.sort() - exl = [(dns.name.from_text('ns1', None), + exl = [ + ( + dns.name.from_text("ns1", None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ] self.assertEqual(l, exl) def testIterateAllRdatas(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) l = list(z.iterate_rdatas()) l.sort(key=_rdata_sort) - exl = [(dns.name.from_text('@', None), + exl = [ + ( + dns.name.from_text("@", None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns1')), - (dns.name.from_text('@', None), + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns1"), + ), + ( + dns.name.from_text("@", None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, - 'ns2')), - (dns.name.from_text('@', None), + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, "ns2"), + ), + ( + dns.name.from_text("@", None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, - 'foo bar 1 2 3 4 5')), - (dns.name.from_text('bar.foo', None), + dns.rdata.from_text( + dns.rdataclass.IN, dns.rdatatype.SOA, "foo bar 1 2 3 4 5" + ), + ), + ( + dns.name.from_text("bar.foo", None), 300, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, - '0 blaz.foo')), - (dns.name.from_text('ns1', None), + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, "0 blaz.foo"), + ), + ( + dns.name.from_text("ns1", None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.1')), - (dns.name.from_text('ns2', None), + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.1"), + ), + ( + dns.name.from_text("ns2", None), 3600, - dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, - '10.0.0.2'))] + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "10.0.0.2"), + ), + ] exl.sort(key=_rdata_sort) self.assertEqual(l, exl) def testNodeGetSetDel(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) n = z.node_factory() - rds = dns.rdataset.from_text('IN', 'A', 300, '10.0.0.1') + rds = dns.rdataset.from_text("IN", "A", 300, "10.0.0.1") n.replace_rdataset(rds) - z['foo'] = n - self.assertTrue(z.find_rdataset('foo', 'A') is rds) - self.assertEqual(z['foo'], n) - self.assertEqual(z.get('foo'), n) - del z['foo'] - self.assertEqual(z.get('foo'), None) + z["foo"] = n + self.assertTrue(z.find_rdataset("foo", "A") is rds) + self.assertEqual(z["foo"], n) + self.assertEqual(z.get("foo"), n) + del z["foo"] + self.assertEqual(z.get("foo"), None) with self.assertRaises(KeyError): z[123] = n with self.assertRaises(KeyError): - z['foo.'] = n + z["foo."] = n with self.assertRaises(KeyError): - bn = z.find_node('bar') - bn = z.find_node('bar', True) + bn = z.find_node("bar") + bn = z.find_node("bar", True) self.assertTrue(isinstance(bn, dns.node.Node)) # The next two tests pass by not raising KeyError - z.delete_node('foo') - z.delete_node('bar') + z.delete_node("foo") + z.delete_node("bar") def testBadReplacement(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - rds = dns.rdataset.from_text('CH', 'TXT', 300, 'hi') + z = dns.zone.from_text(example_text, "example.", relativize=True) + rds = dns.rdataset.from_text("CH", "TXT", 300, "hi") + def bad(): - z.replace_rdataset('foo', rds) + z.replace_rdataset("foo", rds) + self.assertRaises(ValueError, bad) def testTTLs(self): - z = dns.zone.from_text(ttl_example_text, 'example.', relativize=True) - n = z['@'] # type: dns.node.Node - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA)) + z = dns.zone.from_text(ttl_example_text, "example.", relativize=True) + n = z["@"] # type: dns.node.Node + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA) + ) self.assertEqual(rds.ttl, 3600) - n = z['ns1'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A)) + n = z["ns1"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A) + ) self.assertEqual(rds.ttl, 86401) - n = z['ns2'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A)) + n = z["ns2"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A) + ) self.assertEqual(rds.ttl, 694861) def testTTLFromSOA(self): - z = dns.zone.from_text(ttl_from_soa_text, 'example.', relativize=True) - n = z['@'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA)) + z = dns.zone.from_text(ttl_from_soa_text, "example.", relativize=True) + n = z["@"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA) + ) self.assertEqual(rds.ttl, 3600) soa_rd = rds[0] - n = z['ns1'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A)) + n = z["ns1"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A) + ) self.assertEqual(rds.ttl, 694861) - n = z['ns2'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A)) + n = z["ns2"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A) + ) self.assertEqual(rds.ttl, soa_rd.minimum) def testTTLFromLast(self): - z = dns.zone.from_text(ttl_from_last_text, 'example.', check_origin=False) - n = z['@'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NS)) + z = dns.zone.from_text(ttl_from_last_text, "example.", check_origin=False) + n = z["@"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NS) + ) self.assertEqual(rds.ttl, 3600) - n = z['ns1'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A)) + n = z["ns1"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A) + ) self.assertEqual(rds.ttl, 3600) - n = z['ns2'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A)) + n = z["ns2"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A) + ) self.assertEqual(rds.ttl, 694861) def testNoTTL(self): def bad(): - dns.zone.from_text(no_ttl_text, 'example.', check_origin=False) + dns.zone.from_text(no_ttl_text, "example.", check_origin=False) + self.assertRaises(dns.exception.SyntaxError, bad) def testNoSOA(self): def bad(): - dns.zone.from_text(no_soa_text, 'example.', relativize=True) + dns.zone.from_text(no_soa_text, "example.", relativize=True) + self.assertRaises(dns.zone.NoSOA, bad) def testNoNS(self): def bad(): - dns.zone.from_text(no_ns_text, 'example.', relativize=True) + dns.zone.from_text(no_ns_text, "example.", relativize=True) + self.assertRaises(dns.zone.NoNS, bad) def testInclude(self): - z1 = dns.zone.from_text(include_text, 'example.', relativize=True, - allow_include=True) - z2 = dns.zone.from_file(here('example'), 'example.', relativize=True) + z1 = dns.zone.from_text( + include_text, "example.", relativize=True, allow_include=True + ) + z2 = dns.zone.from_file(here("example"), "example.", relativize=True) self.assertEqual(z1, z2) def testBadDirective(self): def bad(): - dns.zone.from_text(bad_directive_text, 'example.', relativize=True) + dns.zone.from_text(bad_directive_text, "example.", relativize=True) + self.assertRaises(dns.exception.SyntaxError, bad) def testFirstRRStartsWithWhitespace(self): # no name is specified, so default to the initial origin - z = dns.zone.from_text(' 300 IN A 10.0.0.1', origin='example.', - check_origin=False) - n = z['@'] - rds = cast(dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A)) + z = dns.zone.from_text( + " 300 IN A 10.0.0.1", origin="example.", check_origin=False + ) + n = z["@"] + rds = cast( + dns.rdataset.Rdataset, n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.A) + ) self.assertEqual(rds.ttl, 300) def testZoneOrigin(self): - z = dns.zone.Zone('example.') - self.assertEqual(z.origin, dns.name.from_text('example.')) + z = dns.zone.Zone("example.") + self.assertEqual(z.origin, dns.name.from_text("example.")) + def bad1(): - o = dns.name.from_text('example', None) + o = dns.name.from_text("example", None) dns.zone.Zone(o) + self.assertRaises(ValueError, bad1) + def bad2(): dns.zone.Zone(cast(str, 1.0)) + self.assertRaises(ValueError, bad2) def testZoneOriginNone(self): dns.zone.Zone(cast(str, None)) def testZoneFromXFR(self): - z1_abs = dns.zone.from_text(example_text, 'example.', relativize=False) + z1_abs = dns.zone.from_text(example_text, "example.", relativize=False) z2_abs = dns.zone.from_xfr(make_xfr(z1_abs), relativize=False) self.assertEqual(z1_abs, z2_abs) - z1_rel = dns.zone.from_text(example_text, 'example.', relativize=True) + z1_rel = dns.zone.from_text(example_text, "example.", relativize=True) z2_rel = dns.zone.from_xfr(make_xfr(z1_rel), relativize=True) self.assertEqual(z1_rel, z2_rel) def testCodec2003(self): - z = dns.zone.from_text(codec_text, 'example.', relativize=True) - n2003 = dns.name.from_text('xn--knigsgsschen-lcb0w', None) - n2008 = dns.name.from_text('xn--knigsgchen-b4a3dun', None) + z = dns.zone.from_text(codec_text, "example.", relativize=True) + n2003 = dns.name.from_text("xn--knigsgsschen-lcb0w", None) + n2008 = dns.name.from_text("xn--knigsgchen-b4a3dun", None) self.assertTrue(n2003 in z) self.assertFalse(n2008 in z) - rrs = z.find_rrset(n2003, 'NS') + rrs = z.find_rrset(n2003, "NS") self.assertEqual(rrs[0].target, n2003) - @unittest.skipUnless(dns.name.have_idna_2008, - 'Python idna cannot be imported; no IDNA2008') + @unittest.skipUnless( + dns.name.have_idna_2008, "Python idna cannot be imported; no IDNA2008" + ) def testCodec2008(self): - z = dns.zone.from_text(codec_text, 'example.', relativize=True, - idna_codec=dns.name.IDNA_2008) - n2003 = dns.name.from_text('xn--knigsgsschen-lcb0w', None) - n2008 = dns.name.from_text('xn--knigsgchen-b4a3dun', None) + z = dns.zone.from_text( + codec_text, "example.", relativize=True, idna_codec=dns.name.IDNA_2008 + ) + n2003 = dns.name.from_text("xn--knigsgsschen-lcb0w", None) + n2008 = dns.name.from_text("xn--knigsgchen-b4a3dun", None) self.assertFalse(n2003 in z) self.assertTrue(n2008 in z) - rrs = z.find_rrset(n2008, 'NS') + rrs = z.find_rrset(n2008, "NS") self.assertEqual(rrs[0].target, n2008) def testZoneMiscCases(self): # test that leading whitespace followed by EOL is treated like # a blank line, and that out-of-zone names are dropped. - z1 = dns.zone.from_text(misc_cases_input, 'example.') - z2 = dns.zone.from_text(misc_cases_expected, 'example.') + z1 = dns.zone.from_text(misc_cases_input, "example.") + z2 = dns.zone.from_text(misc_cases_expected, "example.") self.assertEqual(z1, z2) def testUnknownOrigin(self): def bad(): - dns.zone.from_text('foo 300 in a 10.0.0.1') + dns.zone.from_text("foo 300 in a 10.0.0.1") + self.assertRaises(dns.zone.UnknownOrigin, bad) def testBadClass(self): def bad(): - dns.zone.from_text('foo 300 ch txt hi', 'example.') + dns.zone.from_text("foo 300 ch txt hi", "example.") + self.assertRaises(dns.exception.SyntaxError, bad) def testUnknownRdatatype(self): def bad(): - dns.zone.from_text('foo 300 BOGUSTYPE hi', 'example.') + dns.zone.from_text("foo 300 BOGUSTYPE hi", "example.") + self.assertRaises(dns.exception.SyntaxError, bad) def testDangling(self): def bad1(): - dns.zone.from_text('foo', 'example.') + dns.zone.from_text("foo", "example.") + self.assertRaises(dns.exception.SyntaxError, bad1) + def bad2(): - dns.zone.from_text('foo 300', 'example.') + dns.zone.from_text("foo 300", "example.") + self.assertRaises(dns.exception.SyntaxError, bad2) + def bad3(): - dns.zone.from_text('foo 300 in', 'example.') + dns.zone.from_text("foo 300 in", "example.") + self.assertRaises(dns.exception.SyntaxError, bad3) + def bad4(): - dns.zone.from_text('foo 300 in a', 'example.') + dns.zone.from_text("foo 300 in a", "example.") + self.assertRaises(dns.exception.SyntaxError, bad4) + def bad5(): - dns.zone.from_text('$TTL', 'example.') + dns.zone.from_text("$TTL", "example.") + self.assertRaises(dns.exception.SyntaxError, bad5) + def bad6(): - dns.zone.from_text('$ORIGIN', 'example.') + dns.zone.from_text("$ORIGIN", "example.") + self.assertRaises(dns.exception.SyntaxError, bad6) def testUseLastTTL(self): - z = dns.zone.from_text(last_ttl_input, 'example.') - rds = z.find_rdataset('foo', 'A') + z = dns.zone.from_text(last_ttl_input, "example.") + rds = z.find_rdataset("foo", "A") self.assertEqual(rds.ttl, 300) def testDollarOriginSetsZoneOriginIfUnknown(self): z = dns.zone.from_text(origin_sets_input) - self.assertEqual(z.origin, dns.name.from_text('example')) + self.assertEqual(z.origin, dns.name.from_text("example")) def testValidateNameRelativizesNameInZone(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) - self.assertEqual(z._validate_name('foo.bar.example.'), - dns.name.from_text('foo.bar', None)) + z = dns.zone.from_text(example_text, "example.", relativize=True) + self.assertEqual( + z._validate_name("foo.bar.example."), dns.name.from_text("foo.bar", None) + ) def testComments(self): - z = dns.zone.from_text(example_comments_text, 'example.', - relativize=True) + z = dns.zone.from_text(example_comments_text, "example.", relativize=True) f = StringIO() z.to_file(f, want_comments=True) out = f.getvalue() @@ -873,22 +955,21 @@ class ZoneTestCase(unittest.TestCase): self.assertEqual(out, example_comments_text_output) def testUncomparable(self): - z = dns.zone.from_text(example_comments_text, 'example.', - relativize=True) - self.assertFalse(z == 'a') + z = dns.zone.from_text(example_comments_text, "example.", relativize=True) + self.assertFalse(z == "a") def testUnsorted(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True) + z = dns.zone.from_text(example_text, "example.", relativize=True) f = StringIO() z.to_file(f, sorted=False) out = f.getvalue() f.close() - z2 = dns.zone.from_text(out, 'example.', relativize=True) + z2 = dns.zone.from_text(out, "example.", relativize=True) self.assertEqual(z, z2) def testNodeReplaceRdatasetConvertsRRsets(self): node = dns.node.Node() - rrs = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.1') + rrs = dns.rrset.from_text("foo", 300, "in", "a", "10.0.0.1") node.replace_rdataset(rrs) rds = node.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A) self.assertEqual(rds, rrs) @@ -896,219 +977,223 @@ class ZoneTestCase(unittest.TestCase): self.assertFalse(isinstance(rds, dns.rrset.RRset)) def testCnameAndOtherDataAddOther(self): - z = dns.zone.from_text(example_cname, 'example.', relativize=True) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') - z.replace_rdataset('web', rds) - z.replace_rdataset('web2', rds.copy()) - n = z.find_node('web') + z = dns.zone.from_text(example_cname, "example.", relativize=True) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") + z.replace_rdataset("web", rds) + z.replace_rdataset("web2", rds.copy()) + n = z.find_node("web") self.assertEqual(len(n.rdatasets), 3) - self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), - rds) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.NSEC)) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.RRSIG, - dns.rdatatype.NSEC)) - n = z.find_node('web2') + self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), rds) + self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NSEC)) + self.assertIsNotNone( + n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.NSEC) + ) + n = z.find_node("web2") self.assertEqual(len(n.rdatasets), 3) - self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), - rds) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.NSEC3)) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.RRSIG, - dns.rdatatype.NSEC3)) + self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), rds) + self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NSEC3)) + self.assertIsNotNone( + n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.NSEC3) + ) def testCnameAndOtherDataAddCname(self): - z = dns.zone.from_text(example_other_data, 'example.', relativize=True) - rds = dns.rdataset.from_text('in', 'cname', 300, 'www') - z.replace_rdataset('web', rds) - n = z.find_node('web') + z = dns.zone.from_text(example_other_data, "example.", relativize=True) + rds = dns.rdataset.from_text("in", "cname", 300, "www") + z.replace_rdataset("web", rds) + n = z.find_node("web") self.assertEqual(len(n.rdatasets), 3) - self.assertEqual(n.find_rdataset(dns.rdataclass.IN, - dns.rdatatype.CNAME), - rds) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.NSEC)) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.RRSIG, - dns.rdatatype.NSEC)) + self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.CNAME), rds) + self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NSEC)) + self.assertIsNotNone( + n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.NSEC) + ) def testCnameAndOtherDataInZonefile(self): with self.assertRaises(dns.zonefile.CNAMEAndOtherData): - dns.zone.from_text(example_cname_and_other_data, 'example.', - relativize=True) + dns.zone.from_text( + example_cname_and_other_data, "example.", relativize=True + ) def testNameInZoneWithStr(self): - z = dns.zone.from_text(example_text, 'example.', relativize=False) - self.assertTrue('ns1.example.' in z) - self.assertTrue('bar.foo.example.' in z) + z = dns.zone.from_text(example_text, "example.", relativize=False) + self.assertTrue("ns1.example." in z) + self.assertTrue("bar.foo.example." in z) def testNameInZoneWhereNameIsNotValid(self): - z = dns.zone.from_text(example_text, 'example.', relativize=False) + z = dns.zone.from_text(example_text, "example.", relativize=False) with self.assertRaises(KeyError): self.assertTrue(1 in z) class VersionedZoneTestCase(unittest.TestCase): def testUseTransaction(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) with self.assertRaises(dns.versioned.UseTransaction): - z.find_node('not_there', True) + z.find_node("not_there", True) with self.assertRaises(dns.versioned.UseTransaction): - z.delete_node('not_there') + z.delete_node("not_there") with self.assertRaises(dns.versioned.UseTransaction): - z.find_rdataset('not_there', 'a', create=True) + z.find_rdataset("not_there", "a", create=True) with self.assertRaises(dns.versioned.UseTransaction): - z.get_rdataset('not_there', 'a', create=True) + z.get_rdataset("not_there", "a", create=True) with self.assertRaises(dns.versioned.UseTransaction): - z.delete_rdataset('not_there', 'a') + z.delete_rdataset("not_there", "a") with self.assertRaises(dns.versioned.UseTransaction): - z.replace_rdataset('not_there', None) + z.replace_rdataset("not_there", None) def testImmutableNodes(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) - node = z.find_node('@') + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) + node = z.find_node("@") with self.assertRaises(TypeError): - node.find_rdataset(dns.rdataclass.IN, dns.rdatatype.RP, - create=True) + node.find_rdataset(dns.rdataclass.IN, dns.rdatatype.RP, create=True) with self.assertRaises(TypeError): - node.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RP, - create=True) + node.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RP, create=True) with self.assertRaises(TypeError): node.delete_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA) with self.assertRaises(TypeError): node.replace_rdataset(None) def testSelectDefaultPruningPolicy(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) z.set_pruning_policy(None) self.assertEqual(z._pruning_policy, z._default_pruning_policy) def testSetAlternatePruningPolicyInConstructor(self): def never_prune(version): return False - z = dns.versioned.Zone('example', pruning_policy=never_prune) + + z = dns.versioned.Zone("example", pruning_policy=never_prune) self.assertEqual(z._pruning_policy, never_prune) def testCannotSpecifyBothSerialAndVersionIdToReader(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) with self.assertRaises(ValueError): z.reader(1, 1) def testUnknownVersion(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) with self.assertRaises(KeyError): z.reader(99999) def testUnknownSerial(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) with self.assertRaises(KeyError): z.reader(serial=99999) def testNoRelativizeReader(self): - z = dns.zone.from_text(example_text, 'example.', relativize=False, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone + ) with z.reader(serial=1) as txn: - rds = txn.get('example.', 'soa') + rds = txn.get("example.", "soa") self.assertEqual(rds[0].serial, 1) def testNoRelativizeReaderOriginInText(self): - z = dns.zone.from_text(example_text, relativize=False, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, relativize=False, zone_factory=dns.versioned.Zone + ) with z.reader(serial=1) as txn: - rds = txn.get('example.', 'soa') + rds = txn.get("example.", "soa") self.assertEqual(rds[0].serial, 1) def testNoRelativizeReaderAbsoluteGet(self): - z = dns.zone.from_text(example_text, 'example.', relativize=False, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone + ) with z.reader(serial=1) as txn: - rds = txn.get(dns.name.empty, 'soa') + rds = txn.get(dns.name.empty, "soa") self.assertEqual(rds[0].serial, 1) def testCnameAndOtherDataAddOther(self): - z = dns.zone.from_text(example_cname, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) - rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + z = dns.zone.from_text( + example_cname, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) + rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") with z.writer() as txn: - txn.replace('web', rds) - txn.replace('web2', rds.copy()) - n = z.find_node('web') + txn.replace("web", rds) + txn.replace("web2", rds.copy()) + n = z.find_node("web") self.assertEqual(len(n.rdatasets), 3) - self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), - rds) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.NSEC)) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.RRSIG, - dns.rdatatype.NSEC)) - n = z.find_node('web2') + self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), rds) + self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NSEC)) + self.assertIsNotNone( + n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.NSEC) + ) + n = z.find_node("web2") self.assertEqual(len(n.rdatasets), 3) - self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), - rds) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.NSEC3)) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.RRSIG, - dns.rdatatype.NSEC3)) + self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A), rds) + self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NSEC3)) + self.assertIsNotNone( + n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.NSEC3) + ) def testCnameAndOtherDataAddCname(self): - z = dns.zone.from_text(example_other_data, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) - rds = dns.rdataset.from_text('in', 'cname', 300, 'www') + z = dns.zone.from_text( + example_other_data, + "example.", + relativize=True, + zone_factory=dns.versioned.Zone, + ) + rds = dns.rdataset.from_text("in", "cname", 300, "www") with z.writer() as txn: - txn.replace('web', rds) - n = z.find_node('web') + txn.replace("web", rds) + n = z.find_node("web") self.assertEqual(len(n.rdatasets), 3) - self.assertEqual(n.find_rdataset(dns.rdataclass.IN, - dns.rdatatype.CNAME), - rds) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.NSEC)) - self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, - dns.rdatatype.RRSIG, - dns.rdatatype.NSEC)) + self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.CNAME), rds) + self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NSEC)) + self.assertIsNotNone( + n.get_rdataset(dns.rdataclass.IN, dns.rdatatype.RRSIG, dns.rdatatype.NSEC) + ) def testGetSoa(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) soa = z.get_soa() self.assertTrue(soa.rdtype, dns.rdatatype.SOA) self.assertEqual(soa.serial, 1) def testGetSoaTxn(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) with z.reader(serial=1) as txn: soa = z.get_soa(txn) self.assertTrue(soa.rdtype, dns.rdatatype.SOA) self.assertEqual(soa.serial, 1) def testGetSoaEmptyZone(self): - z = dns.zone.Zone('example.') + z = dns.zone.Zone("example.") with self.assertRaises(dns.zone.NoSOA): soa = z.get_soa() def testGetRdataset1(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) - rds = z.get_rdataset('@', 'soa') - exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo bar 1 2 3 4 5') + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) + rds = z.get_rdataset("@", "soa") + exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5") self.assertEqual(rds, exrds) def testGetRdataset2(self): - z = dns.zone.from_text(example_text, 'example.', relativize=True, - zone_factory=dns.versioned.Zone) - rds = z.get_rdataset('@', 'loc') + z = dns.zone.from_text( + example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + ) + rds = z.get_rdataset("@", "loc") self.assertTrue(rds is None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_zonedigest.py b/tests/test_zonedigest.py index d94be249..bb3f174c 100644 --- a/tests/test_zonedigest.py +++ b/tests/test_zonedigest.py @@ -8,10 +8,12 @@ import dns.rdata import dns.rrset import dns.zone + class ZoneDigestTestCase(unittest.TestCase): # Examples from RFC 8976, fixed per errata. - simple_example = textwrap.dedent(''' + simple_example = textwrap.dedent( + """ example. 86400 IN SOA ns1 admin 2018031900 ( 1800 900 604800 86400 ) 86400 IN NS ns1 @@ -25,9 +27,11 @@ class ZoneDigestTestCase(unittest.TestCase): 777f98b8e730044c ) ns1 3600 IN A 203.0.113.63 ns2 3600 IN AAAA 2001:db8::63 - ''') + """ + ) - complex_example = textwrap.dedent(''' + complex_example = textwrap.dedent( + """ example. 86400 IN SOA ns1 admin 2018031900 ( 1800 900 604800 86400 ) 86400 IN NS ns1 @@ -62,9 +66,11 @@ class ZoneDigestTestCase(unittest.TestCase): 6f77656420627574 2069676e6f726564 2e20616c6c6f7765 ) - ''') + """ + ) - multiple_digests_example = textwrap.dedent(''' + multiple_digests_example = textwrap.dedent( + """ example. 86400 IN SOA ns1 admin 2018031900 ( 1800 900 604800 86400 ) example. 86400 IN NS ns1.example. @@ -98,36 +104,34 @@ class ZoneDigestTestCase(unittest.TestCase): ns1.example. 3600 IN A 203.0.113.63 ns2.example. 86400 IN TXT "This example has multiple digests" NS2.EXAMPLE. 3600 IN AAAA 2001:db8::63 - ''') + """ + ) def _get_zonemd(self, zone): - return zone.get_rdataset(zone.origin, 'ZONEMD') + return zone.get_rdataset(zone.origin, "ZONEMD") def test_zonemd_simple(self): - zone = dns.zone.from_text(self.simple_example, origin='example') + zone = dns.zone.from_text(self.simple_example, origin="example") zone.verify_digest() zonemd = self._get_zonemd(zone) - self.assertEqual(zonemd[0], - zone.compute_digest(zonemd[0].hash_algorithm)) + self.assertEqual(zonemd[0], zone.compute_digest(zonemd[0].hash_algorithm)) def test_zonemd_simple_absolute(self): - zone = dns.zone.from_text(self.simple_example, origin='example', - relativize=False) + zone = dns.zone.from_text( + self.simple_example, origin="example", relativize=False + ) zone.verify_digest() zonemd = self._get_zonemd(zone) - self.assertEqual(zonemd[0], - zone.compute_digest(zonemd[0].hash_algorithm)) + self.assertEqual(zonemd[0], zone.compute_digest(zonemd[0].hash_algorithm)) def test_zonemd_complex(self): - zone = dns.zone.from_text(self.complex_example, origin='example') + zone = dns.zone.from_text(self.complex_example, origin="example") zone.verify_digest() zonemd = self._get_zonemd(zone) - self.assertEqual(zonemd[0], - zone.compute_digest(zonemd[0].hash_algorithm)) + self.assertEqual(zonemd[0], zone.compute_digest(zonemd[0].hash_algorithm)) def test_zonemd_multiple_digests(self): - zone = dns.zone.from_text(self.multiple_digests_example, - origin='example') + zone = dns.zone.from_text(self.multiple_digests_example, origin="example") zone.verify_digest() zonemd = self._get_zonemd(zone) @@ -140,54 +144,56 @@ class ZoneDigestTestCase(unittest.TestCase): zone.verify_digest(rr) def test_zonemd_no_digest(self): - zone = dns.zone.from_text(self.simple_example, origin='example') - zone.delete_rdataset(dns.name.empty, 'ZONEMD') + zone = dns.zone.from_text(self.simple_example, origin="example") + zone.delete_rdataset(dns.name.empty, "ZONEMD") with self.assertRaises(dns.zone.NoDigest): zone.verify_digest() - sha384_hash = 'ab' * 48 - sha512_hash = 'ab' * 64 + sha384_hash = "ab" * 48 + sha512_hash = "ab" * 64 def test_zonemd_parse_rdata(self): - dns.rdata.from_text('IN', 'ZONEMD', '100 1 1 ' + self.sha384_hash) - dns.rdata.from_text('IN', 'ZONEMD', '100 1 2 ' + self.sha512_hash) - dns.rdata.from_text('IN', 'ZONEMD', '100 100 1 ' + self.sha384_hash) - dns.rdata.from_text('IN', 'ZONEMD', '100 1 100 abcd') + dns.rdata.from_text("IN", "ZONEMD", "100 1 1 " + self.sha384_hash) + dns.rdata.from_text("IN", "ZONEMD", "100 1 2 " + self.sha512_hash) + dns.rdata.from_text("IN", "ZONEMD", "100 100 1 " + self.sha384_hash) + dns.rdata.from_text("IN", "ZONEMD", "100 1 100 abcd") def test_zonemd_unknown_scheme(self): - zone = dns.zone.from_text(self.simple_example, origin='example') + zone = dns.zone.from_text(self.simple_example, origin="example") with self.assertRaises(dns.zone.UnsupportedDigestScheme): zone.compute_digest(dns.zone.DigestHashAlgorithm.SHA384, 2) def test_zonemd_unknown_hash_algorithm(self): - zone = dns.zone.from_text(self.simple_example, origin='example') + zone = dns.zone.from_text(self.simple_example, origin="example") with self.assertRaises(dns.zone.UnsupportedDigestHashAlgorithm): zone.compute_digest(5) def test_zonemd_invalid_digest_length(self): with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('IN', 'ZONEMD', '100 1 2 ' + self.sha384_hash) + dns.rdata.from_text("IN", "ZONEMD", "100 1 2 " + self.sha384_hash) with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('IN', 'ZONEMD', '100 2 1 ' + self.sha512_hash) + dns.rdata.from_text("IN", "ZONEMD", "100 2 1 " + self.sha512_hash) def test_zonemd_parse_rdata_reserved(self): with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('IN', 'ZONEMD', '100 0 1 ' + self.sha384_hash) + dns.rdata.from_text("IN", "ZONEMD", "100 0 1 " + self.sha384_hash) with self.assertRaises(dns.exception.SyntaxError): - dns.rdata.from_text('IN', 'ZONEMD', '100 1 0 ' + self.sha384_hash) + dns.rdata.from_text("IN", "ZONEMD", "100 1 0 " + self.sha384_hash) - sorting_zone = textwrap.dedent(''' + sorting_zone = textwrap.dedent( + """ @ 86400 IN SOA ns1 admin 2018031900 ( 1800 900 604800 86400 ) 86400 IN NS ns1 86400 IN NS ns2 86400 IN RP n1.example. a. 86400 IN RP n1. b. - ''') - + """ + ) + def test_relative_zone_sorting(self): - z1 = dns.zone.from_text(self.sorting_zone, 'example.', relativize=True) - z2 = dns.zone.from_text(self.sorting_zone, 'example.', relativize=False) + z1 = dns.zone.from_text(self.sorting_zone, "example.", relativize=True) + z2 = dns.zone.from_text(self.sorting_zone, "example.", relativize=False) zmd1 = z1.compute_digest(dns.zone.DigestHashAlgorithm.SHA384) zmd2 = z2.compute_digest(dns.zone.DigestHashAlgorithm.SHA384) self.assertEqual(zmd1, zmd2) diff --git a/tests/ttxt_module.py b/tests/ttxt_module.py index c66131bb..0886dde4 100644 --- a/tests/ttxt_module.py +++ b/tests/ttxt_module.py @@ -1,4 +1,5 @@ import dns.rdtypes.txtbase + class TTXT(dns.rdtypes.txtbase.TXTBase): """Test TXT-like record""" diff --git a/tests/utest.py b/tests/utest.py index 20f8e186..9f0667e6 100644 --- a/tests/utest.py +++ b/tests/utest.py @@ -2,12 +2,12 @@ import os.path import sys import unittest -if __name__ == '__main__': - sys.path.insert(0, os.path.realpath('..')) +if __name__ == "__main__": + sys.path.insert(0, os.path.realpath("..")) if len(sys.argv) > 1: pattern = sys.argv[1] else: - pattern = 'test*.py' - suites = unittest.defaultTestLoader.discover('.', pattern) + pattern = "test*.py" + suites = unittest.defaultTestLoader.discover(".", pattern) if not unittest.TextTestRunner(verbosity=2).run(suites).wasSuccessful(): sys.exit(1) diff --git a/tests/util.py b/tests/util.py index e8c13630..df9ab444 100644 --- a/tests/util.py +++ b/tests/util.py @@ -19,9 +19,11 @@ import enum import inspect import os.path + def here(filename): return os.path.join(os.path.dirname(__file__), filename) + def enumerate_module(module, super_class): """Yield module attributes which are subclasses of given class""" for attr_name in dir(module): @@ -29,12 +31,13 @@ def enumerate_module(module, super_class): if inspect.isclass(attr) and issubclass(attr, super_class): yield attr + def check_enum_exports(module, eq_callback, only=None): """Make sure module exports all mnemonics from enums""" for attr in enumerate_module(module, enum.Enum): if only is not None and attr not in only: - #print('SKIP', attr) + # print('SKIP', attr) continue for flag, value in attr.__members__.items(): - #print(module, flag, value) + # print(module, flag, value) eq_callback(getattr(module, flag), value) diff --git a/util/generate-mx-pickle.py b/util/generate-mx-pickle.py index ad999421..a917e089 100644 --- a/util/generate-mx-pickle.py +++ b/util/generate-mx-pickle.py @@ -6,14 +6,14 @@ import dns.version # Generate a pickled mx RR for the current dnspython version -mx = dns.rdata.from_text('in', 'mx', '10 mx.example.') -filename = f'pickled-{dns.version.MAJOR}-{dns.version.MINOR}.pickle' -with open(filename, 'wb') as f: +mx = dns.rdata.from_text("in", "mx", "10 mx.example.") +filename = f"pickled-{dns.version.MAJOR}-{dns.version.MINOR}.pickle" +with open(filename, "wb") as f: pickle.dump(mx, f) -with open(filename, 'rb') as f: +with open(filename, "rb") as f: mx2 = pickle.load(f) if mx == mx2: - print('ok') + print("ok") else: - print('DIFFERENT!') + print("DIFFERENT!") sys.exit(1) diff --git a/util/generate-rdatatype-doc.py b/util/generate-rdatatype-doc.py index 99f02b4f..5cc9f114 100644 --- a/util/generate-rdatatype-doc.py +++ b/util/generate-rdatatype-doc.py @@ -1,14 +1,13 @@ - import dns.rdatatype -print('Rdatatypes') -print('----------') +print("Rdatatypes") +print("----------") print() by_name = {} for rdtype in dns.rdatatype.RdataType: - short_name = str(rdtype).split('.')[1] + short_name = str(rdtype).split(".")[1] by_name[short_name] = int(rdtype) for k in sorted(by_name.keys()): v = by_name[k] - print(f'.. py:data:: dns.rdatatype.{k}') - print(f' :annotation: = {v}') + print(f".. py:data:: dns.rdatatype.{k}") + print(f" :annotation: = {v}")