]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add support for Mount API (#1362)
authorTom Christie <tom@tomchristie.com>
Tue, 24 Nov 2020 10:35:51 +0000 (10:35 +0000)
committerGitHub <noreply@github.com>
Tue, 24 Nov 2020 10:35:51 +0000 (10:35 +0000)
* Add support for Mount API

* Add test cases

* Add test case for all: mounted transport

* Use 'transport' variable, in preference to 'proxy'

* Add docs for mounted transports

docs/advanced.md
httpx/_client.py
tests/client/test_async_client.py
tests/client/test_client.py
tests/client/test_proxies.py

index 1a4513b64868c92fce7ede2770065b3600865d16..3e009b1aacb8e484f378e1aa172510df8571531a 100644 (file)
@@ -1040,12 +1040,13 @@ class HelloWorldTransport(httpcore.SyncHTTPTransport):
     A mock transport that always returns a JSON "Hello, world!" response.
     """
 
-    def request(self, method, url, headers=None, stream=None, timeout=None):
+    def request(self, method, url, headers=None, stream=None, ext=None):
         message = {"text": "Hello, world!"}
         content = json.dumps(message).encode("utf-8")
         stream = httpcore.PlainByteStream(content)
         headers = [(b"content-type", b"application/json")]
-        return b"HTTP/1.1", 200, b"OK", headers, stream
+        ext = {"http_version": b"HTTP/1.1"}
+        return 200, headers, stream, ext
 ```
 
 Which we can use in the same way:
@@ -1057,3 +1058,54 @@ Which we can use in the same way:
 >>> response.json()
 {"text": "Hello, world!"}
 ```
+
+### Mounting transports
+
+You can also mount transports against given schemes or domains, to control
+which transport an outgoing request should be routed via, with [the same style
+used for specifying proxy routing](#routing).
+
+```python
+import httpcore
+import httpx
+
+class HTTPSRedirectTransport(httpcore.SyncHTTPTransport):
+    """
+    A transport that always redirects to HTTPS.
+    """
+
+    def request(self, method, url, headers=None, stream=None, ext=None):
+        scheme, host, port, path = url
+        if port is None:
+            location = b"https://%s%s" % (host, path)
+        else:
+            location = b"https://%s:%d%s" % (host, port, path)
+        stream = httpcore.PlainByteStream(b"")
+        headers = [(b"location", location)]
+        ext = {"http_version": b"HTTP/1.1"}
+        return 303, headers, stream, ext
+
+
+# A client where any `http` requests are always redirected to `https`
+mounts = {'http://': HTTPSRedirectTransport()}
+client = httpx.Client(mounts=mounts)
+```
+
+A couple of other sketches of how you might take advantage of mounted transports...
+
+Mocking requests to a given domain:
+
+```python
+# All requests to "example.org" should be mocked out.
+# Other requests occur as usual.
+mounts = {"all://example.org": MockTransport()}
+client = httpx.Client(mounts=mounts)
+```
+
+Adding support for custom schemes:
+
+```python
+# Support URLs like "file:///Users/sylvia_green/websites/new_client/index.html"
+mounts = {"file://": FileSystemTransport()}
+client = httpx.Client(mounts=mounts)
+```
index 2e9bd48dca6774a5906c64cb9f015d23fd7112fe..2d764b0229f51b42f80da49abc9c3290f7d6f048 100644 (file)
@@ -87,7 +87,7 @@ class BaseClient:
         cookies: CookieTypes = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
+        event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
         base_url: URLTypes = "",
         trust_env: bool = True,
     ):
@@ -561,11 +561,12 @@ class Client(BaseClient):
         cert: CertTypes = None,
         http2: bool = False,
         proxies: ProxiesTypes = None,
+        mounts: typing.Mapping[str, httpcore.SyncHTTPTransport] = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
+        event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
         base_url: URLTypes = "",
         transport: httpcore.SyncHTTPTransport = None,
         app: typing.Callable = None,
@@ -611,7 +612,7 @@ class Client(BaseClient):
             app=app,
             trust_env=trust_env,
         )
-        self._proxies: typing.Dict[
+        self._mounts: typing.Dict[
             URLPattern, typing.Optional[httpcore.SyncHTTPTransport]
         ] = {
             URLPattern(key): None
@@ -626,7 +627,12 @@ class Client(BaseClient):
             )
             for key, proxy in proxy_map.items()
         }
-        self._proxies = dict(sorted(self._proxies.items()))
+        if mounts is not None:
+            self._mounts.update(
+                {URLPattern(key): transport for key, transport in mounts.items()}
+            )
+
+        self._mounts = dict(sorted(self._mounts.items()))
 
     def _init_transport(
         self,
@@ -681,7 +687,7 @@ class Client(BaseClient):
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
         """
-        for pattern, transport in self._proxies.items():
+        for pattern, transport in self._mounts.items():
             if pattern.matches(url):
                 return self._transport if transport is None else transport
 
@@ -1109,17 +1115,17 @@ class Client(BaseClient):
             self._state = ClientState.CLOSED
 
             self._transport.close()
-            for proxy in self._proxies.values():
-                if proxy is not None:
-                    proxy.close()
+            for transport in self._mounts.values():
+                if transport is not None:
+                    transport.close()
 
     def __enter__(self: T) -> T:
         self._state = ClientState.OPENED
 
         self._transport.__enter__()
-        for proxy in self._proxies.values():
-            if proxy is not None:
-                proxy.__enter__()
+        for transport in self._mounts.values():
+            if transport is not None:
+                transport.__enter__()
         return self
 
     def __exit__(
@@ -1131,9 +1137,9 @@ class Client(BaseClient):
         self._state = ClientState.CLOSED
 
         self._transport.__exit__(exc_type, exc_value, traceback)
-        for proxy in self._proxies.values():
-            if proxy is not None:
-                proxy.__exit__(exc_type, exc_value, traceback)
+        for transport in self._mounts.values():
+            if transport is not None:
+                transport.__exit__(exc_type, exc_value, traceback)
 
     def __del__(self) -> None:
         self.close()
@@ -1198,11 +1204,12 @@ class AsyncClient(BaseClient):
         cert: CertTypes = None,
         http2: bool = False,
         proxies: ProxiesTypes = None,
+        mounts: typing.Mapping[str, httpcore.AsyncHTTPTransport] = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
+        event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
         base_url: URLTypes = "",
         transport: httpcore.AsyncHTTPTransport = None,
         app: typing.Callable = None,
@@ -1249,7 +1256,7 @@ class AsyncClient(BaseClient):
             trust_env=trust_env,
         )
 
-        self._proxies: typing.Dict[
+        self._mounts: typing.Dict[
             URLPattern, typing.Optional[httpcore.AsyncHTTPTransport]
         ] = {
             URLPattern(key): None
@@ -1264,7 +1271,11 @@ class AsyncClient(BaseClient):
             )
             for key, proxy in proxy_map.items()
         }
-        self._proxies = dict(sorted(self._proxies.items()))
+        if mounts is not None:
+            self._mounts.update(
+                {URLPattern(key): transport for key, transport in mounts.items()}
+            )
+        self._mounts = dict(sorted(self._mounts.items()))
 
     def _init_transport(
         self,
@@ -1319,7 +1330,7 @@ class AsyncClient(BaseClient):
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
         """
-        for pattern, transport in self._proxies.items():
+        for pattern, transport in self._mounts.items():
             if pattern.matches(url):
                 return self._transport if transport is None else transport
 
@@ -1499,7 +1510,7 @@ class AsyncClient(BaseClient):
         await timer.async_start()
 
         with map_exceptions(HTTPCORE_EXC_MAP, request=request):
-            (status_code, headers, stream, ext,) = await transport.arequest(
+            (status_code, headers, stream, ext) = await transport.arequest(
                 request.method.encode(),
                 request.url.raw,
                 headers=request.headers.raw,
@@ -1750,7 +1761,7 @@ class AsyncClient(BaseClient):
             self._state = ClientState.CLOSED
 
             await self._transport.aclose()
-            for proxy in self._proxies.values():
+            for proxy in self._mounts.values():
                 if proxy is not None:
                     await proxy.aclose()
 
@@ -1758,7 +1769,7 @@ class AsyncClient(BaseClient):
         self._state = ClientState.OPENED
 
         await self._transport.__aenter__()
-        for proxy in self._proxies.values():
+        for proxy in self._mounts.values():
             if proxy is not None:
                 await proxy.__aenter__()
         return self
@@ -1772,7 +1783,7 @@ class AsyncClient(BaseClient):
         self._state = ClientState.CLOSED
 
         await self._transport.__aexit__(exc_type, exc_value, traceback)
-        for proxy in self._proxies.values():
+        for proxy in self._mounts.values():
             if proxy is not None:
                 await proxy.__aexit__(exc_type, exc_value, traceback)
 
index 44ff90fe511707b2dc324bb0186d396a7acacb87..696f202cff9aa36870b02a1668eae6751ff288a7 100644 (file)
@@ -1,3 +1,4 @@
+import typing
 from datetime import timedelta
 
 import httpcore
@@ -188,15 +189,8 @@ async def test_context_managed_transport():
             await super().__aexit__(*args)
             self.events.append("transport.__aexit__")
 
-    # Note that we're including 'proxies' here to *also* run through the
-    # proxy context management, although we can't easily test that at the
-    # moment, since we can't add proxies as transport instances.
-    #
-    # Once we have a more generalised Mount API we'll be able to remove this
-    # in favour of ensuring all mounts are context managed, which will
-    # also neccessarily include proxies.
     transport = Transport()
-    async with httpx.AsyncClient(transport=transport, proxies="http://www.example.com"):
+    async with httpx.AsyncClient(transport=transport):
         pass
 
     assert transport.events == [
@@ -206,6 +200,47 @@ async def test_context_managed_transport():
     ]
 
 
+@pytest.mark.usefixtures("async_environment")
+async def test_context_managed_transport_and_mount():
+    class Transport(httpcore.AsyncHTTPTransport):
+        def __init__(self, name: str):
+            self.name: str = name
+            self.events: typing.List[str] = []
+
+        async def aclose(self):
+            # The base implementation of httpcore.AsyncHTTPTransport just
+            # calls into `.aclose`, so simple transport cases can just override
+            # this method for any cleanup, where more complex cases
+            # might want to additionally override `__aenter__`/`__aexit__`.
+            self.events.append(f"{self.name}.aclose")
+
+        async def __aenter__(self):
+            await super().__aenter__()
+            self.events.append(f"{self.name}.__aenter__")
+
+        async def __aexit__(self, *args):
+            await super().__aexit__(*args)
+            self.events.append(f"{self.name}.__aexit__")
+
+    transport = Transport(name="transport")
+    mounted = Transport(name="mounted")
+    async with httpx.AsyncClient(
+        transport=transport, mounts={"http://www.example.org": mounted}
+    ):
+        pass
+
+    assert transport.events == [
+        "transport.__aenter__",
+        "transport.aclose",
+        "transport.__aexit__",
+    ]
+    assert mounted.events == [
+        "mounted.__aenter__",
+        "mounted.aclose",
+        "mounted.__aexit__",
+    ]
+
+
 def hello_world(request):
     return httpx.Response(200, text="Hello, world!")
 
@@ -242,3 +277,28 @@ async def test_deleting_unclosed_async_client_causes_warning():
     await client.get("http://example.com")
     with pytest.warns(UserWarning):
         del client
+
+
+def unmounted(request: httpx.Request) -> httpx.Response:
+    data = {"app": "unmounted"}
+    return httpx.Response(200, json=data)
+
+
+def mounted(request: httpx.Request) -> httpx.Response:
+    data = {"app": "mounted"}
+    return httpx.Response(200, json=data)
+
+
+@pytest.mark.usefixtures("async_environment")
+async def test_mounted_transport():
+    transport = MockTransport(unmounted)
+    mounts = {"custom://": MockTransport(mounted)}
+
+    async with httpx.AsyncClient(transport=transport, mounts=mounts) as client:
+        response = await client.get("https://www.example.com")
+        assert response.status_code == 200
+        assert response.json() == {"app": "unmounted"}
+
+        response = await client.get("custom://www.example.com")
+        assert response.status_code == 200
+        assert response.json() == {"app": "mounted"}
index a41f4232fbb7203c67f1c39396df65efb4a4d335..3675730b309ddbe87d9521e705d6775d29a1059b 100644 (file)
@@ -1,3 +1,4 @@
+import typing
 from datetime import timedelta
 
 import httpcore
@@ -227,15 +228,8 @@ def test_context_managed_transport():
             super().__exit__(*args)
             self.events.append("transport.__exit__")
 
-    # Note that we're including 'proxies' here to *also* run through the
-    # proxy context management, although we can't easily test that at the
-    # moment, since we can't add proxies as transport instances.
-    #
-    # Once we have a more generalised Mount API we'll be able to remove this
-    # in favour of ensuring all mounts are context managed, which will
-    # also neccessarily include proxies.
     transport = Transport()
-    with httpx.Client(transport=transport, proxies="http://www.example.com"):
+    with httpx.Client(transport=transport):
         pass
 
     assert transport.events == [
@@ -245,6 +239,44 @@ def test_context_managed_transport():
     ]
 
 
+def test_context_managed_transport_and_mount():
+    class Transport(httpcore.SyncHTTPTransport):
+        def __init__(self, name: str):
+            self.name: str = name
+            self.events: typing.List[str] = []
+
+        def close(self):
+            # The base implementation of httpcore.SyncHTTPTransport just
+            # calls into `.close`, so simple transport cases can just override
+            # this method for any cleanup, where more complex cases
+            # might want to additionally override `__enter__`/`__exit__`.
+            self.events.append(f"{self.name}.close")
+
+        def __enter__(self):
+            super().__enter__()
+            self.events.append(f"{self.name}.__enter__")
+
+        def __exit__(self, *args):
+            super().__exit__(*args)
+            self.events.append(f"{self.name}.__exit__")
+
+    transport = Transport(name="transport")
+    mounted = Transport(name="mounted")
+    with httpx.Client(transport=transport, mounts={"http://www.example.org": mounted}):
+        pass
+
+    assert transport.events == [
+        "transport.__enter__",
+        "transport.close",
+        "transport.__exit__",
+    ]
+    assert mounted.events == [
+        "mounted.__enter__",
+        "mounted.close",
+        "mounted.__exit__",
+    ]
+
+
 def hello_world(request):
     return httpx.Response(200, text="Hello, world!")
 
@@ -300,3 +332,38 @@ def test_raw_client_header():
         ["User-Agent", f"python-httpx/{httpx.__version__}"],
         ["Example-Header", "example-value"],
     ]
+
+
+def unmounted(request: httpx.Request) -> httpx.Response:
+    data = {"app": "unmounted"}
+    return httpx.Response(200, json=data)
+
+
+def mounted(request: httpx.Request) -> httpx.Response:
+    data = {"app": "mounted"}
+    return httpx.Response(200, json=data)
+
+
+def test_mounted_transport():
+    transport = MockTransport(unmounted)
+    mounts = {"custom://": MockTransport(mounted)}
+
+    client = httpx.Client(transport=transport, mounts=mounts)
+
+    response = client.get("https://www.example.com")
+    assert response.status_code == 200
+    assert response.json() == {"app": "unmounted"}
+
+    response = client.get("custom://www.example.com")
+    assert response.status_code == 200
+    assert response.json() == {"app": "mounted"}
+
+
+def test_all_mounted_transport():
+    mounts = {"all://": MockTransport(mounted)}
+
+    client = httpx.Client(mounts=mounts)
+
+    response = client.get("https://www.example.com")
+    assert response.status_code == 200
+    assert response.json() == {"app": "mounted"}
index 6d30438362f419c2c546ad727cebc5deacc63169..a2d21e9429e50da9bb79a05d09bd53b7e747d5c2 100644 (file)
@@ -41,12 +41,12 @@ def test_proxies_parameter(proxies, expected_proxies):
 
     for proxy_key, url in expected_proxies:
         pattern = URLPattern(proxy_key)
-        assert pattern in client._proxies
-        proxy = client._proxies[pattern]
+        assert pattern in client._mounts
+        proxy = client._mounts[pattern]
         assert isinstance(proxy, httpcore.SyncHTTPProxy)
         assert proxy.proxy_origin == url_to_origin(url)
 
-    assert len(expected_proxies) == len(client._proxies)
+    assert len(expected_proxies) == len(client._mounts)
 
 
 PROXY_URL = "http://[::1]"