+import functools
import typing
import h2.connection
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
- on_release: typing.Callable = None,
+ pool_release_func: typing.Callable = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.timeout = timeout
- self.on_release = on_release
+ self.pool_release_func = pool_release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
+ if self.pool_release_func is None:
+ on_release = None
+ else:
+ on_release = functools.partial(self.pool_release_func, self)
+
reader, writer, protocol = await connect(
hostname, port, ssl_context, timeout
)
writer,
origin=self.origin,
timeout=self.timeout,
- on_release=self.on_release,
+ on_release=on_release,
)
else:
self.h11_connection = HTTP11Connection(
writer,
origin=self.origin,
timeout=self.timeout,
- on_release=self.on_release,
+ on_release=on_release,
)
if self.h2_connection is not None:
+import collections.abc
import typing
from .config import (
from .models import Client, Origin, Request, Response
from .streams import PoolSemaphore
+CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
+
+
+class ConnectionStore(collections.abc.Sequence):
+ """
+ We need to maintain collections of connections in a way that allows us to:
+
+ * Lookup connections by origin.
+ * Iterate over connections by insertion time.
+ * Return the total number of connections.
+ """
+
+ def __init__(self) -> None:
+ self.all = {} # type: typing.Dict[HTTPConnection, float]
+ self.by_origin = (
+ {}
+ ) # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]]
+
+ def pop_by_origin(self, origin: Origin) -> typing.Optional[HTTPConnection]:
+ try:
+ connections = self.by_origin[origin]
+ except KeyError:
+ return None
+
+ connection = next(reversed(list(connections.keys())))
+ del connections[connection]
+ if not connections:
+ del self.by_origin[origin]
+ del self.all[connection]
+
+ return connection
+
+ def add(self, connection: HTTPConnection) -> None:
+ self.all[connection] = 0.0
+ try:
+ self.by_origin[connection.origin][connection] = 0.0
+ except KeyError:
+ self.by_origin[connection.origin] = {connection: 0.0}
+
+ def remove(self, connection: HTTPConnection) -> None:
+ del self.all[connection]
+ del self.by_origin[connection.origin][connection]
+ if not self.by_origin[connection.origin]:
+ del self.by_origin[connection.origin]
+
+ def clear(self) -> None:
+ self.all.clear()
+ self.by_origin.clear()
+
+ def __iter__(self) -> typing.Iterator[HTTPConnection]:
+ return iter(self.all.keys())
+
+ def __getitem__(self, key: typing.Any) -> typing.Any:
+ if key in self.all:
+ return key
+ return None
+
+ def __len__(self) -> int:
+ return len(self.all)
+
class ConnectionPool(Client):
def __init__(
self.timeout = timeout
self.limits = limits
self.is_closed = False
- self.num_active_connections = 0
- self.num_keepalive_connections = 0
- self._keepalive_connections = (
- {}
- ) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
- self._max_connections = PoolSemaphore(limits, timeout)
+
+ self.max_connections = PoolSemaphore(limits, timeout)
+ self.keepalive_connections = ConnectionStore()
+ self.active_connections = ConnectionStore()
+
+ @property
+ def num_connections(self) -> int:
+ return len(self.keepalive_connections) + len(self.active_connections)
async def send(
self,
response = await connection.send(request, ssl=ssl, timeout=timeout)
return response
- @property
- def num_connections(self) -> int:
- return self.num_active_connections + self.num_keepalive_connections
-
async def acquire_connection(
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
) -> HTTPConnection:
- try:
- connection = self._keepalive_connections[origin].pop()
- if not self._keepalive_connections[origin]:
- del self._keepalive_connections[origin]
- self.num_keepalive_connections -= 1
- self.num_active_connections += 1
-
- except (KeyError, IndexError):
- await self._max_connections.acquire(timeout)
+ connection = self.keepalive_connections.pop_by_origin(origin)
+
+ if connection is None:
+ await self.max_connections.acquire(timeout)
connection = HTTPConnection(
origin,
ssl=self.ssl,
timeout=self.timeout,
- on_release=self.release_connection,
+ pool_release_func=self.release_connection,
)
- self.num_active_connections += 1
+
+ self.active_connections.add(connection)
return connection
async def release_connection(self, connection: HTTPConnection) -> None:
if connection.is_closed:
- self._max_connections.release()
- self.num_active_connections -= 1
+ self.active_connections.remove(connection)
+ self.max_connections.release()
elif (
self.limits.soft_limit is not None
and self.num_connections > self.limits.soft_limit
):
- self._max_connections.release()
- self.num_active_connections -= 1
+ self.active_connections.remove(connection)
+ self.max_connections.release()
await connection.close()
else:
- self.num_active_connections -= 1
- self.num_keepalive_connections += 1
- try:
- self._keepalive_connections[connection.origin].append(connection)
- except KeyError:
- self._keepalive_connections[connection.origin] = [connection]
+ self.active_connections.remove(connection)
+ self.keepalive_connections.add(connection)
async def close(self) -> None:
self.is_closed = True
- all_connections = []
- for connections in self._keepalive_connections.values():
- all_connections.extend(list(connections))
- self._keepalive_connections.clear()
- for connection in all_connections:
+ connections = list(self.keepalive_connections)
+ self.keepalive_connections.clear()
+ for connection in connections:
await connection.close()
protocol="HTTP/1.1",
headers=headers,
body=body,
- on_close=self._release,
+ on_close=self.response_closed,
)
async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
return event
- async def _release(self) -> None:
+ async def response_closed(self) -> None:
if (
self.h11_state.our_state is h11.DONE
and self.h11_state.their_state is h11.DONE
await self.close()
if self.on_release is not None:
- await self.on_release(self)
+ await self.on_release()
async def close(self) -> None:
event = h11.ConnectionClosed()
async def release(self) -> None:
if self.on_release is not None:
- await self.on_release(self)
+ await self.on_release()
async def close(self) -> None:
await self.writer.close()
import typing
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
-from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
+from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
OptionalTimeout = typing.Optional[TimeoutConfig]
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 2
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 2
@pytest.mark.asyncio
async with httpcore.ConnectionPool(limits=limits) as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
- assert http.num_active_connections == 1
- assert http.num_keepalive_connections == 0
+ assert len(http.active_connections) == 1
+ assert len(http.keepalive_connections) == 0
await response.read()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
"""
async with httpcore.ConnectionPool() as http:
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
- assert http.num_active_connections == 1
- assert http.num_keepalive_connections == 0
+ assert len(http.active_connections) == 1
+ assert len(http.keepalive_connections) == 0
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
- assert http.num_active_connections == 2
- assert http.num_keepalive_connections == 0
+ assert len(http.active_connections) == 2
+ assert len(http.keepalive_connections) == 0
await response_b.read()
- assert http.num_active_connections == 1
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 1
+ assert len(http.keepalive_connections) == 1
await response_a.read()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 2
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 2
@pytest.mark.asyncio
headers = [(b"connection", b"close")]
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 0
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 0
@pytest.mark.asyncio
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.read()
await response.close()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.close()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 0
+ assert len(http.active_connections) == 0
+ assert len(http.keepalive_connections) == 0