]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add support for saving quic tokens. (#1065)
authorBrian Wellington <bwelling@xbill.org>
Tue, 5 Mar 2024 22:45:59 +0000 (14:45 -0800)
committerGitHub <noreply@github.com>
Tue, 5 Mar 2024 22:45:59 +0000 (14:45 -0800)
This caches tokens in the manager, so that they can be used for address
validation in future connections.

dns/quic/_common.py
dns/quic/_sync.py

index 5e6c40d3c4fda845d28b5348151fb37a604ee436..ada6d54e9e23f1440a4cf748cd9a3b525e87c12d 100644 (file)
@@ -226,6 +226,7 @@ class BaseQuicManager:
         self._connections = {}
         self._connection_factory = connection_factory
         self._session_tickets = {}
+        self._tokens = {}
         self._h3 = h3
         if conf is None:
             verify_path = None
@@ -252,6 +253,7 @@ class BaseQuicManager:
         source=None,
         source_port=0,
         want_session_ticket=True,
+        want_token=True,
     ):
         connection = self._connections.get((address, port))
         if connection is not None:
@@ -273,9 +275,26 @@ class BaseQuicManager:
             )
         else:
             session_ticket_handler = None
+        if want_token:
+            try:
+                token = self._tokens.pop((address, port))
+                # We found a token, so make a configuration that uses it.
+                conf = copy.copy(conf)
+                conf.token = token
+            except KeyError:
+                # No token
+                pass
+            # Whether or not we found a token, we want a handler to save # one.
+            token_handler = functools.partial(self.save_token, address, port)
+        else:
+            token_handler = None
+
+
+
         qconn = aioquic.quic.connection.QuicConnection(
             configuration=conf,
             session_ticket_handler=session_ticket_handler,
+            token_handler=token_handler,
         )
         lladdress = dns.inet.low_level_address_tuple((address, port))
         qconn.connect(lladdress, time.time())
@@ -305,6 +324,17 @@ class BaseQuicManager:
                 del self._session_tickets[key]
         self._session_tickets[(address, port)] = ticket
 
+    def save_token(self, address, port, token):
+        # We rely on dictionaries keys() being in insertion order here.  We
+        # can't just popitem() as that would be LIFO which is the opposite of
+        # what we want.
+        l = len(self._tokens)
+        if l >= MAX_SESSION_TICKETS:
+            keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
+            for key in keys_to_delete:
+                del self._tokens[key]
+        self._tokens[(address, port)] = token
+
 
 class AsyncQuicManager(BaseQuicManager):
     def connect(self, address, port=853, source=None, source_port=0):
index 63ccd4e7aec2f9a8d7ef0b20fe1f05a566fd4975..f253806339a1d7d497f59cbe25d73df35bfb3bf2 100644 (file)
@@ -260,10 +260,11 @@ class SyncQuicManager(BaseQuicManager):
         source=None,
         source_port=0,
         want_session_ticket=True,
+        want_token=True
     ):
         with self._lock:
             (connection, start) = self._connect(
-                address, port, source, source_port, want_session_ticket
+                address, port, source, source_port, want_session_ticket, want_token
             )
             if start:
                 connection.run()
@@ -277,6 +278,10 @@ class SyncQuicManager(BaseQuicManager):
         with self._lock:
             super().save_session_ticket(address, port, ticket)
 
+    def save_token(self, address, port, token):
+        with self._lock:
+            super().save_token(address, port, token)
+
     def __enter__(self):
         return self