]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Support alternate concurrency backends from client & rejig tests 60/head
authorTom Christie <tom@tomchristie.com>
Tue, 14 May 2019 13:11:13 +0000 (14:11 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 14 May 2019 13:11:13 +0000 (14:11 +0100)
httpcore/__init__.py
httpcore/client.py
tests/dispatch/__init__.py [new file with mode: 0644]
tests/dispatch/test_http2.py
tests/dispatch/utils.py [new file with mode: 0644]

index c40ba329dd207c70813003ad26ae83d8a0eb7ed5..5821b2ad37dac50aae8797076c25bf4019dd7718 100644 (file)
@@ -1,4 +1,5 @@
 from .client import AsyncClient, Client
+from .concurrency import AsyncioBackend
 from .config import PoolLimits, SSLConfig, TimeoutConfig
 from .dispatch.connection import HTTPConnection
 from .dispatch.connection_pool import ConnectionPool
@@ -19,7 +20,7 @@ from .exceptions import (
     Timeout,
     TooManyRedirects,
 )
-from .interfaces import BaseReader, BaseWriter, Dispatcher, Protocol
+from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol
 from .models import URL, Headers, Origin, QueryParams, Request, Response
 from .status_codes import codes
 
index ec755750df1827dede2430e84ee7228bb7b3034e..cb8ead9fbdad7c4f09972c943fad2509c81afffb 100644 (file)
@@ -13,7 +13,7 @@ from .config import (
 )
 from .dispatch.connection_pool import ConnectionPool
 from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
-from .interfaces import Dispatcher
+from .interfaces import ConcurrencyBackend, Dispatcher
 from .models import (
     URL,
     Headers,
@@ -36,9 +36,12 @@ class AsyncClient:
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
         dispatch: Dispatcher = None,
+        backend: ConcurrencyBackend = None,
     ):
         if dispatch is None:
-            dispatch = ConnectionPool(ssl=ssl, timeout=timeout, pool_limits=pool_limits)
+            dispatch = ConnectionPool(
+                ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
+            )
 
         self.max_redirects = max_redirects
         self.dispatch = dispatch
@@ -377,6 +380,7 @@ class Client:
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
         dispatch: Dispatcher = None,
+        backend: ConcurrencyBackend = None,
     ) -> None:
         self._client = AsyncClient(
             ssl=ssl,
@@ -384,6 +388,7 @@ class Client:
             pool_limits=pool_limits,
             max_redirects=max_redirects,
             dispatch=dispatch,
+            backend=backend,
         )
         self._loop = asyncio.new_event_loop()
 
diff --git a/tests/dispatch/__init__.py b/tests/dispatch/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
index 940c3ee11f7875523bd55c8545432e11288a671f..1c330777ed5ef3b7f48d55130f1d2e6eea9ba138 100644 (file)
 import json
 
-import h2.config
-import h2.connection
-import h2.events
 import pytest
 
-from httpcore import BaseReader, BaseWriter, HTTP2Connection, Request
+from httpcore import Client, Response
+from .utils import MockHTTP2Backend
 
 
-class MockServer(BaseReader, BaseWriter):
-    """
-    This class exposes Reader and Writer style interfaces
-    """
+def app(request):
+    content = json.dumps({
+        "method": request.method,
+        "path": request.url.path,
+        "body": request.content.decode(),
+    }).encode()
+    headers = {'Content-Length': str(len(content))}
+    return Response(200, headers=headers, content=content)
 
-    def __init__(self):
-        config = h2.config.H2Configuration(client_side=False)
-        self.conn = h2.connection.H2Connection(config=config)
-        self.buffer = b""
-        self.requests = {}
 
-    # BaseReader interface
+def test_http2_get_request():
+    backend = MockHTTP2Backend(app=app)
 
-    async def read(self, n, timeout) -> bytes:
-        send, self.buffer = self.buffer[:n], self.buffer[n:]
-        return send
-
-    # BaseWriter interface
-
-    def write_no_block(self, data: bytes) -> None:
-        events = self.conn.receive_data(data)
-        self.buffer += self.conn.data_to_send()
-        for event in events:
-            if isinstance(event, h2.events.RequestReceived):
-                self.request_received(event.headers, event.stream_id)
-            elif isinstance(event, h2.events.DataReceived):
-                self.receive_data(event.data, event.stream_id)
-            elif isinstance(event, h2.events.StreamEnded):
-                self.stream_complete(event.stream_id)
-
-    async def write(self, data: bytes, timeout) -> None:
-        self.write_no_block(data)
-
-    async def close(self) -> None:
-        pass
-
-    # Server implementation
-
-    def request_received(self, headers, stream_id):
-        if stream_id not in self.requests:
-            self.requests[stream_id] = []
-        self.requests[stream_id].append({"headers": headers, "data": b""})
-
-    def receive_data(self, data, stream_id):
-        self.requests[stream_id][-1]["data"] += data
-
-    def stream_complete(self, stream_id):
-        request = self.requests[stream_id].pop(0)
-        if not self.requests[stream_id]:
-            del self.requests[stream_id]
-
-        request_headers = dict(request["headers"])
-        request_data = request["data"]
-
-        response_body = json.dumps(
-            {
-                "method": request_headers[b":method"].decode(),
-                "path": request_headers[b":path"].decode(),
-                "body": request_data.decode(),
-            }
-        ).encode()
-
-        response_headers = (
-            (b":status", b"200"),
-            (b"content-length", str(len(response_body)).encode()),
-        )
-        self.conn.send_headers(stream_id, response_headers)
-        self.conn.send_data(stream_id, response_body, end_stream=True)
-        self.buffer += self.conn.data_to_send()
-
-
-@pytest.mark.asyncio
-async def test_http2_get_request():
-    server = MockServer()
-    conn = HTTP2Connection(reader=server, writer=server)
-    request = Request("GET", "http://example.org")
-    request.prepare()
-
-    response = await conn.send(request)
+    with Client(backend=backend) as client:
+        response = client.get("http://example.org")
 
     assert response.status_code == 200
     assert json.loads(response.content) == {"method": "GET", "path": "/", "body": ""}
 
 
