]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add QUIC TLS session ticket support.
authorBob Halley <halley@dnspython.org>
Sat, 28 Oct 2023 16:08:00 +0000 (09:08 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 28 Oct 2023 16:08:00 +0000 (09:08 -0700)
dns/quic/_asyncio.py
dns/quic/_common.py
dns/quic/_sync.py
dns/quic/_trio.py

index b2597371849d5fd84ebadc87c1166862e50ea9ee..0f44331f61830b3d9c7da6bb26b4f72e89744d64 100644 (file)
@@ -206,8 +206,12 @@ class AsyncioQuicManager(AsyncQuicManager):
     def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
         super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
 
-    def connect(self, address, port=853, source=None, source_port=0):
-        (connection, start) = self._connect(address, port, source, source_port)
+    def connect(
+        self, address, port=853, source=None, source_port=0, want_session_ticket=True
+    ):
+        (connection, start) = self._connect(
+            address, port, source, source_port, want_session_ticket
+        )
         if start:
             connection.run()
         return connection
index 4cbdcae1d714778088a0e1d2c6b7ac125e6e9886..0eacc691aac712294ff24afbeef91b7dafbcb674 100644 (file)
@@ -1,5 +1,7 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+import copy
+import functools
 import socket
 import struct
 import time
@@ -11,6 +13,10 @@ import aioquic.quic.connection  # type: ignore
 import dns.inet
 
 QUIC_MAX_DATAGRAM = 2048
+MAX_SESSION_TICKETS = 8
+# If we hit the max sessions limit we will delete this many of the oldest connections.
+# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
+SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
 
 
 class UnexpectedEOF(Exception):
@@ -145,6 +151,7 @@ class BaseQuicManager:
     def __init__(self, conf, verify_mode, connection_factory, server_name=None):
         self._connections = {}
         self._connection_factory = connection_factory
+        self._session_tickets = {}
         if conf is None:
             verify_path = None
             if isinstance(verify_mode, str):
@@ -159,11 +166,33 @@ class BaseQuicManager:
                 conf.load_verify_locations(verify_path)
         self._conf = conf
 
-    def _connect(self, address, port=853, source=None, source_port=0):
+    def _connect(
+        self, address, port=853, source=None, source_port=0, want_session_ticket=True
+    ):
         connection = self._connections.get((address, port))
         if connection is not None:
             return (connection, False)
-        qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
+        conf = self._conf
+        if want_session_ticket:
+            try:
+                session_ticket = self._session_tickets.pop((address, port))
+                # We found a session ticket, so make a configuration that uses it.
+                conf = copy.copy(conf)
+                conf.session_ticket = session_ticket
+            except KeyError:
+                # No session ticket.
+                pass
+            # Whether or not we found a session ticket, we want a handler to save
+            # one.
+            session_ticket_handler = functools.partial(
+                self.save_session_ticket, address, port
+            )
+        else:
+            session_ticket_handler = None
+        qconn = aioquic.quic.connection.QuicConnection(
+            configuration=conf,
+            session_ticket_handler=session_ticket_handler,
+        )
         lladdress = dns.inet.low_level_address_tuple((address, port))
         qconn.connect(lladdress, time.time())
         connection = self._connection_factory(
@@ -178,6 +207,17 @@ class BaseQuicManager:
         except KeyError:
             pass
 
+    def save_session_ticket(self, address, port, ticket):
+        # We rely on dictionaries keys() being in insertion order here.  We
+        # can't just popitem() as that would be LIFO which is the opposite of
+        # what we want.
+        l = len(self._session_tickets)
+        if l >= MAX_SESSION_TICKETS:
+            keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
+            for key in keys_to_delete:
+                del self._session_tickets[key]
+        self._session_tickets[(address, port)] = ticket
+
 
 class AsyncQuicManager(BaseQuicManager):
     def connect(self, address, port=853, source=None, source_port=0):
index a71ac67c4b4e7121d1fd3b0a66b9b844f83597aa..d6731c904c34605dff3a78afec9d2e22f0bbb81f 100644 (file)
@@ -207,9 +207,13 @@ class SyncQuicManager(BaseQuicManager):
         super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
         self._lock = threading.Lock()
 
-    def connect(self, address, port=853, source=None, source_port=0):
+    def connect(
+        self, address, port=853, source=None, source_port=0, want_session_ticket=True
+    ):
         with self._lock:
-            (connection, start) = self._connect(address, port, source, source_port)
+            (connection, start) = self._connect(
+                address, port, source, source_port, want_session_ticket
+            )
             if start:
                 connection.run()
             return connection
@@ -218,6 +222,10 @@ class SyncQuicManager(BaseQuicManager):
         with self._lock:
             super().closed(address, port)
 
+    def save_session_ticket(self, address, port, ticket):
+        with self._lock:
+            super().save_session_ticket(address, port, ticket)
+
     def __enter__(self):
         return self
 
index 0ff0497ec01b44dba3a2bf5e518a088a39847649..0284c98294feb1ddcf406e0a8ccc39b791d6e097 100644 (file)
@@ -186,8 +186,12 @@ class TrioQuicManager(AsyncQuicManager):
         super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
         self._nursery = nursery
 
-    def connect(self, address, port=853, source=None, source_port=0):
-        (connection, start) = self._connect(address, port, source, source_port)
+    def connect(
+        self, address, port=853, source=None, source_port=0, want_session_ticket=True
+    ):
+        (connection, start) = self._connect(
+            address, port, source, source_port, want_session_ticket
+        )
         if start:
             self._nursery.start_soon(connection.run)
         return connection