import typing
+from .auth import AuthTypes
from .client import Client, StreamContextManager
from .config import DEFAULT_TIMEOUT_CONFIG, CertTypes, TimeoutTypes, VerifyTypes
from .models import (
- AuthTypes,
CookieTypes,
HeaderTypes,
QueryParamTypes,
from urllib.request import parse_http_list
from .exceptions import ProtocolError
-from .middleware import Middleware
from .models import Request, Response
from .utils import to_bytes, to_str, unquote
+AuthFlow = typing.Generator[Request, Response, None]
+
+AuthTypes = typing.Union[
+ typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
+ typing.Callable[["Request"], "Request"],
+ "Auth",
+]
+
+
+class Auth:
+ """
+ Base class for all authentication schemes.
+ """
+
+ def __call__(self, request: Request) -> AuthFlow:
+ """
+ Execute the authentication flow.
+
+ To dispatch a request, `yield` it:
+
+ ```
+ yield request
+ ```
+
+ The client will `.send()` the response back into the flow generator. You can
+ access it like so:
+
+ ```
+ response = yield request
+ ```
+
+ A `return` (or reaching the end of the generator) will result in the
+ client returning the last response obtained from the server.
+
+ You can dispatch as many requests as is necessary.
+ """
+ yield request
+
+
+class FunctionAuth(Auth):
+ """
+ Allows the 'auth' argument to be passed as a simple callable function,
+ that takes the request, and returns a new, modified request.
+ """
+
+ def __init__(self, func: typing.Callable[[Request], Request]) -> None:
+ self.func = func
+
+ def __call__(self, request: Request) -> AuthFlow:
+ yield self.func(request)
+
+
+class BasicAuth(Auth):
+ """
+ Allows the 'auth' argument to be passed as a (username, password) pair,
+ and uses HTTP Basic authentication.
+ """
-class BasicAuth:
def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
):
self.auth_header = self.build_auth_header(username, password)
- def __call__(self, request: Request) -> Request:
+ def __call__(self, request: Request) -> AuthFlow:
request.headers["Authorization"] = self.auth_header
- return request
+ yield request
def build_auth_header(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
return f"Basic {token}"
-class DigestAuth(Middleware):
+class DigestAuth(Auth):
ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
"MD5": hashlib.md5,
"MD5-SESS": hashlib.md5,
self.username = to_bytes(username)
self.password = to_bytes(password)
- async def __call__(
- self, request: Request, get_response: typing.Callable
- ) -> Response:
- response = await get_response(request)
+ def __call__(self, request: Request) -> AuthFlow:
+ response = yield request
+
if response.status_code != 401 or "www-authenticate" not in response.headers:
- return response
+ # If the response is not a 401 WWW-Authenticate, then we don't
+ # need to build an authenticated request.
+ return
- await response.close()
header = response.headers["www-authenticate"]
try:
challenge = DigestAuthChallenge.from_header(header)
raise ProtocolError("Malformed Digest authentication header")
request.headers["Authorization"] = self._build_auth_header(request, challenge)
- return await get_response(request)
+ yield request
def _build_auth_header(
self, request: Request, challenge: "DigestAuthChallenge"
import hstspreload
-from .auth import BasicAuth
+from .auth import Auth, AuthTypes, BasicAuth, FunctionAuth
from .concurrency.base import ConcurrencyBackend
from .config import (
DEFAULT_MAX_REDIRECTS,
RedirectLoop,
TooManyRedirects,
)
-from .middleware import Middleware
from .models import (
URL,
- AuthTypes,
Cookies,
CookieTypes,
Headers,
if request.url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')
- auth = self.auth if auth is None else auth
- trust_env = self.trust_env if trust_env is None else trust_env
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
- if not isinstance(auth, Middleware):
- request = self.authenticate(request, trust_env, auth)
- response = await self.send_handling_redirects(
- request,
- verify=verify,
- cert=cert,
- timeout=timeout,
- allow_redirects=allow_redirects,
- )
- else:
- get_response = functools.partial(
- self.send_handling_redirects,
- verify=verify,
- cert=cert,
- timeout=timeout,
- allow_redirects=allow_redirects,
- )
- response = await auth(request, get_response)
+ auth = self.setup_auth(request, trust_env, auth)
+
+ response = await self.send_handling_redirects(
+ request,
+ auth=auth,
+ verify=verify,
+ cert=cert,
+ timeout=timeout,
+ allow_redirects=allow_redirects,
+ )
if not stream:
try:
return response
- def authenticate(
- self, request: Request, trust_env: bool, auth: AuthTypes = None
- ) -> "Request":
+ def setup_auth(
+ self, request: Request, trust_env: bool = None, auth: AuthTypes = None
+ ) -> Auth:
+ auth = self.auth if auth is None else auth
+ trust_env = self.trust_env if trust_env is None else trust_env
+
if auth is not None:
if isinstance(auth, tuple):
- auth = BasicAuth(username=auth[0], password=auth[1])
- return auth(request)
+ return BasicAuth(username=auth[0], password=auth[1])
+ elif isinstance(auth, Auth):
+ return auth
+ elif callable(auth):
+ return FunctionAuth(func=auth)
+ raise TypeError('Invalid "auth" argument.')
username, password = request.url.username, request.url.password
if username or password:
- auth = BasicAuth(username=username, password=password)
- return auth(request)
+ return BasicAuth(username=username, password=password)
if trust_env and "Authorization" not in request.headers:
credentials = self.netrc.get_credentials(request.url.authority)
if credentials is not None:
- auth = BasicAuth(username=credentials[0], password=credentials[1])
- return auth(request)
+ return BasicAuth(username=credentials[0], password=credentials[1])
- return request
+ return Auth()
async def send_handling_redirects(
self,
request: Request,
+ auth: Auth,
timeout: Timeout,
verify: VerifyTypes = None,
cert: CertTypes = None,
if request.url in (response.url for response in history):
raise RedirectLoop()
- response = await self.send_single_request(
- request, verify=verify, cert=cert, timeout=timeout
+ response = await self.send_handling_auth(
+ request, auth=auth, timeout=timeout, verify=verify, cert=cert
)
response.history = list(history)
response.call_next = functools.partial(
self.send_handling_redirects,
request=request,
+ auth=auth,
verify=verify,
cert=cert,
timeout=timeout,
raise RedirectBodyUnavailable()
return request.content
+ async def send_handling_auth(
+ self,
+ request: Request,
+ auth: Auth,
+ timeout: Timeout,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ ) -> Response:
+ auth_flow = auth(request)
+ request = next(auth_flow)
+ while True:
+ response = await self.send_single_request(request, timeout, verify, cert)
+ try:
+ next_request = auth_flow.send(response)
+ except StopIteration:
+ return response
+ except BaseException as exc:
+ await response.close()
+ raise exc from None
+ else:
+ request = next_request
+ await response.close()
+
async def send_single_request(
self,
request: Request,
+++ /dev/null
-import typing
-
-from .models import Request, Response
-
-
-class Middleware:
- async def __call__(
- self, request: Request, get_response: typing.Callable
- ) -> Response:
- raise NotImplementedError # pragma: no cover
CookieTypes = typing.Union["Cookies", CookieJar, typing.Dict[str, str]]
-AuthTypes = typing.Union[
- typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
- typing.Callable[["Request"], "Request"],
- "BaseMiddleware",
-]
-
ProxiesTypes = typing.Union[
URLTypes, "Dispatcher", typing.Dict[URLTypes, typing.Union[URLTypes, "Dispatcher"]]
]
os.environ["NETRC"] = "tests/.netrc"
url = "http://netrcexample.org"
- client = Client(dispatch=MockDispatch())
- response = await client.get(url, trust_env=False)
+ client = Client(dispatch=MockDispatch(), trust_env=False)
+ response = await client.get(url)
assert response.status_code == 200
assert response.json() == {"auth": None}
- client = Client(dispatch=MockDispatch(), trust_env=False)
- response = await client.get(url, trust_env=True)
+ client = Client(dispatch=MockDispatch(), trust_env=True)
+ response = await client.get(url)
assert response.status_code == 200
assert response.json() == {