]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add DigestAuth middleware (#332)
authorYeray Diaz Diaz <yeraydiazdiaz@gmail.com>
Tue, 10 Sep 2019 19:53:39 +0000 (20:53 +0100)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Tue, 10 Sep 2019 19:53:39 +0000 (14:53 -0500)
* Remove global variable, just return response from auth request

* Add extra space to Digest header start assertion

* Prevent unpacking errors limiting the number of splits

httpx/__init__.py
httpx/client.py
httpx/middleware/digest_auth.py [new file with mode: 0644]
httpx/models.py
httpx/utils.py
tests/client/test_auth.py

index 255ee96f953127a9806546dba717cbf46cf3c81b..bd58a005af1596c45d33f191f252e49d5f61765c 100644 (file)
@@ -40,6 +40,7 @@ from .exceptions import (
     TooManyRedirects,
     WriteTimeout,
 )
+from .middleware.digest_auth import DigestAuth
 from .models import (
     URL,
     AsyncRequest,
@@ -133,4 +134,5 @@ __all__ = [
     "Response",
     "ResponseContent",
     "RequestFiles",
+    "DigestAuth",
 ]
index b3eaff883ed92e5909c12f18fb5bf0a54a401ba7..430f072442e6a22518ce5fcb4abaf38fff8d1fff 100644 (file)
@@ -211,8 +211,9 @@ class BaseClient:
     ) -> typing.Optional[BaseMiddleware]:
         if isinstance(auth, tuple):
             return BasicAuthMiddleware(username=auth[0], password=auth[1])
-
-        if callable(auth):
+        elif isinstance(auth, BaseMiddleware):
+            return auth
+        elif callable(auth):
             return CustomAuthMiddleware(auth=auth)
 
         if auth is not None:
diff --git a/httpx/middleware/digest_auth.py b/httpx/middleware/digest_auth.py
new file mode 100644 (file)
index 0000000..fb13973
--- /dev/null
@@ -0,0 +1,181 @@
+import hashlib
+import os
+import re
+import time
+import typing
+from urllib.request import parse_http_list
+
+from ..exceptions import ProtocolError
+from ..models import AsyncRequest, AsyncResponse, StatusCode
+from ..utils import to_bytes, to_str, unquote
+from .base import BaseMiddleware
+
+
+class DigestAuth(BaseMiddleware):
+    ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
+        "MD5": hashlib.md5,
+        "MD5-SESS": hashlib.md5,
+        "SHA": hashlib.sha1,
+        "SHA-SESS": hashlib.sha1,
+        "SHA-256": hashlib.sha256,
+        "SHA-256-SESS": hashlib.sha256,
+        "SHA-512": hashlib.sha512,
+        "SHA-512-SESS": hashlib.sha512,
+    }
+
+    def __init__(
+        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+    ) -> None:
+        self.username = to_bytes(username)
+        self.password = to_bytes(password)
+
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        response = await get_response(request)
+        if not (
+            StatusCode.is_client_error(response.status_code)
+            and "www-authenticate" in response.headers
+        ):
+            return response
+
+        header = response.headers["www-authenticate"]
+        try:
+            challenge = DigestAuthChallenge.from_header(header)
+        except ValueError:
+            raise ProtocolError("Malformed Digest authentication header")
+
+        request.headers["Authorization"] = self._build_auth_header(request, challenge)
+        return await get_response(request)
+
+    def _build_auth_header(
+        self, request: AsyncRequest, challenge: "DigestAuthChallenge"
+    ) -> str:
+        hash_func = self.ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm]
+
+        def digest(data: bytes) -> bytes:
+            return hash_func(data).hexdigest().encode()
+
+        A1 = b":".join((self.username, challenge.realm, self.password))
+
+        path = request.url.full_path.encode("utf-8")
+        A2 = b":".join((request.method.encode(), path))
+        # TODO: implement auth-int
+        HA2 = digest(A2)
+
+        nonce_count = 1  # TODO: implement nonce counting
+        nc_value = b"%08x" % nonce_count
+        cnonce = self._get_client_nonce(nonce_count, challenge.nonce)
+
+        HA1 = digest(A1)
+        if challenge.algorithm.lower().endswith("-sess"):
+            HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))
+
+        qop = self._resolve_qop(challenge.qop)
+        if qop is None:
+            digest_data = [HA1, challenge.nonce, HA2]
+        else:
+            digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
+        key_digest = b":".join(digest_data)
+
+        format_args = {
+            "username": self.username,
+            "realm": challenge.realm,
+            "nonce": challenge.nonce,
+            "uri": path,
+            "response": digest(b":".join((HA1, key_digest))),
+            "algorithm": challenge.algorithm.encode(),
+        }
+        if challenge.opaque:
+            format_args["opaque"] = challenge.opaque
+        if qop:
+            format_args["qop"] = b"auth"
+            format_args["nc"] = nc_value
+            format_args["cnonce"] = cnonce
+
+        return "Digest " + self._get_header_value(format_args)
+
+    def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
+        s = str(nonce_count).encode()
+        s += nonce
+        s += time.ctime().encode()
+        s += os.urandom(8)
+
+        return hashlib.sha1(s).hexdigest()[:16].encode()
+
+    def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
+        NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
+        QUOTED_TEMPLATE = '{}="{}"'
+        NON_QUOTED_TEMPLATE = "{}={}"
+
+        header_value = ""
+        for i, (field, value) in enumerate(header_fields.items()):
+            if i > 0:
+                header_value += ", "
+            template = (
+                QUOTED_TEMPLATE
+                if field not in NON_QUOTED_FIELDS
+                else NON_QUOTED_TEMPLATE
+            )
+            header_value += template.format(field, to_str(value))
+
+        return header_value
+
+    def _resolve_qop(self, qop: typing.Optional[bytes]) -> typing.Optional[bytes]:
+        if qop is None:
+            return None
+        qops = re.split(b", ?", qop)
+        if b"auth" in qops:
+            return b"auth"
+
+        if qops == [b"auth-int"]:
+            raise NotImplementedError("Digest auth-int support is not yet implemented")
+
+        raise ProtocolError(f'Unexpected qop value "{qop!r}" in digest auth')
+
+
+class DigestAuthChallenge:
+    def __init__(
+        self,
+        realm: bytes,
+        nonce: bytes,
+        algorithm: str = None,
+        opaque: typing.Optional[bytes] = None,
+        qop: typing.Optional[bytes] = None,
+    ) -> None:
+        self.realm = realm
+        self.nonce = nonce
+        self.algorithm = algorithm or "MD5"
+        self.opaque = opaque
+        self.qop = qop
+
+    @classmethod
+    def from_header(cls, header: str) -> "DigestAuthChallenge":
+        """Returns a challenge from a Digest WWW-Authenticate header.
+        These take the form of:
+        `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
+        """
+        scheme, _, fields = header.partition(" ")
+        if scheme.lower() != "digest":
+            raise ValueError("Header does not start with 'Digest'")
+
+        header_dict: typing.Dict[str, str] = {}
+        for field in parse_http_list(fields):
+            key, value = field.strip().split("=", 1)
+            header_dict[key] = unquote(value)
+
+        try:
+            return cls.from_header_dict(header_dict)
+        except KeyError as exc:
+            raise ValueError("Malformed Digest WWW-Authenticate header") from exc
+
+    @classmethod
+    def from_header_dict(cls, header_dict: dict) -> "DigestAuthChallenge":
+        realm = header_dict["realm"].encode()
+        nonce = header_dict["nonce"].encode()
+        qop = header_dict["qop"].encode() if "qop" in header_dict else None
+        opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
+        algorithm = header_dict.get("algorithm")
+        return cls(
+            realm=realm, nonce=nonce, qop=qop, opaque=opaque, algorithm=algorithm
+        )
index d0bbc4180cb8127ec5087e808cf49a024f694df5..29fc1b931a07e5636edd2a56df8ef4c34509bcbe 100644 (file)
@@ -39,6 +39,9 @@ from .utils import (
     str_query_param,
 )
 
