]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add support for unix domain sockets (#511)
authorJonas Lundberg <jonas@5monkeys.se>
Tue, 19 Nov 2019 22:02:08 +0000 (23:02 +0100)
committerFlorimond Manca <florimond.manca@gmail.com>
Tue, 19 Nov 2019 22:02:08 +0000 (23:02 +0100)
* Add and implement open_uds_stream in concurrency backends

* Add uds arg to BaseClient and select tcp or uds in HttpConnection

* Make open stream methods in backends more explicit

* Close sentence

Co-Authored-By: Florimond Manca <florimond.manca@gmail.com>
* Refactor uds concurrency test

* Remove redundant uds test assertions

httpx/client.py
httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
tests/client/test_async_client.py
tests/client/test_client.py
tests/conftest.py
tests/test_concurrency.py

index 45589b9e23bf96f2b54f681f3cf4e48a6a05b589..c327c2db003f1d66b904d535d578bc055be1695d 100644 (file)
@@ -74,6 +74,7 @@ class BaseClient:
         app: typing.Callable = None,
         backend: ConcurrencyBackend = None,
         trust_env: bool = True,
+        uds: str = None,
     ):
         if backend is None:
             backend = AsyncioBackend()
@@ -99,6 +100,7 @@ class BaseClient:
                 pool_limits=pool_limits,
                 backend=backend,
                 trust_env=self.trust_env,
+                uds=uds,
             )
         elif isinstance(dispatch, Dispatcher):
             async_dispatch = ThreadedDispatcher(dispatch, backend)
@@ -721,6 +723,7 @@ class Client(BaseClient):
     async requests.
     * **trust_env** - *(optional)* Enables or disables usage of environment
     variables for configuration.
