def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
- def connect(self, address, port=853, source=None, source_port=0):
- (connection, start) = self._connect(address, port, source, source_port)
+ def connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
+ (connection, start) = self._connect(
+ address, port, source, source_port, want_session_ticket
+ )
if start:
connection.run()
return connection
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+import copy
+import functools
import socket
import struct
import time
import dns.inet
QUIC_MAX_DATAGRAM = 2048
+MAX_SESSION_TICKETS = 8
+# If we hit the max sessions limit we will delete this many of the oldest connections.
+# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
+SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
class UnexpectedEOF(Exception):
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {}
self._connection_factory = connection_factory
+ self._session_tickets = {}
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
conf.load_verify_locations(verify_path)
self._conf = conf
- def _connect(self, address, port=853, source=None, source_port=0):
+ def _connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
- qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
+ conf = self._conf
+ if want_session_ticket:
+ try:
+ session_ticket = self._session_tickets.pop((address, port))
+ # We found a session ticket, so make a configuration that uses it.
+ conf = copy.copy(conf)
+ conf.session_ticket = session_ticket
+ except KeyError:
+ # No session ticket.
+ pass
+ # Whether or not we found a session ticket, we want a handler to save
+ # one.
+ session_ticket_handler = functools.partial(
+ self.save_session_ticket, address, port
+ )
+ else:
+ session_ticket_handler = None
+ qconn = aioquic.quic.connection.QuicConnection(
+ configuration=conf,
+ session_ticket_handler=session_ticket_handler,
+ )
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
connection = self._connection_factory(
except KeyError:
pass
+ def save_session_ticket(self, address, port, ticket):
+ # We rely on dictionaries keys() being in insertion order here. We
+ # can't just popitem() as that would be LIFO which is the opposite of
+ # what we want.
+ l = len(self._session_tickets)
+ if l >= MAX_SESSION_TICKETS:
+ keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
+ for key in keys_to_delete:
+ del self._session_tickets[key]
+ self._session_tickets[(address, port)] = ticket
+
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock()
- def connect(self, address, port=853, source=None, source_port=0):
+ def connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
with self._lock:
- (connection, start) = self._connect(address, port, source, source_port)
+ (connection, start) = self._connect(
+ address, port, source, source_port, want_session_ticket
+ )
if start:
connection.run()
return connection
with self._lock:
super().closed(address, port)
+ def save_session_ticket(self, address, port, ticket):
+ with self._lock:
+ super().save_session_ticket(address, port, ticket)
+
def __enter__(self):
return self
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery
- def connect(self, address, port=853, source=None, source_port=0):
- (connection, start) = self._connect(address, port, source, source_port)
+ def connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
+ (connection, start) = self._connect(
+ address, port, source, source_port, want_session_ticket
+ )
if start:
self._nursery.start_soon(connection.run)
return connection