]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Additional DigestAuth tests (#334)
authorYeray Diaz Diaz <yeraydiazdiaz@gmail.com>
Wed, 11 Sep 2019 12:25:11 +0000 (13:25 +0100)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Wed, 11 Sep 2019 12:25:11 +0000 (07:25 -0500)
tests/client/test_auth.py

index 6478ea03a68210ce12ff3fe4677226b2ad86a70b..fa57ef528ec2cf1178236233773efd8fc5bff0db 100644 (file)
@@ -19,6 +19,10 @@ from httpx import (
 
 
 class MockDispatch(AsyncDispatcher):
+    def __init__(self, auth_header: str = "", status_code: int = 200) -> None:
+        self.auth_header = auth_header
+        self.status_code = status_code
+
     async def send(
         self,
         request: AsyncRequest,
@@ -26,8 +30,11 @@ class MockDispatch(AsyncDispatcher):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
     ) -> AsyncResponse:
+        headers = [("www-authenticate", self.auth_header)] if self.auth_header else []
         body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
-        return AsyncResponse(200, content=body, request=request)
+        return AsyncResponse(
+            self.status_code, headers=headers, content=body, request=request
+        )
 
 
 class MockDigestAuthDispatch(AsyncDispatcher):
@@ -88,21 +95,6 @@ class MockDigestAuthDispatch(AsyncDispatcher):
         self._response_count = 0
 
 
-class MockAuthHeaderDispatch(AsyncDispatcher):
-    def __init__(self, auth_header: str) -> None:
-        self.auth_header = auth_header
-
-    async def send(
-        self,
-        request: AsyncRequest,
-        verify: VerifyTypes = None,
-        cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
-    ) -> AsyncResponse:
-        headers = [("www-authenticate", self.auth_header)]
-        return AsyncResponse(401, headers=headers, content=b"", request=request)
-
-
 def test_basic_auth():
     url = "https://example.org/"
     auth = ("tomchristie", "password123")
@@ -216,6 +208,31 @@ def test_digest_auth_returns_no_auth_if_no_digest_header_in_response():
     assert response.json() == {"auth": None}
 
 
+def test_digest_auth_200_response_including_digest_auth_header_is_returned_as_is():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+    auth_header = 'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
+
+    with Client(
+        dispatch=MockDispatch(auth_header=auth_header, status_code=200)
+    ) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    assert response.json() == {"auth": None}
+
+
+def test_digest_auth_401_response_without_digest_auth_header_is_returned_as_is():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with Client(dispatch=MockDispatch(auth_header="", status_code=401)) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 401
+    assert response.json() == {"auth": None}
+
+
 @pytest.mark.parametrize(
     "algorithm,expected_hash_length,expected_response_length",
     [
@@ -281,11 +298,12 @@ def test_digest_auth_no_specified_qop():
     assert digest_data["algorithm"] == "SHA-256"
 
 
-def test_digest_auth_qop_including_spaces_and_auth_returns_auth():
+@pytest.mark.parametrize("qop", ("auth, auth-int", "auth,auth-int", "unknown,auth"))
+def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str):
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
 
-    with Client(dispatch=MockDigestAuthDispatch(qop="auth, auth-int")) as client:
+    with Client(dispatch=MockDigestAuthDispatch(qop=qop)) as client:
         client.get(url, auth=auth)
 
 
@@ -334,5 +352,7 @@ def test_digest_auth_raises_protocol_error_on_malformed_header(auth_header: str)
     auth = DigestAuth(username="tomchristie", password="password123")
 
     with pytest.raises(ProtocolError):
-        with Client(dispatch=MockAuthHeaderDispatch(auth_header=auth_header)) as client:
+        with Client(
+            dispatch=MockDispatch(auth_header=auth_header, status_code=401)
+        ) as client:
             client.get(url, auth=auth)