--- /dev/null
+import hashlib
+import os
+import re
+import time
+import typing
+from urllib.request import parse_http_list
+
+from ..exceptions import ProtocolError
+from ..models import AsyncRequest, AsyncResponse, StatusCode
+from ..utils import to_bytes, to_str, unquote
+from .base import BaseMiddleware
+
+
+class DigestAuth(BaseMiddleware):
+ ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
+ "MD5": hashlib.md5,
+ "MD5-SESS": hashlib.md5,
+ "SHA": hashlib.sha1,
+ "SHA-SESS": hashlib.sha1,
+ "SHA-256": hashlib.sha256,
+ "SHA-256-SESS": hashlib.sha256,
+ "SHA-512": hashlib.sha512,
+ "SHA-512-SESS": hashlib.sha512,
+ }
+
+ def __init__(
+ self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+ ) -> None:
+ self.username = to_bytes(username)
+ self.password = to_bytes(password)
+
+ async def __call__(
+ self, request: AsyncRequest, get_response: typing.Callable
+ ) -> AsyncResponse:
+ response = await get_response(request)
+ if not (
+ StatusCode.is_client_error(response.status_code)
+ and "www-authenticate" in response.headers
+ ):
+ return response
+
+ header = response.headers["www-authenticate"]
+ try:
+ challenge = DigestAuthChallenge.from_header(header)
+ except ValueError:
+ raise ProtocolError("Malformed Digest authentication header")
+
+ request.headers["Authorization"] = self._build_auth_header(request, challenge)
+ return await get_response(request)
+
+ def _build_auth_header(
+ self, request: AsyncRequest, challenge: "DigestAuthChallenge"
+ ) -> str:
+ hash_func = self.ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm]
+
+ def digest(data: bytes) -> bytes:
+ return hash_func(data).hexdigest().encode()
+
+ A1 = b":".join((self.username, challenge.realm, self.password))
+
+ path = request.url.full_path.encode("utf-8")
+ A2 = b":".join((request.method.encode(), path))
+ # TODO: implement auth-int
+ HA2 = digest(A2)
+
+ nonce_count = 1 # TODO: implement nonce counting
+ nc_value = b"%08x" % nonce_count
+ cnonce = self._get_client_nonce(nonce_count, challenge.nonce)
+
+ HA1 = digest(A1)
+ if challenge.algorithm.lower().endswith("-sess"):
+ HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))
+
+ qop = self._resolve_qop(challenge.qop)
+ if qop is None:
+ digest_data = [HA1, challenge.nonce, HA2]
+ else:
+ digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
+ key_digest = b":".join(digest_data)
+
+ format_args = {
+ "username": self.username,
+ "realm": challenge.realm,
+ "nonce": challenge.nonce,
+ "uri": path,
+ "response": digest(b":".join((HA1, key_digest))),
+ "algorithm": challenge.algorithm.encode(),
+ }
+ if challenge.opaque:
+ format_args["opaque"] = challenge.opaque
+ if qop:
+ format_args["qop"] = b"auth"
+ format_args["nc"] = nc_value
+ format_args["cnonce"] = cnonce
+
+ return "Digest " + self._get_header_value(format_args)
+
+ def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
+ s = str(nonce_count).encode()
+ s += nonce
+ s += time.ctime().encode()
+ s += os.urandom(8)
+
+ return hashlib.sha1(s).hexdigest()[:16].encode()
+
+ def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
+ NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
+ QUOTED_TEMPLATE = '{}="{}"'
+ NON_QUOTED_TEMPLATE = "{}={}"
+
+ header_value = ""
+ for i, (field, value) in enumerate(header_fields.items()):
+ if i > 0:
+ header_value += ", "
+ template = (
+ QUOTED_TEMPLATE
+ if field not in NON_QUOTED_FIELDS
+ else NON_QUOTED_TEMPLATE
+ )
+ header_value += template.format(field, to_str(value))
+
+ return header_value
+
+ def _resolve_qop(self, qop: typing.Optional[bytes]) -> typing.Optional[bytes]:
+ if qop is None:
+ return None
+ qops = re.split(b", ?", qop)
+ if b"auth" in qops:
+ return b"auth"
+
+ if qops == [b"auth-int"]:
+ raise NotImplementedError("Digest auth-int support is not yet implemented")
+
+ raise ProtocolError(f'Unexpected qop value "{qop!r}" in digest auth')
+
+
+class DigestAuthChallenge:
+ def __init__(
+ self,
+ realm: bytes,
+ nonce: bytes,
+ algorithm: str = None,
+ opaque: typing.Optional[bytes] = None,
+ qop: typing.Optional[bytes] = None,
+ ) -> None:
+ self.realm = realm
+ self.nonce = nonce
+ self.algorithm = algorithm or "MD5"
+ self.opaque = opaque
+ self.qop = qop
+
+ @classmethod
+ def from_header(cls, header: str) -> "DigestAuthChallenge":
+ """Returns a challenge from a Digest WWW-Authenticate header.
+ These take the form of:
+ `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
+ """
+ scheme, _, fields = header.partition(" ")
+ if scheme.lower() != "digest":
+ raise ValueError("Header does not start with 'Digest'")
+
+ header_dict: typing.Dict[str, str] = {}
+ for field in parse_http_list(fields):
+ key, value = field.strip().split("=", 1)
+ header_dict[key] = unquote(value)
+
+ try:
+ return cls.from_header_dict(header_dict)
+ except KeyError as exc:
+ raise ValueError("Malformed Digest WWW-Authenticate header") from exc
+
+ @classmethod
+ def from_header_dict(cls, header_dict: dict) -> "DigestAuthChallenge":
+ realm = header_dict["realm"].encode()
+ nonce = header_dict["nonce"].encode()
+ qop = header_dict["qop"].encode() if "qop" in header_dict else None
+ opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
+ algorithm = header_dict.get("algorithm")
+ return cls(
+ realm=realm, nonce=nonce, qop=qop, opaque=opaque, algorithm=algorithm
+ )
+import hashlib
import json
import os
AsyncResponse,
CertTypes,
Client,
+ DigestAuth,
+ ProtocolError,
TimeoutTypes,
VerifyTypes,
)
return AsyncResponse(200, content=body, request=request)
+class MockDigestAuthDispatch(AsyncDispatcher):
+ def __init__(
+ self,
+ algorithm: str = "SHA-256",
+ send_response_after_attempt: int = 1,
+ qop: str = "auth",
+ regenerate_nonce: bool = True,
+ ) -> None:
+ self.algorithm = algorithm
+ self.send_response_after_attempt = send_response_after_attempt
+ self.qop = qop
+ self._regenerate_nonce = regenerate_nonce
+ self._response_count = 0
+
+ async def send(
+ self,
+ request: AsyncRequest,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ timeout: TimeoutTypes = None,
+ ) -> AsyncResponse:
+ if self._response_count < self.send_response_after_attempt:
+ return self.challenge_send(request)
+
+ body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
+ return AsyncResponse(200, content=body, request=request)
+
+ def challenge_send(self, request: AsyncRequest) -> AsyncResponse:
+ self._response_count += 1
+ nonce = (
+ hashlib.sha256(os.urandom(8)).hexdigest()
+ if self._regenerate_nonce
+ else "ee96edced2a0b43e4869e96ebe27563f369c1205a049d06419bb51d8aeddf3d3"
+ )
+ challenge_data = {
+ "nonce": nonce,
+ "qop": self.qop,
+ "opaque": (
+ "ee6378f3ee14ebfd2fff54b70a91a7c9390518047f242ab2271380db0e14bda1"
+ ),
+ "algorithm": self.algorithm,
+ "stale": "FALSE",
+ }
+ challenge_str = ", ".join(
+ '{}="{}"'.format(key, value)
+ for key, value in challenge_data.items()
+ if value
+ )
+
+ headers = [
+ ("www-authenticate", 'Digest realm="httpx@example.org", ' + challenge_str)
+ ]
+ return AsyncResponse(401, headers=headers, content=b"", request=request)
+
+ def reset(self) -> None:
+ self._response_count = 0
+
+
+class MockAuthHeaderDispatch(AsyncDispatcher):
+ def __init__(self, auth_header: str) -> None:
+ self.auth_header = auth_header
+
+ async def send(
+ self,
+ request: AsyncRequest,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ timeout: TimeoutTypes = None,
+ ) -> AsyncResponse:
+ headers = [("www-authenticate", self.auth_header)]
+ return AsyncResponse(401, headers=headers, content=b"", request=request)
+
+
def test_basic_auth():
url = "https://example.org/"
auth = ("tomchristie", "password123")
with Client(dispatch=MockDispatch(), auth="not a tuple, not a callable") as client:
with pytest.raises(TypeError):
client.get(url)
+
+
+def test_digest_auth_returns_no_auth_if_no_digest_header_in_response():
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with Client(dispatch=MockDispatch()) as client:
+ response = client.get(url, auth=auth)
+
+ assert response.status_code == 200
+ assert response.json() == {"auth": None}
+
+
+@pytest.mark.parametrize(
+ "algorithm,expected_hash_length,expected_response_length",
+ [
+ ("MD5", 64, 32),
+ ("MD5-SESS", 64, 32),
+ ("SHA", 64, 40),
+ ("SHA-SESS", 64, 40),
+ ("SHA-256", 64, 64),
+ ("SHA-256-SESS", 64, 64),
+ ("SHA-512", 64, 128),
+ ("SHA-512-SESS", 64, 128),
+ ],
+)
+def test_digest_auth(algorithm, expected_hash_length, expected_response_length):
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with Client(dispatch=MockDigestAuthDispatch(algorithm=algorithm)) as client:
+ response = client.get(url, auth=auth)
+
+ assert response.status_code == 200
+ auth = response.json()["auth"]
+ assert auth.startswith("Digest ")
+
+ response_fields = [field.strip() for field in auth[auth.find(" ") :].split(",")]
+ digest_data = dict(field.split("=") for field in response_fields)
+
+ assert digest_data["username"] == '"tomchristie"'
+ assert digest_data["realm"] == '"httpx@example.org"'
+ assert "nonce" in digest_data
+ assert digest_data["uri"] == '"/"'
+ assert len(digest_data["response"]) == expected_response_length + 2 # extra quotes
+ assert len(digest_data["opaque"]) == expected_hash_length + 2
+ assert digest_data["algorithm"] == algorithm
+ assert digest_data["qop"] == "auth"
+ assert digest_data["nc"] == "00000001"
+ assert len(digest_data["cnonce"]) == 16 + 2
+
+
+def test_digest_auth_no_specified_qop():
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with Client(dispatch=MockDigestAuthDispatch(qop=None)) as client:
+ response = client.get(url, auth=auth)
+
+ assert response.status_code == 200
+ auth = response.json()["auth"]
+ assert auth.startswith("Digest ")
+
+ response_fields = [field.strip() for field in auth[auth.find(" ") :].split(",")]
+ digest_data = dict(field.split("=") for field in response_fields)
+
+ assert "qop" not in digest_data
+ assert "nc" not in digest_data
+ assert "cnonce" not in digest_data
+ assert digest_data["username"] == '"tomchristie"'
+ assert digest_data["realm"] == '"httpx@example.org"'
+ assert len(digest_data["nonce"]) == 64 + 2 # extra quotes
+ assert digest_data["uri"] == '"/"'
+ assert len(digest_data["response"]) == 64 + 2
+ assert len(digest_data["opaque"]) == 64 + 2
+ assert digest_data["algorithm"] == "SHA-256"
+
+
+def test_digest_auth_qop_including_spaces_and_auth_returns_auth():
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with Client(dispatch=MockDigestAuthDispatch(qop="auth, auth-int")) as client:
+ client.get(url, auth=auth)
+
+
+def test_digest_auth_qop_auth_int_not_implemented():
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with pytest.raises(NotImplementedError):
+ with Client(dispatch=MockDigestAuthDispatch(qop="auth-int")) as client:
+ client.get(url, auth=auth)
+
+
+def test_digest_auth_qop_must_be_auth_or_auth_int():
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with pytest.raises(ProtocolError):
+ with Client(dispatch=MockDigestAuthDispatch(qop="not-auth")) as client:
+ client.get(url, auth=auth)
+
+
+def test_digest_auth_incorrect_credentials():
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with Client(
+ dispatch=MockDigestAuthDispatch(send_response_after_attempt=2)
+ ) as client:
+ response = client.get(url, auth=auth)
+
+ assert response.status_code == 401
+
+
+@pytest.mark.parametrize(
+ "auth_header",
+ [
+ 'Digest realm="httpx@example.org", qop="auth"', # missing fields
+ 'realm="httpx@example.org", qop="auth"', # not starting with Digest
+ 'DigestZ realm="httpx@example.org", qop="auth"'
+ 'qop="auth,auth-int",nonce="abc",opaque="xyz"',
+ 'Digest realm="httpx@example.org", qop="auth,au', # malformed fields list
+ ],
+)
+def test_digest_auth_raises_protocol_error_on_malformed_header(auth_header: str):
+ url = "https://example.org/"
+ auth = DigestAuth(username="tomchristie", password="password123")
+
+ with pytest.raises(ProtocolError):
+ with Client(dispatch=MockAuthHeaderDispatch(auth_header=auth_header)) as client:
+ client.get(url, auth=auth)