app: typing.Callable = None,
backend: ConcurrencyBackend = None,
trust_env: bool = True,
+ uds: str = None,
):
if backend is None:
backend = AsyncioBackend()
pool_limits=pool_limits,
backend=backend,
trust_env=self.trust_env,
+ uds=uds,
)
elif isinstance(dispatch, Dispatcher):
async_dispatch = ThreadedDispatcher(dispatch, backend)
async requests.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
+ * **uds** - *(optional)* A path to a Unix domain socket to connect through.
"""
def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
+ async def open_uds_stream(
+ self,
+ path: str,
+ hostname: typing.Optional[str],
+ ssl_context: typing.Optional[ssl.SSLContext],
+ timeout: TimeoutConfig,
+ ) -> SocketStream:
+ server_hostname = hostname if ssl_context else None
+
+ try:
+ stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
+ asyncio.open_unix_connection(
+ path, ssl=ssl_context, server_hostname=server_hostname
+ ),
+ timeout.connect_timeout,
+ )
+ except asyncio.TimeoutError:
+ raise ConnectTimeout()
+
+ return SocketStream(
+ stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
+ )
+
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover
+ async def open_uds_stream(
+ self,
+ path: str,
+ hostname: typing.Optional[str],
+ ssl_context: typing.Optional[ssl.SSLContext],
+ timeout: TimeoutConfig,
+ ) -> BaseSocketStream:
+ raise NotImplementedError() # pragma: no cover
+
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover
return SocketStream(stream=stream, timeout=timeout)
+ async def open_uds_stream(
+ self,
+ path: str,
+ hostname: typing.Optional[str],
+ ssl_context: typing.Optional[ssl.SSLContext],
+ timeout: TimeoutConfig,
+ ) -> SocketStream:
+ connect_timeout = _or_inf(timeout.connect_timeout)
+
+ with trio.move_on_after(connect_timeout) as cancel_scope:
+ stream: trio.SocketStream = await trio.open_unix_socket(path)
+ if ssl_context is not None:
+ stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
+ await stream.do_handshake()
+
+ if cancel_scope.cancelled_caught:
+ raise ConnectTimeout()
+
+ return SocketStream(stream=stream, timeout=timeout)
+
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
http_versions: HTTPVersionTypes = None,
backend: ConcurrencyBackend = None,
release_func: typing.Optional[ReleaseCallback] = None,
+ uds: typing.Optional[str] = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
self.http_versions = HTTPVersionConfig(http_versions)
self.backend = AsyncioBackend() if backend is None else backend
self.release_func = release_func
+ self.uds = uds
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
else:
on_release = functools.partial(self.release_func, self)
- logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}")
- stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout)
+ if self.uds is None:
+ logger.trace(
+ f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
+ )
+ stream = await self.backend.open_tcp_stream(
+ host, port, ssl_context, timeout
+ )
+ else:
+ logger.trace(
+ f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
+ )
+ stream = await self.backend.open_uds_stream(
+ self.uds, host, ssl_context, timeout
+ )
+
http_version = stream.get_http_version()
logger.trace(f"connected http_version={http_version!r}")
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http_versions: HTTPVersionTypes = None,
backend: ConcurrencyBackend = None,
+ uds: typing.Optional[str] = None,
):
self.verify = verify
self.cert = cert
self.http_versions = http_versions
self.is_closed = False
self.trust_env = trust_env
+ self.uds = uds
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
backend=self.backend,
release_func=self.release_connection,
trust_env=self.trust_env,
+ uds=self.uds,
)
logger.trace(f"new_connection connection={connection!r}")
else:
assert response.status_code == 200
assert response.content == data
+
+
+async def test_uds(uds_server, backend):
+ url = uds_server.url
+ uds = uds_server.config.uds
+ assert uds is not None
+ async with httpx.AsyncClient(backend=backend, uds=uds) as client:
+ response = await client.get(url)
+ assert response.status_code == 200
+ assert response.text == "Hello, world!"
+ assert response.encoding == "iso-8859-1"
assert response.url == base_url
+def test_uds(uds_server):
+ url = uds_server.url
+ uds = uds_server.config.uds
+ assert uds is not None
+ with httpx.Client(uds=uds) as http:
+ response = http.get(url)
+ assert response.status_code == 200
+ assert response.text == "Hello, world!"
+ assert response.encoding == "iso-8859-1"
+
+
def test_merge_url():
client = httpx.Client(base_url="https://www.paypal.com/")
url = client.merge_url("http://www.paypal.com")
yield from serve_in_thread(server)
+@pytest.fixture(scope=SERVER_SCOPE)
+def uds_server():
+ uds = "test_server.sock"
+ config = Config(app=app, lifespan="off", loop="asyncio", uds=uds)
+ server = TestServer(config=config)
+ yield from serve_in_thread(server)
+ os.remove(uds)
+
+
@pytest.fixture(scope=SERVER_SCOPE)
def https_server(cert_pem_file, cert_private_key_file):
config = Config(
)
server = TestServer(config=config)
yield from serve_in_thread(server)
+
+
+@pytest.fixture(scope=SERVER_SCOPE)
+def https_uds_server(cert_pem_file, cert_private_key_file):
+ uds = "https_test_server.sock"
+ config = Config(
+ app=app,
+ lifespan="off",
+ ssl_certfile=cert_pem_file,
+ ssl_keyfile=cert_private_key_file,
+ uds=uds,
+ loop="asyncio",
+ )
+ server = TestServer(config=config)
+ yield from serve_in_thread(server)
+ os.remove(uds)
),
],
)
-async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
+@pytest.mark.parametrize("use_uds", (False, True))
+async def test_start_tls_on_socket_stream(
+ https_server, https_uds_server, backend, get_cipher, use_uds
+):
"""
See that the concurrency backend can make a connection without TLS then
start TLS on an existing connection.
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
timeout = TimeoutConfig(5)
- stream = await backend.open_tcp_stream(
- https_server.url.host, https_server.url.port, None, timeout
- )
+ if use_uds:
+ assert https_uds_server.config.uds is not None
+ stream = await backend.open_uds_stream(
+ https_uds_server.config.uds, https_uds_server.url.host, None, timeout
+ )
+ else:
+ stream = await backend.open_tcp_stream(
+ https_server.url.host, https_server.url.port, None, timeout
+ )
try:
assert stream.is_connection_dropped() is False