from .config import PoolLimits, SSLConfig, TimeoutConfig
-from .datastructures import URL, Request, Response
+from .connections import Connection
+from .datastructures import URL, Origin, Request, Response
from .exceptions import (
BadResponse,
ConnectTimeout,
+import asyncio
+import os
import ssl
import typing
import asyncio
-import ssl
import typing
import h11
-from .config import TimeoutConfig
-from .datastructures import Request, Response
+from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
+from .datastructures import Client, Origin, Request, Response
from .exceptions import ConnectTimeout, ReadTimeout
H11Event = typing.Union[
class Connection:
- def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None):
- self.reader = None
- self.writer = None
- self.state = h11.Connection(our_role=h11.CLIENT)
+ def __init__(
+ self,
+ origin: typing.Union[str, Origin],
+ ssl: SSLConfig = DEFAULT_SSL_CONFIG,
+ timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+ on_release: typing.Callable = None,
+ ):
+ self.origin = Origin(origin) if isinstance(origin, str) else origin
+ self.ssl = ssl
self.timeout = timeout
self.on_release = on_release
+ self._reader = None
+ self._writer = None
+ self._h11_state = h11.Connection(our_role=h11.CLIENT)
@property
def is_closed(self) -> bool:
- return self.state.our_state in (h11.CLOSED, h11.ERROR)
+ return self._h11_state.our_state in (h11.CLOSED, h11.ERROR)
+
+ async def send(
+ self,
+ request: Request,
+ ssl: typing.Optional[SSLConfig] = None,
+ timeout: typing.Optional[TimeoutConfig] = None,
+ stream: bool = False,
+ ) -> Response:
+ assert request.url.origin == self.origin
+
+ if ssl is None:
+ ssl = self.ssl
+ if timeout is None:
+ timeout = self.timeout
+
+ # Make the connection
+ if self._reader is None:
+ await self._connect(ssl, timeout)
- async def open(
- self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None
- ) -> None:
- try:
- self.reader, self.writer = await asyncio.wait_for( # type: ignore
- asyncio.open_connection(hostname, port, ssl=ssl),
- self.timeout.connect_timeout,
- )
- except asyncio.TimeoutError:
- raise ConnectTimeout()
-
- async def send(self, request: Request) -> Response:
+ # Start sending the request.
method = request.method.encode()
target = request.url.target
headers = request.headers
-
- # Start sending the request.
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event)
await self._send_event(event)
# Start getting the response.
- event = await self._receive_event()
+ event = await self._receive_event(timeout)
if isinstance(event, h11.InformationalResponse):
- event = await self._receive_event()
+ event = await self._receive_event(timeout)
assert isinstance(event, h11.Response)
- reason = event.reason.decode('latin1')
+ reason = event.reason.decode("latin1")
status_code = event.status_code
headers = event.headers
- body = self._body_iter()
- return Response(
- status_code=status_code, reason=reason, headers=headers, body=body, on_close=self._release
+ body = self._body_iter(timeout)
+ response = Response(
+ status_code=status_code,
+ reason=reason,
+ headers=headers,
+ body=body,
+ on_close=self._release,
)
- async def _body_iter(self) -> typing.AsyncIterator[bytes]:
- event = await self._receive_event()
+ if not stream:
+ # Read the response body.
+ try:
+ await response.read()
+ finally:
+ await response.close()
+
+ return response
+
+ async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
+ ssl_context = await ssl.load_ssl_context() if self.origin.is_secure else None
+
+ try:
+ self._reader, self._writer = await asyncio.wait_for( # type: ignore
+ asyncio.open_connection(
+ self.origin.hostname, self.origin.port, ssl=ssl_context
+ ),
+ timeout.connect_timeout,
+ )
+ except asyncio.TimeoutError:
+ raise ConnectTimeout()
+
+ async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
+ event = await self._receive_event(timeout)
while isinstance(event, h11.Data):
yield event.data
- event = await self._receive_event()
+ event = await self._receive_event(timeout)
assert isinstance(event, h11.EndOfMessage)
async def _send_event(self, event: H11Event) -> None:
- assert self.writer is not None
+ assert self._writer is not None
- data = self.state.send(event)
- self.writer.write(data)
+ data = self._h11_state.send(event)
+ self._writer.write(data)
- async def _receive_event(self) -> H11Event:
- assert self.reader is not None
+ async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
+ assert self._reader is not None
- event = self.state.next_event()
+ event = self._h11_state.next_event()
while event is h11.NEED_DATA:
try:
data = await asyncio.wait_for(
- self.reader.read(2048), self.timeout.read_timeout
+ self._reader.read(2048), timeout.read_timeout
)
except asyncio.TimeoutError:
raise ReadTimeout()
- self.state.receive_data(data)
- event = self.state.next_event()
+ self._h11_state.receive_data(data)
+ event = self._h11_state.next_event()
return event
async def _release(self) -> None:
- assert self.writer is not None
+ assert self._writer is not None
- if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE:
- self.state.start_next_cycle()
+ if (
+ self._h11_state.our_state is h11.DONE
+ and self._h11_state.their_state is h11.DONE
+ ):
+ self._h11_state.start_next_cycle()
else:
- self.close()
+ await self.close()
if self.on_release is not None:
await self.on_release(self)
- def close(self) -> None:
- assert self.writer is not None
-
+ async def close(self) -> None:
event = h11.ConnectionClosed()
try:
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
- self.state.send(event)
+ self._h11_state.send(event)
except h11.ProtocolError:
# If we're in some other state then it's a premature close,
# and we'll end up in h11.ERROR.
pass
- self.writer.close()
+ if self._writer is not None:
+ self._writer.close()
def is_secure(self) -> bool:
return self.components.scheme == "https"
+ @property
+ def origin(self) -> "Origin":
+ return Origin(self)
+
+
+class Origin:
+ def __init__(self, url: typing.Union[str, URL]) -> None:
+ if isinstance(url, str):
+ url = URL(url)
+ self.scheme = url.scheme
+ self.hostname = url.hostname
+ self.port = url.port
+
+ @property
+ def is_secure(self) -> bool:
+ return self.scheme == "https"
+
+ def __eq__(self, other: typing.Any) -> bool:
+ return (
+ isinstance(other, self.__class__)
+ and self.scheme == other.scheme
+ and self.hostname == other.hostname
+ and self.port == other.port
+ )
+
+ def __hash__(self) -> int:
+ return hash((self.scheme, self.hostname, self.port))
+
class Request:
def __init__(
self.is_closed = True
if self.on_close is not None:
await self.on_close()
+
+
+class Client:
+ async def send(self, request: Request, **options: typing.Any) -> Response:
+ raise NotImplementedError() # pragma: nocover
+
+ async def close(self) -> None:
+ raise NotImplementedError() # pragma: nocover
import asyncio
-import functools
-import os
-import ssl
import typing
from types import TracebackType
TimeoutConfig,
)
from .connections import Connection
-from .datastructures import URL, Request, Response
+from .datastructures import Client, Origin, Request, Response
from .exceptions import PoolTimeout
-ConnectionKey = typing.Tuple[str, str, int, SSLConfig, TimeoutConfig]
-
class ConnectionSemaphore:
def __init__(self, max_connections: int = None):
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
limits: PoolLimits = DEFAULT_POOL_LIMITS,
):
- self.ssl_config = ssl
+ self.ssl = ssl
self.timeout = timeout
self.limits = limits
self.is_closed = False
self.num_keepalive_connections = 0
self._keepalive_connections = (
{}
- ) # type: typing.Dict[ConnectionKey, typing.List[Connection]]
+ ) # type: typing.Dict[Origin, typing.List[Connection]]
self._max_connections = ConnectionSemaphore(
max_connections=self.limits.hard_limit
)
- async def request(
+ async def send(
self,
- method: str,
- url: str,
- *,
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
- body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
- stream: bool = False,
+ request: Request,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
+ stream: bool = False,
) -> Response:
- if ssl is None:
- ssl = self.ssl_config
- if timeout is None:
- timeout = self.timeout
-
- parsed_url = URL(url)
- request = Request(method, parsed_url, headers=headers, body=body)
- connection = await self.acquire_connection(parsed_url, ssl=ssl, timeout=timeout)
- response = await connection.send(request)
- if not stream:
- try:
- await response.read()
- finally:
- await response.close()
+ connection = await self.acquire_connection(request.url.origin, timeout=timeout)
+ response = await connection.send(
+ request, ssl=ssl, timeout=timeout, stream=stream
+ )
return response
@property
return self.num_active_connections + self.num_keepalive_connections
async def acquire_connection(
- self, url: URL, ssl: SSLConfig, timeout: TimeoutConfig
+ self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
) -> Connection:
- key = (url.scheme, url.hostname, url.port, ssl, timeout)
try:
- connection = self._keepalive_connections[key].pop()
- if not self._keepalive_connections[key]:
- del self._keepalive_connections[key]
+ connection = self._keepalive_connections[origin].pop()
+ if not self._keepalive_connections[origin]:
+ del self._keepalive_connections[origin]
self.num_keepalive_connections -= 1
self.num_active_connections += 1
except (KeyError, IndexError):
- if url.is_secure:
- ssl_context = await ssl.load_ssl_context()
+ if timeout is None:
+ pool_timeout = self.timeout.pool_timeout
else:
- ssl_context = None
+ pool_timeout = timeout.pool_timeout
try:
- await asyncio.wait_for(
- self._max_connections.acquire(), timeout.pool_timeout
- )
+ await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
except asyncio.TimeoutError:
raise PoolTimeout()
- release = functools.partial(self.release_connection, key=key)
- connection = Connection(timeout=timeout, on_release=release)
+ connection = Connection(
+ origin,
+ ssl=self.ssl,
+ timeout=self.timeout,
+ on_release=self.release_connection,
+ )
self.num_active_connections += 1
- await connection.open(url.hostname, url.port, ssl=ssl_context)
return connection
- async def release_connection(
- self, connection: Connection, key: ConnectionKey
- ) -> None:
+ async def release_connection(self, connection: Connection) -> None:
if connection.is_closed:
self._max_connections.release()
self.num_active_connections -= 1
):
self._max_connections.release()
self.num_active_connections -= 1
- connection.close()
+ await connection.close()
else:
self.num_active_connections -= 1
self.num_keepalive_connections += 1
try:
- self._keepalive_connections[key].append(connection)
+ self._keepalive_connections[connection.origin].append(connection)
except KeyError:
- self._keepalive_connections[key] = [connection]
+ self._keepalive_connections[connection.origin] = [connection]
async def close(self) -> None:
self.is_closed = True
@pytest.mark.asyncio
async def test_get(server):
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/")
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request)
assert response.status_code == 200
assert response.body == b"Hello, world!"
@pytest.mark.asyncio
async def test_post(server):
- async with httpcore.ConnectionPool() as http:
- response = await http.request(
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request(
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
)
+ response = await client.send(request)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_stream_response(server):
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request, stream=True)
assert response.status_code == 200
assert not hasattr(response, "body")
body = await response.read()
yield b"Hello, "
yield b"world!"
- async with httpcore.ConnectionPool() as http:
- response = await http.request(
- "POST", "http://127.0.0.1:8000/", body=hello_world()
- )
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("POST", "http://127.0.0.1:8000/", body=hello_world())
+ response = await client.send(request)
assert response.status_code == 200
--- /dev/null
+import pytest
+
+import httpcore
+
+
+@pytest.mark.asyncio
+async def test_get(server):
+ client = httpcore.Connection(origin="http://127.0.0.1:8000/")
+ request = httpcore.Request(method="GET", url="http://127.0.0.1:8000/")
+ response = await client.send(request)
+ assert response.status_code == 200
+ assert response.body == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_post(server):
+ client = httpcore.Connection(origin="http://127.0.0.1:8000/")
+ request = httpcore.Request(
+ method="POST", url="http://127.0.0.1:8000/", body=b"Hello, world!"
+ )
+ response = await client.send(request)
+ assert response.status_code == 200
"""
Connections should default to staying in a keep-alive state.
"""
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request)
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 1
- response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request)
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 1
@pytest.mark.asyncio
"""
Connnections to differing connection keys should result in multiple connections.
"""
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request)
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 1
- response = await http.request("GET", "http://localhost:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 2
+ request = httpcore.Request("GET", "http://localhost:8000/")
+ response = await client.send(request)
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 2
@pytest.mark.asyncio
"""
limits = httpcore.PoolLimits(soft_limit=1)
- async with httpcore.ConnectionPool(limits=limits) as http:
- response = await http.request("GET", "http://127.0.0.1:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ async with httpcore.ConnectionPool(limits=limits) as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request)
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 1
- response = await http.request("GET", "http://localhost:8000/")
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ request = httpcore.Request("GET", "http://localhost:8000/")
+ response = await client.send(request)
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 1
@pytest.mark.asyncio
"""
A streaming request should hold the connection open until the response is read.
"""
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
- assert http.num_active_connections == 1
- assert http.num_keepalive_connections == 0
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request, stream=True)
+ assert client.num_active_connections == 1
+ assert client.num_keepalive_connections == 0
await response.read()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 1
@pytest.mark.asyncio
"""
Multiple conncurrent requests should open multiple conncurrent connections.
"""
- async with httpcore.ConnectionPool() as http:
- response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
- assert http.num_active_connections == 1
- assert http.num_keepalive_connections == 0
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response_a = await client.send(request, stream=True)
+ assert client.num_active_connections == 1
+ assert client.num_keepalive_connections == 0
- response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
- assert http.num_active_connections == 2
- assert http.num_keepalive_connections == 0
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response_b = await client.send(request, stream=True)
+ assert client.num_active_connections == 2
+ assert client.num_keepalive_connections == 0
await response_b.read()
- assert http.num_active_connections == 1
- assert http.num_keepalive_connections == 1
+ assert client.num_active_connections == 1
+ assert client.num_keepalive_connections == 1
await response_a.read()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 2
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 2
@pytest.mark.asyncio
Using a `Connection: close` header should close the connection.
"""
headers = [(b"connection", b"close")]
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 0
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/", headers=headers)
+ response = await client.send(request)
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 0
@pytest.mark.asyncio
"""
A standard close should keep the connection open.
"""
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request, stream=True)
await response.read()
await response.close()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 1
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 1
@pytest.mark.asyncio
"""
A premature close should close the connection.
"""
- async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ async with httpcore.ConnectionPool() as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request, stream=True)
await response.close()
- assert http.num_active_connections == 0
- assert http.num_keepalive_connections == 0
+ assert client.num_active_connections == 0
+ assert client.num_keepalive_connections == 0
import httpcore
-class MockHTTP(httpcore.ConnectionPool):
- async def request(
- self, method, url, *, headers=(), body=b"", stream=False
- ) -> httpcore.Response:
- if stream:
+async def streaming_body():
+ yield b"Hello, "
+ yield b"world!"
- async def streaming_body():
- yield b"Hello, "
- yield b"world!"
- return httpcore.Response(200, body=streaming_body())
- return httpcore.Response(200, body=b"Hello, world!")
-
-
-http = MockHTTP()
-
-
-@pytest.mark.asyncio
-async def test_request():
- response = await http.request("GET", "http://example.com")
+def test_response():
+ response = httpcore.Response(200, body=b"Hello, world!")
assert response.status_code == 200
assert response.reason == "OK"
assert response.body == b"Hello, world!"
@pytest.mark.asyncio
async def test_read_response():
- response = await http.request("GET", "http://example.com")
+ response = httpcore.Response(200, body=b"Hello, world!")
assert response.status_code == 200
assert response.body == b"Hello, world!"
@pytest.mark.asyncio
-async def test_stream_response():
- response = await http.request("GET", "http://example.com")
-
- assert response.status_code == 200
- assert response.body == b"Hello, world!"
- assert response.is_closed
-
- body = b""
- async for part in response.stream():
- body += part
-
- assert body == b"Hello, world!"
- assert response.body == b"Hello, world!"
- assert response.is_closed
-
-
-@pytest.mark.asyncio
-async def test_read_streaming_response():
- response = await http.request("GET", "http://example.com", stream=True)
+async def test_streaming_response():
+ response = httpcore.Response(200, body=streaming_body())
assert response.status_code == 200
assert not hasattr(response, "body")
assert response.is_closed
-@pytest.mark.asyncio
-async def test_stream_streaming_response():
- response = await http.request("GET", "http://example.com", stream=True)
-
- assert response.status_code == 200
- assert not hasattr(response, "body")
- assert not response.is_closed
-
- body = b""
- async for part in response.stream():
- body += part
-
- assert body == b"Hello, world!"
- assert not hasattr(response, "body")
- assert response.is_closed
-
-
@pytest.mark.asyncio
async def test_cannot_read_after_stream_consumed():
- response = await http.request("GET", "http://example.com", stream=True)
+ response = httpcore.Response(200, body=streaming_body())
body = b""
async for part in response.stream():
@pytest.mark.asyncio
async def test_cannot_read_after_response_closed():
- response = await http.request("GET", "http://example.com", stream=True)
+ response = httpcore.Response(200, body=streaming_body())
await response.close()
async def test_read_timeout(server):
timeout = httpcore.TimeoutConfig(read_timeout=0.0001)
- async with httpcore.ConnectionPool(timeout=timeout) as http:
+ async with httpcore.ConnectionPool(timeout=timeout) as client:
with pytest.raises(httpcore.ReadTimeout):
- await http.request("GET", "http://127.0.0.1:8000/slow_response")
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/slow_response")
+ await client.send(request)
@pytest.mark.asyncio
async def test_connect_timeout(server):
timeout = httpcore.TimeoutConfig(connect_timeout=0.0001)
- async with httpcore.ConnectionPool(timeout=timeout) as http:
+ async with httpcore.ConnectionPool(timeout=timeout) as client:
with pytest.raises(httpcore.ConnectTimeout):
# See https://stackoverflow.com/questions/100841/
- await http.request("GET", "http://10.255.255.1/")
+ request = httpcore.Request("GET", "http://10.255.255.1/")
+ await client.send(request)
@pytest.mark.asyncio
timeout = httpcore.TimeoutConfig(pool_timeout=0.0001)
limits = httpcore.PoolLimits(hard_limit=1)
- async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as client:
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ response = await client.send(request, stream=True)
with pytest.raises(httpcore.PoolTimeout):
- await http.request("GET", "http://localhost:8000/")
+ request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+ await client.send(request)
await response.read()