]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Record history of requests made during authentication (#718)
authorFlorimond Manca <florimond.manca@gmail.com>
Fri, 3 Jan 2020 21:59:16 +0000 (22:59 +0100)
committerGitHub <noreply@github.com>
Fri, 3 Jan 2020 21:59:16 +0000 (22:59 +0100)
* Record history of requests made during authentication

* Add asserts on digest auth history

Co-Authored-By: Gaurav Dhameeja <gdhameeja@gmail.com>
httpx/client.py
tests/client/test_auth.py

index 27a146ae75fc9f9a82ff4e8f494d4b283a3ef948..75f29ac3b7aac6e4cff0ee2c3060fc417f345105 100644 (file)
@@ -451,7 +451,7 @@ class AsyncClient:
                 raise RedirectLoop()
 
             response = await self.send_handling_auth(
-                request, auth=auth, timeout=timeout,
+                request, history, auth=auth, timeout=timeout,
             )
             response.history = list(history)
 
@@ -571,7 +571,11 @@ class AsyncClient:
         return request.stream
 
     async def send_handling_auth(
-        self, request: Request, auth: Auth, timeout: Timeout,
+        self,
+        request: Request,
+        history: typing.List[Response],
+        auth: Auth,
+        timeout: Timeout,
     ) -> Response:
         auth_flow = auth(request)
         request = next(auth_flow)
@@ -585,8 +589,10 @@ class AsyncClient:
                 await response.aclose()
                 raise exc from None
             else:
+                response.history = list(history)
+                await response.aread()
                 request = next_request
-                await response.aclose()
+                history.append(response)
 
     async def send_single_request(
         self, request: Request, timeout: Timeout,
index d4dd76a905c2e4f12c2c0bb4761d23b18d048f4e..ea6ff8acac860741e49f260c4fe40f368cefa647 100644 (file)
@@ -6,6 +6,7 @@ import typing
 import pytest
 
 from httpx import URL, AsyncClient, DigestAuth, ProtocolError, Request, Response
+from httpx.auth import Auth, AuthFlow
 from httpx.config import CertTypes, TimeoutTypes, VerifyTypes
 from httpx.dispatch.base import Dispatcher
 
@@ -218,6 +219,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
 
     assert response.status_code == 200
     assert response.json() == {"auth": None}
+    assert len(response.history) == 0
 
 
 @pytest.mark.asyncio
@@ -233,6 +235,7 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
 
     assert response.status_code == 200
     assert response.json() == {"auth": None}
+    assert len(response.history) == 0
 
 
 @pytest.mark.asyncio
@@ -245,6 +248,7 @@ async def test_digest_auth_401_response_without_digest_auth_header() -> None:
 
     assert response.status_code == 401
     assert response.json() == {"auth": None}
+    assert len(response.history) == 0
 
 
 @pytest.mark.parametrize(
@@ -271,6 +275,8 @@ async def test_digest_auth(
     response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
+    assert len(response.history) == 1
+
     authorization = typing.cast(dict, response.json())["auth"]
     scheme, _, fields = authorization.partition(" ")
     assert scheme == "Digest"
@@ -299,6 +305,8 @@ async def test_digest_auth_no_specified_qop() -> None:
     response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
+    assert len(response.history) == 1
+
     authorization = typing.cast(dict, response.json())["auth"]
     scheme, _, fields = authorization.partition(" ")
     assert scheme == "Digest"
@@ -325,7 +333,10 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
     auth = DigestAuth(username="tomchristie", password="password123")
 
     client = AsyncClient(dispatch=MockDigestAuthDispatch(qop=qop))
-    await client.get(url, auth=auth)
+    response = await client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    assert len(response.history) == 1
 
 
 @pytest.mark.asyncio
@@ -357,6 +368,7 @@ async def test_digest_auth_incorrect_credentials() -> None:
     response = await client.get(url, auth=auth)
 
     assert response.status_code == 401
+    assert len(response.history) == 1
 
 
 @pytest.mark.parametrize(
@@ -381,3 +393,52 @@ async def test_digest_auth_raises_protocol_error_on_malformed_header(
 
     with pytest.raises(ProtocolError):
         await client.get(url, auth=auth)
+
+
+@pytest.mark.asyncio
+async def test_auth_history() -> None:
+    """
+    Test that intermediate requests sent as part of an authentication flow
+    are recorded in the response history.
+    """
+
+    class RepeatAuth(Auth):
+        """
+        A mock authentication scheme that requires clients to send
+        the request a fixed number of times, and then send a last request containing
+        an aggregation of nonces that the server sent in 'WWW-Authenticate' headers
+        of intermediate responses.
+        """
+
+        def __init__(self, repeat: int):
+            self.repeat = repeat
+
+        def __call__(self, request: Request) -> AuthFlow:
+            nonces = []
+
+            for index in range(self.repeat):
+                request.headers["Authorization"] = f"Repeat {index}"
+                response = yield request
+                nonces.append(response.headers["www-authenticate"])
+
+            key = ".".join(nonces)
+            request.headers["Authorization"] = f"Repeat {key}"
+            yield request
+
+    url = "https://example.org/"
+    auth = RepeatAuth(repeat=2)
+    client = AsyncClient(dispatch=MockDispatch(auth_header="abc"))
+
+    response = await client.get(url, auth=auth)
+    assert response.status_code == 200
+    assert response.json() == {"auth": "Repeat abc.abc"}
+
+    assert len(response.history) == 2
+    resp1, resp2 = response.history
+    assert resp1.json() == {"auth": "Repeat 0"}
+    assert resp2.json() == {"auth": "Repeat 1"}
+
+    assert len(resp2.history) == 1
+    assert resp2.history == [resp1]
+
+    assert len(resp1.history) == 0