]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Public Auth API (#732)
authorTom Christie <tom@tomchristie.com>
Tue, 7 Jan 2020 13:20:23 +0000 (13:20 +0000)
committerGitHub <noreply@github.com>
Tue, 7 Jan 2020 13:20:23 +0000 (13:20 +0000)
* Public Auth API

* Minor docs tweak

* Request.aread and Request.content

* Support requires_request_body

* Update tests/models/test_requests.py

Co-Authored-By: Florimond Manca <florimond.manca@gmail.com>
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
docs/advanced.md
httpx/__init__.py
httpx/auth.py
httpx/client.py
httpx/exceptions.py
httpx/models.py
tests/client/test_auth.py
tests/models/test_requests.py

index c3df1d136d6a0bab989add2c1f19fb78a4315a29..c22ff647a932ebe1f7e60259d82842a6852ce1a3 100644 (file)
@@ -380,6 +380,69 @@ MIME header field.
 }
 ```
 
+## Customizing authentication
+
+When issuing requests or instantiating a client, the `auth` argument can be used to pass an authentication scheme to use. The `auth` argument may be one of the following...
+
+* A two-tuple of `username`/`password`, to be used with basic authentication.
+* An instance of `httpx.BasicAuth()` or `httpx.DigestAuth()`.
+* A callable, accepting a request and returning an authenticated request instance.
+* A subclass of `httpx.Auth`.
+
+The most involved of these is the last, which allows you to create authentication flows involving one or more requests. A subclass of `httpx.Auth` should implement `def auth_flow(request)`, and yield any requests that need to be made...
+
+```python
+class MyCustomAuth(httpx.Auth):
+    def __init__(self, token):
+        self.token = token
+
+    def auth_flow(self, request):
+        # Send the request, with a custom `X-Authentication` header.
+        request.headers['X-Authentication'] = self.token
+        yield request
+```
+
+If the auth flow requires more that one request, you can issue multiple yields, and obtain the response in each case...
+
+```python
+class MyCustomAuth(httpx.Auth):
+    def __init__(self, token):
+        self.token = token
+
+    def auth_flow(self, request):
+      response = yield request
+      if response.status_code == 401:
+          # If the server issues a 401 response then resend the request,
+          # with a custom `X-Authentication` header.
+          request.headers['X-Authentication'] = self.token
+          yield request
+```
+
+Custom authentication classes are designed to not perform any I/O, so that they may be used with both sync and async client instances. If you are implementing an authentication scheme that requires the request body, then you need to indicate this on the class using a `requires_request_body` property.
+
+You will then be able to access `request.content` inside the `.auth_flow()` method.
+
+```python
+class MyCustomAuth(httpx.Auth):
+    requires_request_body = True
+
+    def __init__(self, token):
+        self.token = token
+
+    def auth_flow(self, request):
+      response = yield request
+      if response.status_code == 401:
+          # If the server issues a 401 response then resend the request,
+          # with a custom `X-Authentication` header.
+          request.headers['X-Authentication'] = self.sign_request(...)
+          yield request
+
+    def sign_request(self, request):
+        # Create a request signature, based on `request.method`, `request.url`,
+        # `request.headers`, and `request.content`.
+        ...
+```
+
 ## SSL certificates
 
 When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA).
index 9c4f31b320e259a073063f94417047407cd8ae30..80c29da7133afeacf6c818f6b21e7820733781a7 100644 (file)
@@ -1,6 +1,6 @@
 from .__version__ import __description__, __title__, __version__
 from .api import delete, get, head, options, patch, post, put, request, stream
-from .auth import BasicAuth, DigestAuth
+from .auth import Auth, BasicAuth, DigestAuth
 from .client import AsyncClient, Client
 from .config import TimeoutConfig  # For 0.8 backwards compat.
 from .config import PoolLimits, Proxy, Timeout
@@ -19,6 +19,7 @@ from .exceptions import (
     ReadTimeout,
     RedirectLoop,
     RequestBodyUnavailable,
+    RequestNotRead,
     ResponseClosed,
     ResponseNotRead,
     StreamConsumed,
@@ -45,6 +46,7 @@ __all__ = [
     "stream",
     "codes",
     "AsyncClient",
+    "Auth",
     "BasicAuth",
     "Client",
     "DigestAuth",
@@ -68,6 +70,7 @@ __all__ = [
     "RequestBodyUnavailable",
     "ResponseClosed",
     "ResponseNotRead",
+    "RequestNotRead",
     "StreamConsumed",
     "ProxyError",
     "TooManyRedirects",
index e412c5707fb0cb75bbbe32006b451b2daf3387e2..d38322f7d73a467a7589b2d8f3e73b7c8a243117 100644 (file)
@@ -10,8 +10,6 @@ from .exceptions import ProtocolError, RequestBodyUnavailable
 from .models import Request, Response
 from .utils import to_bytes, to_str, unquote
 
-AuthFlow = typing.Generator[Request, Response, None]
-
 AuthTypes = typing.Union[
     typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
     typing.Callable[["Request"], "Request"],
@@ -24,7 +22,9 @@ class Auth:
     Base class for all authentication schemes.
     """
 
-    def __call__(self, request: Request) -> AuthFlow:
+    requires_request_body = False
+
+    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
         """
         Execute the authentication flow.
 
@@ -58,7 +58,7 @@ class FunctionAuth(Auth):
     def __init__(self, func: typing.Callable[[Request], Request]) -> None:
         self.func = func
 
-    def __call__(self, request: Request) -> AuthFlow:
+    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
         yield self.func(request)
 
 
@@ -73,7 +73,7 @@ class BasicAuth(Auth):
     ):
         self.auth_header = self.build_auth_header(username, password)
 
-    def __call__(self, request: Request) -> AuthFlow:
+    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
         request.headers["Authorization"] = self.auth_header
         yield request
 
@@ -103,7 +103,7 @@ class DigestAuth(Auth):
         self.username = to_bytes(username)
         self.password = to_bytes(password)
 
-    def __call__(self, request: Request) -> AuthFlow:
+    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
         if not request.stream.can_replay():
             raise RequestBodyUnavailable("Request body is no longer available.")
         response = yield request
index 45b6ec70411b625869d062ff343398a79c74aedf..a38659092477b21a8a9c61b2904fe414caafab1e 100644 (file)
@@ -676,7 +676,10 @@ class AsyncClient:
         auth: Auth,
         timeout: Timeout,
     ) -> Response:
-        auth_flow = auth(request)
+        if auth.requires_request_body:
+            await request.aread()
+
+        auth_flow = auth.auth_flow(request)
         request = next(auth_flow)
         while True:
             response = await self.send_single_request(request, timeout)
index 9f2119852c8ac34885a11ad53ea43545773f7bf0..7efe6fb3c92329cc649f465af1a45c08c4a778e6 100644 (file)
@@ -146,6 +146,12 @@ class ResponseNotRead(StreamError):
     """
 
 
+class RequestNotRead(StreamError):
+    """
+    Attempted to access request content, without having called `read()`.
+    """
+
+
 class ResponseClosed(StreamError):
     """
     Attempted to read or stream response content, but the request has been
index d0a438feafde71bd2ed9e2b67a54a790a2cbb86e..0cc13ed1ad0ef79903063d26c7a4ded885060b97 100644 (file)
@@ -33,6 +33,7 @@ from .exceptions import (
     HTTPError,
     InvalidURL,
     NotRedirectResponse,
+    RequestNotRead,
     ResponseClosed,
     ResponseNotRead,
     StreamConsumed,
@@ -641,6 +642,24 @@ class Request:
         for item in reversed(auto_headers):
             self.headers.raw.insert(0, item)
 
+    @property
+    def content(self) -> bytes:
+        if not hasattr(self, "_content"):
+            raise RequestNotRead()
+        return self._content
+
+    async def aread(self) -> bytes:
+        """
+        Read and return the request content.
+        """
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part async for part in self.stream])
+            # If a streaming request has been read entirely into memory, then
+            # we can replace the stream with a raw bytes implementation,
+            # to ensure that any non-replayable streams can still be used.
+            self.stream = ByteStream(self._content)
+        return self._content
+
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
         url = str(self.url)
index af9a17be4e106c79a6c0bdbf7e8d936f35e8f0a8..8449b941e0adc02e4e9d5f721f98aabe4a325439 100644 (file)
@@ -8,13 +8,13 @@ import pytest
 from httpx import (
     URL,
     AsyncClient,
+    Auth,
     DigestAuth,
     ProtocolError,
     Request,
     RequestBodyUnavailable,
     Response,
 )
-from httpx.auth import Auth, AuthFlow
 from httpx.config import CertTypes, TimeoutTypes, VerifyTypes
 from httpx.dispatch.base import AsyncDispatcher
 
@@ -418,10 +418,14 @@ async def test_auth_history() -> None:
         of intermediate responses.
         """
 
+        requires_request_body = True
+
         def __init__(self, repeat: int):
             self.repeat = repeat
 
-        def __call__(self, request: Request) -> AuthFlow:
+        def auth_flow(
+            self, request: Request
+        ) -> typing.Generator[Request, Response, None]:
             nonces = []
 
             for index in range(self.repeat):
index 43afe0436089964a38bf38e21243ad4a180d2ce6..0c7269e21957bb293eae90edfeb3d81ca8a601c0 100644 (file)
@@ -21,19 +21,38 @@ def test_content_length_header():
 @pytest.mark.asyncio
 async def test_url_encoded_data():
     request = httpx.Request("POST", "http://example.org", data={"test": "123"})
-    content = b"".join([part async for part in request.stream])
+    await request.aread()
 
     assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
-    assert content == b"test=123"
+    assert request.content == b"test=123"
 
 
 @pytest.mark.asyncio
 async def test_json_encoded_data():
     request = httpx.Request("POST", "http://example.org", json={"test": 123})
-    content = b"".join([part async for part in request.stream])
+    await request.aread()
 
     assert request.headers["Content-Type"] == "application/json"
-    assert content == b'{"test": 123}'
+    assert request.content == b'{"test": 123}'
+
+
+@pytest.mark.asyncio
+async def test_read_and_stream_data():
+    # Ensure a request may still be streamed if it has been read.
+    # Needed for cases such as authentication classes that read the request body.
+    request = httpx.Request("POST", "http://example.org", json={"test": 123})
+    await request.aread()
+    content = b"".join([part async for part in request.stream])
+    assert content == request.content
+
+
+@pytest.mark.asyncio
+async def test_cannot_access_content_without_read():
+    # Ensure a request may still be streamed if it has been read.
+    #  Needed for cases such as authentication classes that read the request body.
+    request = httpx.Request("POST", "http://example.org", json={"test": 123})
+    with pytest.raises(httpx.RequestNotRead):
+        request.content
 
 
 def test_transfer_encoding_header():