from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
from .models import URL, Origin, Request, Response
+from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
from .sync import SyncClient, SyncConnectionPool
__version__ = "0.2.1"
*,
connect_timeout: float = None,
read_timeout: float = None,
+ write_timeout: float = None,
pool_timeout: float = None,
):
if timeout is not None:
# Specified as a single timeout value
assert connect_timeout is None
assert read_timeout is None
+ assert write_timeout is None
assert pool_timeout is None
connect_timeout = timeout
read_timeout = timeout
+ write_timeout = timeout
pool_timeout = timeout
self.timeout = timeout
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
+ self.write_timeout = write_timeout
self.pool_timeout = pool_timeout
def __eq__(self, other: typing.Any) -> bool:
isinstance(other, self.__class__)
and self.connect_timeout == other.connect_timeout
and self.read_timeout == other.read_timeout
+ and self.write_timeout == other.write_timeout
and self.pool_timeout == other.pool_timeout
)
def __hash__(self) -> int:
- as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout)
+ as_tuple = (
+ self.connect_timeout,
+ self.read_timeout,
+ self.write_timeout,
+ self.pool_timeout,
+ )
return hash(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
if self.timeout is not None:
return f"{class_name}(timeout={self.timeout})"
- return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})"
+ return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, pool_timeout={self.pool_timeout})"
class PoolLimits:
-import asyncio
import typing
import h2.connection
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
from .models import Client, Origin, Request, Response
+from .streams import Protocol, connect
class HTTPConnection(Client):
if timeout is None:
timeout = self.timeout
- reader, writer, protocol = await self.connect(ssl, timeout)
- if protocol == "h2":
+ hostname = self.origin.hostname
+ port = self.origin.port
+ ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
+
+ reader, writer, protocol = await connect(
+ hostname, port, ssl_context, timeout
+ )
+ if protocol == Protocol.HTTP_2:
self.h2_connection = HTTP2Connection(
reader,
writer,
async def close(self) -> None:
if self.h2_connection is not None:
await self.h2_connection.close()
- else:
- assert self.h11_connection is not None
+ elif self.h11_connection is not None:
await self.h11_connection.close()
@property
else:
assert self.h11_connection is not None
return self.h11_connection.is_closed
-
- async def connect(
- self, ssl: SSLConfig, timeout: TimeoutConfig
- ) -> typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter, str]:
- hostname = self.origin.hostname
- port = self.origin.port
- ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
-
- try:
- reader, writer = await asyncio.wait_for( # type: ignore
- asyncio.open_connection(hostname, port, ssl=ssl_context),
- timeout.connect_timeout,
- )
- except asyncio.TimeoutError:
- raise ConnectTimeout()
-
- ssl_object = writer.get_extra_info("ssl_object")
- if ssl_object is None:
- protocol = "http/1.1"
- else:
- protocol = ssl_object.selected_alpn_protocol()
- if protocol is None:
- protocol = ssl_object.selected_npn_protocol()
-
- return (reader, writer, protocol)
"""
+class WriteTimeout(Timeout):
+ """
+ Timeout while writing request data.
+ """
+
+
class PoolTimeout(Timeout):
"""
Timeout while waiting to acquire a connection from the pool.
-import asyncio
import typing
import h11
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
from .exceptions import ConnectTimeout, ReadTimeout
from .models import Client, Origin, Request, Response
+from .streams import BaseReader, BaseWriter
H11Event = typing.Union[
h11.Request,
]
+OptionalTimeout = typing.Optional[TimeoutConfig]
+
+
class HTTP11Connection(Client):
+ READ_NUM_BYTES = 4096
+
def __init__(
self,
- reader: asyncio.StreamReader,
- writer: asyncio.StreamWriter,
+ reader: BaseReader,
+ writer: BaseWriter,
origin: Origin,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
on_release: typing.Callable = None,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None
) -> Response:
- if timeout is None:
- timeout = self.timeout
-
# Start sending the request.
method = request.method.encode()
target = request.url.full_path
headers = request.headers
event = h11.Request(method=method, target=target, headers=headers)
- await self._send_event(event)
+ await self._send_event(event, timeout)
# Send the request body.
async for data in request.stream():
event = h11.Data(data=data)
- await self._send_event(event)
+ await self._send_event(event, timeout)
# Finalize sending the request.
event = h11.EndOfMessage()
- await self._send_event(event)
+ await self._send_event(event, timeout)
# Start getting the response.
event = await self._receive_event(timeout)
on_close=self._release,
)
- async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
+ async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
event = await self._receive_event(timeout)
while isinstance(event, h11.Data):
yield event.data
event = await self._receive_event(timeout)
assert isinstance(event, h11.EndOfMessage)
- async def _send_event(self, event: H11Event) -> None:
+ async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None:
data = self.h11_state.send(event)
- self.writer.write(data)
+ await self.writer.write(data, timeout)
- async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
+ async def _receive_event(self, timeout: OptionalTimeout) -> H11Event:
event = self.h11_state.next_event()
while event is h11.NEED_DATA:
- try:
- data = await asyncio.wait_for(
- self.reader.read(2048), timeout.read_timeout
- )
- except asyncio.TimeoutError:
- raise ReadTimeout()
+ data = await self.reader.read(self.READ_NUM_BYTES, timeout)
self.h11_state.receive_data(data)
event = self.h11_state.next_event()
# and we'll end up in h11.ERROR.
pass
- if self.writer is not None:
- self.writer.close()
+ await self.writer.close()
-import asyncio
import typing
import h2.connection
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
from .exceptions import ConnectTimeout, ReadTimeout
from .models import Client, Origin, Request, Response
+from .streams import BaseReader, BaseWriter
+
+OptionalTimeout = typing.Optional[TimeoutConfig]
class HTTP2Connection(Client):
+ READ_NUM_BYTES = 4096
+
def __init__(
self,
- reader: asyncio.StreamReader,
- writer: asyncio.StreamWriter,
+ reader: BaseReader,
+ writer: BaseWriter,
origin: Origin,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
on_release: typing.Callable = None,
self.initiate_connection()
# Start sending the request.
- stream_id = await self.send_headers(stream_id, request)
+ stream_id = await self.send_headers(request, timeout)
self.events[stream_id] = []
# Send the request body.
async for data in request.stream():
- await self.send_data(stream_id, data)
+ await self.send_data(stream_id, data, timeout)
# Finalize sending the request.
- await self.end_stream(stream_id)
+ await self.end_stream(stream_id, timeout)
# Start getting the response.
while True:
def initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
- self.writer.write(data_to_send)
+ self.writer.write_no_block(data_to_send)
self.initialized = True
- async def send_headers(self, stream_id: int, request: Request) -> int:
+ async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int:
stream_id = self.h2_state.get_next_available_stream_id()
headers = [
(b":method", request.method.encode()),
] + request.headers
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
- self.writer.write(data_to_send)
+ await self.writer.write(data_to_send, timeout)
return stream_id
- async def send_data(self, stream_id: int, data: bytes) -> None:
+ async def send_data(
+ self, stream_id: int, data: bytes, timeout: OptionalTimeout
+ ) -> None:
self.h2_state.send_data(stream_id, data)
data_to_send = self.h2_state.data_to_send()
- self.writer.write(data_to_send)
+ await self.writer.write(data_to_send, timeout)
- async def end_stream(self, stream_id: int) -> None:
+ async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None:
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
- self.writer.write(data_to_send)
+ await self.writer.write(data_to_send, timeout)
async def body_iter(
- self, stream_id: int, timeout: TimeoutConfig
+ self, stream_id: int, timeout: OptionalTimeout
) -> typing.AsyncIterator[bytes]:
while True:
event = await self.receive_event(stream_id, timeout)
break
async def receive_event(
- self, stream_id: int, timeout: TimeoutConfig
+ self, stream_id: int, timeout: OptionalTimeout
) -> h2.events.Event:
while not self.events[stream_id]:
- try:
- data = await asyncio.wait_for(
- self.reader.read(2048), timeout.read_timeout
- )
- except asyncio.TimeoutError:
- raise ReadTimeout()
-
+ data = await self.reader.read(self.READ_NUM_BYTES, timeout)
events = self.h2_state.receive_data(data)
for event in events:
if getattr(event, "stream_id", 0):
self.events[event.stream_id].append(event)
data_to_send = self.h2_state.data_to_send()
- if data_to_send:
- self.writer.write(data_to_send)
+ await self.writer.write(data_to_send, timeout)
return self.events[stream_id].pop(0)
--- /dev/null
+"""
+The `Reader` and `Writer` classes here provide a lightweight layer over
+`asyncio.StreamReader` and `asyncio.StreamWriter`.
+
+They help encapsulate the timeout logic, make it easier to unit-test
+protocols, and help keep the rest of the package more `async`/`await`
+based, and less strictly `asyncio`-specific.
+"""
+import asyncio
+import enum
+import ssl
+import typing
+
+from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG
+from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+
+OptionalTimeout = typing.Optional[TimeoutConfig]
+
+
+class Protocol(enum.Enum):
+ HTTP_11 = 1
+ HTTP_2 = 2
+
+
+class BaseReader:
+ async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
+ raise NotImplementedError() # pragma: no cover
+
+
+class BaseWriter:
+ def write_no_block(self, data: bytes) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+ async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+ async def close(self) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+
+class Reader(BaseReader):
+ def __init__(
+ self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
+ ) -> None:
+ self.stream_reader = stream_reader
+ self.timeout = timeout
+
+ async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
+ if timeout is None:
+ timeout = self.timeout
+
+ try:
+ data = await asyncio.wait_for(
+ self.stream_reader.read(n), timeout.read_timeout
+ )
+ except asyncio.TimeoutError:
+ raise ReadTimeout()
+
+ return data
+
+
+class Writer(BaseWriter):
+ def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
+ self.stream_writer = stream_writer
+ self.timeout = timeout
+
+ def write_no_block(self, data: bytes) -> None:
+ self.stream_writer.write(data)
+
+ async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
+ if not data:
+ return
+
+ if timeout is None:
+ timeout = self.timeout
+
+ self.stream_writer.write(data)
+ try:
+ data = await asyncio.wait_for( # type: ignore
+ self.stream_writer.drain(), timeout.write_timeout
+ )
+ except asyncio.TimeoutError:
+ raise WriteTimeout()
+
+ async def close(self) -> None:
+ self.stream_writer.close()
+
+
+async def connect(
+ hostname: str,
+ port: int,
+ ssl_context: typing.Optional[ssl.SSLContext] = None,
+ timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+) -> typing.Tuple[Reader, Writer, Protocol]:
+ try:
+ stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
+ asyncio.open_connection(hostname, port, ssl=ssl_context),
+ timeout.connect_timeout,
+ )
+ except asyncio.TimeoutError:
+ raise ConnectTimeout()
+
+ ssl_object = stream_writer.get_extra_info("ssl_object")
+ if ssl_object is None:
+ ident = "http/1.1"
+ else:
+ ident = ssl_object.selected_alpn_protocol()
+ if ident is None:
+ ident = ssl_object.selected_npn_protocol()
+
+ reader = Reader(stream_reader=stream_reader, timeout=timeout)
+ writer = Writer(stream_writer=stream_writer, timeout=timeout)
+ protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
+
+ return (reader, writer, protocol)
timeout = httpcore.TimeoutConfig(read_timeout=5.0)
assert (
repr(timeout)
- == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)"
+ == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None, pool_timeout=None)"
)
--- /dev/null
+import h2.config
+import h2.connection
+import h2.events
+import pytest
+
+import httpcore
+
+
+class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
+ """
+ This class exposes Reader and Writer style interfaces
+ """
+
+ def __init__(self):
+ config = h2.config.H2Configuration(client_side=False)
+ self.conn = h2.connection.H2Connection(config=config)
+ self.buffer = b""
+ self.requests = {}
+
+ # BaseReader interface
+
+ async def read(self, n, timeout) -> bytes:
+ send, self.buffer = self.buffer[:n], self.buffer[n:]
+ return send
+
+ # BaseWriter interface
+
+ def write_no_block(self, data: bytes) -> None:
+ events = self.conn.receive_data(data)
+ self.buffer += self.conn.data_to_send()
+ for event in events:
+ if isinstance(event, h2.events.RequestReceived):
+ self.request_received(event.headers, event.stream_id)
+ elif isinstance(event, h2.events.DataReceived):
+ self.receive_data(event.data, event.stream_id)
+ elif isinstance(event, h2.events.StreamEnded):
+ self.stream_complete(event.stream_id)
+
+ async def write(self, data: bytes, timeout) -> None:
+ self.write_no_block(data)
+
+ async def close(self) -> None:
+ pass
+
+ # Server implementation
+
+ def request_received(self, headers, stream_id):
+ if stream_id not in self.requests:
+ self.requests[stream_id] = []
+ self.requests[stream_id].append({"headers": headers, "data": b""})
+
+ def receive_data(self, data, stream_id):
+ self.requests[stream_id][-1]["data"] += data
+
+ def stream_complete(self, stream_id):
+ requests = self.requests[stream_id].pop(0)
+ if not self.requests[stream_id]:
+ del self.requests[stream_id]
+
+ response_headers = (
+ (b":status", b"200"),
+ )
+ response_body = b"Hello, world!"
+ self.conn.send_headers(stream_id, response_headers)
+ self.conn.send_data(stream_id, response_body, end_stream=True)
+ self.buffer += self.conn.data_to_send()
+
+
+@pytest.mark.asyncio
+async def test_http2():
+ server = MockServer()
+ origin = httpcore.Origin("http://example.org")
+ client = httpcore.HTTP2Connection(reader=server, writer=server, origin=origin)
+ response = await client.request("GET", "http://example.org")
+ assert response.status_code == 200
+ assert response.body == b"Hello, world!"