]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor start_tls tests (#567)
authorFlorimond Manca <florimond.manca@gmail.com>
Fri, 29 Nov 2019 21:16:32 +0000 (22:16 +0100)
committerGitHub <noreply@github.com>
Fri, 29 Nov 2019 21:16:32 +0000 (22:16 +0100)
* Refactor start_tls tests

* Clean up read_response()

tests/test_concurrency.py

index 878d6a8933b51e7919a7a154d0aeb87e5f02d16a..da43987f0cb461ccac506168238be310578fa1c9 100644 (file)
@@ -6,45 +6,46 @@ from httpx.concurrency.trio import TrioBackend
 from tests.concurrency import run_concurrently
 
 
+def get_asyncio_cipher(stream):
+    return stream.stream_writer.get_extra_info("cipher", default=None)
+
+
+def get_trio_cipher(stream):
+    return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
+
+
+async def read_response(stream, timeout: float, should_contain: bytes) -> bytes:
+    # stream.read() only gives us *up to* as much data as we ask for. In order to
+    # cleanly close the stream, we must read until the end of the HTTP response.
+    response = b""
+    ended = False
+
+    for _ in range(5):  # Try read some (not too large) number of times...
+        response += await stream.read(8192, timeout)
+        # We know we're at the end of the response when we've received the body plus
+        # the terminating CRLFs.
+        if should_contain in response and response.endswith(b"\r\n\r\n"):
+            ended = True
+            break
+
+    assert ended
+    return response
+
+
 @pytest.mark.parametrize(
     "backend, get_cipher",
     [
-        pytest.param(
-            AsyncioBackend(),
-            lambda stream: stream.stream_writer.get_extra_info("cipher", default=None),
-            marks=pytest.mark.asyncio,
-        ),
-        pytest.param(
-            TrioBackend(),
-            lambda stream: (
-                stream.stream.cipher()
-                if isinstance(stream.stream, trio.SSLStream)
-                else None
-            ),
-            marks=pytest.mark.trio,
-        ),
+        pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio),
+        pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio),
     ],
 )
-@pytest.mark.parametrize("use_uds", (False, True))
-async def test_start_tls_on_socket_stream(
-    https_server, https_uds_server, backend, get_cipher, use_uds
-):
-    """
-    See that the concurrency backend can make a connection without TLS then
-    start TLS on an existing connection.
-    """
+async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher):
     ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
     timeout = TimeoutConfig(5)
 
-    if use_uds:
-        assert https_uds_server.config.uds is not None
-        stream = await backend.open_uds_stream(
-            https_uds_server.config.uds, https_uds_server.url.host, None, timeout
-        )
-    else:
-        stream = await backend.open_tcp_stream(
-            https_server.url.host, https_server.url.port, None, timeout
-        )
+    stream = await backend.open_tcp_stream(
+        https_server.url.host, https_server.url.port, None, timeout
+    )
 
     try:
         assert stream.is_connection_dropped() is False
@@ -56,20 +57,40 @@ async def test_start_tls_on_socket_stream(
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n")
 
-        # stream.read() only gives us *up to* as much data as we ask for. In order to
-        # cleanly close the stream, we must read until the end of the HTTP response.
-        read = b""
-        ended = False
-        for _ in range(5):  # Try read some (not too large) number of times...
-            read += await stream.read(8192, timeout)
-            # We know we're at the end of the response when we've received the body plus
-            # the terminating CRLFs.
-            if b"Hello, world!" in read and read.endswith(b"\r\n\r\n"):
-                ended = True
-                break
-
-        assert ended
-        assert read.startswith(b"HTTP/1.1 200 OK\r\n")
+        response = await read_response(stream, timeout, should_contain=b"Hello, world")
+        assert response.startswith(b"HTTP/1.1 200 OK\r\n")
+
+    finally:
+        await stream.close()
+
+
+@pytest.mark.parametrize(
+    "backend, get_cipher",
+    [
+        pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio),
+        pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio),
+    ],
+)
+async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cipher):
+    ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
+    timeout = TimeoutConfig(5)
+
+    stream = await backend.open_uds_stream(
+        https_uds_server.config.uds, https_uds_server.url.host, None, timeout
+    )
+
+    try:
+        assert stream.is_connection_dropped() is False
+        assert get_cipher(stream) is None
+
+        stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
+        assert stream.is_connection_dropped() is False
+        assert get_cipher(stream) is not None
+
+        await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+
+        response = await read_response(stream, timeout, should_contain=b"Hello, world")
+        assert response.startswith(b"HTTP/1.1 200 OK\r\n")
 
     finally:
         await stream.close()