...
```
+If you _do_ need to perform I/O other than HTTP requests, such as accessing a disk-based cache, or you need to use concurrency primitives, such as locks, then you should override `.sync_auth_flow()` and `.async_auth_flow()` (instead of `.auth_flow()`). The former will be used by `httpx.Client`, while the latter will be used by `httpx.AsyncClient`.
+
+```python
+import asyncio
+import threading
+import httpx
+
+
+class MyCustomAuth(httpx.Auth):
+ def __init__(self):
+ self._sync_lock = threading.RLock()
+ self._async_lock = asyncio.Lock()
+
+ def sync_get_token(self):
+ with self._sync_lock:
+ ...
+
+ def sync_auth_flow(self, request):
+ token = self.sync_get_token()
+ request.headers["Authorization"] = f"Token {token}"
+ yield request
+
+ async def async_get_token(self):
+ async with self._async_lock:
+ ...
+
+ async def async_auth_flow(self, request):
+ token = await self.async_get_token()
+ request.headers["Authorization"] = f"Token {token}"
+ yield request
+```
+
+If you only want to support one of the two methods, then you should still override it, but raise an explicit `RuntimeError`.
+
+```python
+import httpx
+import sync_only_library
+
+
+class MyCustomAuth(httpx.Auth):
+ def sync_auth_flow(self, request):
+ token = sync_only_library.get_token(...)
+ request.headers["Authorization"] = f"Token {token}"
+ yield request
+
+ async def async_auth_flow(self, request):
+ raise RuntimeError("Cannot use a sync authentication class with httpx.AsyncClient")
+```
+
## SSL certificates
When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA).
To implement a custom authentication scheme, subclass `Auth` and override
the `.auth_flow()` method.
+
+ If the authentication scheme does I/O such as disk access or network calls, or uses
+ synchronization primitives such as locks, you should override `.sync_auth_flow()`
+ and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
+ implementations that will be used by `Client` and `AsyncClient` respectively.
"""
requires_request_body = False
"""
yield request
+ def sync_auth_flow(
+ self, request: Request
+ ) -> typing.Generator[Request, Response, None]:
+ """
+ Execute the authentication flow synchronously.
+
+ By default, this defers to `.auth_flow()`. You should override this method
+ when the authentication scheme does I/O and/or uses concurrency primitives.
+ """
+ if self.requires_request_body:
+ request.read()
+
+ flow = self.auth_flow(request)
+ request = next(flow)
+
+ while True:
+ response = yield request
+ if self.requires_response_body:
+ response.read()
+
+ try:
+ request = flow.send(response)
+ except StopIteration:
+ break
+
+ async def async_auth_flow(
+ self, request: Request
+ ) -> typing.AsyncGenerator[Request, Response]:
+ """
+ Execute the authentication flow asynchronously.
+
+ By default, this defers to `.auth_flow()`. You should override this method
+ when the authentication scheme does I/O and/or uses concurrency primitives.
+ """
+ if self.requires_request_body:
+ await request.aread()
+
+ flow = self.auth_flow(request)
+ request = next(flow)
+
+ while True:
+ response = yield request
+ if self.requires_response_body:
+ await response.aread()
+
+ try:
+ request = flow.send(response)
+ except StopIteration:
+ break
+
class FunctionAuth(Auth):
"""
auth: Auth,
timeout: Timeout,
) -> Response:
- if auth.requires_request_body:
- request.read()
-
- auth_flow = auth.auth_flow(request)
+ auth_flow = auth.sync_auth_flow(request)
request = next(auth_flow)
+
while True:
response = self._send_single_request(request, timeout)
- if auth.requires_response_body:
- response.read()
+
try:
next_request = auth_flow.send(response)
except StopIteration:
auth: Auth,
timeout: Timeout,
) -> Response:
- if auth.requires_request_body:
- await request.aread()
+ auth_flow = auth.async_auth_flow(request)
+ request = await auth_flow.__anext__()
- auth_flow = auth.auth_flow(request)
- request = next(auth_flow)
while True:
response = await self._send_single_request(request, timeout)
- if auth.requires_response_body:
- await response.aread()
+
try:
- next_request = auth_flow.send(response)
- except StopIteration:
+ next_request = await auth_flow.asend(response)
+ except StopAsyncIteration:
return response
except BaseException as exc:
await response.aclose()
+"""
+Integration tests for authentication.
+
+Unit tests for auth classes also exist in tests/test_auth.py
+"""
+import asyncio
import hashlib
import os
+import threading
import typing
import httpcore
yield request
+class SyncOrAsyncAuth(Auth):
+ """
+ A mock authentication scheme that uses a different implementation for the
+ sync and async cases.
+ """
+
+ def __init__(self) -> None:
+ self._lock = threading.Lock()
+ self._async_lock = asyncio.Lock()
+
+ def sync_auth_flow(
+ self, request: Request
+ ) -> typing.Generator[Request, Response, None]:
+ with self._lock:
+ request.headers["Authorization"] = "sync-auth"
+ yield request
+
+ async def async_auth_flow(
+ self, request: Request
+ ) -> typing.AsyncGenerator[Request, Response]:
+ async with self._async_lock:
+ request.headers["Authorization"] = "async-auth"
+ yield request
+
+
@pytest.mark.asyncio
async def test_basic_auth() -> None:
url = "https://example.org/"
assert response.status_code == 200
assert response.json() == {"auth": '{"auth": "xyz"}'}
+
+
+@pytest.mark.asyncio
+async def test_async_auth() -> None:
+ """
+ Test that we can use an auth implementation specific to the async case, to
+ support cases that require performing I/O or using concurrency primitives (such
+ as checking a disk-based cache or fetching a token from a remote auth server).
+ """
+ url = "https://example.org/"
+ auth = SyncOrAsyncAuth()
+
+ async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+ response = await client.get(url, auth=auth)
+
+ assert response.status_code == 200
+ assert response.json() == {"auth": "async-auth"}
+
+
+def test_sync_auth() -> None:
+ """
+ Test that we can use an auth implementation specific to the sync case.
+ """
+ url = "https://example.org/"
+ auth = SyncOrAsyncAuth()
+
+ with httpx.Client(transport=SyncMockTransport()) as client:
+ response = client.get(url, auth=auth)
+
+ assert response.status_code == 200
+ assert response.json() == {"auth": "sync-auth"}
--- /dev/null
+"""
+Unit tests for auth classes.
+
+Integration tests also exist in tests/client/test_auth.py
+"""
+import pytest
+
+import httpx
+
+
+def test_basic_auth():
+ auth = httpx.BasicAuth(username="user", password="pass")
+ request = httpx.Request("GET", "https://www.example.com")
+
+ # The initial request should include a basic auth header.
+ flow = auth.sync_auth_flow(request)
+ request = next(flow)
+ assert request.headers["Authorization"].startswith("Basic")
+
+ # No other requests are made.
+ response = httpx.Response(content=b"Hello, world!", status_code=200)
+ with pytest.raises(StopIteration):
+ flow.send(response)
+
+
+def test_digest_auth_with_200():
+ auth = httpx.DigestAuth(username="user", password="pass")
+ request = httpx.Request("GET", "https://www.example.com")
+
+ # The initial request should not include an auth header.
+ flow = auth.sync_auth_flow(request)
+ request = next(flow)
+ assert "Authorization" not in request.headers
+
+ # If a 200 response is returned, then no other requests are made.
+ response = httpx.Response(content=b"Hello, world!", status_code=200)
+ with pytest.raises(StopIteration):
+ flow.send(response)
+
+
+def test_digest_auth_with_401():
+ auth = httpx.DigestAuth(username="user", password="pass")
+ request = httpx.Request("GET", "https://www.example.com")
+
+ # The initial request should not include an auth header.
+ flow = auth.sync_auth_flow(request)
+ request = next(flow)
+ assert "Authorization" not in request.headers
+
+ # If a 401 response is returned, then a digest auth request is made.
+ headers = {
+ "WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
+ }
+ response = httpx.Response(
+ content=b"Auth required", status_code=401, headers=headers
+ )
+ request = flow.send(response)
+ assert request.headers["Authorization"].startswith("Digest")
+
+ # No other requests are made.
+ response = httpx.Response(content=b"Hello, world!", status_code=200)
+ with pytest.raises(StopIteration):
+ flow.send(response)