From 0ec74e5b20a25bd55c28318cf68033ebc014bce9 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Tue, 5 Mar 2024 14:45:59 -0800 Subject: [PATCH] Add support for saving quic tokens. (#1065) This caches tokens in the manager, so that they can be used for address validation in future connections. --- dns/quic/_common.py | 30 ++++++++++++++++++++++++++++++ dns/quic/_sync.py | 7 ++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/dns/quic/_common.py b/dns/quic/_common.py index 5e6c40d3..ada6d54e 100644 --- a/dns/quic/_common.py +++ b/dns/quic/_common.py @@ -226,6 +226,7 @@ class BaseQuicManager: self._connections = {} self._connection_factory = connection_factory self._session_tickets = {} + self._tokens = {} self._h3 = h3 if conf is None: verify_path = None @@ -252,6 +253,7 @@ class BaseQuicManager: source=None, source_port=0, want_session_ticket=True, + want_token=True, ): connection = self._connections.get((address, port)) if connection is not None: @@ -273,9 +275,26 @@ class BaseQuicManager: ) else: session_ticket_handler = None + if want_token: + try: + token = self._tokens.pop((address, port)) + # We found a token, so make a configuration that uses it. + conf = copy.copy(conf) + conf.token = token + except KeyError: + # No token + pass + # Whether or not we found a token, we want a handler to save # one. + token_handler = functools.partial(self.save_token, address, port) + else: + token_handler = None + + + qconn = aioquic.quic.connection.QuicConnection( configuration=conf, session_ticket_handler=session_ticket_handler, + token_handler=token_handler, ) lladdress = dns.inet.low_level_address_tuple((address, port)) qconn.connect(lladdress, time.time()) @@ -305,6 +324,17 @@ class BaseQuicManager: del self._session_tickets[key] self._session_tickets[(address, port)] = ticket + def save_token(self, address, port, token): + # We rely on dictionaries keys() being in insertion order here. We + # can't just popitem() as that would be LIFO which is the opposite of + # what we want. + l = len(self._tokens) + if l >= MAX_SESSION_TICKETS: + keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE] + for key in keys_to_delete: + del self._tokens[key] + self._tokens[(address, port)] = token + class AsyncQuicManager(BaseQuicManager): def connect(self, address, port=853, source=None, source_port=0): diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index 63ccd4e7..f2538063 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -260,10 +260,11 @@ class SyncQuicManager(BaseQuicManager): source=None, source_port=0, want_session_ticket=True, + want_token=True ): with self._lock: (connection, start) = self._connect( - address, port, source, source_port, want_session_ticket + address, port, source, source_port, want_session_ticket, want_token ) if start: connection.run() @@ -277,6 +278,10 @@ class SyncQuicManager(BaseQuicManager): with self._lock: super().save_session_ticket(address, port, ticket) + def save_token(self, address, port, token): + with self._lock: + super().save_token(address, port, token) + def __enter__(self): return self -- 2.47.3