+    * **uds** - *(optional)* A path to a Unix domain socket to connect through.
     """
 
     def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
index 019876e43e883cdf659fbdf5350458d8c0710bd2..a0163ed02168ad05df947880cba9b30046ed5a11 100644 (file)
@@ -275,6 +275,29 @@ class AsyncioBackend(ConcurrencyBackend):
             stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
         )
 
+    async def open_uds_stream(
+        self,
+        path: str,
+        hostname: typing.Optional[str],
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> SocketStream:
+        server_hostname = hostname if ssl_context else None
+
+        try:
+            stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
+                asyncio.open_unix_connection(
+                    path, ssl=ssl_context, server_hostname=server_hostname
+                ),
+                timeout.connect_timeout,
+            )
+        except asyncio.TimeoutError:
+            raise ConnectTimeout()
+
+        return SocketStream(
+            stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
+        )
+
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
index 9d5bffde3eef585a0b1b875ad3c13a062c001e88..2109c2121aa1bedfe2e37dd7c5b0e5082e81f452 100644 (file)
@@ -124,6 +124,15 @@ class ConcurrencyBackend:
     ) -> BaseSocketStream:
         raise NotImplementedError()  # pragma: no cover
 
+    async def open_uds_stream(
+        self,
+        path: str,
+        hostname: typing.Optional[str],
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> BaseSocketStream:
+        raise NotImplementedError()  # pragma: no cover
+
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         raise NotImplementedError()  # pragma: no cover
 
index 5d3b50dfbba420e6f3cfecdd62547ae92be52666..c84b1e4b8b2089c8c23dacaea614565d5bc9b193 100644 (file)
@@ -191,6 +191,26 @@ class TrioBackend(ConcurrencyBackend):
 
         return SocketStream(stream=stream, timeout=timeout)
 
+    async def open_uds_stream(
+        self,
+        path: str,
+        hostname: typing.Optional[str],
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> SocketStream:
+        connect_timeout = _or_inf(timeout.connect_timeout)
+
+        with trio.move_on_after(connect_timeout) as cancel_scope:
+            stream: trio.SocketStream = await trio.open_unix_socket(path)
+            if ssl_context is not None:
+                stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
+                await stream.do_handshake()
+
+        if cancel_scope.cancelled_caught:
+            raise ConnectTimeout()
+
+        return SocketStream(stream=stream, timeout=timeout)
+
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
index 91feb97676d4c9a74f40e4cf17f06393274cf7c4..0612bccb86828f77aee783ad2be47deb73ed3c28 100644 (file)
@@ -38,6 +38,7 @@ class HTTPConnection(AsyncDispatcher):
         http_versions: HTTPVersionTypes = None,
         backend: ConcurrencyBackend = None,
         release_func: typing.Optional[ReleaseCallback] = None,
+        uds: typing.Optional[str] = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
         self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
@@ -45,6 +46,7 @@ class HTTPConnection(AsyncDispatcher):
         self.http_versions = HTTPVersionConfig(http_versions)
         self.backend = AsyncioBackend() if backend is None else backend
         self.release_func = release_func
+        self.uds = uds
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
         self.h2_connection = None  # type: typing.Optional[HTTP2Connection]
 
@@ -84,8 +86,21 @@ class HTTPConnection(AsyncDispatcher):
         else:
             on_release = functools.partial(self.release_func, self)
 
-        logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}")
-        stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout)
+        if self.uds is None:
+            logger.trace(
+                f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
+            )
+            stream = await self.backend.open_tcp_stream(
+                host, port, ssl_context, timeout
+            )
+        else:
+            logger.trace(
+                f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
+            )
+            stream = await self.backend.open_uds_stream(
+                self.uds, host, ssl_context, timeout
+            )
+
         http_version = stream.get_http_version()
         logger.trace(f"connected http_version={http_version!r}")
 
index 189fcff699973b82c92d4f48c18f848fda29d30a..5d7d886dec4ceab518c006a40c1d4a2d07ba7539 100644 (file)
@@ -89,6 +89,7 @@ class ConnectionPool(AsyncDispatcher):
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         http_versions: HTTPVersionTypes = None,
         backend: ConcurrencyBackend = None,
+        uds: typing.Optional[str] = None,
     ):
         self.verify = verify
         self.cert = cert
@@ -97,6 +98,7 @@ class ConnectionPool(AsyncDispatcher):
         self.http_versions = http_versions
         self.is_closed = False
         self.trust_env = trust_env
+        self.uds = uds
 
         self.keepalive_connections = ConnectionStore()
         self.active_connections = ConnectionStore()
@@ -142,6 +144,7 @@ class ConnectionPool(AsyncDispatcher):
                 backend=self.backend,
                 release_func=self.release_connection,
                 trust_env=self.trust_env,
+                uds=self.uds,
             )
             logger.trace(f"new_connection connection={connection!r}")
         else:
index eaac5ef76a00d25b2f0bc55e5f436598949d69b6..42202aa180d4738c6f91b054f9e5aee00fa0620f 100644 (file)
@@ -146,3 +146,14 @@ async def test_100_continue(server, backend):
 
     assert response.status_code == 200
     assert response.content == data
+
+
+async def test_uds(uds_server, backend):
+    url = uds_server.url
+    uds = uds_server.config.uds
+    assert uds is not None
+    async with httpx.AsyncClient(backend=backend, uds=uds) as client:
+        response = await client.get(url)
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+    assert response.encoding == "iso-8859-1"
index f7be6070e916132c294bae8a01b82f2a377c1fca..5dc196d933cc67f45005cb6d33f31122a0b652e1 100644 (file)
@@ -138,6 +138,17 @@ def test_base_url(server):
     assert response.url == base_url
 
 
+def test_uds(uds_server):
+    url = uds_server.url
+    uds = uds_server.config.uds
+    assert uds is not None
+    with httpx.Client(uds=uds) as http:
+        response = http.get(url)
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+    assert response.encoding == "iso-8859-1"
+
+
 def test_merge_url():
     client = httpx.Client(base_url="https://www.paypal.com/")
     url = client.merge_url("http://www.paypal.com")
index ef57caef30be072b9673f3718988bea924fc211e..de67ff7f921d95816bf552360591d92d038f2de9 100644 (file)
@@ -288,6 +288,15 @@ def server():
     yield from serve_in_thread(server)
 
 
+@pytest.fixture(scope=SERVER_SCOPE)
+def uds_server():
+    uds = "test_server.sock"
+    config = Config(app=app, lifespan="off", loop="asyncio", uds=uds)
+    server = TestServer(config=config)
+    yield from serve_in_thread(server)
+    os.remove(uds)
+
+
 @pytest.fixture(scope=SERVER_SCOPE)
 def https_server(cert_pem_file, cert_private_key_file):
     config = Config(
@@ -301,3 +310,19 @@ def https_server(cert_pem_file, cert_private_key_file):
     )
     server = TestServer(config=config)
     yield from serve_in_thread(server)
+
+
+@pytest.fixture(scope=SERVER_SCOPE)
+def https_uds_server(cert_pem_file, cert_private_key_file):
+    uds = "https_test_server.sock"
+    config = Config(
+        app=app,
+        lifespan="off",
+        ssl_certfile=cert_pem_file,
+        ssl_keyfile=cert_private_key_file,
+        uds=uds,
+        loop="asyncio",
+    )
+    server = TestServer(config=config)
+    yield from serve_in_thread(server)
+    os.remove(uds)
index 7477ea3b96ba8c208576c3f0c5ca5f281a4f6850..cf3844cffedf33b7e69dd8f88e8b8f5755fb5894 100644 (file)
@@ -24,7 +24,10 @@ from httpx.concurrency.trio import TrioBackend
         ),
     ],
 )
-async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
+@pytest.mark.parametrize("use_uds", (False, True))
+async def test_start_tls_on_socket_stream(
+    https_server, https_uds_server, backend, get_cipher, use_uds
+):
     """
     See that the concurrency backend can make a connection without TLS then
     start TLS on an existing connection.
@@ -32,9 +35,15 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
     ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
     timeout = TimeoutConfig(5)
 
-    stream = await backend.open_tcp_stream(
-        https_server.url.host, https_server.url.port, None, timeout
-    )
+    if use_uds:
+        assert https_uds_server.config.uds is not None
+        stream = await backend.open_uds_stream(
+            https_uds_server.config.uds, https_uds_server.url.host, None, timeout
+        )
+    else:
+        stream = await backend.open_tcp_stream(
+            https_server.url.host, https_server.url.port, None, timeout
+        )
 
     try:
         assert stream.is_connection_dropped() is False