]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Context managed transports (#1218)
authorTom Christie <tom@tomchristie.com>
Wed, 26 Aug 2020 11:05:05 +0000 (12:05 +0100)
committerGitHub <noreply@github.com>
Wed, 26 Aug 2020 11:05:05 +0000 (12:05 +0100)
* Context managed transports

* Update httpx/_client.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Update httpx/_client.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Update tests/client/test_client.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Update tests/client/test_async_client.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Code comment around close/__enter__/__exit__ interaction

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/_client.py
tests/client/test_async_client.py
tests/client/test_client.py

index 2d2ca9ac1612d9dadc9a6943f7930527d345bb3d..0f4110c8f0645052264a398503c3518dbcd63473 100644 (file)
@@ -1035,6 +1035,10 @@ class Client(BaseClient):
                 proxy.close()
 
     def __enter__(self) -> "Client":
+        self._transport.__enter__()
+        for proxy in self._proxies.values():
+            if proxy is not None:
+                proxy.__enter__()
         return self
 
     def __exit__(
@@ -1043,7 +1047,10 @@ class Client(BaseClient):
         exc_value: BaseException = None,
         traceback: TracebackType = None,
     ) -> None:
-        self.close()
+        self._transport.__exit__(exc_type, exc_value, traceback)
+        for proxy in self._proxies.values():
+            if proxy is not None:
+                proxy.__exit__(exc_type, exc_value, traceback)
 
 
 class AsyncClient(BaseClient):
@@ -1639,6 +1646,10 @@ class AsyncClient(BaseClient):
                 await proxy.aclose()
 
     async def __aenter__(self) -> "AsyncClient":
+        await self._transport.__aenter__()
+        for proxy in self._proxies.values():
+            if proxy is not None:
+                await proxy.__aenter__()
         return self
 
     async def __aexit__(
@@ -1647,7 +1658,10 @@ class AsyncClient(BaseClient):
         exc_value: BaseException = None,
         traceback: TracebackType = None,
     ) -> None:
-        await self.aclose()
+        await self._transport.__aexit__(exc_type, exc_value, traceback)
+        for proxy in self._proxies.values():
+            if proxy is not None:
+                await proxy.__aexit__(exc_type, exc_value, traceback)
 
 
 class StreamContextManager:
index 78669e2860bf9de984c174482df15c671d439deb..126c50e95bec93e392233a4423b6122598b2cb9d 100644 (file)
@@ -1,5 +1,6 @@
 from datetime import timedelta
 
+import httpcore
 import pytest
 
 import httpx
@@ -166,3 +167,42 @@ async def test_100_continue(server):
 
     assert response.status_code == 200
     assert response.content == data
+
+
+@pytest.mark.usefixtures("async_environment")
+async def test_context_managed_transport():
+    class Transport(httpcore.AsyncHTTPTransport):
+        def __init__(self):
+            self.events = []
+
+        async def aclose(self):
+            # The base implementation of httpcore.AsyncHTTPTransport just
+            # calls into `.aclose`, so simple transport cases can just override
+            # this method for any cleanup, where more complex cases
+            # might want to additionally override `__aenter__`/`__aexit__`.
+            self.events.append("transport.aclose")
+
+        async def __aenter__(self):
+            await super().__aenter__()
+            self.events.append("transport.__aenter__")
+
+        async def __aexit__(self, *args):
+            await super().__aexit__(*args)
+            self.events.append("transport.__aexit__")
+
+    # Note that we're including 'proxies' here to *also* run through the
+    # proxy context management, although we can't easily test that at the
+    # moment, since we can't add proxies as transport instances.
+    #
+    # Once we have a more generalised Mount API we'll be able to remove this
+    # in favour of ensuring all mounts are context managed, which will
+    # also neccessarily include proxies.
+    transport = Transport()
+    async with httpx.AsyncClient(transport=transport, proxies="http://www.example.com"):
+        pass
+
+    assert transport.events == [
+        "transport.__aenter__",
+        "transport.aclose",
+        "transport.__aexit__",
+    ]
index b3e449677d014230c2e4051be9f24b8b750f0044..7d4bdd34411604374a11813e7d259b45f0d3731d 100644 (file)
@@ -1,5 +1,6 @@
 from datetime import timedelta
 
+import httpcore
 import pytest
 
 import httpx
@@ -208,3 +209,41 @@ def test_pool_limits_deprecated():
 
     with pytest.warns(DeprecationWarning):
         httpx.AsyncClient(pool_limits=limits)
+
+
+def test_context_managed_transport():
+    class Transport(httpcore.SyncHTTPTransport):
+        def __init__(self):
+            self.events = []
+
+        def close(self):
+            # The base implementation of httpcore.SyncHTTPTransport just
+            # calls into `.close`, so simple transport cases can just override
+            # this method for any cleanup, where more complex cases
+            # might want to additionally override `__enter__`/`__exit__`.
+            self.events.append("transport.close")
+
+        def __enter__(self):
+            super().__enter__()
+            self.events.append("transport.__enter__")
+
+        def __exit__(self, *args):
+            super().__exit__(*args)
+            self.events.append("transport.__exit__")
+
+    # Note that we're including 'proxies' here to *also* run through the
+    # proxy context management, although we can't easily test that at the
+    # moment, since we can't add proxies as transport instances.
+    #
+    # Once we have a more generalised Mount API we'll be able to remove this
+    # in favour of ensuring all mounts are context managed, which will
+    # also neccessarily include proxies.
+    transport = Transport()
+    with httpx.Client(transport=transport, proxies="http://www.example.com"):
+        pass
+
+    assert transport.events == [
+        "transport.__enter__",
+        "transport.close",
+        "transport.__exit__",
+    ]