-import asyncio
import typing
from .config import (
from .connection import HTTPConnection
from .exceptions import PoolTimeout
from .models import Client, Origin, Request, Response
+from .streams import PoolSemaphore
class ConnectionPool(Client):
self._keepalive_connections = (
{}
) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
- self._max_connections = ConnectionSemaphore(
- max_connections=self.limits.hard_limit
- )
+ self._max_connections = PoolSemaphore(limits, timeout)
async def send(
self,
self.num_active_connections += 1
except (KeyError, IndexError):
- if timeout is None:
- pool_timeout = self.timeout.pool_timeout
- else:
- pool_timeout = timeout.pool_timeout
-
- try:
- await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
- except asyncio.TimeoutError:
- raise PoolTimeout()
+ await self._max_connections.acquire(timeout)
connection = HTTPConnection(
origin,
ssl=self.ssl,
self._keepalive_connections.clear()
for connection in all_connections:
await connection.close()
-
-
-class ConnectionSemaphore:
- def __init__(self, max_connections: int = None):
- self.max_connections = max_connections
-
- @property
- def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
- if not hasattr(self, "_semaphore"):
- if self.max_connections is None:
- self._semaphore = None
- else:
- self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
- return self._semaphore
-
- async def acquire(self) -> None:
- if self.semaphore is not None:
- await self.semaphore.acquire()
-
- def release(self) -> None:
- if self.semaphore is not None:
- self.semaphore.release()
The `Reader` and `Writer` classes here provide a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.
-They help encapsulate the timeout logic, make it easier to unit-test
+Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
+
+These classes help encapsulate the timeout logic, make it easier to unit-test
protocols, and help keep the rest of the package more `async`/`await`
based, and less strictly `asyncio`-specific.
"""
import ssl
import typing
-from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG
-from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
+from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
OptionalTimeout = typing.Optional[TimeoutConfig]
raise NotImplementedError() # pragma: no cover
+class BasePoolSemaphore:
+ def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
+ raise NotImplementedError() # pragma: no cover
+
+ async def acquire(self, timeout: OptionalTimeout = None) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+ def release(self) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+
class Reader(BaseReader):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
self.stream_writer.close()
+class PoolSemaphore(BasePoolSemaphore):
+ def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
+ self.limits = limits
+ self.timeout = timeout
+
+ @property
+ def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
+ if not hasattr(self, "_semaphore"):
+ max_connections = self.limits.hard_limit
+ if max_connections is None:
+ self._semaphore = None
+ else:
+ self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
+ return self._semaphore
+
+ async def acquire(self, timeout: OptionalTimeout = None) -> None:
+ if self.semaphore is None:
+ return
+
+ if timeout is None:
+ timeout = self.timeout
+
+ try:
+ await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout)
+ except asyncio.TimeoutError:
+ raise PoolTimeout()
+
+ def release(self) -> None:
+ if self.semaphore is None:
+ return
+
+ self.semaphore.release()
+
+
async def connect(
hostname: str,
port: int,
+import json
+
import h2.config
import h2.connection
import h2.events
import pytest
-import json
import httpcore
request_headers = dict(request["headers"])
request_data = request["data"]
- response_body = json.dumps({
- "method": request_headers[b":method"].decode(),
- "path": request_headers[b":path"].decode(),
- "body": request_data.decode()
- }).encode()
+ response_body = json.dumps(
+ {
+ "method": request_headers[b":method"].decode(),
+ "path": request_headers[b":path"].decode(),
+ "body": request_data.decode(),
+ }
+ ).encode()
- response_headers = (
- (b":status", b"200"),
- )
+ response_headers = ((b":status", b"200"),)
self.conn.send_headers(stream_id, response_headers)
self.conn.send_data(stream_id, response_body, end_stream=True)
self.buffer += self.conn.data_to_send()
async def test_http2_get_request():
server = MockServer()
origin = httpcore.Origin("http://example.org")
- async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
+ async with httpcore.HTTP2Connection(
+ reader=server, writer=server, origin=origin
+ ) as client:
response = await client.request("GET", "http://example.org")
assert response.status_code == 200
assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
async def test_http2_post_request():
server = MockServer()
origin = httpcore.Origin("http://example.org")
- async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
+ async with httpcore.HTTP2Connection(
+ reader=server, writer=server, origin=origin
+ ) as client:
response = await client.request("POST", "http://example.org", body=b"<data>")
assert response.status_code == 200
- assert json.loads(response.body) == {"method": "POST", "path": "/", "body": "<data>"}
+ assert json.loads(response.body) == {
+ "method": "POST",
+ "path": "/",
+ "body": "<data>",
+ }
@pytest.mark.asyncio
async def test_http2_multiple_requests():
server = MockServer()
origin = httpcore.Origin("http://example.org")
- async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
+ async with httpcore.HTTP2Connection(
+ reader=server, writer=server, origin=origin
+ ) as client:
response_1 = await client.request("GET", "http://example.org/1")
response_2 = await client.request("GET", "http://example.org/2")
response_3 = await client.request("GET", "http://example.org/3")