]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add client.request
authorTom Christie <tom@tomchristie.com>
Tue, 23 Apr 2019 08:34:41 +0000 (09:34 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 23 Apr 2019 08:34:41 +0000 (09:34 +0100)
httpcore/connections.py
httpcore/datastructures.py
httpcore/pool.py
tests/test_api.py
tests/test_connections.py
tests/test_pool.py
tests/test_timeouts.py

index c9b695730ddfca78c13d700377379e726607ede3..6ddea6da0be2a3f387b65bb5c926198262da0cc1 100644 (file)
@@ -17,7 +17,7 @@ H11Event = typing.Union[
 ]
 
 
-class Connection:
+class Connection(Client):
     def __init__(
         self,
         origin: typing.Union[str, Origin],
@@ -43,7 +43,6 @@ class Connection:
         *,
         ssl: typing.Optional[SSLConfig] = None,
         timeout: typing.Optional[TimeoutConfig] = None,
-        stream: bool = False,
     ) -> Response:
         assert request.url.origin == self.origin
 
@@ -85,7 +84,7 @@ class Connection:
         status_code = event.status_code
         headers = event.headers
         body = self._body_iter(timeout)
-        response = Response(
+        return Response(
             status_code=status_code,
             reason=reason,
             headers=headers,
@@ -93,15 +92,6 @@ class Connection:
             on_close=self._release,
         )
 
-        if not stream:
-            # Read the response body.
-            try:
-                await response.read()
-            finally:
-                await response.close()
-
-        return response
-
     async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
         ssl_context = await ssl.load_ssl_context() if self.origin.is_secure else None
 
index 09c288bc2fa07afb6c85a5d1303fa7362460b1d4..7389d451d8a59b60f4f7929ed5b3c71e6ed124b2 100644 (file)
@@ -1,5 +1,6 @@
 import http
 import typing
+from types import TracebackType
 from urllib.parse import urlsplit
 
 from .config import SSLConfig, TimeoutConfig
@@ -251,7 +252,13 @@ class Client:
         stream: bool = False,
     ) -> Response:
         request = Request(method, url, headers=headers, body=body)
-        return await self.send(request, ssl=ssl, timeout=timeout, stream=stream)
+        response = await self.send(request, ssl=ssl, timeout=timeout)
+        if not stream:
+            try:
+                await response.read()
+            finally:
+                await response.close()
+        return response
 
     async def send(
         self,
@@ -259,9 +266,19 @@ class Client:
         *,
         ssl: typing.Optional[SSLConfig] = None,
         timeout: typing.Optional[TimeoutConfig] = None,
-        stream: bool = False,
     ) -> Response:
         raise NotImplementedError()  # pragma: nocover
 
     async def close(self) -> None:
         raise NotImplementedError()  # pragma: nocover
+
+    async def __aenter__(self) -> "Client":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        await self.close()
index 3f9ecd70213914dfe2620a10773a543dbbd67053..a13657185a436328a304453590b405660abed9f6 100644 (file)
@@ -1,6 +1,5 @@
 import asyncio
 import typing
-from types import TracebackType
 
 from .config import (
     DEFAULT_CA_BUNDLE_PATH,
@@ -16,21 +15,7 @@ from .datastructures import Client, Origin, Request, Response
 from .exceptions import PoolTimeout
 
 
-class ConnectionSemaphore:
-    def __init__(self, max_connections: int = None):
-        if max_connections is not None:
-            self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
-
-    async def acquire(self) -> None:
-        if hasattr(self, "semaphore"):
-            await self.semaphore.acquire()
-
-    def release(self) -> None:
-        if hasattr(self, "semaphore"):
-            self.semaphore.release()
-
-
-class ConnectionPool:
+class ConnectionPool(Client):
     def __init__(
         self,
         *,
@@ -57,12 +42,9 @@ class ConnectionPool:
         *,
         ssl: typing.Optional[SSLConfig] = None,
         timeout: typing.Optional[TimeoutConfig] = None,
-        stream: bool = False,
     ) -> Response:
         connection = await self.acquire_connection(request.url.origin, timeout=timeout)
-        response = await connection.send(
-            request, ssl=ssl, timeout=timeout, stream=stream
-        )
+        response = await connection.send(request, ssl=ssl, timeout=timeout)
         return response
 
     @property
@@ -121,13 +103,16 @@ class ConnectionPool:
     async def close(self) -> None:
         self.is_closed = True
 
-    async def __aenter__(self) -> "ConnectionPool":
-        return self
 
-    async def __aexit__(
-        self,
-        exc_type: typing.Type[BaseException] = None,
-        exc_value: BaseException = None,
-        traceback: TracebackType = None,
-    ) -> None:
-        await self.close()
+class ConnectionSemaphore:
+    def __init__(self, max_connections: int = None):
+        if max_connections is not None:
+            self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
+
+    async def acquire(self) -> None:
+        if hasattr(self, "semaphore"):
+            await self.semaphore.acquire()
+
+    def release(self) -> None:
+        if hasattr(self, "semaphore"):
+            self.semaphore.release()
index 30199c93de4641264540bd028e5a4d714cec9044..6b80587dbc4f475f27175ba9175093c06b257431 100644 (file)
@@ -5,28 +5,25 @@ import httpcore
 
 @pytest.mark.asyncio
 async def test_get(server):
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request)
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.body == b"Hello, world!"
 
 
 @pytest.mark.asyncio
 async def test_post(server):
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request(
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request(
             "POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
         )
-        response = await client.send(request)
     assert response.status_code == 200
 
 
 @pytest.mark.asyncio
 async def test_stream_response(server):
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request, stream=True)
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
     assert response.status_code == 200
     assert not hasattr(response, "body")
     body = await response.read()
@@ -39,7 +36,8 @@ async def test_stream_request(server):
         yield b"Hello, "
         yield b"world!"
 
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("POST", "http://127.0.0.1:8000/", body=hello_world())
-        response = await client.send(request)
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request(
+            "POST", "http://127.0.0.1:8000/", body=hello_world()
+        )
     assert response.status_code == 200
index 958c1b523093c45f25bdee3468a24b68cc198f32..5cfca6116ab81aaadccb1010ecd8d2cfc36d710b 100644 (file)
@@ -5,18 +5,16 @@ import httpcore
 
 @pytest.mark.asyncio
 async def test_get(server):
-    client = httpcore.Connection(origin="http://127.0.0.1:8000/")
-    request = httpcore.Request(method="GET", url="http://127.0.0.1:8000/")
-    response = await client.send(request)
+    http = httpcore.Connection(origin="http://127.0.0.1:8000/")
+    response = await http.request("GET", "http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.body == b"Hello, world!"
 
 
 @pytest.mark.asyncio
 async def test_post(server):
-    client = httpcore.Connection(origin="http://127.0.0.1:8000/")
-    request = httpcore.Request(
-        method="POST", url="http://127.0.0.1:8000/", body=b"Hello, world!"
+    http = httpcore.Connection(origin="http://127.0.0.1:8000/")
+    response = await http.request(
+        "POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
     )
-    response = await client.send(request)
     assert response.status_code == 200
index de25fe65a792e4d206865459618754dc407ef38b..77a221575441fda934bd4799ed8a4a33b00726f0 100644 (file)
@@ -8,16 +8,14 @@ async def test_keepalive_connections(server):
     """
     Connections should default to staying in a keep-alive state.
     """
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request)
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 1
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
 
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request)
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 1
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -25,16 +23,14 @@ async def test_differing_connection_keys(server):
     """
     Connnections to differing connection keys should result in multiple connections.
     """
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request)
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 1
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
 
