]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
reuse the digest auth state to avoid unnecessary requests (#2463)
authorrettier <rettier@users.noreply.github.com>
Tue, 29 Nov 2022 17:05:37 +0000 (18:05 +0100)
committerGitHub <noreply@github.com>
Tue, 29 Nov 2022 17:05:37 +0000 (17:05 +0000)
* reuse the digest auth challenge to avoid sending twice as many requests

* fix for digest testcase

* ran testing/linting scripts

* codereview changes, removed tomchristie username from all authentication tests

Co-authored-by: Philipp Reitter <p.reitter@accessio.at>
Co-authored-by: Tom Christie <tom@tomchristie.com>
httpx/_auth.py
tests/client/test_auth.py
tests/test_auth.py

index 0f54be9b407a5e847916f4c8dfe0bb809bc102be..b3b7a19851a23e76dafad27e8a8d1c9c91aad7b1 100644 (file)
@@ -158,8 +158,15 @@ class DigestAuth(Auth):
     ) -> None:
         self._username = to_bytes(username)
         self._password = to_bytes(password)
+        self._last_challenge: typing.Optional[_DigestAuthChallenge] = None
+        self._nonce_count = 1
 
     def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
+        if self._last_challenge:
+            request.headers["Authorization"] = self._build_auth_header(
+                request, self._last_challenge
+            )
+
         response = yield request
 
         if response.status_code != 401 or "www-authenticate" not in response.headers:
@@ -175,8 +182,12 @@ class DigestAuth(Auth):
             # header, then we don't need to build an authenticated request.
             return
 
