]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add support for sync-specific or async-specific auth flows (#1217)
authorFlorimond Manca <florimond.manca@gmail.com>
Wed, 9 Sep 2020 13:37:20 +0000 (15:37 +0200)
committerGitHub <noreply@github.com>
Wed, 9 Sep 2020 13:37:20 +0000 (14:37 +0100)
* Add support for async auth flows

* Move body logic to Auth, add sync_auth_flow, add NoAuth

* Update tests

* Stick to next() / __anext__()

* Fix undefined name errors

* Add docs

* Add unit tests for auth classes

Co-authored-by: Tom Christie <tom@tomchristie.com>
docs/advanced.md
httpx/_auth.py
httpx/_client.py
tests/client/test_auth.py
tests/test_auth.py [new file with mode: 0644]

index b2a07df371478a18cfdaa4316acab50c8bed6bea..0f0b2ddf72a531c9d3ad974c88eea0ac3249b677 100644 (file)
@@ -724,6 +724,55 @@ class MyCustomAuth(httpx.Auth):
         ...
 ```
 
+If you _do_ need to perform I/O other than HTTP requests, such as accessing a disk-based cache, or you need to use concurrency primitives, such as locks, then you should override `.sync_auth_flow()` and `.async_auth_flow()` (instead of `.auth_flow()`). The former will be used by `httpx.Client`, while the latter will be used by `httpx.AsyncClient`.
+
+```python
+import asyncio
+import threading
+import httpx
+
+
+class MyCustomAuth(httpx.Auth):
+    def __init__(self):
+        self._sync_lock = threading.RLock()
+        self._async_lock = asyncio.Lock()
+
+    def sync_get_token(self):
+        with self._sync_lock:
+            ...
+
+    def sync_auth_flow(self, request):
+        token = self.sync_get_token()
+        request.headers["Authorization"] = f"Token {token}"
+        yield request
+
+    async def async_get_token(self):
+        async with self._async_lock:
+            ...
+
+    async def async_auth_flow(self, request):
+        token = await self.async_get_token()
+        request.headers["Authorization"] = f"Token {token}"
+        yield request
+```
+
+If you only want to support one of the two methods, then you should still override it, but raise an explicit `RuntimeError`.
+
+```python
+import httpx
+import sync_only_library
+
+
+class MyCustomAuth(httpx.Auth):
+    def sync_auth_flow(self, request):
+        token = sync_only_library.get_token(...)
+        request.headers["Authorization"] = f"Token {token}"
+        yield request
+
+    async def async_auth_flow(self, request):
+        raise RuntimeError("Cannot use a sync authentication class with httpx.AsyncClient")
+```
+
 ## SSL certificates
 
 When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA).
index eb110dea3ae02e2affd7e348c571156a7a89bc34..439f337fbfe5eb5141e2d59c3a91a34e58110a5a 100644 (file)
@@ -17,6 +17,11 @@ class Auth:
 
     To implement a custom authentication scheme, subclass `Auth` and override
     the `.auth_flow()` method.
+
+    If the authentication scheme does I/O such as disk access or network calls, or uses
+    synchronization primitives such as locks, you should override `.sync_auth_flow()`
+    and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
+    implementations that will be used by `Client` and `AsyncClient` respectively.
     """
 
     requires_request_body = False
@@ -46,6 +51,56 @@ class Auth:
         """
         yield request
 
+    def sync_auth_flow(
+        self, request: Request
+    ) -> typing.Generator[Request, Response, None]:
+        """
+        Execute the authentication flow synchronously.
+
+        By default, this defers to `.auth_flow()`. You should override this method
+        when the authentication scheme does I/O and/or uses concurrency primitives.
+        """
+        if self.requires_request_body:
+            request.read()
+
+        flow = self.auth_flow(request)
+        request = next(flow)
+
+        while True:
+            response = yield request
+            if self.requires_response_body:
+                response.read()
+
+            try:
+                request = flow.send(response)
+            except StopIteration:
+                break
+
+    async def async_auth_flow(
+        self, request: Request
+    ) -> typing.AsyncGenerator[Request, Response]:
+        """
+        Execute the authentication flow asynchronously.
+
+        By default, this defers to `.auth_flow()`. You should override this method
+        when the authentication scheme does I/O and/or uses concurrency primitives.
+        """
+        if self.requires_request_body:
+            await request.aread()
+
+        flow = self.auth_flow(request)
+        request = next(flow)
+
+        while True:
+            response = yield request
+            if self.requires_response_body:
+                await response.aread()
+
+            try:
+                request = flow.send(response)
+            except StopIteration:
+                break
+
 
 class FunctionAuth(Auth):
     """
index 0b67a78dddf3f0847abdad05d724757efdcfacf6..61a862bde367649eb568c0d1e5d9555acc7307ab 100644 (file)
@@ -785,15 +785,12 @@ class Client(BaseClient):
         auth: Auth,
         timeout: Timeout,
     ) -> Response:
-        if auth.requires_request_body:
-            request.read()
-
-        auth_flow = auth.auth_flow(request)
+        auth_flow = auth.sync_auth_flow(request)
         request = next(auth_flow)
+
         while True:
             response = self._send_single_request(request, timeout)
-            if auth.requires_response_body:
-                response.read()
+
             try:
                 next_request = auth_flow.send(response)
             except StopIteration:
@@ -1409,18 +1406,15 @@ class AsyncClient(BaseClient):
         auth: Auth,
         timeout: Timeout,
     ) -> Response:
-        if auth.requires_request_body:
-            await request.aread()
+        auth_flow = auth.async_auth_flow(request)
+        request = await auth_flow.__anext__()
 
-        auth_flow = auth.auth_flow(request)
-        request = next(auth_flow)
         while True:
             response = await self._send_single_request(request, timeout)
-            if auth.requires_response_body:
-                await response.aread()
+
             try:
-                next_request = auth_flow.send(response)
-            except StopIteration:
+                next_request = await auth_flow.asend(response)
+            except StopAsyncIteration:
                 return response
             except BaseException as exc:
                 await response.aclose()
index a08c3292fdb60feeeab1a63aa974abc605dbf749..c6c6d979accc785bb2f753ff49c9e35548d14015 100644 (file)
@@ -1,5 +1,12 @@
+"""
+Integration tests for authentication.
+
+Unit tests for auth classes also exist in tests/test_auth.py
+"""
+import asyncio
 import hashlib
 import os
+import threading
 import typing
 
 import httpcore
@@ -183,6 +190,31 @@ class ResponseBodyAuth(Auth):
         yield request
 
 
+class SyncOrAsyncAuth(Auth):
+    """
+    A mock authentication scheme that uses a different implementation for the
+    sync and async cases.
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._async_lock = asyncio.Lock()
+
+    def sync_auth_flow(
+        self, request: Request
+    ) -> typing.Generator[Request, Response, None]:
+        with self._lock:
+            request.headers["Authorization"] = "sync-auth"
+        yield request
+
+    async def async_auth_flow(
+        self, request: Request
+    ) -> typing.AsyncGenerator[Request, Response]:
+        async with self._async_lock:
+            request.headers["Authorization"] = "async-auth"
+        yield request
+
+
 @pytest.mark.asyncio
 async def test_basic_auth() -> None:
     url = "https://example.org/"
@@ -664,3 +696,34 @@ def test_sync_auth_reads_response_body() -> None:
 
     assert response.status_code == 200
     assert response.json() == {"auth": '{"auth": "xyz"}'}
+
+
+@pytest.mark.asyncio
+async def test_async_auth() -> None:
+    """
+    Test that we can use an auth implementation specific to the async case, to
+    support cases that require performing I/O or using concurrency primitives (such
+    as checking a disk-based cache or fetching a token from a remote auth server).
+    """
+    url = "https://example.org/"
+    auth = SyncOrAsyncAuth()
+
+    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+        response = await client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    assert response.json() == {"auth": "async-auth"}
+
+
+def test_sync_auth() -> None:
+    """
+    Test that we can use an auth implementation specific to the sync case.
+    """
+    url = "https://example.org/"
+    auth = SyncOrAsyncAuth()
+
+    with httpx.Client(transport=SyncMockTransport()) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    assert response.json() == {"auth": "sync-auth"}
diff --git a/tests/test_auth.py b/tests/test_auth.py
new file mode 100644 (file)
index 0000000..20c666a
--- /dev/null
@@ -0,0 +1,63 @@
+"""
+Unit tests for auth classes.
+
+Integration tests also exist in tests/client/test_auth.py
+"""
+import pytest
+
+import httpx
+
+
+def test_basic_auth():
+    auth = httpx.BasicAuth(username="user", password="pass")
+    request = httpx.Request("GET", "https://www.example.com")
+
+    # The initial request should include a basic auth header.
+    flow = auth.sync_auth_flow(request)
+    request = next(flow)
+    assert request.headers["Authorization"].startswith("Basic")
+
+    # No other requests are made.
+    response = httpx.Response(content=b"Hello, world!", status_code=200)
+    with pytest.raises(StopIteration):
+        flow.send(response)
+
+
+def test_digest_auth_with_200():
+    auth = httpx.DigestAuth(username="user", password="pass")
+    request = httpx.Request("GET", "https://www.example.com")
+
+    # The initial request should not include an auth header.
+    flow = auth.sync_auth_flow(request)
+    request = next(flow)
+    assert "Authorization" not in request.headers
+
+    # If a 200 response is returned, then no other requests are made.
+    response = httpx.Response(content=b"Hello, world!", status_code=200)
+    with pytest.raises(StopIteration):
+        flow.send(response)
+
+
+def test_digest_auth_with_401():
+    auth = httpx.DigestAuth(username="user", password="pass")
+    request = httpx.Request("GET", "https://www.example.com")
+
+    # The initial request should not include an auth header.
+    flow = auth.sync_auth_flow(request)
+    request = next(flow)
+    assert "Authorization" not in request.headers
+
+    # If a 401 response is returned, then a digest auth request is made.
+    headers = {
+        "WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
+    }
+    response = httpx.Response(
+        content=b"Auth required", status_code=401, headers=headers
+    )
+    request = flow.send(response)
+    assert request.headers["Authorization"].startswith("Digest")
+
+    # No other requests are made.
+    response = httpx.Response(content=b"Hello, world!", status_code=200)
+    with pytest.raises(StopIteration):
+        flow.send(response)