From: Bob Halley Date: Sat, 28 Oct 2023 16:08:00 +0000 (-0700) Subject: Add QUIC TLS session ticket support. X-Git-Tag: v2.5.0rc1~35 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7bd3df6286b1653b788404f2a6a9e5d2132ced40;p=thirdparty%2Fdnspython.git Add QUIC TLS session ticket support. --- diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index b2597371..0f44331f 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -206,8 +206,12 @@ class AsyncioQuicManager(AsyncQuicManager): def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) - def connect(self, address, port=853, source=None, source_port=0): - (connection, start) = self._connect(address, port, source, source_port) + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) if start: connection.run() return connection diff --git a/dns/quic/_common.py b/dns/quic/_common.py index 4cbdcae1..0eacc691 100644 --- a/dns/quic/_common.py +++ b/dns/quic/_common.py @@ -1,5 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import copy +import functools import socket import struct import time @@ -11,6 +13,10 @@ import aioquic.quic.connection # type: ignore import dns.inet QUIC_MAX_DATAGRAM = 2048 +MAX_SESSION_TICKETS = 8 +# If we hit the max sessions limit we will delete this many of the oldest connections. +# The value must be a integer > 0 and <= MAX_SESSION_TICKETS. +SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4 class UnexpectedEOF(Exception): @@ -145,6 +151,7 @@ class BaseQuicManager: def __init__(self, conf, verify_mode, connection_factory, server_name=None): self._connections = {} self._connection_factory = connection_factory + self._session_tickets = {} if conf is None: verify_path = None if isinstance(verify_mode, str): @@ -159,11 +166,33 @@ class BaseQuicManager: conf.load_verify_locations(verify_path) self._conf = conf - def _connect(self, address, port=853, source=None, source_port=0): + def _connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): connection = self._connections.get((address, port)) if connection is not None: return (connection, False) - qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf) + conf = self._conf + if want_session_ticket: + try: + session_ticket = self._session_tickets.pop((address, port)) + # We found a session ticket, so make a configuration that uses it. + conf = copy.copy(conf) + conf.session_ticket = session_ticket + except KeyError: + # No session ticket. + pass + # Whether or not we found a session ticket, we want a handler to save + # one. + session_ticket_handler = functools.partial( + self.save_session_ticket, address, port + ) + else: + session_ticket_handler = None + qconn = aioquic.quic.connection.QuicConnection( + configuration=conf, + session_ticket_handler=session_ticket_handler, + ) lladdress = dns.inet.low_level_address_tuple((address, port)) qconn.connect(lladdress, time.time()) connection = self._connection_factory( @@ -178,6 +207,17 @@ class BaseQuicManager: except KeyError: pass + def save_session_ticket(self, address, port, ticket): + # We rely on dictionaries keys() being in insertion order here. We + # can't just popitem() as that would be LIFO which is the opposite of + # what we want. + l = len(self._session_tickets) + if l >= MAX_SESSION_TICKETS: + keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE] + for key in keys_to_delete: + del self._session_tickets[key] + self._session_tickets[(address, port)] = ticket + class AsyncQuicManager(BaseQuicManager): def connect(self, address, port=853, source=None, source_port=0): diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index a71ac67c..d6731c90 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -207,9 +207,13 @@ class SyncQuicManager(BaseQuicManager): super().__init__(conf, verify_mode, SyncQuicConnection, server_name) self._lock = threading.Lock() - def connect(self, address, port=853, source=None, source_port=0): + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): with self._lock: - (connection, start) = self._connect(address, port, source, source_port) + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) if start: connection.run() return connection @@ -218,6 +222,10 @@ class SyncQuicManager(BaseQuicManager): with self._lock: super().closed(address, port) + def save_session_ticket(self, address, port, ticket): + with self._lock: + super().save_session_ticket(address, port, ticket) + def __enter__(self): return self diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index 0ff0497e..0284c982 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -186,8 +186,12 @@ class TrioQuicManager(AsyncQuicManager): super().__init__(conf, verify_mode, TrioQuicConnection, server_name) self._nursery = nursery - def connect(self, address, port=853, source=None, source_port=0): - (connection, start) = self._connect(address, port, source, source_port) + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) if start: self._nursery.start_soon(connection.run) return connection