-@pytest.mark.asyncio
-async def test_http2_post_request():
-    server = MockServer()
-    conn = HTTP2Connection(reader=server, writer=server)
-    request = Request("POST", "http://example.org", data=b"<data>")
-    request.prepare()
+def test_http2_post_request():
+    backend = MockHTTP2Backend(app=app)
 
-    response = await conn.send(request)
+    with Client(backend=backend) as client:
+        response = client.post("http://example.org", data=b"<data>")
 
     assert response.status_code == 200
     assert json.loads(response.content) == {
@@ -109,21 +40,13 @@ async def test_http2_post_request():
     }
 
 
-@pytest.mark.asyncio
-async def test_http2_multiple_requests():
-    server = MockServer()
-    conn = HTTP2Connection(reader=server, writer=server)
-    request_1 = Request("GET", "http://example.org/1")
-    request_2 = Request("GET", "http://example.org/2")
-    request_3 = Request("GET", "http://example.org/3")
+def test_http2_multiple_requests():
+    backend = MockHTTP2Backend(app=app)
 
-    request_1.prepare()
-    request_2.prepare()
-    request_3.prepare()
-
-    response_1 = await conn.send(request_1)
-    response_2 = await conn.send(request_2)
-    response_3 = await conn.send(request_3)
+    with Client(backend=backend) as client:
+        response_1 = client.get("http://example.org/1")
+        response_2 = client.get("http://example.org/2")
+        response_3 = client.get("http://example.org/3")
 
     assert response_1.status_code == 200
     assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
@@ -133,5 +56,3 @@ async def test_http2_multiple_requests():
 
     assert response_3.status_code == 200
     assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""}
-
-    await conn.close()
diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py
new file mode 100644 (file)
index 0000000..c4df4b7
--- /dev/null
@@ -0,0 +1,108 @@
+import ssl
+import typing
+
+import h2.config
+import h2.connection
+import h2.events
+
+from httpcore import AsyncioBackend, BaseReader, BaseWriter, Protocol, Request, TimeoutConfig
+
+
+class MockHTTP2Backend(AsyncioBackend):
+    def __init__(self, app):
+        self.app = app
+
+    async def connect(
+        self,
+        hostname: str,
+        port: int,
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
+        server = MockHTTP2Server(self.app)
+        return (server, server, Protocol.HTTP_2)
+
+
+class MockHTTP2Server(BaseReader, BaseWriter):
+    """
+    This class exposes Reader and Writer style interfaces.
+    """
+
+    def __init__(self, app):
+        config = h2.config.H2Configuration(client_side=False)
+        self.conn = h2.connection.H2Connection(config=config)
+        self.app = app
+        self.buffer = b""
+        self.requests = {}
+
+    # BaseReader interface
+
+    async def read(self, n, timeout) -> bytes:
+        send, self.buffer = self.buffer[:n], self.buffer[n:]
+        return send
+
+    # BaseWriter interface
+
+    def write_no_block(self, data: bytes) -> None:
+        events = self.conn.receive_data(data)
+        self.buffer += self.conn.data_to_send()
+        for event in events:
+            if isinstance(event, h2.events.RequestReceived):
+                self.request_received(event.headers, event.stream_id)
+            elif isinstance(event, h2.events.DataReceived):
+                self.receive_data(event.data, event.stream_id)
+            elif isinstance(event, h2.events.StreamEnded):
+                self.stream_complete(event.stream_id)
+
+    async def write(self, data: bytes, timeout) -> None:
+        self.write_no_block(data)
+
+    async def close(self) -> None:
+        pass
+
+    # Server implementation
+
+    def request_received(self, headers, stream_id):
+        """
+        Handler for when the initial part of the HTTP request is received.
+        """
+        if stream_id not in self.requests:
+            self.requests[stream_id] = []
+        self.requests[stream_id].append({"headers": headers, "data": b""})
+
+    def receive_data(self, data, stream_id):
+        """
+        Handler for when a data part of the HTTP request is received.
+        """
+        self.requests[stream_id][-1]["data"] += data
+
+    def stream_complete(self, stream_id):
+        """
+        Handler for when the HTTP request is completed.
+        """
+        request = self.requests[stream_id].pop(0)
+        if not self.requests[stream_id]:
+            del self.requests[stream_id]
+
+        headers_dict = dict(request["headers"])
+
+        method = headers_dict[b":method"].decode("ascii")
+        url = "%s://%s%s" % (
+            headers_dict[b":scheme"].decode("ascii"),
+            headers_dict[b":authority"].decode("ascii"),
+            headers_dict[b":path"].decode("ascii"),
+        )
+        headers = [(k, v) for k, v in request["headers"] if not k.startswith(b":")]
+        data = request["data"]
+
+        # Call out to the app.
+        request = Request(method, url, headers=headers, data=data)
+        response = self.app(request)
+
+        # Write the response to the buffer.
+        status_code_bytes = str(int(response.status_code)).encode("ascii")
+        response_headers = [(b":status", status_code_bytes)] + response.headers.raw
+
+        self.conn.send_headers(stream_id, response_headers)
+        self.conn.send_data(stream_id, response.content, end_stream=True)
+        self.buffer += self.conn.data_to_send()