-        request = httpcore.Request("GET", "http://localhost:8000/")
-        response = await client.send(request)
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 2
+        response = await http.request("GET", "http://localhost:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 2
 
 
 @pytest.mark.asyncio
@@ -44,16 +40,14 @@ async def test_soft_limit(server):
     """
     limits = httpcore.PoolLimits(soft_limit=1)
 
-    async with httpcore.ConnectionPool(limits=limits) as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request)
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 1
+    async with httpcore.ConnectionPool(limits=limits) as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
 
-        request = httpcore.Request("GET", "http://localhost:8000/")
-        response = await client.send(request)
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 1
+        response = await http.request("GET", "http://localhost:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -61,16 +55,15 @@ async def test_streaming_response_holds_connection(server):
     """
     A streaming request should hold the connection open until the response is read.
     """
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request, stream=True)
-        assert client.num_active_connections == 1
-        assert client.num_keepalive_connections == 0
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        assert http.num_active_connections == 1
+        assert http.num_keepalive_connections == 0
 
         await response.read()
 
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 1
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -78,24 +71,22 @@ async def test_multiple_concurrent_connections(server):
     """
     Multiple conncurrent requests should open multiple conncurrent connections.
     """
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response_a = await client.send(request, stream=True)
-        assert client.num_active_connections == 1
-        assert client.num_keepalive_connections == 0
+    async with httpcore.ConnectionPool() as http:
+        response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        assert http.num_active_connections == 1
+        assert http.num_keepalive_connections == 0
 
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response_b = await client.send(request, stream=True)
-        assert client.num_active_connections == 2
-        assert client.num_keepalive_connections == 0
+        response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        assert http.num_active_connections == 2
+        assert http.num_keepalive_connections == 0
 
         await response_b.read()
-        assert client.num_active_connections == 1
-        assert client.num_keepalive_connections == 1
+        assert http.num_active_connections == 1
+        assert http.num_keepalive_connections == 1
 
         await response_a.read()
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 2
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 2
 
 
 @pytest.mark.asyncio
@@ -104,11 +95,10 @@ async def test_close_connections(server):
     Using a `Connection: close` header should close the connection.
     """
     headers = [(b"connection", b"close")]
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/", headers=headers)
-        response = await client.send(request)
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 0
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 0
 
 
 @pytest.mark.asyncio
@@ -116,13 +106,12 @@ async def test_standard_response_close(server):
     """
     A standard close should keep the connection open.
     """
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request, stream=True)
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
         await response.read()
         await response.close()
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 1
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -130,9 +119,8 @@ async def test_premature_response_close(server):
     """
     A premature close should close the connection.
     """
