]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor test_auth.py to use `MockTransport` class. (#1288)
authorTom Christie <tom@tomchristie.com>
Mon, 14 Sep 2020 16:44:05 +0000 (17:44 +0100)
committerGitHub <noreply@github.com>
Mon, 14 Sep 2020 16:44:05 +0000 (17:44 +0100)
* Use tests.utils.MockTransport

* Use tests.utils.MockTransport

tests/client/test_auth.py

index c6c6d979accc785bb2f753ff49c9e35548d14015..cc6fd69c000ab5deabcd4928cc05e0033f8084be 100644 (file)
@@ -5,11 +5,11 @@ Unit tests for auth classes also exist in tests/test_auth.py
 """
 import asyncio
 import hashlib
+import json
 import os
 import threading
 import typing
 
-import httpcore
 import pytest
 
 import httpx
@@ -23,61 +23,24 @@ from httpx import (
     RequestBodyUnavailable,
     Response,
 )
-from httpx._content_streams import ContentStream, JSONStream
+from tests.utils import AsyncMockTransport, MockTransport
 
 from ..common import FIXTURES_DIR
 
 
-def get_header_value(headers, key, default=None):
-    lookup = key.encode("ascii").lower()
-    for header_key, header_value in headers:
-        if header_key.lower() == lookup:
-            return header_value.decode("ascii")
-    return default
-
-
-class MockTransport:
-    def __init__(self, auth_header: bytes = b"", status_code: int = 200) -> None:
+class App:
+    def __init__(self, auth_header: str = "", status_code: int = 200) -> None:
         self.auth_header = auth_header
         self.status_code = status_code
 
-    def _request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: ContentStream,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        authorization = get_header_value(headers, "Authorization")
-        response_headers = (
-            [(b"www-authenticate", self.auth_header)] if self.auth_header else []
-        )
-        response_stream = JSONStream({"auth": authorization})
-        return b"HTTP/1.1", self.status_code, b"", response_headers, response_stream
-
+    def __call__(self, request: httpx.Request) -> httpx.Response:
+        headers = {"www-authenticate": self.auth_header} if self.auth_header else {}
+        data = {"auth": request.headers.get("Authorization")}
+        content = json.dumps(data).encode("utf-8")
+        return httpx.Response(self.status_code, headers=headers, content=content)
 
-class AsyncMockTransport(MockTransport, httpcore.AsyncHTTPTransport):
-    async def request(
-        self, *args, **kwargs
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        return self._request(*args, **kwargs)
 
-
-class SyncMockTransport(MockTransport, httpcore.SyncHTTPTransport):
-    def request(
-        self, *args, **kwargs
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        return self._request(*args, **kwargs)
-
-
-class MockDigestAuthTransport(httpcore.AsyncHTTPTransport):
+class DigestApp:
     def __init__(
         self,
         algorithm: str = "SHA-256",
@@ -91,29 +54,15 @@ class MockDigestAuthTransport(httpcore.AsyncHTTPTransport):
         self._regenerate_nonce = regenerate_nonce
         self._response_count = 0
 
-    async def request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
-        stream: httpcore.AsyncByteStream = None,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
+    def __call__(self, request: httpx.Request) -> httpx.Response:
         if self._response_count < self.send_response_after_attempt:
-            assert headers is not None
-            return self.challenge_send(method, headers)
-
-        authorization = get_header_value(headers, "Authorization")
-        body = JSONStream({"auth": authorization})
-        return b"HTTP/1.1", 200, b"", [], body
-
-    def challenge_send(
-        self, method: bytes, headers: typing.List[typing.Tuple[bytes, bytes]]
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
+            return self.challenge_send(request)
+
+        data = {"auth": request.headers.get("Authorization")}
+        content = json.dumps(data).encode("utf-8")
+        return httpx.Response(200, content=content)
+
+    def challenge_send(self, request: httpx.Request) -> httpx.Response:
         self._response_count += 1
         nonce = (
             hashlib.sha256(os.urandom(8)).hexdigest()
@@ -135,13 +84,10 @@ class MockDigestAuthTransport(httpcore.AsyncHTTPTransport):
             if value
         )
 
-        headers = [
-            (
-                b"www-authenticate",
-                b'Digest realm="httpx@example.org", ' + challenge_str.encode("ascii"),
-            )
-        ]
-        return b"HTTP/1.1", 401, b"", headers, ContentStream()
+        headers = {
+            "www-authenticate": f'Digest realm="httpx@example.org", {challenge_str}',
+        }
+        return Response(401, headers=headers)
 
 
 class RepeatAuth(Auth):
@@ -219,8 +165,9 @@ class SyncOrAsyncAuth(Auth):
 async def test_basic_auth() -> None:
     url = "https://example.org/"
     auth = ("tomchristie", "password123")
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -230,8 +177,9 @@ async def test_basic_auth() -> None:
 @pytest.mark.asyncio
 async def test_basic_auth_in_url() -> None:
     url = "https://tomchristie:password123@example.org/"
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url)
 
     assert response.status_code == 200
@@ -242,8 +190,11 @@ async def test_basic_auth_in_url() -> None:
 async def test_basic_auth_on_session() -> None:
     url = "https://example.org/"
     auth = ("tomchristie", "password123")
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport(), auth=auth) as client:
+    async with httpx.AsyncClient(
+        transport=AsyncMockTransport(app), auth=auth
+    ) as client:
         response = await client.get(url)
 
     assert response.status_code == 200
@@ -253,12 +204,13 @@ async def test_basic_auth_on_session() -> None:
 @pytest.mark.asyncio
 async def test_custom_auth() -> None:
     url = "https://example.org/"
+    app = App()
 
     def auth(request: Request) -> Request:
         request.headers["Authorization"] = "Token 123"
         return request
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -269,8 +221,9 @@ async def test_custom_auth() -> None:
 async def test_netrc_auth() -> None:
     os.environ["NETRC"] = str(FIXTURES_DIR / ".netrc")
     url = "http://netrcexample.org"
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url)
 
     assert response.status_code == 200
@@ -283,8 +236,9 @@ async def test_netrc_auth() -> None:
 async def test_auth_header_has_priority_over_netrc() -> None:
     os.environ["NETRC"] = str(FIXTURES_DIR / ".netrc")
     url = "http://netrcexample.org"
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, headers={"Authorization": "Override"})
 
     assert response.status_code == 200
@@ -295,9 +249,10 @@ async def test_auth_header_has_priority_over_netrc() -> None:
 async def test_trust_env_auth() -> None:
     os.environ["NETRC"] = str(FIXTURES_DIR / ".netrc")
     url = "http://netrcexample.org"
+    app = App()
 
     async with httpx.AsyncClient(
-        transport=AsyncMockTransport(), trust_env=False
+        transport=AsyncMockTransport(app), trust_env=False
     ) as client:
         response = await client.get(url)
 
@@ -305,7 +260,7 @@ async def test_trust_env_auth() -> None:
     assert response.json() == {"auth": None}
 
     async with httpx.AsyncClient(
-        transport=AsyncMockTransport(), trust_env=True
+        transport=AsyncMockTransport(app), trust_env=True
     ) as client:
         response = await client.get(url)
 
@@ -319,8 +274,11 @@ async def test_trust_env_auth() -> None:
 async def test_auth_disable_per_request() -> None:
     url = "https://example.org/"
     auth = ("tomchristie", "password123")
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport(), auth=auth) as client:
+    async with httpx.AsyncClient(
+        transport=AsyncMockTransport(app), auth=auth
+    ) as client:
         response = await client.get(url, auth=None)
 
     assert response.status_code == 200
@@ -338,8 +296,9 @@ def test_auth_hidden_url() -> None:
 async def test_auth_hidden_header() -> None:
     url = "https://example.org/"
     auth = ("example-username", "example-password")
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert "'authorization': '[secure]'" in str(response.request.headers)
@@ -347,7 +306,9 @@ async def test_auth_hidden_header() -> None:
 
 @pytest.mark.asyncio
 async def test_auth_property() -> None:
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    app = App()
+
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         assert client.auth is None
 
         client.auth = ("tomchristie", "password123")  # type: ignore
@@ -361,13 +322,15 @@ async def test_auth_property() -> None:
 
 @pytest.mark.asyncio
 async def test_auth_invalid_type() -> None:
+    app = App()
+
     with pytest.raises(TypeError):
         client = httpx.AsyncClient(
-            transport=AsyncMockTransport(),
+            transport=AsyncMockTransport(app),
             auth="not a tuple, not a callable",  # type: ignore
         )
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         with pytest.raises(TypeError):
             await client.get(auth="not a tuple, not a callable")  # type: ignore
 
@@ -379,8 +342,9 @@ async def test_auth_invalid_type() -> None:
 async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -391,11 +355,10 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
 def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
-    auth_header = b"Token ..."
+    auth_header = "Token ..."
+    app = App(auth_header=auth_header, status_code=401)
 
-    client = httpx.Client(
-        transport=SyncMockTransport(auth_header=auth_header, status_code=401)
-    )
+    client = httpx.Client(transport=MockTransport(app))
     response = client.get(url, auth=auth)
 
     assert response.status_code == 401
@@ -407,11 +370,10 @@ def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
 async def test_digest_auth_200_response_including_digest_auth_header() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
-    auth_header = b'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
+    auth_header = 'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
+    app = App(auth_header=auth_header, status_code=200)
 
-    async with httpx.AsyncClient(
-        transport=AsyncMockTransport(auth_header=auth_header, status_code=200)
-    ) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -423,10 +385,9 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
 async def test_digest_auth_401_response_without_digest_auth_header() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = App(auth_header="", status_code=401)
 
-    async with httpx.AsyncClient(
-        transport=AsyncMockTransport(auth_header=b"", status_code=401)
-    ) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 401
@@ -453,10 +414,9 @@ async def test_digest_auth(
 ) -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = DigestApp(algorithm=algorithm)
 
-    async with httpx.AsyncClient(
-        transport=MockDigestAuthTransport(algorithm=algorithm)
-    ) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -485,8 +445,9 @@ async def test_digest_auth(
 async def test_digest_auth_no_specified_qop() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = DigestApp(qop="")
 
-    async with httpx.AsyncClient(transport=MockDigestAuthTransport(qop="")) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -516,8 +477,9 @@ async def test_digest_auth_no_specified_qop() -> None:
 async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str) -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = DigestApp(qop=qop)
 
-    async with httpx.AsyncClient(transport=MockDigestAuthTransport(qop=qop)) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -528,10 +490,9 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
 async def test_digest_auth_qop_auth_int_not_implemented() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = DigestApp(qop="auth-int")
 
-    async with httpx.AsyncClient(
-        transport=MockDigestAuthTransport(qop="auth-int")
-    ) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         with pytest.raises(NotImplementedError):
             await client.get(url, auth=auth)
 
@@ -540,10 +501,9 @@ async def test_digest_auth_qop_auth_int_not_implemented() -> None:
 async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = DigestApp(qop="not-auth")
 
-    async with httpx.AsyncClient(
-        transport=MockDigestAuthTransport(qop="not-auth")
-    ) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         with pytest.raises(ProtocolError):
             await client.get(url, auth=auth)
 
@@ -552,10 +512,9 @@ async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
 async def test_digest_auth_incorrect_credentials() -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = DigestApp(send_response_after_attempt=2)
 
-    async with httpx.AsyncClient(
-        transport=MockDigestAuthTransport(send_response_after_attempt=2)
-    ) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 401
@@ -565,19 +524,19 @@ async def test_digest_auth_incorrect_credentials() -> None:
 @pytest.mark.parametrize(
     "auth_header",
     [
-        b'Digest realm="httpx@example.org", qop="auth"',  # missing fields
-        b'Digest realm="httpx@example.org", qop="auth,au',  # malformed fields list
+        'Digest realm="httpx@example.org", qop="auth"',  # missing fields
+        'Digest realm="httpx@example.org", qop="auth,au',  # malformed fields list
     ],
 )
 @pytest.mark.asyncio
 async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
-    auth_header: bytes,
+    auth_header: str,
 ) -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
-    async with httpx.AsyncClient(
-        transport=AsyncMockTransport(auth_header=auth_header, status_code=401)
-    ) as client:
+    app = App(auth_header=auth_header, status_code=401)
+
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         with pytest.raises(ProtocolError):
             await client.get(url, auth=auth)
 
@@ -585,19 +544,18 @@ async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
 @pytest.mark.parametrize(
     "auth_header",
     [
-        b'Digest realm="httpx@example.org", qop="auth"',  # missing fields
-        b'Digest realm="httpx@example.org", qop="auth,au',  # malformed fields list
+        'Digest realm="httpx@example.org", qop="auth"',  # missing fields
+        'Digest realm="httpx@example.org", qop="auth,au',  # malformed fields list
     ],
 )
 def test_sync_digest_auth_raises_protocol_error_on_malformed_header(
-    auth_header: bytes,
+    auth_header: str,
 ) -> None:
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = App(auth_header=auth_header, status_code=401)
 
-    with httpx.Client(
-        transport=SyncMockTransport(auth_header=auth_header, status_code=401)
-    ) as client:
+    with httpx.Client(transport=MockTransport(app)) as client:
         with pytest.raises(ProtocolError):
             client.get(url, auth=auth)
 
@@ -610,10 +568,9 @@ async def test_async_auth_history() -> None:
     """
     url = "https://example.org/"
     auth = RepeatAuth(repeat=2)
+    app = App(auth_header="abc")
 
-    async with httpx.AsyncClient(
-        transport=AsyncMockTransport(auth_header=b"abc")
-    ) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -637,8 +594,9 @@ def test_sync_auth_history() -> None:
     """
     url = "https://example.org/"
     auth = RepeatAuth(repeat=2)
+    app = App(auth_header="abc")
 
-    with httpx.Client(transport=SyncMockTransport(auth_header=b"abc")) as client:
+    with httpx.Client(transport=MockTransport(app)) as client:
         response = client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -659,11 +617,12 @@ def test_sync_auth_history() -> None:
 async def test_digest_auth_unavailable_streaming_body():
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
+    app = App()
 
     async def streaming_body():
         yield b"Example request body"  # pragma: nocover
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         with pytest.raises(RequestBodyUnavailable):
             await client.post(url, data=streaming_body(), auth=auth)
 
@@ -676,7 +635,9 @@ async def test_async_auth_reads_response_body() -> None:
     """
     url = "https://example.org/"
     auth = ResponseBodyAuth("xyz")
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    app = App()
+
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -690,8 +651,9 @@ def test_sync_auth_reads_response_body() -> None:
     """
     url = "https://example.org/"
     auth = ResponseBodyAuth("xyz")
+    app = App()
 
-    with httpx.Client(transport=SyncMockTransport()) as client:
+    with httpx.Client(transport=MockTransport(app)) as client:
         response = client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -707,8 +669,9 @@ async def test_async_auth() -> None:
     """
     url = "https://example.org/"
     auth = SyncOrAsyncAuth()
+    app = App()
 
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
@@ -721,8 +684,9 @@ def test_sync_auth() -> None:
     """
     url = "https://example.org/"
     auth = SyncOrAsyncAuth()
+    app = App()
 
-    with httpx.Client(transport=SyncMockTransport()) as client:
+    with httpx.Client(transport=MockTransport(app)) as client:
         response = client.get(url, auth=auth)
 
     assert response.status_code == 200