]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Tighten public API on auth classes (#1084)
authorTom Christie <tom@tomchristie.com>
Sun, 26 Jul 2020 18:05:33 +0000 (19:05 +0100)
committerGitHub <noreply@github.com>
Sun, 26 Jul 2020 18:05:33 +0000 (19:05 +0100)
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/_auth.py

index d60dbdc028643a53306043be36083f8bb8ac4c76..6940019e7e328d7303aee380b00752c88ff0da37 100644 (file)
@@ -14,6 +14,9 @@ from ._utils import to_bytes, to_str, unquote
 class Auth:
     """
     Base class for all authentication schemes.
+
+    To implement a custom authentication scheme, subclass `Auth` and override
+    the `.auth_flow()` method.
     """
 
     requires_request_body = False
@@ -51,10 +54,10 @@ class FunctionAuth(Auth):
     """
 
     def __init__(self, func: typing.Callable[[Request], Request]) -> None:
-        self.func = func
+        self._func = func
 
     def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
-        yield self.func(request)
+        yield self._func(request)
 
 
 class BasicAuth(Auth):
@@ -66,13 +69,13 @@ class BasicAuth(Auth):
     def __init__(
         self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
     ):
-        self.auth_header = self.build_auth_header(username, password)
+        self._auth_header = self._build_auth_header(username, password)
 
     def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
-        request.headers["Authorization"] = self.auth_header
+        request.headers["Authorization"] = self._auth_header
         yield request
 
-    def build_auth_header(
+    def _build_auth_header(
         self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
     ) -> str:
         userpass = b":".join((to_bytes(username), to_bytes(password)))
@@ -81,7 +84,7 @@ class BasicAuth(Auth):
 
 
 class DigestAuth(Auth):
-    ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
+    _ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
         "MD5": hashlib.md5,
         "MD5-SESS": hashlib.md5,
         "SHA": hashlib.sha1,
@@ -95,8 +98,8 @@ class DigestAuth(Auth):
     def __init__(
         self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
     ) -> None:
-        self.username = to_bytes(username)
-        self.password = to_bytes(password)
+        self._username = to_bytes(username)
+        self._password = to_bytes(password)
 
     def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
         if not request.stream.can_replay():
@@ -114,23 +117,46 @@ class DigestAuth(Auth):
             return
 
         header = response.headers["www-authenticate"]
-        try:
-            challenge = DigestAuthChallenge.from_header(header)
-        except ValueError:
-            raise ProtocolError("Malformed Digest authentication header")
-
+        challenge = self._parse_challenge(header)
         request.headers["Authorization"] = self._build_auth_header(request, challenge)
         yield request
 
+    def _parse_challenge(self, header: str) -> "_DigestAuthChallenge":
+        """
+        Returns a challenge from a Digest WWW-Authenticate header.
+        These take the form of:
+        `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
+        """
+        scheme, _, fields = header.partition(" ")
+        if scheme.lower() != "digest":
+            raise ProtocolError("Header does not start with 'Digest'")
+
+        header_dict: typing.Dict[str, str] = {}
+        for field in parse_http_list(fields):
+            key, value = field.strip().split("=", 1)
+            header_dict[key] = unquote(value)
+
+        try:
+            realm = header_dict["realm"].encode()
+            nonce = header_dict["nonce"].encode()
+            qop = header_dict["qop"].encode() if "qop" in header_dict else None
+            opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
+            algorithm = header_dict.get("algorithm", "MD5")
+            return _DigestAuthChallenge(
+                realm=realm, nonce=nonce, qop=qop, opaque=opaque, algorithm=algorithm
+            )
+        except KeyError as exc:
+            raise ProtocolError("Malformed Digest WWW-Authenticate header") from exc
+
     def _build_auth_header(
-        self, request: Request, challenge: "DigestAuthChallenge"
+        self, request: Request, challenge: "_DigestAuthChallenge"
     ) -> str:
-        hash_func = self.ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm]
+        hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm]
 
         def digest(data: bytes) -> bytes:
             return hash_func(data).hexdigest().encode()
 
-        A1 = b":".join((self.username, challenge.realm, self.password))
+        A1 = b":".join((self._username, challenge.realm, self._password))
 
         path = request.url.full_path.encode("utf-8")
         A2 = b":".join((request.method.encode(), path))
@@ -153,7 +179,7 @@ class DigestAuth(Auth):
         key_digest = b":".join(digest_data)
 
         format_args = {
-            "username": self.username,
+            "username": self._username,
             "realm": challenge.realm,
             "nonce": challenge.nonce,
             "uri": path,
@@ -208,48 +234,17 @@ class DigestAuth(Auth):
         raise ProtocolError(f'Unexpected qop value "{qop!r}" in digest auth')
 
 
-class DigestAuthChallenge:
+class _DigestAuthChallenge:
     def __init__(
         self,
         realm: bytes,
         nonce: bytes,
-        algorithm: str = None,
+        algorithm: str,
         opaque: typing.Optional[bytes] = None,
         qop: typing.Optional[bytes] = None,
     ) -> None:
         self.realm = realm
         self.nonce = nonce
-        self.algorithm = algorithm or "MD5"
+        self.algorithm = algorithm
         self.opaque = opaque
         self.qop = qop
-
-    @classmethod
-    def from_header(cls, header: str) -> "DigestAuthChallenge":
-        """Returns a challenge from a Digest WWW-Authenticate header.
-        These take the form of:
-        `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
-        """
-        scheme, _, fields = header.partition(" ")
-        if scheme.lower() != "digest":
-            raise ValueError("Header does not start with 'Digest'")
-
-        header_dict: typing.Dict[str, str] = {}
-        for field in parse_http_list(fields):
-            key, value = field.strip().split("=", 1)
-            header_dict[key] = unquote(value)
-
-        try:
-            return cls.from_header_dict(header_dict)
-        except KeyError as exc:
-            raise ValueError("Malformed Digest WWW-Authenticate header") from exc
-
-    @classmethod
-    def from_header_dict(cls, header_dict: dict) -> "DigestAuthChallenge":
-        realm = header_dict["realm"].encode()
-        nonce = header_dict["nonce"].encode()
-        qop = header_dict["qop"].encode() if "qop" in header_dict else None
-        opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
-        algorithm = header_dict.get("algorithm")
-        return cls(
-            realm=realm, nonce=nonce, qop=qop, opaque=opaque, algorithm=algorithm
-        )