From: Florimond Manca Date: Sat, 7 Dec 2019 14:17:35 +0000 (+0100) Subject: Refactor tests in the light of backend auto-detection (#615) X-Git-Tag: 0.9.4~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ab41a5d5c39627dfae8d31584e30183c28525b71;p=thirdparty%2Fhttpx.git Refactor tests in the light of backend auto-detection (#615) * Refactor tests in the light of backend auto-detection * Test passing explicit backend separately * Drop 'backend=backend' * Fix usage of asyncio.run() on 3.6 --- diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index cc433e57..eb3fa83b 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -7,7 +7,7 @@ import httpx async def test_get(server, backend): url = server.url - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.get(url) assert response.status_code == 200 assert response.text == "Hello, world!" @@ -20,7 +20,7 @@ async def test_get(server, backend): async def test_build_request(server, backend): url = server.url.copy_with(path="/echo_headers") headers = {"Custom-header": "value"} - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: request = client.build_request("GET", url) request.headers.update(headers) response = await client.send(request) @@ -48,20 +48,20 @@ async def test_get_no_backend(server): async def test_post(server, backend): url = server.url - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.post(url, data=b"Hello, world!") assert response.status_code == 200 async def test_post_json(server, backend): url = server.url - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.post(url, json={"text": "Hello, world!"}) assert response.status_code == 200 async def test_stream_response(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: async with client.stream("GET", server.url) as response: body = await response.read() @@ -71,7 +71,7 @@ async def test_stream_response(server, backend): async def test_access_content_stream_response(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: async with client.stream("GET", server.url) as response: pass @@ -85,13 +85,13 @@ async def test_stream_request(server, backend): yield b"Hello, " yield b"world!" - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.request("POST", server.url, data=hello_world()) assert response.status_code == 200 async def test_raise_for_status(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: for status_code in (200, 400, 404, 500, 505): response = await client.request( "GET", server.url.copy_with(path=f"/status/{status_code}") @@ -106,33 +106,33 @@ async def test_raise_for_status(server, backend): async def test_options(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.options(server.url) assert response.status_code == 200 assert response.text == "Hello, world!" async def test_head(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.head(server.url) assert response.status_code == 200 assert response.text == "" async def test_put(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.put(server.url, data=b"Hello, world!") assert response.status_code == 200 async def test_patch(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.patch(server.url, data=b"Hello, world!") assert response.status_code == 200 async def test_delete(server, backend): - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.delete(server.url) assert response.status_code == 200 assert response.text == "Hello, world!" @@ -142,7 +142,7 @@ async def test_100_continue(server, backend): headers = {"Expect": "100-continue"} data = b"Echo request body" - async with httpx.Client(backend=backend) as client: + async with httpx.Client() as client: response = await client.post( server.url.copy_with(path="/echo_body"), headers=headers, data=data ) @@ -155,8 +155,22 @@ async def test_uds(uds_server, backend): url = uds_server.url uds = uds_server.config.uds assert uds is not None - async with httpx.Client(backend=backend, uds=uds) as client: + async with httpx.Client(uds=uds) as client: response = await client.get(url) assert response.status_code == 200 assert response.text == "Hello, world!" assert response.encoding == "iso-8859-1" + + +@pytest.mark.parametrize( + "backend", + [ + pytest.param("asyncio", marks=pytest.mark.asyncio), + pytest.param("trio", marks=pytest.mark.trio), + ], +) +async def test_explicit_backend(server, backend): + async with httpx.Client(backend=backend) as client: + response = await client.get(server.url) + assert response.status_code == 200 + assert response.text == "Hello, world!" diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 5e763b00..ec851419 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -105,7 +105,7 @@ class MockDispatch(Dispatcher): async def test_no_redirect(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.com/no_redirect" response = await client.get(url) assert response.status_code == 200 @@ -114,7 +114,7 @@ async def test_no_redirect(backend): async def test_redirect_301(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.post("https://example.org/redirect_301") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") @@ -122,7 +122,7 @@ async def test_redirect_301(backend): async def test_redirect_302(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.post("https://example.org/redirect_302") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") @@ -130,7 +130,7 @@ async def test_redirect_302(backend): async def test_redirect_303(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.get("https://example.org/redirect_303") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") @@ -138,7 +138,7 @@ async def test_redirect_303(backend): async def test_disallow_redirects(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.post( "https://example.org/redirect_303", allow_redirects=False ) @@ -155,7 +155,7 @@ async def test_disallow_redirects(backend): async def test_relative_redirect(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.get("https://example.org/relative_redirect") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") @@ -163,7 +163,7 @@ async def test_relative_redirect(backend): async def test_no_scheme_redirect(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.get("https://example.org/no_scheme_redirect") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") @@ -171,7 +171,7 @@ async def test_no_scheme_redirect(backend): async def test_fragment_redirect(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.get("https://example.org/relative_redirect#fragment") assert response.status_code == codes.OK assert response.url == URL("https://example.org/#fragment") @@ -179,7 +179,7 @@ async def test_fragment_redirect(backend): async def test_multiple_redirects(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) response = await client.get("https://example.org/multiple_redirects?count=20") assert response.status_code == codes.OK assert response.url == URL("https://example.org/multiple_redirects") @@ -195,13 +195,13 @@ async def test_multiple_redirects(backend): async def test_too_many_redirects(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) with pytest.raises(TooManyRedirects): await client.get("https://example.org/multiple_redirects?count=21") async def test_too_many_redirects_calling_next(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.org/multiple_redirects?count=21" response = await client.get(url, allow_redirects=False) with pytest.raises(TooManyRedirects): @@ -210,13 +210,13 @@ async def test_too_many_redirects_calling_next(backend): async def test_redirect_loop(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) with pytest.raises(RedirectLoop): await client.get("https://example.org/redirect_loop") async def test_redirect_loop_calling_next(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.org/redirect_loop" response = await client.get(url, allow_redirects=False) with pytest.raises(RedirectLoop): @@ -225,7 +225,7 @@ async def test_redirect_loop_calling_next(backend): async def test_cross_domain_redirect(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.com/cross_domain" headers = {"Authorization": "abc"} response = await client.get(url, headers=headers) @@ -234,7 +234,7 @@ async def test_cross_domain_redirect(backend): async def test_same_domain_redirect(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.org/cross_domain" headers = {"Authorization": "abc"} response = await client.get(url, headers=headers) @@ -246,7 +246,7 @@ async def test_body_redirect(backend): """ A 308 redirect should preserve the request body. """ - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.org/redirect_body" data = b"Example request body" response = await client.post(url, data=data) @@ -259,7 +259,7 @@ async def test_no_body_redirect(backend): """ A 303 redirect should remove the request body. """ - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.org/redirect_no_body" data = b"Example request body" response = await client.post(url, data=data) @@ -269,7 +269,7 @@ async def test_no_body_redirect(backend): async def test_cannot_redirect_streaming_body(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.org/redirect_body" async def streaming_body(): @@ -280,7 +280,7 @@ async def test_cannot_redirect_streaming_body(backend): async def test_cross_subdomain_redirect(backend): - client = Client(dispatch=MockDispatch(), backend=backend) + client = Client(dispatch=MockDispatch()) url = "https://example.com/cross_subdomain" response = await client.get(url) assert response.url == URL("https://www.example.org/cross_subdomain") @@ -326,7 +326,7 @@ class MockCookieDispatch(Dispatcher): async def test_redirect_cookie_behavior(backend): - client = Client(dispatch=MockCookieDispatch(), backend=backend) + client = Client(dispatch=MockCookieDispatch()) # The client is not logged in. response = await client.get("https://example.com/") diff --git a/tests/concurrency.py b/tests/concurrency.py index 74c3ab05..777ed5be 100644 --- a/tests/concurrency.py +++ b/tests/concurrency.py @@ -10,6 +10,7 @@ import typing import trio from httpx.concurrency.asyncio import AsyncioBackend +from httpx.concurrency.auto import AutoBackend from httpx.concurrency.trio import TrioBackend @@ -18,6 +19,11 @@ async def sleep(backend, seconds: int): raise NotImplementedError # pragma: no cover +@sleep.register(AutoBackend) +async def _sleep_auto(backend, seconds: int): + return await sleep(backend.backend, seconds=seconds) + + @sleep.register(AsyncioBackend) async def _sleep_asyncio(backend, seconds: int): await asyncio.sleep(seconds) @@ -33,6 +39,11 @@ async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awai raise NotImplementedError # pragma: no cover +@run_concurrently.register(AutoBackend) +async def _run_concurrently_auto(backend, *coroutines): + await run_concurrently(backend.backend, *coroutines) + + @run_concurrently.register(AsyncioBackend) async def _run_concurrently_asyncio(backend, *coroutines): coros = (coroutine() for coroutine in coroutines) @@ -44,3 +55,23 @@ async def _run_concurrently_trio(backend, *coroutines): async with trio.open_nursery() as nursery: for coroutine in coroutines: nursery.start_soon(coroutine) + + +@functools.singledispatch +def get_cipher(backend, stream): + raise NotImplementedError # pragma: no cover + + +@get_cipher.register(AutoBackend) +def _get_cipher_auto(backend, stream): + return get_cipher(backend.backend, stream) + + +@get_cipher.register(AsyncioBackend) +def _get_cipher_asyncio(backend, stream): + return stream.stream_writer.get_extra_info("cipher", default=None) + + +@get_cipher.register(TrioBackend) +def get_trio_cipher(backend, stream): + return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None diff --git a/tests/conftest.py b/tests/conftest.py index a3037cb2..b27913ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ from uvicorn.main import Server from httpx import URL from httpx.concurrency.asyncio import AsyncioBackend -from httpx.concurrency.trio import TrioBackend +from httpx.concurrency.base import lookup_backend ENVIRONMENT_VARIABLES = { "SSL_CERT_FILE", @@ -51,13 +51,17 @@ def clean_environ() -> typing.Dict[str, typing.Any]: @pytest.fixture( params=[ - pytest.param(AsyncioBackend, marks=pytest.mark.asyncio), - pytest.param(TrioBackend, marks=pytest.mark.trio), + # pytest uses the marks to set up the specified async environment and run + # 'async def' test functions. The "auto" backend should then auto-detect + # the environment it's running in. + # Passing the backend explicitly, e.g. `backend="asyncio"`, + # is tested separately. + pytest.param("auto", marks=pytest.mark.asyncio), + pytest.param("auto", marks=pytest.mark.trio), ] ) def backend(request): - backend_cls = request.param - return backend_cls() + return request.param async def app(scope, receive, send): @@ -271,9 +275,13 @@ def restart(backend): This fixture deals with possible differences between the environment of the test function and that of the server. """ + asyncio_backend = AsyncioBackend() + backend_implementation = lookup_backend(backend) async def restart(server): - await backend.run_in_threadpool(AsyncioBackend().run, server.restart) + await backend_implementation.run_in_threadpool( + asyncio_backend.run, server.restart + ) return restart diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index 4311655c..f10b16ab 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -6,7 +6,7 @@ async def test_keepalive_connections(server, backend): """ Connections should default to staying in a keep-alive state. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url) await response.read() assert len(http.active_connections) == 0 @@ -22,7 +22,7 @@ async def test_differing_connection_keys(server, backend): """ Connections to differing connection keys should result in multiple connections. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url) await response.read() assert len(http.active_connections) == 0 @@ -40,7 +40,7 @@ async def test_soft_limit(server, backend): """ pool_limits = httpx.PoolLimits(soft_limit=1) - async with ConnectionPool(pool_limits=pool_limits, backend=backend) as http: + async with ConnectionPool(pool_limits=pool_limits) as http: response = await http.request("GET", server.url) await response.read() assert len(http.active_connections) == 0 @@ -56,7 +56,7 @@ async def test_streaming_response_holds_connection(server, backend): """ A streaming request should hold the connection open until the response is read. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url) assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 @@ -71,7 +71,7 @@ async def test_multiple_concurrent_connections(server, backend): """ Multiple conncurrent requests should open multiple conncurrent connections. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response_a = await http.request("GET", server.url) assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 @@ -94,7 +94,7 @@ async def test_close_connections(server, backend): Using a `Connection: close` header should close the connection. """ headers = [(b"connection", b"close")] - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url, headers=headers) await response.read() assert len(http.active_connections) == 0 @@ -105,7 +105,7 @@ async def test_standard_response_close(server, backend): """ A standard close should keep the connection open. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url) await response.read() await response.close() @@ -117,7 +117,7 @@ async def test_premature_response_close(server, backend): """ A premature close should close the connection. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url) await response.close() assert len(http.active_connections) == 0 @@ -131,7 +131,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished( Upon keep-alive connection closed by remote a new connection should be reestablished. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url) await response.read() @@ -151,7 +151,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished( Upon keep-alive connection closed by remote a new connection should be reestablished. """ - async with ConnectionPool(backend=backend) as http: + async with ConnectionPool() as http: response = await http.request("GET", server.url) await response.read() @@ -169,9 +169,7 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart, back Verify that max_connections semaphore is released properly on a disconnected connection. """ - async with ConnectionPool( - pool_limits=httpx.PoolLimits(hard_limit=1), backend=backend - ) as http: + async with ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http: response = await http.request("GET", server.url) await response.read() diff --git a/tests/dispatch/test_connections.py b/tests/dispatch/test_connections.py index 54fb2ba7..6fdb24ea 100644 --- a/tests/dispatch/test_connections.py +++ b/tests/dispatch/test_connections.py @@ -5,7 +5,7 @@ from httpx.dispatch.connection import HTTPConnection async def test_get(server, backend): - async with HTTPConnection(origin=server.url, backend=backend) as conn: + async with HTTPConnection(origin=server.url) as conn: response = await conn.request("GET", server.url) await response.read() assert response.status_code == 200 @@ -13,14 +13,14 @@ async def test_get(server, backend): async def test_post(server, backend): - async with HTTPConnection(origin=server.url, backend=backend) as conn: + async with HTTPConnection(origin=server.url) as conn: response = await conn.request("GET", server.url, data=b"Hello, world!") assert response.status_code == 200 async def test_premature_close(server, backend): with pytest.raises(httpx.ConnectionClosed): - async with HTTPConnection(origin=server.url, backend=backend) as conn: + async with HTTPConnection(origin=server.url) as conn: response = await conn.request( "GET", server.url.copy_with(path="/premature_close") ) @@ -31,9 +31,7 @@ async def test_https_get_with_ssl_defaults(https_server, ca_cert_pem_file, backe """ An HTTPS request, with default SSL configuration set on the client. """ - async with HTTPConnection( - origin=https_server.url, verify=ca_cert_pem_file, backend=backend - ) as conn: + async with HTTPConnection(origin=https_server.url, verify=ca_cert_pem_file) as conn: response = await conn.request("GET", https_server.url) await response.read() assert response.status_code == 200 @@ -44,7 +42,7 @@ async def test_https_get_with_sll_overrides(https_server, ca_cert_pem_file, back """ An HTTPS request, with SSL configuration set on the request. """ - async with HTTPConnection(origin=https_server.url, backend=backend) as conn: + async with HTTPConnection(origin=https_server.url) as conn: response = await conn.request("GET", https_server.url, verify=ca_cert_pem_file) await response.read() assert response.status_code == 200 diff --git a/tests/dispatch/test_http2.py b/tests/dispatch/test_http2.py index 375b8053..a9c45097 100644 --- a/tests/dispatch/test_http2.py +++ b/tests/dispatch/test_http2.py @@ -104,7 +104,7 @@ async def test_http2_reconnect(): async def test_http2_settings_in_handshake(backend): - backend = MockHTTP2Backend(app=app, backend=backend) + backend = MockHTTP2Backend(app=app) async with Client(backend=backend, http2=True) as client: await client.get("http://example.org") @@ -139,7 +139,7 @@ async def test_http2_settings_in_handshake(backend): async def test_http2_live_request(backend): - async with Client(backend=backend, http2=True) as client: + async with Client(http2=True) as client: try: resp = await client.get("https://nghttp2.org/httpbin/anything") except TimeoutException: diff --git a/tests/dispatch/test_proxy_http.py b/tests/dispatch/test_proxy_http.py index 4234b984..eab7738d 100644 --- a/tests/dispatch/test_proxy_http.py +++ b/tests/dispatch/test_proxy_http.py @@ -19,7 +19,6 @@ async def test_proxy_tunnel_success(backend): b"\r\n", ] ), - backend=backend, ) async with httpx.HTTPProxy( proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY", @@ -53,7 +52,6 @@ async def test_proxy_tunnel_non_2xx_response(backend, status_code): b"\r\n", ] ), - backend=backend, ) with pytest.raises(httpx.ProxyError) as e: @@ -105,7 +103,6 @@ async def test_proxy_tunnel_start_tls(backend): b"\r\n", ] ), - backend=backend, ) async with httpx.HTTPProxy( proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY", @@ -155,7 +152,6 @@ async def test_proxy_forwarding(backend, proxy_mode): b"\r\n" ] ), - backend=backend, ) async with httpx.HTTPProxy( proxy_url="http://127.0.0.1:8000", diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index 2f26fda3..00565289 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -6,15 +6,14 @@ import h2.connection import h2.events from httpx import Request, Timeout -from httpx.concurrency.asyncio import AsyncioBackend -from httpx.concurrency.base import BaseSocketStream +from httpx.concurrency.base import BaseSocketStream, lookup_backend from tests.concurrency import sleep class MockHTTP2Backend: - def __init__(self, app, backend=None): + def __init__(self, app, backend="auto"): self.app = app - self.backend = AsyncioBackend() if backend is None else backend + self.backend = lookup_backend(backend) self.server = None async def open_tcp_stream( @@ -168,8 +167,8 @@ class MockHTTP2Server(BaseSocketStream): class MockRawSocketBackend: - def __init__(self, data_to_send=b"", backend=None): - self.backend = AsyncioBackend() if backend is None else backend + def __init__(self, data_to_send=b"", backend="auto"): + self.backend = lookup_backend(backend) self.data_to_send = data_to_send self.received_data = [] self.stream = MockRawSocketStream(self) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index a3286c55..3e3d6403 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,19 +1,14 @@ +import asyncio + import pytest import trio from httpx import Timeout from httpx.concurrency.asyncio import AsyncioBackend +from httpx.concurrency.base import lookup_backend from httpx.concurrency.trio import TrioBackend from httpx.config import SSLConfig -from tests.concurrency import run_concurrently, sleep - - -def get_asyncio_cipher(stream): - return stream.stream_writer.get_extra_info("cipher", default=None) - - -def get_trio_cipher(stream): - return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None +from tests.concurrency import get_cipher, run_concurrently, sleep async def read_response(stream, timeout: Timeout, should_contain: bytes) -> bytes: @@ -34,14 +29,8 @@ async def read_response(stream, timeout: Timeout, should_contain: bytes) -> byte return response -@pytest.mark.parametrize( - "backend, get_cipher", - [ - pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio), - pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio), - ], -) -async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher): +async def test_start_tls_on_tcp_socket_stream(https_server, backend): + backend = lookup_backend(backend) ctx = SSLConfig().load_ssl_context_no_verify() timeout = Timeout(5) @@ -51,11 +40,11 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher) try: assert stream.is_connection_dropped() is False - assert get_cipher(stream) is None + assert get_cipher(backend, stream) is None stream = await stream.start_tls(https_server.url.host, ctx, timeout) assert stream.is_connection_dropped() is False - assert get_cipher(stream) is not None + assert get_cipher(backend, stream) is not None await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout) @@ -66,14 +55,8 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher) await stream.close() -@pytest.mark.parametrize( - "backend, get_cipher", - [ - pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio), - pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio), - ], -) -async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cipher): +async def test_start_tls_on_uds_socket_stream(https_uds_server, backend): + backend = lookup_backend(backend) ctx = SSLConfig().load_ssl_context_no_verify() timeout = Timeout(5) @@ -83,11 +66,11 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cip try: assert stream.is_connection_dropped() is False - assert get_cipher(stream) is None + assert get_cipher(backend, stream) is None stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout) assert stream.is_connection_dropped() is False - assert get_cipher(stream) is not None + assert get_cipher(backend, stream) is not None await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout) @@ -102,6 +85,7 @@ async def test_concurrent_read(server, backend): """ Regression test for: https://github.com/encode/httpx/issues/527 """ + backend = lookup_backend(backend) stream = await backend.open_tcp_stream( server.url.host, server.url.port, ssl_context=None, timeout=Timeout(5) ) @@ -116,6 +100,7 @@ async def test_concurrent_read(server, backend): async def test_fork(backend): + backend = lookup_backend(backend) ok_counter = 0 async def ok(delay: int) -> None: @@ -159,3 +144,23 @@ async def test_fork(backend): # No 'match', since we can't know which will be raised first. with pytest.raises(RuntimeError): await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0]) + + +def test_lookup_backend(): + assert isinstance(lookup_backend("asyncio"), AsyncioBackend) + assert isinstance(lookup_backend("trio"), TrioBackend) + assert isinstance(lookup_backend(AsyncioBackend()), AsyncioBackend) + + async def get_backend_from_auto(): + auto_backend = lookup_backend("auto") + return auto_backend.backend + + loop = asyncio.get_event_loop() + backend = loop.run_until_complete(get_backend_from_auto()) + assert isinstance(backend, AsyncioBackend) + + backend = trio.run(get_backend_from_auto) + assert isinstance(backend, TrioBackend) + + with pytest.raises(Exception, match="unknownio"): + lookup_backend("unknownio") diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py index 80fefd33..8619dc90 100644 --- a/tests/test_timeouts.py +++ b/tests/test_timeouts.py @@ -6,7 +6,7 @@ import httpx async def test_read_timeout(server, backend): timeout = httpx.Timeout(read_timeout=1e-6) - async with httpx.Client(timeout=timeout, backend=backend) as client: + async with httpx.Client(timeout=timeout) as client: with pytest.raises(httpx.ReadTimeout): await client.get(server.url.copy_with(path="/slow_response")) @@ -14,7 +14,7 @@ async def test_read_timeout(server, backend): async def test_write_timeout(server, backend): timeout = httpx.Timeout(write_timeout=1e-6) - async with httpx.Client(timeout=timeout, backend=backend) as client: + async with httpx.Client(timeout=timeout) as client: with pytest.raises(httpx.WriteTimeout): data = b"*" * 1024 * 1024 * 100 await client.put(server.url.copy_with(path="/slow_response"), data=data) @@ -23,7 +23,7 @@ async def test_write_timeout(server, backend): async def test_connect_timeout(server, backend): timeout = httpx.Timeout(connect_timeout=1e-6) - async with httpx.Client(timeout=timeout, backend=backend) as client: + async with httpx.Client(timeout=timeout) as client: with pytest.raises(httpx.ConnectTimeout): # See https://stackoverflow.com/questions/100841/ await client.get("http://10.255.255.1/") @@ -33,9 +33,7 @@ async def test_pool_timeout(server, backend): pool_limits = httpx.PoolLimits(hard_limit=1) timeout = httpx.Timeout(pool_timeout=1e-4) - async with httpx.Client( - pool_limits=pool_limits, timeout=timeout, backend=backend - ) as client: + async with httpx.Client(pool_limits=pool_limits, timeout=timeout) as client: async with client.stream("GET", server.url): with pytest.raises(httpx.PoolTimeout): await client.get("http://localhost:8000/")