]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
pyright linting for quic
authorBob Halley <halley@dnspython.org>
Fri, 18 Oct 2024 17:44:02 +0000 (10:44 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 18 Oct 2024 17:44:02 +0000 (10:44 -0700)
dns/quic/__init__.py
dns/quic/_asyncio.py
dns/quic/_common.py
dns/quic/_sync.py
dns/quic/_trio.py
pyproject.toml

index 0750e729b4401e77bf3da1f8716e23e4538c1d24..e371c417ba8094a285fc2d18aab69533bb129a33 100644 (file)
@@ -1,6 +1,6 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import dns._features
 import dns.asyncbackend
@@ -14,8 +14,11 @@ if dns._features.have("doq"):
         AsyncioQuicManager,
         AsyncioQuicStream,
     )
-    from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
-    from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream
+    from dns.quic._common import AsyncQuicConnection  # pyright: ignore
+    from dns.quic._common import AsyncQuicManager
+    from dns.quic._sync import SyncQuicConnection  # pyright: ignore
+    from dns.quic._sync import SyncQuicStream  # pyright: ignore
+    from dns.quic._sync import SyncQuicManager
 
     have_quic = True
 
@@ -33,7 +36,9 @@ if dns._features.have("doq"):
     # We have a context factory and a manager factory as for trio we need to have
     # a nursery.
 
-    _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
+    _async_factories: Dict[str, Tuple[Any, Any]] = {
+        "asyncio": (null_factory, _asyncio_manager_factory)
+    }
 
     if dns._features.have("trio"):
         import trio
@@ -60,8 +65,6 @@ if dns._features.have("doq"):
 else:  # pragma: no cover
     have_quic = False
 
-    from typing import Any
-
     class AsyncQuicStream:  # type: ignore
         pass
 
index f87515dacfd2252fbb8204afdd25b8c91dff0207..0a177b676e2ee4925d3207e172186fcb95d3e28e 100644 (file)
@@ -6,6 +6,8 @@ import ssl
 import struct
 import time
 
+import aioquic.h3.connection  # type: ignore
+import aioquic.h3.events  # type: ignore
 import aioquic.quic.configuration  # type: ignore
 import aioquic.quic.connection  # type: ignore
 import aioquic.quic.events  # type: ignore
@@ -144,6 +146,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
             datagrams = self._connection.datagrams_to_send(time.time())
             for datagram, address in datagrams:
                 assert address == self._peer
+                assert self._socket is not None
                 await self._socket.sendto(datagram, self._peer, None)
             (expiration, interval) = self._get_timer_values()
             try:
@@ -161,6 +164,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
                 return
             if isinstance(event, aioquic.quic.events.StreamDataReceived):
                 if self.is_h3():
+                    assert self._h3_conn is not None
                     h3_events = self._h3_conn.handle_event(event)
                     for h3_event in h3_events:
                         if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
@@ -186,7 +190,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
                 self._handshake_complete.set()
             elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
                 self._done = True
-                self._receiver_task.cancel()
+                if self._receiver_task is not None:
+                    self._receiver_task.cancel()
             elif isinstance(event, aioquic.quic.events.StreamReset):
                 stream = self._streams.get(event.stream_id)
                 if stream:
@@ -222,21 +227,25 @@ class AsyncioQuicConnection(AsyncQuicConnection):
 
     async def close(self):
         if not self._closed:
-            self._manager.closed(self._peer[0], self._peer[1])
+            if self._manager is not None:
+                self._manager.closed(self._peer[0], self._peer[1])
             self._closed = True
             self._connection.close()
             # sender might be blocked on this, so set it
             self._socket_created.set()
             await self._wakeup()
             try:
-                await self._receiver_task
+                if self._receiver_task is not None:
+                    await self._receiver_task
             except asyncio.CancelledError:
                 pass
             try:
-                await self._sender_task
+                if self._sender_task is not None:
+                    await self._sender_task
             except asyncio.CancelledError:
                 pass
-            await self._socket.close()
+            if self._socket is not None:
+                await self._socket.close()
 
 
 class AsyncioQuicManager(AsyncQuicManager):
