From: Tom Christie Date: Sat, 6 Apr 2019 12:18:39 +0000 (+0100) Subject: Connections X-Git-Tag: 0.0.3~1^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=86263fa073aa97cc4f11026e9ca81bef12599f1b;p=thirdparty%2Fhttpx.git Connections --- diff --git a/README.md b/README.md index c6b8d257..63eb03dd 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,8 @@ of it, and exposes only plain datastructures that reflect the network response. ```python import httpcore -response = await httpcore.request('GET', 'http://example.com') +http = httpcore.ConnectionPool() +response = await http.request('GET', 'http://example.com') assert response.status_code == 200 assert response.body == b'Hello, world' ``` @@ -71,20 +72,22 @@ assert response.body == b'Hello, world' Top-level API... ```python -response = await httpcore.request(method, url, [headers], [body], [stream]) +http = httpcore.ConnectionPool([ssl], [timeout], [limits]) +response = await http.request(method, url, [headers], [body], [stream]) ``` -Explicit PoolManager... +ConnectionPool as a context-manager... ```python -async with httpcore.PoolManager([ssl], [timeout], [limits]) as pool: - response = await pool.request(method, url, [headers], [body], [stream]) +async with httpcore.ConnectionPool([ssl], [timeout], [limits]) as http: + response = await http.request(method, url, [headers], [body], [stream]) ``` Streaming... ```python -response = await httpcore.request(method, url, stream=True) +http = httpcore.ConnectionPool() +response = await http.request(method, url, stream=True) async for part in response.stream(): ... ``` @@ -100,7 +103,7 @@ import httpcore class GatewayServer: def __init__(self, base_url): self.base_url = base_url - self.pool = httpcore.PoolManager() + self.http = httpcore.ConnectionPool() async def __call__(self, scope, receive, send): assert scope['type'] == 'http' @@ -122,7 +125,7 @@ class GatewayServer: if not message.get('more_body', False): break - response = await self.pool.request( + response = await self.http.request( method, url, headers=headers, body=body, stream=True ) diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 69894f36..24a6bbfb 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,5 +1,6 @@ -from .api import PoolManager, Response, request from .config import PoolLimits, SSLConfig, TimeoutConfig +from .datastructures import URL, Request, Response from .exceptions import ResponseClosed, StreamConsumed +from .pool import ConnectionPool __version__ = "0.0.2" diff --git a/httpcore/api.py b/httpcore/api.py deleted file mode 100644 index 24c4fec0..00000000 --- a/httpcore/api.py +++ /dev/null @@ -1,67 +0,0 @@ -import typing -from types import TracebackType - -from .config import ( - DEFAULT_POOL_LIMITS, - DEFAULT_SSL_CONFIG, - DEFAULT_TIMEOUT_CONFIG, - PoolLimits, - SSLConfig, - TimeoutConfig, -) -from .models import Response - - -async def request( - method: str, - url: str, - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - stream: bool = False, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, -) -> Response: - async with PoolManager(ssl=ssl, timeout=timeout) as pool: - return await pool.request( - method=method, url=url, headers=headers, body=body, stream=stream - ) - - -class PoolManager: - def __init__( - self, - *, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, - limits: PoolLimits = DEFAULT_POOL_LIMITS, - ): - self.ssl = ssl - self.timeout = timeout - self.limits = limits - self.is_closed = False - - async def request( - self, - method: str, - url: str, - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - stream: bool = False, - ) -> Response: - raise NotImplementedError() - - async def close(self) -> None: - self.is_closed = True - - async def __aenter__(self) -> "PoolManager": - return self - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.close() diff --git a/httpcore/config.py b/httpcore/config.py index aa0c0718..d169e0af 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -1,5 +1,7 @@ import typing +import certifi + class SSLConfig: """ @@ -52,3 +54,4 @@ class PoolLimits: DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True) DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0) DEFAULT_POOL_LIMITS = PoolLimits(max_hosts=10, conns_per_host=10, hard_limit=False) +DEFAULT_CA_BUNDLE_PATH = certifi.where() diff --git a/httpcore/connections.py b/httpcore/connections.py index dc9c4ff0..db27b03c 100644 --- a/httpcore/connections.py +++ b/httpcore/connections.py @@ -1,62 +1,119 @@ -from config import TimeoutConfig - import asyncio -import h11 import ssl +import typing + +import h11 + +from .config import TimeoutConfig +from .datastructures import Request, Response +from .exceptions import ConnectTimeout, ReadTimeout + +H11Event = typing.Union[ + h11.Request, + h11.Response, + h11.InformationalResponse, + h11.Data, + h11.EndOfMessage, + h11.ConnectionClosed, +] class Connection: - def __init__(self): + def __init__(self, timeout: TimeoutConfig): self.reader = None self.writer = None self.state = h11.Connection(our_role=h11.CLIENT) + self.timeout = timeout - async def open(self, host: str, port: int, ssl: ssl.SSLContext): + async def open( + self, + hostname: str, + port: int, + *, + ssl: typing.Union[bool, ssl.SSLContext] = False + ) -> None: try: - self.reader, self.writer = await asyncio.wait_for( - asyncio.open_connection(host, port, ssl=ssl), timeout + self.reader, self.writer = await asyncio.wait_for( # type: ignore + asyncio.open_connection(hostname, port, ssl=ssl), + self.timeout.connect_timeout, ) except asyncio.TimeoutError: raise ConnectTimeout() - async def send(self, request: Request) -> Response: - method = request.method - - target = request.url.path - if request.url.query: - target += "?" + request.url.query + async def send(self, request: Request, stream: bool=False) -> Response: + method = request.method.encode() + target = request.url.target + host_header = (b"host", request.url.netloc.encode("ascii")) + if request.is_streaming: + content_length = (b"transfer-encoding", b"chunked") + else: + content_length = (b"content-length", str(len(request.body)).encode()) - headers = [ - ("host", request.url.netloc) - ] += request.headers + headers = [host_header, content_length] + request.headers - # Send the request method, path/query, and headers. + #  Start sending the request. event = h11.Request(method=method, target=target, headers=headers) await self._send_event(event) # Send the request body. if request.is_streaming: - async for data in request.raw(): + async for data in request.stream(): event = h11.Data(data=data) await self._send_event(event) - else: + elif request.body: event = h11.Data(data=request.body) await self._send_event(event) # Finalize sending the request. event = h11.EndOfMessage() - await connection.send_event(event) + await self._send_event(event) + + # Start getting the response. + event = await self._receive_event() + if isinstance(event, h11.InformationalResponse): + event = await self._receive_event() + assert isinstance(event, h11.Response) + status_code = event.status_code + headers = event.headers + + if stream: + return Response(status_code=status_code, headers=headers, body=self.body_iter()) + + #  Get the response body. + body = b"" + event = await self._receive_event() + while isinstance(event, h11.Data): + body += event.data + event = await self._receive_event() + assert isinstance(event, h11.EndOfMessage) + await self.close() - async def _send_event(self, message): - data = self.state.send(message) + return Response(status_code=status_code, headers=headers, body=body) + + async def body_iter(self) -> typing.Iterable[bytes]: + event = await self._receive_event() + while isinstance(event, h11.Data): + yield event.data + event = await self._receive_event() + assert isinstance(event, h11.EndOfMessage) + await self.close() + + async def _send_event(self, event: H11Event) -> None: + assert self.writer is not None + + data = self.state.send(event) self.writer.write(data) - async def _receive_event(self, timeout): + async def _receive_event(self) -> H11Event: + assert self.reader is not None + event = self.state.next_event() - while type(event) is h11.NEED_DATA: + while event is h11.NEED_DATA: try: - data = await asyncio.wait_for(self.reader.read(2048), timeout) + data = await asyncio.wait_for( + self.reader.read(2048), self.timeout.read_timeout + ) except asyncio.TimeoutError: raise ReadTimeout() self.state.receive_data(data) @@ -64,7 +121,8 @@ class Connection: return event - async def close(self): - self.writer.close() - if hasattr(self.writer, "wait_closed"): - await self.writer.wait_closed() + async def close(self) -> None: + if self.writer is not None: + self.writer.close() + if hasattr(self.writer, "wait_closed"): + await self.writer.wait_closed() diff --git a/httpcore/datastructures.py b/httpcore/datastructures.py new file mode 100644 index 00000000..d60e18a5 --- /dev/null +++ b/httpcore/datastructures.py @@ -0,0 +1,145 @@ +import typing +from urllib.parse import urlsplit + +from .decoders import IdentityDecoder +from .exceptions import ResponseClosed, StreamConsumed + + +class URL: + def __init__(self, url: str = "") -> None: + self.components = urlsplit(url) + if not self.components.scheme: + raise ValueError("No scheme included in URL.") + if self.components.scheme not in ("http", "https"): + raise ValueError('URL scheme must be "http" or "https".') + if not self.components.hostname: + raise ValueError("No hostname included in URL.") + + @property + def scheme(self) -> str: + return self.components.scheme + + @property + def netloc(self) -> str: + return self.components.netloc + + @property + def path(self) -> str: + return self.components.path + + @property + def query(self) -> str: + return self.components.query + + @property + def hostname(self) -> str: + return self.components.hostname + + @property + def port(self) -> int: + port = self.components.port + if port is None: + return {"https": 443, "http": 80}[self.scheme] + return port + + @property + def target(self) -> str: + path = self.path or "/" + query = self.query + if query: + return path + "?" + query + return path + + @property + def is_secure(self) -> bool: + return self.components.scheme == "https" + + +class Request: + def __init__( + self, + method: str, + url: URL, + *, + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + ): + self.method = method + self.url = url + self.headers = list(headers) + if isinstance(body, bytes): + self.is_streaming = False + self.body = body + else: + self.is_streaming = True + self.body_aiter = body + + async def stream(self) -> typing.AsyncIterator[bytes]: + assert self.is_streaming + + async for part in self.body_aiter: + yield part + + +class Response: + def __init__( + self, + status_code: int, + *, + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + on_close: typing.Callable = None, + ): + self.status_code = status_code + self.headers = list(headers) + self.on_close = on_close + self.is_closed = False + self.is_streamed = False + self.decoder = IdentityDecoder() + if isinstance(body, bytes): + self.is_closed = True + self.body = body + else: + self.body_aiter = body + + async def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "body"): + body = b"" + async for part in self.stream(): + body += part + self.body = body + return self.body + + async def stream(self) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + This will allow us to handle gzip, deflate, and brotli encoded responses. + """ + if hasattr(self, "body"): + yield self.body + else: + async for chunk in self.raw(): + yield self.decoder.decode(chunk) + yield self.decoder.flush() + + async def raw(self) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_streamed: + raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() + self.is_streamed = True + async for part in self.body_aiter: + yield part + await self.close() + + async def close(self) -> None: + if not self.is_closed: + self.is_closed = True + if self.on_close is not None: + await self.on_close() diff --git a/httpcore/decoders.py b/httpcore/decoders.py index 09b9336b..2d35a44f 100644 --- a/httpcore/decoders.py +++ b/httpcore/decoders.py @@ -8,7 +8,7 @@ class IdentityDecoder: return chunk def flush(self) -> bytes: - return b'' + return b"" # class DeflateDecoder: diff --git a/httpcore/models.py b/httpcore/models.py deleted file mode 100644 index edf174b1..00000000 --- a/httpcore/models.py +++ /dev/null @@ -1,68 +0,0 @@ -import typing - -from .decoders import IdentityDecoder -from .exceptions import ResponseClosed, StreamConsumed - - -class Response: - def __init__( - self, - status_code: int, - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - on_close: typing.Callable = None, - ): - self.status_code = status_code - self.headers = list(headers) - self.on_close = on_close - self.is_closed = False - self.is_streamed = False - self.decoder = IdentityDecoder() - if isinstance(body, bytes): - self.is_closed = True - self.body = body - else: - self.body_aiter = body - - async def read(self) -> bytes: - """ - Read and return the response content. - """ - if not hasattr(self, "body"): - body = b"" - async for part in self.stream(): - body += part - self.body = body - return self.body - - async def stream(self): - """ - A byte-iterator over the decoded response content. - This will allow us to handle gzip, deflate, and brotli encoded responses. - """ - if hasattr(self, "body"): - yield self.body - else: - async for chunk in self.raw(): - yield self.decoder.decode(chunk) - yield self.decoder.flush() - - async def raw(self) -> typing.AsyncIterator[bytes]: - """ - A byte-iterator over the raw response content. - """ - if self.is_streamed: - raise StreamConsumed() - if self.is_closed: - raise ResponseClosed() - self.is_streamed = True - async for part in self.body_aiter(): - yield part - await self.close() - - async def close(self) -> None: - if not self.is_closed: - self.is_closed = True - if self.on_close is not None: - await self.on_close() diff --git a/httpcore/pool.py b/httpcore/pool.py new file mode 100644 index 00000000..75948477 --- /dev/null +++ b/httpcore/pool.py @@ -0,0 +1,126 @@ +import asyncio +import os +import ssl +import typing +from types import TracebackType + +from .config import ( + DEFAULT_CA_BUNDLE_PATH, + DEFAULT_POOL_LIMITS, + DEFAULT_SSL_CONFIG, + DEFAULT_TIMEOUT_CONFIG, + PoolLimits, + SSLConfig, + TimeoutConfig, +) +from .connections import Connection +from .datastructures import URL, Request, Response + + +class ConnectionPool: + def __init__( + self, + *, + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + limits: PoolLimits = DEFAULT_POOL_LIMITS, + ): + self.ssl_config = ssl + self.timeout = timeout + self.limits = limits + self.is_closed = False + + async def request( + self, + method: str, + url: str, + *, + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + stream: bool = False, + ) -> Response: + parsed_url = URL(url) + request = Request(method, parsed_url, headers=headers, body=body) + ssl_context = await self.get_ssl_context(parsed_url) + connection = await self.acquire_connection(parsed_url, ssl=ssl_context) + response = await connection.send(request, stream=stream) + return response + + async def acquire_connection( + self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False + ) -> Connection: + connection = Connection(timeout=self.timeout) + await connection.open(url.hostname, url.port, ssl=ssl) + return connection + + async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]: + if not url.is_secure: + return False + + if not hasattr(self, "ssl_context"): + if not self.ssl_config.verify: + self.ssl_context = self.get_ssl_context_no_verify() + else: + # Run the SSL loading in a threadpool, since it makes disk accesses. + loop = asyncio.get_event_loop() + self.ssl_context = await loop.run_in_executor( + None, self.get_ssl_context_verify + ) + + return self.ssl_context + + def get_ssl_context_no_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for unverified connections. + """ + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_COMPRESSION + context.set_default_verify_paths() + return context + + def get_ssl_context_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for verified connections. + """ + cert = self.ssl_config.cert + verify = self.ssl_config.verify + + if isinstance(verify, bool): + ca_bundle_path = DEFAULT_CA_BUNDLE_PATH + elif os.path.exists(verify): + ca_bundle_path = verify + else: + raise IOError( + "Could not find a suitable TLS CA certificate bundle, " + "invalid path: {}".format(verify) + ) + + context = ssl.create_default_context() + if os.path.isfile(ca_bundle_path): + context.load_verify_locations(cafile=ca_bundle_path) + elif os.path.isdir(ca_bundle_path): + context.load_verify_locations(capath=ca_bundle_path) + + if cert is not None: + if isinstance(cert, str): + context.load_cert_chain(certfile=cert) + else: + context.load_cert_chain(certfile=cert[0], keyfile=cert[1]) + + return context + + async def close(self) -> None: + self.is_closed = True + + async def __aenter__(self) -> "ConnectionPool": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.close() diff --git a/requirements.txt b/requirements.txt index 16c77063..1baef341 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +certifi h11 # Testing @@ -9,3 +10,4 @@ mypy pytest pytest-asyncio pytest-cov +uvicorn diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..08e36702 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,34 @@ +import asyncio +import json + +import pytest +from uvicorn.config import Config +from uvicorn.main import Server + + +async def app(scope, receive, send): + assert scope['type'] == 'http' + await send({ + 'type': 'http.response.start', + 'status': 200, + 'headers': [ + [b'content-type', b'text/plain'], + ] + }) + await send({ + 'type': 'http.response.body', + 'body': b'Hello, world!', + }) + + +@pytest.fixture +async def server(): + config = Config(app=app, lifespan="off") + server = Server(config=config) + task = asyncio.ensure_future(server.serve()) + try: + while not server.started: + await asyncio.sleep(0.0001) + yield server + finally: + task.cancel() diff --git a/tests/test_api.py b/tests/test_api.py index e69de29b..c24e7373 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -0,0 +1,38 @@ +import pytest +import httpcore + + +@pytest.mark.asyncio +async def test_get(server): + async with httpcore.ConnectionPool() as http: + response = await http.request('GET', "http://127.0.0.1:8000/") + assert response.status_code == 200 + assert response.body == b'Hello, world!' + + +@pytest.mark.asyncio +async def test_post(server): + async with httpcore.ConnectionPool() as http: + response = await http.request('POST', "http://127.0.0.1:8000/", body=b"Hello, world!") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_stream_response(server): + async with httpcore.ConnectionPool() as http: + response = await http.request('GET', "http://127.0.0.1:8000/", stream=True) + assert response.status_code == 200 + assert not hasattr(response, 'body') + body = await response.read() + assert body == b'Hello, world!' + + +@pytest.mark.asyncio +async def test_stream_request(server): + async def hello_world(): + yield b"Hello, " + yield b"world!" + + async with httpcore.ConnectionPool() as http: + response = await http.request('POST', "http://127.0.0.1:8000/", body=hello_world()) + assert response.status_code == 200 diff --git a/tests/test_responses.py b/tests/test_responses.py index 19a387ba..ae754b40 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -3,17 +3,21 @@ import pytest import httpcore -class MockRequests(httpcore.PoolManager): - async def request(self, method, url, *, headers = (), body = b'', stream = False) -> httpcore.Response: +class MockHTTP(httpcore.ConnectionPool): + async def request( + self, method, url, *, headers=(), body=b"", stream=False + ) -> httpcore.Response: if stream: + async def streaming_body(): yield b"Hello, " yield b"world!" - return httpcore.Response(200, body=streaming_body) + + return httpcore.Response(200, body=streaming_body()) return httpcore.Response(200, body=b"Hello, world!") -http = MockRequests() +http = MockHTTP() @pytest.mark.asyncio @@ -47,7 +51,7 @@ async def test_stream_response(): assert response.body == b"Hello, world!" assert response.is_closed - body = b'' + body = b"" async for part in response.stream(): body += part @@ -61,7 +65,7 @@ async def test_read_streaming_response(): response = await http.request("GET", "http://example.com", stream=True) assert response.status_code == 200 - assert not hasattr(response, 'body') + assert not hasattr(response, "body") assert not response.is_closed body = await response.read() @@ -76,15 +80,15 @@ async def test_stream_streaming_response(): response = await http.request("GET", "http://example.com", stream=True) assert response.status_code == 200 - assert not hasattr(response, 'body') + assert not hasattr(response, "body") assert not response.is_closed - body = b'' + body = b"" async for part in response.stream(): body += part assert body == b"Hello, world!" - assert not hasattr(response, 'body') + assert not hasattr(response, "body") assert response.is_closed @@ -92,13 +96,14 @@ async def test_stream_streaming_response(): async def test_cannot_read_after_stream_consumed(): response = await http.request("GET", "http://example.com", stream=True) - body = b'' + body = b"" async for part in response.stream(): body += part with pytest.raises(httpcore.StreamConsumed): await response.read() + @pytest.mark.asyncio async def test_cannot_read_after_response_closed(): response = await http.request("GET", "http://example.com", stream=True)