-    async with httpcore.ConnectionPool() as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request, stream=True)
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
         await response.close()
-        assert client.num_active_connections == 0
-        assert client.num_keepalive_connections == 0
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 0
index 5b61aee21c0403f43a6aaddb1c039fa9ad7dffef..e9003195691fb1491694fb877e6707875c1feff8 100644 (file)
@@ -7,21 +7,19 @@ import httpcore
 async def test_read_timeout(server):
     timeout = httpcore.TimeoutConfig(read_timeout=0.0001)
 
-    async with httpcore.ConnectionPool(timeout=timeout) as client:
+    async with httpcore.ConnectionPool(timeout=timeout) as http:
         with pytest.raises(httpcore.ReadTimeout):
-            request = httpcore.Request("GET", "http://127.0.0.1:8000/slow_response")
-            await client.send(request)
+            await http.request("GET", "http://127.0.0.1:8000/slow_response")
 
 
 @pytest.mark.asyncio
 async def test_connect_timeout(server):
     timeout = httpcore.TimeoutConfig(connect_timeout=0.0001)
 
-    async with httpcore.ConnectionPool(timeout=timeout) as client:
+    async with httpcore.ConnectionPool(timeout=timeout) as http:
         with pytest.raises(httpcore.ConnectTimeout):
             # See https://stackoverflow.com/questions/100841/
-            request = httpcore.Request("GET", "http://10.255.255.1/")
-            await client.send(request)
+            await http.request("GET", "http://10.255.255.1/")
 
 
 @pytest.mark.asyncio
@@ -29,12 +27,10 @@ async def test_pool_timeout(server):
     timeout = httpcore.TimeoutConfig(pool_timeout=0.0001)
     limits = httpcore.PoolLimits(hard_limit=1)
 
-    async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as client:
-        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-        response = await client.send(request, stream=True)
+    async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
 
         with pytest.raises(httpcore.PoolTimeout):
-            request = httpcore.Request("GET", "http://127.0.0.1:8000/")
-            await client.send(request)
+            await http.request("GET", "http://127.0.0.1:8000/")
 
         await response.read()