+if typing.TYPE_CHECKING:
+    from .middleware.base import BaseMiddleware  # noqa: F401
+
 PrimitiveData = typing.Optional[typing.Union[str, int, float, bool]]
 
 URLTypes = typing.Union["URL", str]
@@ -61,6 +64,7 @@ CookieTypes = typing.Union["Cookies", CookieJar, typing.Dict[str, str]]
 AuthTypes = typing.Union[
     typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
     typing.Callable[["AsyncRequest"], "AsyncRequest"],
+    "BaseMiddleware",
 ]
 
 AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
index b80003ecc4cff2c86f788ef2c36caeadc755fe62..f6b83c89dbd2aaad71244ad2a2dd69a7b7ee67ec 100644 (file)
@@ -173,3 +173,13 @@ def get_logger(name: str) -> logging.Logger:
 
 def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
     return value.encode(encoding) if isinstance(value, str) else value
+
+
+def to_str(str_or_bytes: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
+    return (
+        str_or_bytes if isinstance(str_or_bytes, str) else str_or_bytes.decode(encoding)
+    )
+
+
+def unquote(value: str) -> str:
+    return value[1:-1] if value[0] == value[-1] == '"' else value
index fc3b192f496d17ff5bd2d28a0c62903b5be36450..6478ea03a68210ce12ff3fe4677226b2ad86a70b 100644 (file)
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import os
 
@@ -10,6 +11,8 @@ from httpx import (
     AsyncResponse,
     CertTypes,
     Client,
+    DigestAuth,
+    ProtocolError,
     TimeoutTypes,
     VerifyTypes,
 )
@@ -27,6 +30,79 @@ class MockDispatch(AsyncDispatcher):
         return AsyncResponse(200, content=body, request=request)
 
 
+class MockDigestAuthDispatch(AsyncDispatcher):
+    def __init__(
+        self,
+        algorithm: str = "SHA-256",
+        send_response_after_attempt: int = 1,
+        qop: str = "auth",
+        regenerate_nonce: bool = True,
+    ) -> None:
+        self.algorithm = algorithm
+        self.send_response_after_attempt = send_response_after_attempt
+        self.qop = qop
+        self._regenerate_nonce = regenerate_nonce
+        self._response_count = 0
+
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        if self._response_count < self.send_response_after_attempt:
+            return self.challenge_send(request)
+
+        body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
+        return AsyncResponse(200, content=body, request=request)
+
+    def challenge_send(self, request: AsyncRequest) -> AsyncResponse:
+        self._response_count += 1
+        nonce = (
+            hashlib.sha256(os.urandom(8)).hexdigest()
+            if self._regenerate_nonce
+            else "ee96edced2a0b43e4869e96ebe27563f369c1205a049d06419bb51d8aeddf3d3"
+        )
+        challenge_data = {
+            "nonce": nonce,
+            "qop": self.qop,
+            "opaque": (
+                "ee6378f3ee14ebfd2fff54b70a91a7c9390518047f242ab2271380db0e14bda1"
+            ),
+            "algorithm": self.algorithm,
+            "stale": "FALSE",
+        }
+        challenge_str = ", ".join(
+            '{}="{}"'.format(key, value)
+            for key, value in challenge_data.items()
+            if value
+        )
+
+        headers = [
+            ("www-authenticate", 'Digest realm="httpx@example.org", ' + challenge_str)
+        ]
+        return AsyncResponse(401, headers=headers, content=b"", request=request)
+
+    def reset(self) -> None:
+        self._response_count = 0
+
+
+class MockAuthHeaderDispatch(AsyncDispatcher):
+    def __init__(self, auth_header: str) -> None:
+        self.auth_header = auth_header
+
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        headers = [("www-authenticate", self.auth_header)]
+        return AsyncResponse(401, headers=headers, content=b"", request=request)
+
+
 def test_basic_auth():
     url = "https://example.org/"
     auth = ("tomchristie", "password123")
@@ -127,3 +203,136 @@ def test_auth_invalid_type():
     with Client(dispatch=MockDispatch(), auth="not a tuple, not a callable") as client:
         with pytest.raises(TypeError):
             client.get(url)
+
+
+def test_digest_auth_returns_no_auth_if_no_digest_header_in_response():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    assert response.json() == {"auth": None}
+
+
+@pytest.mark.parametrize(
+    "algorithm,expected_hash_length,expected_response_length",
+    [
+        ("MD5", 64, 32),
+        ("MD5-SESS", 64, 32),
+        ("SHA", 64, 40),
+        ("SHA-SESS", 64, 40),
+        ("SHA-256", 64, 64),
+        ("SHA-256-SESS", 64, 64),
+        ("SHA-512", 64, 128),
+        ("SHA-512-SESS", 64, 128),
+    ],
+)
+def test_digest_auth(algorithm, expected_hash_length, expected_response_length):
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with Client(dispatch=MockDigestAuthDispatch(algorithm=algorithm)) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    auth = response.json()["auth"]
+    assert auth.startswith("Digest ")
+
+    response_fields = [field.strip() for field in auth[auth.find(" ") :].split(",")]
+    digest_data = dict(field.split("=") for field in response_fields)
+
+    assert digest_data["username"] == '"tomchristie"'
+    assert digest_data["realm"] == '"httpx@example.org"'
+    assert "nonce" in digest_data
+    assert digest_data["uri"] == '"/"'
+    assert len(digest_data["response"]) == expected_response_length + 2  # extra quotes
+    assert len(digest_data["opaque"]) == expected_hash_length + 2
+    assert digest_data["algorithm"] == algorithm
+    assert digest_data["qop"] == "auth"
+    assert digest_data["nc"] == "00000001"
+    assert len(digest_data["cnonce"]) == 16 + 2
+
+
+def test_digest_auth_no_specified_qop():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with Client(dispatch=MockDigestAuthDispatch(qop=None)) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    auth = response.json()["auth"]
+    assert auth.startswith("Digest ")
+
+    response_fields = [field.strip() for field in auth[auth.find(" ") :].split(",")]
+    digest_data = dict(field.split("=") for field in response_fields)
+
+    assert "qop" not in digest_data
+    assert "nc" not in digest_data
+    assert "cnonce" not in digest_data
+    assert digest_data["username"] == '"tomchristie"'
+    assert digest_data["realm"] == '"httpx@example.org"'
+    assert len(digest_data["nonce"]) == 64 + 2  # extra quotes
+    assert digest_data["uri"] == '"/"'
+    assert len(digest_data["response"]) == 64 + 2
+    assert len(digest_data["opaque"]) == 64 + 2
+    assert digest_data["algorithm"] == "SHA-256"
+
+
+def test_digest_auth_qop_including_spaces_and_auth_returns_auth():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with Client(dispatch=MockDigestAuthDispatch(qop="auth, auth-int")) as client:
+        client.get(url, auth=auth)
+
+
+def test_digest_auth_qop_auth_int_not_implemented():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with pytest.raises(NotImplementedError):
+        with Client(dispatch=MockDigestAuthDispatch(qop="auth-int")) as client:
+            client.get(url, auth=auth)
+
+
+def test_digest_auth_qop_must_be_auth_or_auth_int():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with pytest.raises(ProtocolError):
+        with Client(dispatch=MockDigestAuthDispatch(qop="not-auth")) as client:
+            client.get(url, auth=auth)
+
+
+def test_digest_auth_incorrect_credentials():
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with Client(
+        dispatch=MockDigestAuthDispatch(send_response_after_attempt=2)
+    ) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 401
+
+
+@pytest.mark.parametrize(
+    "auth_header",
+    [
+        'Digest realm="httpx@example.org", qop="auth"',  # missing fields
+        'realm="httpx@example.org", qop="auth"',  # not starting with Digest
+        'DigestZ realm="httpx@example.org", qop="auth"'
+        'qop="auth,auth-int",nonce="abc",opaque="xyz"',
+        'Digest realm="httpx@example.org", qop="auth,au',  # malformed fields list
+    ],
+)
+def test_digest_auth_raises_protocol_error_on_malformed_header(auth_header: str):
+    url = "https://example.org/"
+    auth = DigestAuth(username="tomchristie", password="password123")
+
+    with pytest.raises(ProtocolError):
+        with Client(dispatch=MockAuthHeaderDispatch(auth_header=auth_header)) as client:
+            client.get(url, auth=auth)