class Connection:
- def __init__(self, timeout: TimeoutConfig):
+ def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None):
self.reader = None
self.writer = None
self.state = h11.Connection(our_role=h11.CLIENT)
self.timeout = timeout
+ self.on_release = on_release
+
+ @property
+ def is_closed(self) -> bool:
+ return self.state.our_state in (h11.CLOSED, h11.ERROR)
async def open(
- self,
- hostname: str,
- port: int,
- *,
- ssl: typing.Union[bool, ssl.SSLContext] = False
+ self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None
) -> None:
try:
self.reader, self.writer = await asyncio.wait_for( # type: ignore
assert isinstance(event, h11.Response)
status_code = event.status_code
headers = event.headers
- body = self.body_iter()
+ body = self._body_iter()
return Response(
- status_code=status_code, headers=headers, body=body, on_close=self.close
+ status_code=status_code, headers=headers, body=body, on_close=self._release
)
- async def body_iter(self) -> typing.AsyncIterator[bytes]:
+ async def _body_iter(self) -> typing.AsyncIterator[bytes]:
event = await self._receive_event()
while isinstance(event, h11.Data):
yield event.data
event = await self._receive_event()
assert isinstance(event, h11.EndOfMessage)
- await self.close()
async def _send_event(self, event: H11Event) -> None:
assert self.writer is not None
return event
- async def close(self) -> None:
- if self.writer is not None:
+ async def _release(self) -> None:
+ assert self.writer is not None
+
+ if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE:
+ self.state.start_next_cycle()
+ else:
+ event = h11.ConnectionClosed()
+ try:
+ # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
+ self.state.send(event)
+ except h11.ProtocolError:
+ # If we're in some other state then it's a premature close,
+ # and we'll end up in h11.ERROR.
+ pass
+
+ if self.is_closed:
self.writer.close()
if hasattr(self.writer, "wait_closed"):
await self.writer.wait_closed()
+
+ if self.on_release is not None:
+ await self.on_release(self)
import asyncio
+import functools
import os
import ssl
import typing
from .connections import Connection
from .datastructures import URL, Request, Response
+ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port)
+
class ConnectionPool:
def __init__(
self.timeout = timeout
self.limits = limits
self.is_closed = False
+ self.num_active_connections = 0
+ self.num_keepalive_connections = 0
+ self._connections = (
+ {}
+ ) # type: typing.Dict[ConnectionKey, typing.List[Connection]]
async def request(
self,
return response
async def acquire_connection(
- self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False
+ self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None
) -> Connection:
- connection = Connection(timeout=self.timeout)
- await connection.open(url.hostname, url.port, ssl=ssl)
+ key = (url.scheme, url.hostname, url.port)
+ try:
+ connection = self._connections[key].pop()
+ if not self._connections[key]:
+ del self._connections[key]
+ self.num_keepalive_connections -= 1
+ self.num_active_connections += 1
+
+ except (KeyError, IndexError):
+ release = functools.partial(self.release_connection, key=key)
+ connection = Connection(timeout=self.timeout, on_release=release)
+ self.num_active_connections += 1
+ await connection.open(url.hostname, url.port, ssl=ssl)
+
return connection
- async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]:
+ async def release_connection(
+ self, connection: Connection, key: ConnectionKey
+ ) -> None:
+ self.num_active_connections -= 1
+ if not connection.is_closed:
+ self.num_keepalive_connections += 1
+ try:
+ self._connections[key].append(connection)
+ except KeyError:
+ self._connections[key] = [connection]
+
+ async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]:
if not url.is_secure:
- return False
+ return None
if not hasattr(self, "ssl_context"):
if not self.ssl_config.verify:
await asyncio.sleep(0.0001)
yield server
finally:
- task.cancel()
+ server.should_exit = True
+ server.force_exit = True
+ await task
--- /dev/null
+import pytest
+
+import httpcore
+
+
+@pytest.mark.asyncio
+async def test_keepalive_connections(server):
+ """
+ Connections should default to staying in a keep-alive state.
+ """
+ async with httpcore.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 1
+
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 1
+
+
+@pytest.mark.asyncio
+async def test_differing_connection_keys(server):
+ """
+ Connnections to differing connection keys should result in multiple connections.
+ """
+ async with httpcore.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/")
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 1
+
+ response = await http.request("GET", "http://localhost:8000/")
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 2
+
+
+@pytest.mark.asyncio
+async def test_streaming_response_holds_connection(server):
+ """
+ A streaming request should hold the connection open until the response is read.
+ """
+ async with httpcore.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ assert http.num_active_connections == 1
+ assert http.num_keepalive_connections == 0
+
+ await response.read()
+
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 1
+
+
+@pytest.mark.asyncio
+async def test_multiple_concurrent_connections(server):
+ """
+ Multiple conncurrent requests should open multiple conncurrent connections.
+ """
+ async with httpcore.ConnectionPool() as http:
+ response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ assert http.num_active_connections == 1
+ assert http.num_keepalive_connections == 0
+
+ response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ assert http.num_active_connections == 2
+ assert http.num_keepalive_connections == 0
+
+ await response_b.read()
+ assert http.num_active_connections == 1
+ assert http.num_keepalive_connections == 1
+
+ await response_a.read()
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 2
+
+
+@pytest.mark.asyncio
+async def test_close_connections(server):
+ """
+ Using a `Connection: close` header should close the connection.
+ """
+ headers = [(b"connection", b"close")]
+ async with httpcore.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 0
+
+
+@pytest.mark.asyncio
+async def test_standard_response_close(server):
+ """
+ A standard close should keep the connection open.
+ """
+ async with httpcore.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ await response.read()
+ await response.close()
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 1
+
+
+@pytest.mark.asyncio
+async def test_premature_response_close(server):
+ """
+ A premature close should close the connection.
+ """
+ async with httpcore.ConnectionPool() as http:
+ response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ await response.close()
+ assert http.num_active_connections == 0
+ assert http.num_keepalive_connections == 0