]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
No I/O auth (#644)
authorTom Christie <tom@tomchristie.com>
Wed, 18 Dec 2019 15:25:31 +0000 (15:25 +0000)
committerGitHub <noreply@github.com>
Wed, 18 Dec 2019 15:25:31 +0000 (15:25 +0000)
* No I/O on Auth

httpx/api.py
httpx/auth.py
httpx/client.py
httpx/middleware.py [deleted file]
httpx/models.py
tests/client/test_auth.py

index d2f1d6dd78a295c455272e30f4c3f56c21ee86a2..ccf7a346a95de0e16608ac53a31c4916f15e6563 100644 (file)
@@ -1,9 +1,9 @@
 import typing
 
+from .auth import AuthTypes
 from .client import Client, StreamContextManager
 from .config import DEFAULT_TIMEOUT_CONFIG, CertTypes, TimeoutTypes, VerifyTypes
 from .models import (
-    AuthTypes,
     CookieTypes,
     HeaderTypes,
     QueryParamTypes,
index eb93ff35928a6adb4b823f0f2ad47f963a3fefce..e0ef50c3a0b8dcadad8fae943ea7a3447fdc44b7 100644 (file)
@@ -7,20 +7,75 @@ from base64 import b64encode
 from urllib.request import parse_http_list
 
 from .exceptions import ProtocolError
-from .middleware import Middleware
 from .models import Request, Response
 from .utils import to_bytes, to_str, unquote
 
+AuthFlow = typing.Generator[Request, Response, None]
+
+AuthTypes = typing.Union[
+    typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
+    typing.Callable[["Request"], "Request"],
+    "Auth",
+]
+
+
+class Auth:
+    """
+    Base class for all authentication schemes.
+    """
+
+    def __call__(self, request: Request) -> AuthFlow:
+        """
+        Execute the authentication flow.
+
+        To dispatch a request, `yield` it:
+
+        ```
+        yield request
+        ```
+
+        The client will `.send()` the response back into the flow generator. You can
+        access it like so:
+
+        ```
+        response = yield request
+        ```
+
+        A `return` (or reaching the end of the generator) will result in the
+        client returning the last response obtained from the server.
+
+        You can dispatch as many requests as is necessary.
+        """
+        yield request
+
+
+class FunctionAuth(Auth):
+    """
+    Allows the 'auth' argument to be passed as a simple callable function,
+    that takes the request, and returns a new, modified request.
+    """
+
+    def __init__(self, func: typing.Callable[[Request], Request]) -> None:
+        self.func = func
+
+    def __call__(self, request: Request) -> AuthFlow:
+        yield self.func(request)
+
+
+class BasicAuth(Auth):
+    """
+    Allows the 'auth' argument to be passed as a (username, password) pair,
+    and uses HTTP Basic authentication.
+    """
 
-class BasicAuth:
     def __init__(
         self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
     ):
         self.auth_header = self.build_auth_header(username, password)
 
-    def __call__(self, request: Request) -> Request:
+    def __call__(self, request: Request) -> AuthFlow:
         request.headers["Authorization"] = self.auth_header
-        return request
+        yield request
 
     def build_auth_header(
         self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
@@ -30,7 +85,7 @@ class BasicAuth:
         return f"Basic {token}"
 
 
-class DigestAuth(Middleware):
+class DigestAuth(Auth):
     ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
         "MD5": hashlib.md5,
         "MD5-SESS": hashlib.md5,
@@ -48,14 +103,14 @@ class DigestAuth(Middleware):
         self.username = to_bytes(username)
         self.password = to_bytes(password)
 
-    async def __call__(
-        self, request: Request, get_response: typing.Callable
-    ) -> Response:
-        response = await get_response(request)
+    def __call__(self, request: Request) -> AuthFlow:
+        response = yield request
+
         if response.status_code != 401 or "www-authenticate" not in response.headers:
-            return response
+            # If the response is not a 401 WWW-Authenticate, then we don't
+            # need to build an authenticated request.
+            return
 
-        await response.close()
         header = response.headers["www-authenticate"]
         try:
             challenge = DigestAuthChallenge.from_header(header)
@@ -63,7 +118,7 @@ class DigestAuth(Middleware):
             raise ProtocolError("Malformed Digest authentication header")
 
         request.headers["Authorization"] = self._build_auth_header(request, challenge)
-        return await get_response(request)
+        yield request
 
     def _build_auth_header(
         self, request: Request, challenge: "DigestAuthChallenge"
index ab86be15c254b055d0e7e23213e460141364b137..a6c22d8f49a800409981d3d843602f3778912621 100644 (file)
@@ -5,7 +5,7 @@ from types import TracebackType
 
 import hstspreload
 
-from .auth import BasicAuth
+from .auth import Auth, AuthTypes, BasicAuth, FunctionAuth
 from .concurrency.base import ConcurrencyBackend
 from .config import (
     DEFAULT_MAX_REDIRECTS,
@@ -31,10 +31,8 @@ from .exceptions import (
     RedirectLoop,
     TooManyRedirects,
 )
-from .middleware import Middleware
 from .models import (
     URL,
-    AuthTypes,
     Cookies,
     CookieTypes,
     Headers,
@@ -397,28 +395,18 @@ class Client:
         if request.url.scheme not in ("http", "https"):
             raise InvalidURL('URL scheme must be "http" or "https".')
 
-        auth = self.auth if auth is None else auth
-        trust_env = self.trust_env if trust_env is None else trust_env
         timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
 
-        if not isinstance(auth, Middleware):
-            request = self.authenticate(request, trust_env, auth)
-            response = await self.send_handling_redirects(
-                request,
-                verify=verify,
-                cert=cert,
-                timeout=timeout,
-                allow_redirects=allow_redirects,
-            )
-        else:
-            get_response = functools.partial(
-                self.send_handling_redirects,
-                verify=verify,
-                cert=cert,
-                timeout=timeout,
-                allow_redirects=allow_redirects,
-            )
-            response = await auth(request, get_response)
+        auth = self.setup_auth(request, trust_env, auth)
+
+        response = await self.send_handling_redirects(
+            request,
+            auth=auth,
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            allow_redirects=allow_redirects,
+        )
 
         if not stream:
             try:
@@ -428,30 +416,36 @@ class Client:
 
         return response
 
-    def authenticate(
-        self, request: Request, trust_env: bool, auth: AuthTypes = None
-    ) -> "Request":
+    def setup_auth(
+        self, request: Request, trust_env: bool = None, auth: AuthTypes = None
+    ) -> Auth:
+        auth = self.auth if auth is None else auth
+        trust_env = self.trust_env if trust_env is None else trust_env
+
         if auth is not None:
             if isinstance(auth, tuple):
-                auth = BasicAuth(username=auth[0], password=auth[1])
-            return auth(request)
+                return BasicAuth(username=auth[0], password=auth[1])
+            elif isinstance(auth, Auth):
+                return auth
+            elif callable(auth):
+                return FunctionAuth(func=auth)
+            raise TypeError('Invalid "auth" argument.')
 
         username, password = request.url.username, request.url.password
         if username or password:
-            auth = BasicAuth(username=username, password=password)
-            return auth(request)
+            return BasicAuth(username=username, password=password)
 
         if trust_env and "Authorization" not in request.headers:
             credentials = self.netrc.get_credentials(request.url.authority)
             if credentials is not None:
-                auth = BasicAuth(username=credentials[0], password=credentials[1])
-                return auth(request)
+                return BasicAuth(username=credentials[0], password=credentials[1])
 
-        return request
+        return Auth()
 
     async def send_handling_redirects(
         self,
         request: Request,
+        auth: Auth,
         timeout: Timeout,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
@@ -467,8 +461,8 @@ class Client:
             if request.url in (response.url for response in history):
                 raise RedirectLoop()
 
-            response = await self.send_single_request(
-                request, verify=verify, cert=cert, timeout=timeout
+            response = await self.send_handling_auth(
+                request, auth=auth, timeout=timeout, verify=verify, cert=cert
             )
             response.history = list(history)
 
@@ -483,6 +477,7 @@ class Client:
                 response.call_next = functools.partial(
                     self.send_handling_redirects,
                     request=request,
+                    auth=auth,
                     verify=verify,
                     cert=cert,
                     timeout=timeout,
@@ -581,6 +576,29 @@ class Client:
             raise RedirectBodyUnavailable()
         return request.content
 
+    async def send_handling_auth(
+        self,
+        request: Request,
+        auth: Auth,
+        timeout: Timeout,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+    ) -> Response:
+        auth_flow = auth(request)
+        request = next(auth_flow)
+        while True:
+            response = await self.send_single_request(request, timeout, verify, cert)
+            try:
+                next_request = auth_flow.send(response)
+            except StopIteration:
+                return response
+            except BaseException as exc:
+                await response.close()
+                raise exc from None
+            else:
+                request = next_request
+                await response.close()
+
     async def send_single_request(
         self,
         request: Request,
diff --git a/httpx/middleware.py b/httpx/middleware.py
deleted file mode 100644 (file)
index 7ef8707..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-import typing
-
-from .models import Request, Response
-
-
-class Middleware:
-    async def __call__(
-        self, request: Request, get_response: typing.Callable
-    ) -> Response:
-        raise NotImplementedError  # pragma: no cover
index f936da5eaca3c109304877287d6dfddf9605fa8c..ba19f43fe2455b68a15e34e5db81d5af7561b9be 100644 (file)
@@ -67,12 +67,6 @@ HeaderTypes = typing.Union[
 
 CookieTypes = typing.Union["Cookies", CookieJar, typing.Dict[str, str]]
 
-AuthTypes = typing.Union[
-    typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
-    typing.Callable[["Request"], "Request"],
-    "BaseMiddleware",
-]
-
 ProxiesTypes = typing.Union[
     URLTypes, "Dispatcher", typing.Dict[URLTypes, typing.Union[URLTypes, "Dispatcher"]]
 ]
index 16cf86d27024e7714afe9b0cad2ba5d6794f6de3..75c293061bae756656ca7244845893776ca18cf7 100644 (file)
@@ -164,14 +164,14 @@ async def test_trust_env_auth():
     os.environ["NETRC"] = "tests/.netrc"
     url = "http://netrcexample.org"
 
-    client = Client(dispatch=MockDispatch())
-    response = await client.get(url, trust_env=False)
+    client = Client(dispatch=MockDispatch(), trust_env=False)
+    response = await client.get(url)
 
     assert response.status_code == 200
     assert response.json() == {"auth": None}
 
-    client = Client(dispatch=MockDispatch(), trust_env=False)
-    response = await client.get(url, trust_env=True)
+    client = Client(dispatch=MockDispatch(), trust_env=True)
+    response = await client.get(url)
 
     assert response.status_code == 200
     assert response.json() == {