index ce575b038959e56380c3d33e53bb778fb5b5684d..930cf660cacbb4ca413110a6a9e567f67107d8b3 100644 (file)
@@ -6,7 +6,7 @@ import functools
 import socket
 import struct
 import time
-import urllib
+import urllib.parse
 from typing import Any, Optional
 
 import aioquic.h3.connection  # type: ignore
@@ -165,7 +165,7 @@ class BaseQuicConnection:
         self._closed = False
         self._manager = manager
         self._streams = {}
-        if manager.is_h3():
+        if manager is not None and manager.is_h3():
             self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
         else:
             self._h3_conn = None
@@ -190,9 +190,11 @@ class BaseQuicConnection:
         del self._streams[stream_id]
 
     def send_headers(self, stream_id, headers, is_end=False):
+        assert self._h3_conn is not None
         self._h3_conn.send_headers(stream_id, headers, is_end)
 
     def send_data(self, stream_id, data, is_end=False):
+        assert self._h3_conn is not None
         self._h3_conn.send_data(stream_id, data, is_end)
 
     def _get_timer_values(self, closed_is_special=True):
index 473d1f4811e4a11dc67ac88acec282bb68ddb805..7115984540330be4bc43f96657c728f69834b3c6 100644 (file)
@@ -7,6 +7,8 @@ import struct
 import threading
 import time
 
+import aioquic.h3.connection  # type: ignore
+import aioquic.h3.events  # type: ignore
 import aioquic.quic.configuration  # type: ignore
 import aioquic.quic.connection  # type: ignore
 import aioquic.quic.events  # type: ignore
@@ -165,6 +167,7 @@ class SyncQuicConnection(BaseQuicConnection):
                 return
             if isinstance(event, aioquic.quic.events.StreamDataReceived):
                 if self.is_h3():
+                    assert self._h3_conn is not None
                     h3_events = self._h3_conn.handle_event(event)
                     for h3_event in h3_events:
                         if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
@@ -240,11 +243,13 @@ class SyncQuicConnection(BaseQuicConnection):
         with self._lock:
             if self._closed:
                 return
-            self._manager.closed(self._peer[0], self._peer[1])
+            if self._manager is not None:
+                self._manager.closed(self._peer[0], self._peer[1])
             self._closed = True
             self._connection.close()
             self._send_wakeup.send(b"\x01")
-        self._worker_thread.join()
+        if self._worker_thread is not None:
+            self._worker_thread.join()
 
 
 class SyncQuicManager(BaseQuicManager):
index ae79f36957c20aed90f785720d9a767c85a9825c..7eead5796deb66bf37dc2fbd3ebc583c58e3fbbe 100644 (file)
@@ -5,6 +5,8 @@ import ssl
 import struct
 import time
 
+import aioquic.h3.connection  # type: ignore
+import aioquic.h3.events  # type: ignore
 import aioquic.quic.configuration  # type: ignore
 import aioquic.quic.connection  # type: ignore
 import aioquic.quic.events  # type: ignore
@@ -137,6 +139,7 @@ class TrioQuicConnection(AsyncQuicConnection):
                 return
             if isinstance(event, aioquic.quic.events.StreamDataReceived):
                 if self.is_h3():
+                    assert self._h3_conn is not None
                     h3_events = self._h3_conn.handle_event(event)
                     for h3_event in h3_events:
                         if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
@@ -203,7 +206,8 @@ class TrioQuicConnection(AsyncQuicConnection):
 
     async def close(self):
         if not self._closed:
-            self._manager.closed(self._peer[0], self._peer[1])
+            if self._manager is not None:
+                self._manager.closed(self._peer[0], self._peer[1])
             self._closed = True
             self._connection.close()
             self._send_pending = True
index d7d72e101a75de8eef7d4bb3324a6a38f56e8c8a..9b0ba018a37c8ad71431240422d6765bd52b70de 100644 (file)
@@ -119,4 +119,4 @@ ignore_missing_imports = true
 
 [tool.pyright]
 reportUnsupportedDunderAll = false
-exclude = ["dns/quic/*.py", "examples/*.py", "tests/*.py"] # (mostly) temporary!
+exclude = ["examples/*.py", "tests/*.py"]