From 81699fe23e87efd25006995a2b55f3e17de970aa Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 18 Oct 2024 10:44:02 -0700 Subject: [PATCH] pyright linting for quic --- dns/quic/__init__.py | 15 +++++++++------ dns/quic/_asyncio.py | 19 ++++++++++++++----- dns/quic/_common.py | 6 ++++-- dns/quic/_sync.py | 9 +++++++-- dns/quic/_trio.py | 6 +++++- pyproject.toml | 2 +- 6 files changed, 40 insertions(+), 17 deletions(-) diff --git a/dns/quic/__init__.py b/dns/quic/__init__.py index 0750e729..e371c417 100644 --- a/dns/quic/__init__.py +++ b/dns/quic/__init__.py @@ -1,6 +1,6 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import List, Tuple +from typing import Any, Dict, List, Tuple import dns._features import dns.asyncbackend @@ -14,8 +14,11 @@ if dns._features.have("doq"): AsyncioQuicManager, AsyncioQuicStream, ) - from dns.quic._common import AsyncQuicConnection, AsyncQuicManager - from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream + from dns.quic._common import AsyncQuicConnection # pyright: ignore + from dns.quic._common import AsyncQuicManager + from dns.quic._sync import SyncQuicConnection # pyright: ignore + from dns.quic._sync import SyncQuicStream # pyright: ignore + from dns.quic._sync import SyncQuicManager have_quic = True @@ -33,7 +36,9 @@ if dns._features.have("doq"): # We have a context factory and a manager factory as for trio we need to have # a nursery. - _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} + _async_factories: Dict[str, Tuple[Any, Any]] = { + "asyncio": (null_factory, _asyncio_manager_factory) + } if dns._features.have("trio"): import trio @@ -60,8 +65,6 @@ if dns._features.have("doq"): else: # pragma: no cover have_quic = False - from typing import Any - class AsyncQuicStream: # type: ignore pass diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index f87515da..0a177b67 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -6,6 +6,8 @@ import ssl import struct import time +import aioquic.h3.connection # type: ignore +import aioquic.h3.events # type: ignore import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore @@ -144,6 +146,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): datagrams = self._connection.datagrams_to_send(time.time()) for datagram, address in datagrams: assert address == self._peer + assert self._socket is not None await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() try: @@ -161,6 +164,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): return if isinstance(event, aioquic.quic.events.StreamDataReceived): if self.is_h3(): + assert self._h3_conn is not None h3_events = self._h3_conn.handle_event(event) for h3_event in h3_events: if isinstance(h3_event, aioquic.h3.events.HeadersReceived): @@ -186,7 +190,8 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._handshake_complete.set() elif isinstance(event, aioquic.quic.events.ConnectionTerminated): self._done = True - self._receiver_task.cancel() + if self._receiver_task is not None: + self._receiver_task.cancel() elif isinstance(event, aioquic.quic.events.StreamReset): stream = self._streams.get(event.stream_id) if stream: @@ -222,21 +227,25 @@ class AsyncioQuicConnection(AsyncQuicConnection): async def close(self): if not self._closed: - self._manager.closed(self._peer[0], self._peer[1]) + if self._manager is not None: + self._manager.closed(self._peer[0], self._peer[1]) self._closed = True self._connection.close() # sender might be blocked on this, so set it self._socket_created.set() await self._wakeup() try: - await self._receiver_task + if self._receiver_task is not None: + await self._receiver_task except asyncio.CancelledError: pass try: - await self._sender_task + if self._sender_task is not None: + await self._sender_task except asyncio.CancelledError: pass - await self._socket.close() + if self._socket is not None: + await self._socket.close() class AsyncioQuicManager(AsyncQuicManager): diff --git a/dns/quic/_common.py b/dns/quic/_common.py index ce575b03..930cf660 100644 --- a/dns/quic/_common.py +++ b/dns/quic/_common.py @@ -6,7 +6,7 @@ import functools import socket import struct import time -import urllib +import urllib.parse from typing import Any, Optional import aioquic.h3.connection # type: ignore @@ -165,7 +165,7 @@ class BaseQuicConnection: self._closed = False self._manager = manager self._streams = {} - if manager.is_h3(): + if manager is not None and manager.is_h3(): self._h3_conn = aioquic.h3.connection.H3Connection(connection, False) else: self._h3_conn = None @@ -190,9 +190,11 @@ class BaseQuicConnection: del self._streams[stream_id] def send_headers(self, stream_id, headers, is_end=False): + assert self._h3_conn is not None self._h3_conn.send_headers(stream_id, headers, is_end) def send_data(self, stream_id, data, is_end=False): + assert self._h3_conn is not None self._h3_conn.send_data(stream_id, data, is_end) def _get_timer_values(self, closed_is_special=True): diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index 473d1f48..71159845 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -7,6 +7,8 @@ import struct import threading import time +import aioquic.h3.connection # type: ignore +import aioquic.h3.events # type: ignore import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore @@ -165,6 +167,7 @@ class SyncQuicConnection(BaseQuicConnection): return if isinstance(event, aioquic.quic.events.StreamDataReceived): if self.is_h3(): + assert self._h3_conn is not None h3_events = self._h3_conn.handle_event(event) for h3_event in h3_events: if isinstance(h3_event, aioquic.h3.events.HeadersReceived): @@ -240,11 +243,13 @@ class SyncQuicConnection(BaseQuicConnection): with self._lock: if self._closed: return - self._manager.closed(self._peer[0], self._peer[1]) + if self._manager is not None: + self._manager.closed(self._peer[0], self._peer[1]) self._closed = True self._connection.close() self._send_wakeup.send(b"\x01") - self._worker_thread.join() + if self._worker_thread is not None: + self._worker_thread.join() class SyncQuicManager(BaseQuicManager): diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index ae79f369..7eead579 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -5,6 +5,8 @@ import ssl import struct import time +import aioquic.h3.connection # type: ignore +import aioquic.h3.events # type: ignore import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore @@ -137,6 +139,7 @@ class TrioQuicConnection(AsyncQuicConnection): return if isinstance(event, aioquic.quic.events.StreamDataReceived): if self.is_h3(): + assert self._h3_conn is not None h3_events = self._h3_conn.handle_event(event) for h3_event in h3_events: if isinstance(h3_event, aioquic.h3.events.HeadersReceived): @@ -203,7 +206,8 @@ class TrioQuicConnection(AsyncQuicConnection): async def close(self): if not self._closed: - self._manager.closed(self._peer[0], self._peer[1]) + if self._manager is not None: + self._manager.closed(self._peer[0], self._peer[1]) self._closed = True self._connection.close() self._send_pending = True diff --git a/pyproject.toml b/pyproject.toml index d7d72e10..9b0ba018 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,4 +119,4 @@ ignore_missing_imports = true [tool.pyright] reportUnsupportedDunderAll = false -exclude = ["dns/quic/*.py", "examples/*.py", "tests/*.py"] # (mostly) temporary! +exclude = ["examples/*.py", "tests/*.py"] -- 2.47.3