Generated with: ruff check --extend-select UP045 --fix && black .
control_subdomain = "set-spoofing-mode"
def __init__(self) -> None:
- self._current_handler: Optional[ResponseSpoofer] = None
+ self._current_handler: ResponseSpoofer | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
- ) -> Optional[str]:
+ ) -> str | None:
if len(args) != 1:
qctx.response.set_rcode(dns.rcode.SERVFAIL)
return "invalid control command"
from dataclasses import dataclass
from enum import Enum
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator
import abc
import logging
control_subdomain = "setup-chain"
def __init__(self) -> None:
- self._current_handler: Optional[ChainResponseHandler] = None
+ self._current_handler: ChainResponseHandler | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
- ) -> Optional[str]:
+ ) -> str | None:
try:
actions, selectors = self._parse_args(args)
except ValueError as exc:
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator
import logging
control_subdomain = "response-sequence"
def __init__(self) -> None:
- self._current_handler: Optional[ResponseHandler] = None
+ self._current_handler: ResponseHandler | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
"""
from dataclasses import dataclass, field
-from typing import Any, AsyncGenerator, Callable, Coroutine, Optional, Sequence, cast
+from typing import Any, AsyncGenerator, Callable, Coroutine, Sequence, cast
import abc
import asyncio
self,
handler: _UdpHandler,
) -> None:
- self._transport: Optional[asyncio.DatagramTransport] = None
+ self._transport: asyncio.DatagramTransport | None = None
self._handler: _UdpHandler = handler
def connection_made(self, transport: asyncio.BaseTransport) -> None:
def __init__(
self,
- udp_handler: Optional[_UdpHandler],
- tcp_handler: Optional[_TcpHandler],
- pidfile: Optional[str] = None,
+ udp_handler: _UdpHandler | None,
+ tcp_handler: _TcpHandler | None,
+ pidfile: str | None = None,
) -> None:
logging.basicConfig(
format="%(asctime)s %(levelname)8s %(message)s",
self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address)
self._port: int = port
- self._udp_handler: Optional[_UdpHandler] = udp_handler
- self._tcp_handler: Optional[_TcpHandler] = tcp_handler
- self._pidfile: Optional[str] = pidfile
- self._work_done: Optional[asyncio.Future] = None
+ self._udp_handler: _UdpHandler | None = udp_handler
+ self._tcp_handler: _TcpHandler | None = tcp_handler
+ self._pidfile: str | None = pidfile
+ self._work_done: asyncio.Future | None = None
def _get_ipv4_address_from_directory_name(self) -> str:
containing_directory = pathlib.Path().absolute().stem
socket: Peer
peer: Peer
protocol: DnsProtocol
- zone: Optional[dns.zone.Zone] = field(default=None, init=False)
- soa: Optional[dns.rrset.RRset] = field(default=None, init=False)
- node: Optional[dns.node.Node] = field(default=None, init=False)
- answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False)
- alias: Optional[dns.name.Name] = field(default=None, init=False)
- _initialized_response: Optional[dns.message.Message] = field(
- default=None, init=False
- )
- _initialized_response_with_zone_data: Optional[dns.message.Message] = field(
+ zone: dns.zone.Zone | None = field(default=None, init=False)
+ soa: dns.rrset.RRset | None = field(default=None, init=False)
+ node: dns.node.Node | None = field(default=None, init=False)
+ answer: dns.rdataset.Rdataset | None = field(default=None, init=False)
+ alias: dns.name.Name | None = field(default=None, init=False)
+ _initialized_response: dns.message.Message | None = field(default=None, init=False)
+ _initialized_response_with_zone_data: dns.message.Message | None = field(
default=None, init=False
)
"""
@abc.abstractmethod
- async def perform(self) -> Optional[dns.message.Message | bytes]:
+ async def perform(self) -> dns.message.Message | bytes | None:
"""
This method is expected to carry out arbitrary actions (e.g. wait for a
specific amount of time, modify the answer, etc.) and then return the
"""
response: dns.message.Message
- authoritative: Optional[bool] = None
+ authoritative: bool | None = None
delay: float = 0.0
acknowledge_hand_rolled_response: bool = False
- async def perform(self) -> Optional[dns.message.Message | bytes]:
+ async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a dns.message.Message.
"""
response: bytes
delay: float = 0.0
- async def perform(self) -> Optional[dns.message.Message | bytes]:
+ async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a sequence of bytes.
"""
Action which does nothing - as if a packet was dropped.
"""
- async def perform(self) -> Optional[dns.message.Message | bytes]:
+ async def perform(self) -> dns.message.Message | bytes | None:
return None
delay: float = 0.0
- async def perform(self) -> Optional[dns.message.Message | bytes]:
+ async def perform(self) -> dns.message.Message | bytes | None:
if self.delay > 0:
logging.info("Waiting %.1fs before closing TCP connection", self.delay)
await asyncio.sleep(self.delay)
"""
@property
- def rcode(self) -> Optional[dns.rcode.Rcode]:
+ def rcode(self) -> dns.rcode.Rcode | None:
"""
Optional RCODE to be set in the response.
"""
return []
@property
- def authoritative(self) -> Optional[bool]:
+ def authoritative(self) -> bool | None:
"""
Whether to set the AA bit in the response.
"""
self._domains: list[dns.name.Name] = sorted(
[dns.name.from_text(d) for d in self.domains], reverse=True
)
- self._matched_domain: Optional[dns.name.Name] = None
+ self._matched_domain: dns.name.Name | None = None
@property
def matched_domain(self) -> dns.name.Name:
A node representing a zone with one origin.
"""
- zone: Optional[dns.zone.Zone]
+ zone: dns.zone.Zone | None
children: list["_ZoneTreeNode"] = field(default_factory=list)
node_from.children.remove(child)
node_to.children.append(child)
- def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]:
+ def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None:
"""
Return the closest matching zone (if any) for the domain name.
"""
"""
class _DisableTsigHandling(contextlib.ContextDecorator):
- def __init__(self, message: Optional[dns.message.Message] = None) -> None:
+ def __init__(self, message: dns.message.Message | None = None) -> None:
self.original_tsig_sign = dns.tsig.sign
self.original_tsig_validate = dns.tsig.validate
if message:
super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
self._zone_tree: _ZoneTree = _ZoneTree()
- self._connection_handler: Optional[ConnectionHandler] = None
+ self._connection_handler: ConnectionHandler | None = None
self._response_handlers: list[ResponseHandler] = []
self._default_rcode = default_rcode
self._default_aa = default_aa
async def _read_tcp_query(
self, reader: asyncio.StreamReader, peer: Peer
- ) -> Optional[bytes]:
+ ) -> bytes | None:
wire_length = await self._read_tcp_query_wire_length(reader, peer)
if not wire_length:
return None
async def _read_tcp_query_wire_length(
self, reader: asyncio.StreamReader, peer: Peer
- ) -> Optional[int]:
+ ) -> int | None:
logging.debug("Receiving TCP message length from %s...", peer)
wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
async def _read_tcp_query_wire(
self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
- ) -> Optional[bytes]:
+ ) -> bytes | None:
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
wire = await self._read_tcp_octets(reader, peer, wire_length)
async def _read_tcp_octets(
self, reader: asyncio.StreamReader, peer: Peer, expected: int
- ) -> Optional[bytes]:
+ ) -> bytes | None:
buffer = b""
while len(buffer) < expected:
)
def _log_response(
- self, qctx: QueryContext, response: Optional[dns.message.Message | bytes]
+ self, qctx: QueryContext, response: dns.message.Message | bytes | None
) -> None:
if not response:
logging.info(
async def _prepare_responses(
self, qctx: QueryContext
- ) -> AsyncGenerator[Optional[dns.message.Message | bytes], None]:
+ ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Yield response(s) either from response handlers or zone data.
"""
async def _prepare_responses(
self, qctx: QueryContext
- ) -> AsyncGenerator[Optional[dns.message.Message | bytes], None]:
+ ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Detect and handle control queries, falling back to normal processing
for non-control queries.
async for response in super()._prepare_responses(qctx):
yield response
- def _handle_control_command(
- self, qctx: QueryContext
- ) -> Optional[dns.message.Message]:
+ def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None:
"""
Detect and handle control queries.
@abc.abstractmethod
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
- ) -> Optional[str]:
+ ) -> str | None:
"""
This method is expected to carry out arbitrary actions in response to a
control query. Note that it is invoked synchronously (it is not a
control_subdomain = "send-responses"
def __init__(self) -> None:
- self._current_handler: Optional[IgnoreAllQueries] = None
+ self._current_handler: IgnoreAllQueries | None = None
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
- ) -> Optional[str]:
+ ) -> str | None:
if len(args) != 1:
logging.error("Invalid %s query %s", self, qctx.qname)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
def handle(
self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
- ) -> Optional[str]:
+ ) -> str | None:
if len(args) != 1 or args[0] not in self._handler_mapping:
logging.error("Invalid %s query %s", self, qctx.qname)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
-from typing import Optional, cast
+from typing import cast
import difflib
import os
assert not ede_options, f"unexpected EDE options {ede_options} in {message}"
-def ede(
- message: dns.message.Message, code: EDECode, text: Optional[str] = None
-) -> None:
+def ede(message: dns.message.Message, code: EDECode, text: str | None = None) -> None:
"""Check if message contains expected EDE code (and its text)."""
msg_opts = _extract_ede_options(message)
matching_opts = [opt for opt in msg_opts if opt.code == code]
def rrsets_equal(
first_rrset: dns.rrset.RRset,
second_rrset: dns.rrset.RRset,
- compare_ttl: Optional[bool] = False,
+ compare_ttl: bool | None = False,
) -> None:
"""Compare two RRset (optionally including TTL)"""
def zones_equal(
first_zone: dns.zone.Zone,
second_zone: dns.zone.Zone,
- compare_ttl: Optional[bool] = False,
+ compare_ttl: bool | None = False,
) -> None:
"""Compare two zones (optionally including TTL)"""
# information regarding copyright ownership.
from pathlib import Path
-from typing import NamedTuple, Optional
+from typing import NamedTuple
import os
import re
def __init__(
self,
identifier: str,
- num: Optional[int] = None,
- ports: Optional[NamedPorts] = None,
+ num: int | None = None,
+ ports: NamedPorts | None = None,
) -> None:
"""
`identifier` is the name of the instance's directory
return f"10.53.0.{self.num}"
@staticmethod
- def _identifier_to_num(identifier: str, num: Optional[int] = None) -> int:
+ def _identifier_to_num(identifier: str, num: int | None = None) -> int:
regex_match = re.match(r"^ns(?P<index>[0-9]{1,2})$", identifier)
if not regex_match:
if num is None:
watcher.wait_for_line("all zones loaded")
return cmd
- def stop(self, args: Optional[list[str]] = None) -> None:
+ def stop(self, args: list[str] | None = None) -> None:
"""Stop the instance."""
args = args or []
perl(
[self.system_test_name, self.identifier] + args,
)
- def start(self, args: Optional[list[str]] = None) -> None:
+ def start(self, args: list[str] | None = None) -> None:
"""Start the instance."""
args = args or []
perl(
from functools import total_ordering
from pathlib import Path
from re import compile as Re
-from typing import Optional
import glob
import os
@dataclass
class SettimeOptions:
- P: Optional[str] = None
+ P: str | None = None
"""-P date/[+-]offset/none: set/unset key publication date"""
- P_ds: Optional[str] = None
+ P_ds: str | None = None
"""-P ds date/[+-]offset/none: set/unset DS publication date"""
- P_sync: Optional[str] = None
+ P_sync: str | None = None
"""-P sync date/[+-]offset/none: set/unset CDS and CDNSKEY publication date"""
- A: Optional[str] = None
+ A: str | None = None
"""-A date/[+-]offset/none: set/unset key activation date"""
- R: Optional[str] = None
+ R: str | None = None
"""-R date/[+-]offset/none: set/unset key revocation date"""
- I: Optional[str] = None
+ I: str | None = None
"""-I date/[+-]offset/none: set/unset key inactivation date"""
- D: Optional[str] = None
+ D: str | None = None
"""-D date/[+-]offset/none: set/unset key deletion date"""
- D_ds: Optional[str] = None
+ D_ds: str | None = None
"""-D ds date/[+-]offset/none: set/unset DS deletion date"""
- D_sync: Optional[str] = None
+ D_sync: str | None = None
"""-D sync date/[+-]offset/none: set/unset CDS and CDNSKEY deletion date"""
- g: Optional[str] = None
+ g: str | None = None
"""-g state: set the goal state for this key"""
- d: Optional[str] = None
+ d: str | None = None
"""-d state date/[+-]offset: set the DS state"""
- k: Optional[str] = None
+ k: str | None = None
"""-k state date/[+-]offset: set the DNSKEY state"""
- r: Optional[str] = None
+ r: str | None = None
"""-r state date/[+-]offset: set the RRSIG (KSK) state"""
- z: Optional[str] = None
+ z: str | None = None
"""-z state date/[+-]offset: set the RRSIG (ZSK) state"""
def __str__(self):
operations for KASP tests.
"""
- def __init__(self, name: str, keydir: Optional[str | Path] = None):
+ def __init__(self, name: str, keydir: str | Path | None = None):
self.name = name
if keydir is None:
self.keydir = Path()
def get_timing(
self, metadata: str, must_exist: bool = True
- ) -> Optional[KeyTimingMetadata]:
+ ) -> KeyTimingMetadata | None:
regex = rf";\s+{metadata}:\s+(\d+).*"
with open(self.keyfile, "r", encoding="utf-8") as file:
for line in file:
def keydir_to_keylist(
- zone: Optional[str], keydir: Optional[str] = None, in_use: bool = False
+ zone: str | None, keydir: str | None = None, in_use: bool = False
) -> list[Key]:
"""
Retrieve all keys from the key files in a directory. If 'zone' is None,
return [k for k in all_keys if used(k)]
-def keystr_to_keylist(keystr: str, keydir: Optional[str] = None) -> list[Key]:
+def keystr_to_keylist(keystr: str, keydir: str | None = None) -> list[Key]:
return [Key(name, keydir) for name in keystr.split()]
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
-
-from typing import Any, Match, Optional, Pattern, TextIO, TypeAlias, TypeVar
+from typing import Any, Match, Pattern, TextIO, TypeAlias, TypeVar
import abc
import os
...
isctest.log.watchlog.WatchLogException: timeout must be greater than 0
"""
- self._fd: Optional[TextIO] = None
- self._reader: Optional[LineReader] = None
+ self._fd: TextIO | None = None
+ self._reader: LineReader | None = None
self._path = path
self._wait_function_called = False
if timeout <= 0.0:
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
-from typing import Any, Callable, Optional
+from typing import Any, Callable
import os
import time
query_func: Callable[..., Any],
message: dns.message.Message,
ip: str,
- port: Optional[int] = None,
- source: Optional[str] = None,
+ port: int | None = None,
+ source: str | None = None,
timeout: int = QUERY_TIMEOUT,
attempts: int = 10,
- expected_rcode: Optional[dns.rcode.Rcode] = None,
+ expected_rcode: dns.rcode.Rcode | None = None,
verify: bool = False,
log_query: bool = True,
log_response: bool = True,
# information regarding copyright ownership.
from pathlib import Path
-from typing import Optional
import os
import subprocess
stderr=subprocess.PIPE,
log_stdout=True,
log_stderr=True,
- input_text: Optional[bytes] = None,
+ input_text: bytes | None = None,
raise_on_exception=True,
- env: Optional[dict] = None,
+ env: dict | None = None,
) -> CmdResult:
"""Execute a command with given args as subprocess."""
isctest.log.debug(f"isctest.run.cmd(): {' '.join(args)}")
def _run_script(
interpreter: str,
script: str,
- args: Optional[list[str]] = None,
+ args: list[str] | None = None,
):
if args is None:
args = []
isctest.log.debug(" exited with %d", returncode)
-def shell(script: str, args: Optional[list[str]] = None) -> None:
+def shell(script: str, args: list[str] | None = None) -> None:
"""Run a given script with system's shell interpreter."""
_run_script(os.environ["SHELL"], script, args)
-def perl(script: str, args: Optional[list[str]] = None) -> None:
+def perl(script: str, args: list[str] | None = None) -> None:
"""Run a given script with system's perl interpreter."""
_run_script(os.environ["PERL"], script, args)
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Optional
+from typing import Any
import jinja2
def render(
self,
output: str,
- data: Optional[dict[str, Any]] = None,
- template: Optional[str] = None,
+ data: dict[str, Any] | None = None,
+ template: str | None = None,
) -> None:
"""
Render `output` file from jinja `template` and fill in the `data`. The
stream = self.j2env.get_template(template).stream(data)
stream.dump(output, encoding="utf-8")
- def render_auto(self, data: Optional[dict[str, Any]] = None):
+ def render_auto(self, data: dict[str, Any] | None = None):
"""
Render all *.j2 templates with default (and optionally the provided)
values and write the output to files without the .j2 extensions.
# information regarding copyright ownership.
from re import compile as Re
-from typing import Iterator, Match, Optional, Pattern, TextIO
+from typing import Iterator, Match, Pattern, TextIO
import abc
import re
self._stream = stream
self._linebuf = ""
- def readline(self) -> Optional[str]:
+ def readline(self) -> str | None:
"""
Wrapper around io.readline() function to handle unfinished lines.
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
-from typing import NamedTuple, Optional
+from typing import NamedTuple
import os
import platform
return algs_env
-def set_algorithm_set(name: Optional[str]):
+def set_algorithm_set(name: str | None):
if name is None:
name = "stable"
assert name in ALGORITHM_SETS, f'ALGORITHM_SET "{name}" unknown'
# information regarding copyright ownership.
from re import compile as Re
-from typing import Optional
import os
}
-def parse_openssl_config(path: Optional[str]):
+def parse_openssl_config(path: str | None):
if path is None or not os.path.exists(path):
OPENSSL_VARS["SOFTHSM2_MODULE"] = None
os.environ.pop("SOFTHSM2_MODULE", None)
from dataclasses import dataclass
from pathlib import Path
-from typing import Container, Iterable, Optional
+from typing import Container, Iterable
import os
algorithm: int
flags: int
iterations: int
- salt: Optional[bytes]
+ salt: bytes | None
class NSEC3Checker:
# unnecessary `typing` imports
"UP006",
"UP007",
+ "UP045",
# f-strings
"UP031",
"UP032",