]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor tests in the light of backend auto-detection (#615)
authorFlorimond Manca <florimond.manca@gmail.com>
Sat, 7 Dec 2019 14:17:35 +0000 (15:17 +0100)
committerGitHub <noreply@github.com>
Sat, 7 Dec 2019 14:17:35 +0000 (15:17 +0100)
* Refactor tests in the light of backend auto-detection

* Test passing explicit backend separately

* Drop 'backend=backend'

* Fix usage of asyncio.run() on 3.6

tests/client/test_async_client.py
tests/client/test_redirects.py
tests/concurrency.py
tests/conftest.py
tests/dispatch/test_connection_pools.py
tests/dispatch/test_connections.py
tests/dispatch/test_http2.py
tests/dispatch/test_proxy_http.py
tests/dispatch/utils.py
tests/test_concurrency.py
tests/test_timeouts.py

index cc433e574dbb5adfe769075dd59c702785a57585..eb3fa83b1096123992dcaa4d91d1bf2868fea7a3 100644 (file)
@@ -7,7 +7,7 @@ import httpx
 
 async def test_get(server, backend):
     url = server.url
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.get(url)
     assert response.status_code == 200
     assert response.text == "Hello, world!"
@@ -20,7 +20,7 @@ async def test_get(server, backend):
 async def test_build_request(server, backend):
     url = server.url.copy_with(path="/echo_headers")
     headers = {"Custom-header": "value"}
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         request = client.build_request("GET", url)
         request.headers.update(headers)
         response = await client.send(request)
@@ -48,20 +48,20 @@ async def test_get_no_backend(server):
 
 async def test_post(server, backend):
     url = server.url
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.post(url, data=b"Hello, world!")
     assert response.status_code == 200
 
 
 async def test_post_json(server, backend):
     url = server.url
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.post(url, json={"text": "Hello, world!"})
     assert response.status_code == 200
 
 
 async def test_stream_response(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         async with client.stream("GET", server.url) as response:
             body = await response.read()
 
@@ -71,7 +71,7 @@ async def test_stream_response(server, backend):
 
 
 async def test_access_content_stream_response(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         async with client.stream("GET", server.url) as response:
             pass
 
@@ -85,13 +85,13 @@ async def test_stream_request(server, backend):
         yield b"Hello, "
         yield b"world!"
 
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.request("POST", server.url, data=hello_world())
     assert response.status_code == 200
 
 
 async def test_raise_for_status(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         for status_code in (200, 400, 404, 500, 505):
             response = await client.request(
                 "GET", server.url.copy_with(path=f"/status/{status_code}")
@@ -106,33 +106,33 @@ async def test_raise_for_status(server, backend):
 
 
 async def test_options(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.options(server.url)
     assert response.status_code == 200
     assert response.text == "Hello, world!"
 
 
 async def test_head(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.head(server.url)
     assert response.status_code == 200
     assert response.text == ""
 
 
 async def test_put(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.put(server.url, data=b"Hello, world!")
     assert response.status_code == 200
 
 
 async def test_patch(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.patch(server.url, data=b"Hello, world!")
     assert response.status_code == 200
 
 
 async def test_delete(server, backend):
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.delete(server.url)
     assert response.status_code == 200
     assert response.text == "Hello, world!"
@@ -142,7 +142,7 @@ async def test_100_continue(server, backend):
     headers = {"Expect": "100-continue"}
     data = b"Echo request body"
 
-    async with httpx.Client(backend=backend) as client:
+    async with httpx.Client() as client:
         response = await client.post(
             server.url.copy_with(path="/echo_body"), headers=headers, data=data
         )
@@ -155,8 +155,22 @@ async def test_uds(uds_server, backend):
     url = uds_server.url
     uds = uds_server.config.uds
     assert uds is not None
-    async with httpx.Client(backend=backend, uds=uds) as client:
+    async with httpx.Client(uds=uds) as client:
         response = await client.get(url)
     assert response.status_code == 200
     assert response.text == "Hello, world!"
     assert response.encoding == "iso-8859-1"
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("asyncio", marks=pytest.mark.asyncio),
+        pytest.param("trio", marks=pytest.mark.trio),
+    ],
+)
+async def test_explicit_backend(server, backend):
+    async with httpx.Client(backend=backend) as client:
+        response = await client.get(server.url)
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
index 5e763b002d53e3db5d7db1062deb0dcb16f9e2c6..ec8514195a281e29d613fb29762344b75ea409de 100644 (file)
@@ -105,7 +105,7 @@ class MockDispatch(Dispatcher):
 
 
 async def test_no_redirect(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.com/no_redirect"
     response = await client.get(url)
     assert response.status_code == 200
@@ -114,7 +114,7 @@ async def test_no_redirect(backend):
 
 
 async def test_redirect_301(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.post("https://example.org/redirect_301")
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/")
@@ -122,7 +122,7 @@ async def test_redirect_301(backend):
 
 
 async def test_redirect_302(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.post("https://example.org/redirect_302")
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/")
@@ -130,7 +130,7 @@ async def test_redirect_302(backend):
 
 
 async def test_redirect_303(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.get("https://example.org/redirect_303")
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/")
@@ -138,7 +138,7 @@ async def test_redirect_303(backend):
 
 
 async def test_disallow_redirects(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.post(
         "https://example.org/redirect_303", allow_redirects=False
     )
@@ -155,7 +155,7 @@ async def test_disallow_redirects(backend):
 
 
 async def test_relative_redirect(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.get("https://example.org/relative_redirect")
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/")
@@ -163,7 +163,7 @@ async def test_relative_redirect(backend):
 
 
 async def test_no_scheme_redirect(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.get("https://example.org/no_scheme_redirect")
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/")
@@ -171,7 +171,7 @@ async def test_no_scheme_redirect(backend):
 
 
 async def test_fragment_redirect(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.get("https://example.org/relative_redirect#fragment")
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/#fragment")
@@ -179,7 +179,7 @@ async def test_fragment_redirect(backend):
 
 
 async def test_multiple_redirects(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     response = await client.get("https://example.org/multiple_redirects?count=20")
     assert response.status_code == codes.OK
     assert response.url == URL("https://example.org/multiple_redirects")
@@ -195,13 +195,13 @@ async def test_multiple_redirects(backend):
 
 
 async def test_too_many_redirects(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     with pytest.raises(TooManyRedirects):
         await client.get("https://example.org/multiple_redirects?count=21")
 
 
 async def test_too_many_redirects_calling_next(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.org/multiple_redirects?count=21"
     response = await client.get(url, allow_redirects=False)
     with pytest.raises(TooManyRedirects):
@@ -210,13 +210,13 @@ async def test_too_many_redirects_calling_next(backend):
 
 
 async def test_redirect_loop(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     with pytest.raises(RedirectLoop):
         await client.get("https://example.org/redirect_loop")
 
 
 async def test_redirect_loop_calling_next(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.org/redirect_loop"
     response = await client.get(url, allow_redirects=False)
     with pytest.raises(RedirectLoop):
@@ -225,7 +225,7 @@ async def test_redirect_loop_calling_next(backend):
 
 
 async def test_cross_domain_redirect(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.com/cross_domain"
     headers = {"Authorization": "abc"}
     response = await client.get(url, headers=headers)
@@ -234,7 +234,7 @@ async def test_cross_domain_redirect(backend):
 
 
 async def test_same_domain_redirect(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.org/cross_domain"
     headers = {"Authorization": "abc"}
     response = await client.get(url, headers=headers)
@@ -246,7 +246,7 @@ async def test_body_redirect(backend):
     """
     A 308 redirect should preserve the request body.
     """
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.org/redirect_body"
     data = b"Example request body"
     response = await client.post(url, data=data)
@@ -259,7 +259,7 @@ async def test_no_body_redirect(backend):
     """
     A 303 redirect should remove the request body.
     """
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.org/redirect_no_body"
     data = b"Example request body"
     response = await client.post(url, data=data)
@@ -269,7 +269,7 @@ async def test_no_body_redirect(backend):
 
 
 async def test_cannot_redirect_streaming_body(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.org/redirect_body"
 
     async def streaming_body():
@@ -280,7 +280,7 @@ async def test_cannot_redirect_streaming_body(backend):
 
 
 async def test_cross_subdomain_redirect(backend):
-    client = Client(dispatch=MockDispatch(), backend=backend)
+    client = Client(dispatch=MockDispatch())
     url = "https://example.com/cross_subdomain"
     response = await client.get(url)
     assert response.url == URL("https://www.example.org/cross_subdomain")
@@ -326,7 +326,7 @@ class MockCookieDispatch(Dispatcher):
 
 
 async def test_redirect_cookie_behavior(backend):
-    client = Client(dispatch=MockCookieDispatch(), backend=backend)
+    client = Client(dispatch=MockCookieDispatch())
 
     # The client is not logged in.
     response = await client.get("https://example.com/")
index 74c3ab05bfddaf3ac5688e10db426c4c885bc4f3..777ed5be120e825ec561dd86aac1ad93e662e95f 100644 (file)
@@ -10,6 +10,7 @@ import typing
 import trio
 
 from httpx.concurrency.asyncio import AsyncioBackend
+from httpx.concurrency.auto import AutoBackend
 from httpx.concurrency.trio import TrioBackend
 
 
@@ -18,6 +19,11 @@ async def sleep(backend, seconds: int):
     raise NotImplementedError  # pragma: no cover
 
 
+@sleep.register(AutoBackend)
+async def _sleep_auto(backend, seconds: int):
+    return await sleep(backend.backend, seconds=seconds)
+
+
 @sleep.register(AsyncioBackend)
 async def _sleep_asyncio(backend, seconds: int):
     await asyncio.sleep(seconds)
@@ -33,6 +39,11 @@ async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awai
     raise NotImplementedError  # pragma: no cover
 
 
+@run_concurrently.register(AutoBackend)
+async def _run_concurrently_auto(backend, *coroutines):
+    await run_concurrently(backend.backend, *coroutines)
+
+
 @run_concurrently.register(AsyncioBackend)
 async def _run_concurrently_asyncio(backend, *coroutines):
     coros = (coroutine() for coroutine in coroutines)
@@ -44,3 +55,23 @@ async def _run_concurrently_trio(backend, *coroutines):
     async with trio.open_nursery() as nursery:
         for coroutine in coroutines:
             nursery.start_soon(coroutine)
+
+
+@functools.singledispatch
+def get_cipher(backend, stream):
+    raise NotImplementedError  # pragma: no cover
+
+
+@get_cipher.register(AutoBackend)
+def _get_cipher_auto(backend, stream):
+    return get_cipher(backend.backend, stream)
+
+
+@get_cipher.register(AsyncioBackend)
+def _get_cipher_asyncio(backend, stream):
+    return stream.stream_writer.get_extra_info("cipher", default=None)
+
+
+@get_cipher.register(TrioBackend)
+def get_trio_cipher(backend, stream):
+    return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
index a3037cb2b0614716999cb275d4b5a258e2a8932f..b27913eaf635c16196d9d6134afd650ce1dc8b19 100644 (file)
@@ -19,7 +19,7 @@ from uvicorn.main import Server
 
 from httpx import URL
 from httpx.concurrency.asyncio import AsyncioBackend
-from httpx.concurrency.trio import TrioBackend
+from httpx.concurrency.base import lookup_backend
 
 ENVIRONMENT_VARIABLES = {
     "SSL_CERT_FILE",
@@ -51,13 +51,17 @@ def clean_environ() -> typing.Dict[str, typing.Any]:
 
 @pytest.fixture(
     params=[
-        pytest.param(AsyncioBackend, marks=pytest.mark.asyncio),
-        pytest.param(TrioBackend, marks=pytest.mark.trio),
+        # pytest uses the marks to set up the specified async environment and run
+        # 'async def' test functions. The "auto" backend should then auto-detect
+        # the environment it's running in.
+        # Passing the backend explicitly, e.g. `backend="asyncio"`,
+        # is tested separately.
+        pytest.param("auto", marks=pytest.mark.asyncio),
+        pytest.param("auto", marks=pytest.mark.trio),
     ]
 )
 def backend(request):
-    backend_cls = request.param
-    return backend_cls()
+    return request.param
 
 
 async def app(scope, receive, send):
@@ -271,9 +275,13 @@ def restart(backend):
     This fixture deals with possible differences between the environment of the
     test function and that of the server.
     """
+    asyncio_backend = AsyncioBackend()
+    backend_implementation = lookup_backend(backend)
 
     async def restart(server):
-        await backend.run_in_threadpool(AsyncioBackend().run, server.restart)
+        await backend_implementation.run_in_threadpool(
+            asyncio_backend.run, server.restart
+        )
 
     return restart
 
index 4311655c59e126372a1ba56833540dba60d2e983..f10b16ab1451dcc85d21bcf924c97abe2c8ea270 100644 (file)
@@ -6,7 +6,7 @@ async def test_keepalive_connections(server, backend):
     """
     Connections should default to staying in a keep-alive state.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url)
         await response.read()
         assert len(http.active_connections) == 0
@@ -22,7 +22,7 @@ async def test_differing_connection_keys(server, backend):
     """
     Connections to differing connection keys should result in multiple connections.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url)
         await response.read()
         assert len(http.active_connections) == 0
@@ -40,7 +40,7 @@ async def test_soft_limit(server, backend):
     """
     pool_limits = httpx.PoolLimits(soft_limit=1)
 
-    async with ConnectionPool(pool_limits=pool_limits, backend=backend) as http:
+    async with ConnectionPool(pool_limits=pool_limits) as http:
         response = await http.request("GET", server.url)
         await response.read()
         assert len(http.active_connections) == 0
@@ -56,7 +56,7 @@ async def test_streaming_response_holds_connection(server, backend):
     """
     A streaming request should hold the connection open until the response is read.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url)
         assert len(http.active_connections) == 1
         assert len(http.keepalive_connections) == 0
@@ -71,7 +71,7 @@ async def test_multiple_concurrent_connections(server, backend):
     """
     Multiple conncurrent requests should open multiple conncurrent connections.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response_a = await http.request("GET", server.url)
         assert len(http.active_connections) == 1
         assert len(http.keepalive_connections) == 0
@@ -94,7 +94,7 @@ async def test_close_connections(server, backend):
     Using a `Connection: close` header should close the connection.
     """
     headers = [(b"connection", b"close")]
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url, headers=headers)
         await response.read()
         assert len(http.active_connections) == 0
@@ -105,7 +105,7 @@ async def test_standard_response_close(server, backend):
     """
     A standard close should keep the connection open.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url)
         await response.read()
         await response.close()
@@ -117,7 +117,7 @@ async def test_premature_response_close(server, backend):
     """
     A premature close should close the connection.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url)
         await response.close()
         assert len(http.active_connections) == 0
@@ -131,7 +131,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(
     Upon keep-alive connection closed by remote a new connection
     should be reestablished.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url)
         await response.read()
 
@@ -151,7 +151,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
     Upon keep-alive connection closed by remote a new connection
     should be reestablished.
     """
-    async with ConnectionPool(backend=backend) as http:
+    async with ConnectionPool() as http:
         response = await http.request("GET", server.url)
         await response.read()
 
@@ -169,9 +169,7 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart, back
     Verify that max_connections semaphore is released
     properly on a disconnected connection.
     """
-    async with ConnectionPool(
-        pool_limits=httpx.PoolLimits(hard_limit=1), backend=backend
-    ) as http:
+    async with ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http:
         response = await http.request("GET", server.url)
         await response.read()
 
index 54fb2ba7addd222efcfaf4026c76ab24f22e0db2..6fdb24ea6e7405b51eaa48f1fc2adc7020cbf935 100644 (file)
@@ -5,7 +5,7 @@ from httpx.dispatch.connection import HTTPConnection
 
 
 async def test_get(server, backend):
-    async with HTTPConnection(origin=server.url, backend=backend) as conn:
+    async with HTTPConnection(origin=server.url) as conn:
         response = await conn.request("GET", server.url)
         await response.read()
         assert response.status_code == 200
@@ -13,14 +13,14 @@ async def test_get(server, backend):
 
 
 async def test_post(server, backend):
-    async with HTTPConnection(origin=server.url, backend=backend) as conn:
+    async with HTTPConnection(origin=server.url) as conn:
         response = await conn.request("GET", server.url, data=b"Hello, world!")
         assert response.status_code == 200
 
 
 async def test_premature_close(server, backend):
     with pytest.raises(httpx.ConnectionClosed):
-        async with HTTPConnection(origin=server.url, backend=backend) as conn:
+        async with HTTPConnection(origin=server.url) as conn:
             response = await conn.request(
                 "GET", server.url.copy_with(path="/premature_close")
             )
@@ -31,9 +31,7 @@ async def test_https_get_with_ssl_defaults(https_server, ca_cert_pem_file, backe
     """
     An HTTPS request, with default SSL configuration set on the client.
     """
-    async with HTTPConnection(
-        origin=https_server.url, verify=ca_cert_pem_file, backend=backend
-    ) as conn:
+    async with HTTPConnection(origin=https_server.url, verify=ca_cert_pem_file) as conn:
         response = await conn.request("GET", https_server.url)
         await response.read()
         assert response.status_code == 200
@@ -44,7 +42,7 @@ async def test_https_get_with_sll_overrides(https_server, ca_cert_pem_file, back
     """
     An HTTPS request, with SSL configuration set on the request.
     """
-    async with HTTPConnection(origin=https_server.url, backend=backend) as conn:
+    async with HTTPConnection(origin=https_server.url) as conn:
         response = await conn.request("GET", https_server.url, verify=ca_cert_pem_file)
         await response.read()
         assert response.status_code == 200
index 375b805394cf5cac487a273c88a95434b9b904eb..a9c450975613e00c396e8393b7e780129b2a69f2 100644 (file)
@@ -104,7 +104,7 @@ async def test_http2_reconnect():
 
 
 async def test_http2_settings_in_handshake(backend):
-    backend = MockHTTP2Backend(app=app, backend=backend)
+    backend = MockHTTP2Backend(app=app)
 
     async with Client(backend=backend, http2=True) as client:
         await client.get("http://example.org")
@@ -139,7 +139,7 @@ async def test_http2_settings_in_handshake(backend):
 
 
 async def test_http2_live_request(backend):
-    async with Client(backend=backend, http2=True) as client:
+    async with Client(http2=True) as client:
         try:
             resp = await client.get("https://nghttp2.org/httpbin/anything")
         except TimeoutException:
index 4234b9847f7374e0dfda1e7a25e0be3cb859ccae..eab7738d70503b9cb4ca074f3312b30b8c006241 100644 (file)
@@ -19,7 +19,6 @@ async def test_proxy_tunnel_success(backend):
                 b"\r\n",
             ]
         ),
-        backend=backend,
     )
     async with httpx.HTTPProxy(
         proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
@@ -53,7 +52,6 @@ async def test_proxy_tunnel_non_2xx_response(backend, status_code):
                 b"\r\n",
             ]
         ),
-        backend=backend,
     )
 
     with pytest.raises(httpx.ProxyError) as e:
@@ -105,7 +103,6 @@ async def test_proxy_tunnel_start_tls(backend):
                 b"\r\n",
             ]
         ),
-        backend=backend,
     )
     async with httpx.HTTPProxy(
         proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
@@ -155,7 +152,6 @@ async def test_proxy_forwarding(backend, proxy_mode):
                 b"\r\n"
             ]
         ),
-        backend=backend,
     )
     async with httpx.HTTPProxy(
         proxy_url="http://127.0.0.1:8000",
index 2f26fda3b3fd47ca2d97c20e798d701e8e7df1ce..00565289b69f2393737f4755b0e56234416647cf 100644 (file)
@@ -6,15 +6,14 @@ import h2.connection
 import h2.events
 
 from httpx import Request, Timeout
-from httpx.concurrency.asyncio import AsyncioBackend
-from httpx.concurrency.base import BaseSocketStream
+from httpx.concurrency.base import BaseSocketStream, lookup_backend
 from tests.concurrency import sleep
 
 
 class MockHTTP2Backend:
-    def __init__(self, app, backend=None):
+    def __init__(self, app, backend="auto"):
         self.app = app
-        self.backend = AsyncioBackend() if backend is None else backend
+        self.backend = lookup_backend(backend)
         self.server = None
 
     async def open_tcp_stream(
@@ -168,8 +167,8 @@ class MockHTTP2Server(BaseSocketStream):
 
 
 class MockRawSocketBackend:
-    def __init__(self, data_to_send=b"", backend=None):
-        self.backend = AsyncioBackend() if backend is None else backend
+    def __init__(self, data_to_send=b"", backend="auto"):
+        self.backend = lookup_backend(backend)
         self.data_to_send = data_to_send
         self.received_data = []
         self.stream = MockRawSocketStream(self)
index a3286c5569551d34537b38d6d5ec54c21624a55a..3e3d64030c2aa6c9783bb8092dece452d956202a 100644 (file)
@@ -1,19 +1,14 @@
+import asyncio
+
 import pytest
 import trio
 
 from httpx import Timeout
 from httpx.concurrency.asyncio import AsyncioBackend
+from httpx.concurrency.base import lookup_backend
 from httpx.concurrency.trio import TrioBackend
 from httpx.config import SSLConfig
-from tests.concurrency import run_concurrently, sleep
-
-
-def get_asyncio_cipher(stream):
-    return stream.stream_writer.get_extra_info("cipher", default=None)
-
-
-def get_trio_cipher(stream):
-    return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
+from tests.concurrency import get_cipher, run_concurrently, sleep
 
 
 async def read_response(stream, timeout: Timeout, should_contain: bytes) -> bytes:
@@ -34,14 +29,8 @@ async def read_response(stream, timeout: Timeout, should_contain: bytes) -> byte
     return response
 
 
-@pytest.mark.parametrize(
-    "backend, get_cipher",
-    [
-        pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio),
-        pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio),
-    ],
-)
-async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher):
+async def test_start_tls_on_tcp_socket_stream(https_server, backend):
+    backend = lookup_backend(backend)
     ctx = SSLConfig().load_ssl_context_no_verify()
     timeout = Timeout(5)
 
@@ -51,11 +40,11 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher)
 
     try:
         assert stream.is_connection_dropped() is False
-        assert get_cipher(stream) is None
+        assert get_cipher(backend, stream) is None
 
         stream = await stream.start_tls(https_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
-        assert get_cipher(stream) is not None
+        assert get_cipher(backend, stream) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
 
@@ -66,14 +55,8 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher)
         await stream.close()
 
 
-@pytest.mark.parametrize(
-    "backend, get_cipher",
-    [
-        pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio),
-        pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio),
-    ],
-)
-async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cipher):
+async def test_start_tls_on_uds_socket_stream(https_uds_server, backend):
+    backend = lookup_backend(backend)
     ctx = SSLConfig().load_ssl_context_no_verify()
     timeout = Timeout(5)
 
@@ -83,11 +66,11 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cip
 
     try:
         assert stream.is_connection_dropped() is False
-        assert get_cipher(stream) is None
+        assert get_cipher(backend, stream) is None
 
         stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
-        assert get_cipher(stream) is not None
+        assert get_cipher(backend, stream) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
 
@@ -102,6 +85,7 @@ async def test_concurrent_read(server, backend):
     """
     Regression test for: https://github.com/encode/httpx/issues/527
     """
+    backend = lookup_backend(backend)
     stream = await backend.open_tcp_stream(
         server.url.host, server.url.port, ssl_context=None, timeout=Timeout(5)
     )
@@ -116,6 +100,7 @@ async def test_concurrent_read(server, backend):
 
 
 async def test_fork(backend):
+    backend = lookup_backend(backend)
     ok_counter = 0
 
     async def ok(delay: int) -> None:
@@ -159,3 +144,23 @@ async def test_fork(backend):
     # No 'match', since we can't know which will be raised first.
     with pytest.raises(RuntimeError):
         await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0])
+
+
+def test_lookup_backend():
+    assert isinstance(lookup_backend("asyncio"), AsyncioBackend)
+    assert isinstance(lookup_backend("trio"), TrioBackend)
+    assert isinstance(lookup_backend(AsyncioBackend()), AsyncioBackend)
+
+    async def get_backend_from_auto():
+        auto_backend = lookup_backend("auto")
+        return auto_backend.backend
+
+    loop = asyncio.get_event_loop()
+    backend = loop.run_until_complete(get_backend_from_auto())
+    assert isinstance(backend, AsyncioBackend)
+
+    backend = trio.run(get_backend_from_auto)
+    assert isinstance(backend, TrioBackend)
+
+    with pytest.raises(Exception, match="unknownio"):
+        lookup_backend("unknownio")
index 80fefd33341363d80e671b6b19b936fcbfda8254..8619dc90b24afa3be006d3218d560571f5d9f055 100644 (file)
@@ -6,7 +6,7 @@ import httpx
 async def test_read_timeout(server, backend):
     timeout = httpx.Timeout(read_timeout=1e-6)
 
-    async with httpx.Client(timeout=timeout, backend=backend) as client:
+    async with httpx.Client(timeout=timeout) as client:
         with pytest.raises(httpx.ReadTimeout):
             await client.get(server.url.copy_with(path="/slow_response"))
 
@@ -14,7 +14,7 @@ async def test_read_timeout(server, backend):
 async def test_write_timeout(server, backend):
     timeout = httpx.Timeout(write_timeout=1e-6)
 
-    async with httpx.Client(timeout=timeout, backend=backend) as client:
+    async with httpx.Client(timeout=timeout) as client:
         with pytest.raises(httpx.WriteTimeout):
             data = b"*" * 1024 * 1024 * 100
             await client.put(server.url.copy_with(path="/slow_response"), data=data)
@@ -23,7 +23,7 @@ async def test_write_timeout(server, backend):
 async def test_connect_timeout(server, backend):
     timeout = httpx.Timeout(connect_timeout=1e-6)
 
-    async with httpx.Client(timeout=timeout, backend=backend) as client:
+    async with httpx.Client(timeout=timeout) as client:
         with pytest.raises(httpx.ConnectTimeout):
             # See https://stackoverflow.com/questions/100841/
             await client.get("http://10.255.255.1/")
@@ -33,9 +33,7 @@ async def test_pool_timeout(server, backend):
     pool_limits = httpx.PoolLimits(hard_limit=1)
     timeout = httpx.Timeout(pool_timeout=1e-4)
 
-    async with httpx.Client(
-        pool_limits=pool_limits, timeout=timeout, backend=backend
-    ) as client:
+    async with httpx.Client(pool_limits=pool_limits, timeout=timeout) as client:
         async with client.stream("GET", server.url):
             with pytest.raises(httpx.PoolTimeout):
                 await client.get("http://localhost:8000/")