-        challenge = self._parse_challenge(request, response, auth_header)
-        request.headers["Authorization"] = self._build_auth_header(request, challenge)
+        self._last_challenge = self._parse_challenge(request, response, auth_header)
+        self._nonce_count = 1
+
+        request.headers["Authorization"] = self._build_auth_header(
+            request, self._last_challenge
+        )
         yield request
 
     def _parse_challenge(
@@ -225,9 +236,9 @@ class DigestAuth(Auth):
         # TODO: implement auth-int
         HA2 = digest(A2)
 
-        nonce_count = 1  # TODO: implement nonce counting
-        nc_value = b"%08x" % nonce_count
-        cnonce = self._get_client_nonce(nonce_count, challenge.nonce)
+        nc_value = b"%08x" % self._nonce_count
+        cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
+        self._nonce_count += 1
 
         HA1 = digest(A1)
         if challenge.algorithm.lower().endswith("-sess"):
index 735205c3aa33b4b6b7eb253645cc75d65772a6df..19ffcfcfc7c0833a48e32c6f88c895f2b546df83 100644 (file)
@@ -8,6 +8,7 @@ import hashlib
 import os
 import threading
 import typing
+from urllib.request import parse_keqv_list
 
 import pytest
 
@@ -151,14 +152,14 @@ class SyncOrAsyncAuth(Auth):
 @pytest.mark.asyncio
 async def test_basic_auth() -> None:
     url = "https://example.org/"
-    auth = ("tomchristie", "password123")
+    auth = ("user", "password123")
     app = App()
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
         response = await client.get(url, auth=auth)
 
     assert response.status_code == 200
-    assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
+    assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
 
 
 @pytest.mark.asyncio
@@ -167,7 +168,7 @@ async def test_basic_auth_with_stream() -> None:
     See: https://github.com/encode/httpx/pull/1312
     """
     url = "https://example.org/"
-    auth = ("tomchristie", "password123")
+    auth = ("user", "password123")
     app = App()
 
     async with httpx.AsyncClient(
@@ -177,25 +178,25 @@ async def test_basic_auth_with_stream() -> None:
             await response.aread()
 
     assert response.status_code == 200
-    assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
+    assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
 
 
 @pytest.mark.asyncio
 async def test_basic_auth_in_url() -> None:
-    url = "https://tomchristie:password123@example.org/"
+    url = "https://user:password123@example.org/"
     app = App()
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
         response = await client.get(url)
 
     assert response.status_code == 200
-    assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
+    assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
 
 
 @pytest.mark.asyncio
 async def test_basic_auth_on_session() -> None:
     url = "https://example.org/"
-    auth = ("tomchristie", "password123")
+    auth = ("user", "password123")
     app = App()
 
     async with httpx.AsyncClient(
@@ -204,7 +205,7 @@ async def test_basic_auth_on_session() -> None:
         response = await client.get(url)
 
     assert response.status_code == 200
-    assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
+    assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
 
 
 @pytest.mark.asyncio
@@ -279,7 +280,7 @@ async def test_trust_env_auth() -> None:
 @pytest.mark.asyncio
 async def test_auth_disable_per_request() -> None:
     url = "https://example.org/"
-    auth = ("tomchristie", "password123")
+    auth = ("user", "password123")
     app = App()
 
     async with httpx.AsyncClient(
@@ -317,13 +318,13 @@ async def test_auth_property() -> None:
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
         assert client.auth is None
 
-        client.auth = ("tomchristie", "password123")  # type: ignore
+        client.auth = ("user", "password123")  # type: ignore
         assert isinstance(client.auth, BasicAuth)
 
         url = "https://example.org/"
         response = await client.get(url)
         assert response.status_code == 200
-        assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
+        assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
 
 
 @pytest.mark.asyncio
@@ -347,7 +348,7 @@ async def test_auth_invalid_type() -> None:
 @pytest.mark.asyncio
 async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = App()
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -360,7 +361,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
 
 def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     auth_header = "Token ..."
     app = App(auth_header=auth_header, status_code=401)
 
@@ -375,7 +376,7 @@ def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
 @pytest.mark.asyncio
 async def test_digest_auth_200_response_including_digest_auth_header() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     auth_header = 'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
     app = App(auth_header=auth_header, status_code=200)
 
@@ -390,7 +391,7 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
 @pytest.mark.asyncio
 async def test_digest_auth_401_response_without_digest_auth_header() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = App(auth_header="", status_code=401)
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -419,7 +420,7 @@ async def test_digest_auth(
     algorithm: str, expected_hash_length: int, expected_response_length: int
 ) -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = DigestApp(algorithm=algorithm)
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -435,7 +436,7 @@ async def test_digest_auth(
     response_fields = [field.strip() for field in fields.split(",")]
     digest_data = dict(field.split("=") for field in response_fields)
 
-    assert digest_data["username"] == '"tomchristie"'
+    assert digest_data["username"] == '"user"'
     assert digest_data["realm"] == '"httpx@example.org"'
     assert "nonce" in digest_data
     assert digest_data["uri"] == '"/"'
@@ -450,7 +451,7 @@ async def test_digest_auth(
 @pytest.mark.asyncio
 async def test_digest_auth_no_specified_qop() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = DigestApp(qop="")
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -469,7 +470,7 @@ async def test_digest_auth_no_specified_qop() -> None:
     assert "qop" not in digest_data
     assert "nc" not in digest_data
     assert "cnonce" not in digest_data
-    assert digest_data["username"] == '"tomchristie"'
+    assert digest_data["username"] == '"user"'
     assert digest_data["realm"] == '"httpx@example.org"'
     assert len(digest_data["nonce"]) == 64 + 2  # extra quotes
     assert digest_data["uri"] == '"/"'
@@ -482,7 +483,7 @@ async def test_digest_auth_no_specified_qop() -> None:
 @pytest.mark.asyncio
 async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str) -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = DigestApp(qop=qop)
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -495,7 +496,7 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
 @pytest.mark.asyncio
 async def test_digest_auth_qop_auth_int_not_implemented() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = DigestApp(qop="auth-int")
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -506,7 +507,7 @@ async def test_digest_auth_qop_auth_int_not_implemented() -> None:
 @pytest.mark.asyncio
 async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = DigestApp(qop="not-auth")
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -517,7 +518,7 @@ async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
 @pytest.mark.asyncio
 async def test_digest_auth_incorrect_credentials() -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = DigestApp(send_response_after_attempt=2)
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -527,6 +528,62 @@ async def test_digest_auth_incorrect_credentials() -> None:
     assert len(response.history) == 1
 
 
+@pytest.mark.asyncio
+async def test_digest_auth_reuses_challenge() -> None:
+    url = "https://example.org/"
+    auth = DigestAuth(username="user", password="password123")
+    app = DigestApp()
+
+    async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
+        response_1 = await client.get(url, auth=auth)
+        response_2 = await client.get(url, auth=auth)
+
+        assert response_1.status_code == 200
+        assert response_2.status_code == 200
+
+        assert len(response_1.history) == 1
+        assert len(response_2.history) == 0
+
+
+@pytest.mark.asyncio
+async def test_digest_auth_resets_nonce_count_after_401() -> None:
+    url = "https://example.org/"
+    auth = DigestAuth(username="user", password="password123")
+    app = DigestApp()
+
+    async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
+        response_1 = await client.get(url, auth=auth)
+        assert response_1.status_code == 200
+        assert len(response_1.history) == 1
+
+        first_nonce = parse_keqv_list(
+            response_1.request.headers["Authorization"].split(", ")
+        )["nonce"]
+        first_nc = parse_keqv_list(
+            response_1.request.headers["Authorization"].split(", ")
+        )["nc"]
+
+        # with this we now force a 401 on a subsequent (but initial) request
+        app.send_response_after_attempt = 2
+
+        # we expect the client again to try to authenticate, i.e. the history length must be 1
+        response_2 = await client.get(url, auth=auth)
+        assert response_2.status_code == 200
+        assert len(response_2.history) == 1
+
+        second_nonce = parse_keqv_list(
+            response_2.request.headers["Authorization"].split(", ")
+        )["nonce"]
+        second_nc = parse_keqv_list(
+            response_2.request.headers["Authorization"].split(", ")
+        )["nc"]
+
+    assert first_nonce != second_nonce  # ensures that the auth challenge was reset
+    assert (
+        first_nc == second_nc
+    )  # ensures the nonce count is reset when the authentication failed
+
+
 @pytest.mark.parametrize(
     "auth_header",
     [
@@ -539,7 +596,7 @@ async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
     auth_header: str,
 ) -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = App(auth_header=auth_header, status_code=401)
 
     async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@@ -558,7 +615,7 @@ def test_sync_digest_auth_raises_protocol_error_on_malformed_header(
     auth_header: str,
 ) -> None:
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = App(auth_header=auth_header, status_code=401)
 
     with httpx.Client(transport=httpx.MockTransport(app)) as client:
@@ -629,7 +686,7 @@ class ConsumeBodyTransport(httpx.MockTransport):
 @pytest.mark.asyncio
 async def test_digest_auth_unavailable_streaming_body():
     url = "https://example.org/"
-    auth = DigestAuth(username="tomchristie", password="password123")
+    auth = DigestAuth(username="user", password="password123")
     app = DigestApp()
 
     async def streaming_body():
index 20c666a88cc667e7940b9717617015eb47c47b2e..a1997c2fe25d2a0d2051ec494c2af1465cfa48b6 100644 (file)
@@ -3,6 +3,8 @@ Unit tests for auth classes.
 
 Integration tests also exist in tests/client/test_auth.py
 """
+from urllib.request import parse_keqv_list
+
 import pytest
 
 import httpx
@@ -61,3 +63,41 @@ def test_digest_auth_with_401():
     response = httpx.Response(content=b"Hello, world!", status_code=200)
     with pytest.raises(StopIteration):
         flow.send(response)
+
+
+def test_digest_auth_with_401_nonce_counting():
+    auth = httpx.DigestAuth(username="user", password="pass")
+    request = httpx.Request("GET", "https://www.example.com")
+
+    # The initial request should not include an auth header.
+    flow = auth.sync_auth_flow(request)
+    request = next(flow)
+    assert "Authorization" not in request.headers
+
+    # If a 401 response is returned, then a digest auth request is made.
+    headers = {
+        "WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
+    }
+    response = httpx.Response(
+        content=b"Auth required", status_code=401, headers=headers
+    )
+    first_request = flow.send(response)
+    assert first_request.headers["Authorization"].startswith("Digest")
+
+    # Each subsequent request contains the digest header by default...
+    request = httpx.Request("GET", "https://www.example.com")
+    flow = auth.sync_auth_flow(request)
+    second_request = next(flow)
+    assert second_request.headers["Authorization"].startswith("Digest")
+
+    # ... and the client nonce count (nc) is increased
+    first_nc = parse_keqv_list(first_request.headers["Authorization"].split(", "))["nc"]
+    second_nc = parse_keqv_list(second_request.headers["Authorization"].split(", "))[
+        "nc"
+    ]
+    assert int(first_nc, 16) + 1 == int(second_nc, 16)
+
+    # No other requests are made.
+    response = httpx.Response(content=b"Hello, world!", status_code=200)
+    with pytest.raises(StopIteration):
+        flow.send(response)