+import functools
import inspect
import typing
from types import TracebackType
import hstspreload
-from .auth import HTTPBasicAuth
from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import ConcurrencyBackend
from .config import (
from .dispatch.connection_pool import ConnectionPool
from .dispatch.threaded import ThreadedDispatcher
from .dispatch.wsgi import WSGIDispatch
-from .exceptions import (
- HTTPError,
- InvalidURL,
- RedirectBodyUnavailable,
- RedirectLoop,
- TooManyRedirects,
+from .exceptions import HTTPError, InvalidURL
+from .middleware import (
+ BaseMiddleware,
+ BasicAuthMiddleware,
+ CustomAuthMiddleware,
+ RedirectMiddleware,
)
from .models import (
URL,
ResponseContent,
URLTypes,
)
-from .status_codes import codes
from .utils import get_netrc_login
dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
app: typing.Callable = None,
backend: ConcurrencyBackend = None,
- trust_env: bool = None,
+ trust_env: bool = True,
):
if backend is None:
backend = AsyncioBackend()
timeout: TimeoutTypes = None,
trust_env: bool = None,
) -> AsyncResponse:
- if auth is None:
- auth = self.auth
-
- url = request.url
-
- if url.scheme not in ("http", "https"):
+ if request.url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')
- if auth is None:
- if url.username or url.password:
- auth = HTTPBasicAuth(username=url.username, password=url.password)
- elif self.trust_env if trust_env is None else trust_env:
- netrc_login = get_netrc_login(url.authority)
- if netrc_login:
- netrc_username, _, netrc_password = netrc_login
- auth = HTTPBasicAuth(
- username=netrc_username, password=netrc_password
- )
-
- if auth is not None:
- if isinstance(auth, tuple):
- auth = HTTPBasicAuth(username=auth[0], password=auth[1])
- request = auth(request)
-
- try:
- response = await self.send_handling_redirects(
- request,
- verify=verify,
- cert=cert,
- timeout=timeout,
- allow_redirects=allow_redirects,
- )
- except HTTPError as exc:
- # Add the original request to any HTTPError
- exc.request = request
- raise
-
- if not stream:
+ async def get_response(request: AsyncRequest) -> AsyncResponse:
try:
- await response.read()
- finally:
- await response.close()
-
- return response
-
- async def send_handling_redirects(
- self,
- request: AsyncRequest,
- *,
- cert: CertTypes = None,
- verify: VerifyTypes = None,
- timeout: TimeoutTypes = None,
- allow_redirects: bool = True,
- history: typing.List[AsyncResponse] = None,
- ) -> AsyncResponse:
- if history is None:
- history = []
-
- while True:
- # We perform these checks here, so that calls to `response.next()`
- # will raise redirect errors if appropriate.
- if len(history) > self.max_redirects:
- raise TooManyRedirects(response=history[-1])
- if request.url in (response.url for response in history):
- raise RedirectLoop(response=history[-1])
-
- response = await self.dispatch.send(
- request, verify=verify, cert=cert, timeout=timeout
- )
-
- should_close_response = True
- try:
- assert isinstance(response, AsyncResponse)
- response.history = list(history)
- self.cookies.extract_cookies(response)
- history.append(response)
-
- if allow_redirects and response.is_redirect:
- request = self.build_redirect_request(request, response)
- else:
- should_close_response = False
- break
- finally:
- if should_close_response:
- await response.close()
-
- if response.is_redirect:
-
- async def call_next() -> AsyncResponse:
- nonlocal request, response, verify, cert
- nonlocal allow_redirects, timeout, history
- request = self.build_redirect_request(request, response)
- response = await self.send_handling_redirects(
- request,
- allow_redirects=allow_redirects,
- verify=verify,
- cert=cert,
- timeout=timeout,
- history=history,
+ response = await self.dispatch.send(
+ request, verify=verify, cert=cert, timeout=timeout
)
- return response
+ except HTTPError as exc:
+ # Add the original request to any HTTPError
+ exc.request = request
+ raise
+
+ self.cookies.extract_cookies(response)
+ if not stream:
+ try:
+ await response.read()
+ finally:
+ await response.close()
- response.call_next = call_next # type: ignore
+ return response
- return response
+ def wrap(
+ get_response: typing.Callable, middleware: BaseMiddleware
+ ) -> typing.Callable:
+ return functools.partial(middleware, get_response=get_response)
- def build_redirect_request(
- self, request: AsyncRequest, response: AsyncResponse
- ) -> AsyncRequest:
- method = self.redirect_method(request, response)
- url = self.redirect_url(request, response)
- headers = self.redirect_headers(request, url)
- content = self.redirect_content(request, method, response)
- cookies = self.merge_cookies(request.cookies)
- return AsyncRequest(
- method=method, url=url, headers=headers, data=content, cookies=cookies
+ get_response = wrap(
+ get_response,
+ RedirectMiddleware(allow_redirects=allow_redirects, cookies=self.cookies),
)
- def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
- """
- When being redirected we may want to change the method of the request
- based on certain specs or browser behavior.
- """
- method = request.method
-
- # https://tools.ietf.org/html/rfc7231#section-6.4.4
- if response.status_code == codes.SEE_OTHER and method != "HEAD":
- method = "GET"
-
- # Do what the browsers do, despite standards...
- # Turn 302s into GETs.
- if response.status_code == codes.FOUND and method != "HEAD":
- method = "GET"
-
- # If a POST is responded to with a 301, turn it into a GET.
- # This bizarre behaviour is explained in 'requests' issue 1704.
- if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
- method = "GET"
+ auth_middleware = self._get_auth_middleware(
+ request=request,
+ trust_env=self.trust_env if trust_env is None else trust_env,
+ auth=self.auth if auth is None else auth,
+ )
- return method
+ if auth_middleware is not None:
+ get_response = wrap(get_response, auth_middleware)
- def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
- """
- Return the URL for the redirect to follow.
- """
- location = response.headers["Location"]
+ return await get_response(request)
- url = URL(location, allow_relative=True)
+ def _get_auth_middleware(
+ self, request: AsyncRequest, trust_env: bool, auth: AuthTypes = None
+ ) -> typing.Optional[BaseMiddleware]:
+ if isinstance(auth, tuple):
+ return BasicAuthMiddleware(username=auth[0], password=auth[1])
- # Facilitate relative 'Location' headers, as allowed by RFC 7231.
- # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
- if url.is_relative_url:
- url = request.url.join(url)
+ if callable(auth):
+ return CustomAuthMiddleware(auth=auth)
- # Attach previous fragment if needed (RFC 7231 7.1.2)
- if request.url.fragment and not url.fragment:
- url = url.copy_with(fragment=request.url.fragment)
+ if auth is not None:
+ raise TypeError(
+ 'When specified, "auth" must be a (username, password) tuple or '
+ "a callable with signature (AsyncRequest) -> AsyncRequest "
+ f"(got {auth!r})"
+ )
- return url
+ if request.url.username or request.url.password:
+ return BasicAuthMiddleware(
+ username=request.url.username, password=request.url.password
+ )
- def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
- """
- Strip Authorization headers when responses are redirected away from
- the origin.
- """
- headers = Headers(request.headers)
- if url.origin != request.url.origin:
- del headers["Authorization"]
- del headers["host"]
- return headers
+ if trust_env:
+ netrc_login = get_netrc_login(request.url.authority)
+ if netrc_login:
+ username, _, password = netrc_login
+ return BasicAuthMiddleware(username=username, password=password)
- def redirect_content(
- self, request: AsyncRequest, method: str, response: AsyncResponse
- ) -> bytes:
- """
- Return the body that should be used for the redirect request.
- """
- if method != request.method and method == "GET":
- return b""
- if request.is_streaming:
- raise RedirectBodyUnavailable(response=response)
- return request.content
+ return None
class AsyncClient(BaseClient):
--- /dev/null
+import functools
+import typing
+from base64 import b64encode
+
+from .config import DEFAULT_MAX_REDIRECTS
+from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
+from .models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
+from .status_codes import codes
+
+
+class BaseMiddleware:
+ async def __call__(
+ self, request: AsyncRequest, get_response: typing.Callable
+ ) -> AsyncResponse:
+ raise NotImplementedError # pragma: no cover
+
+
+class BasicAuthMiddleware(BaseMiddleware):
+ def __init__(
+ self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+ ):
+ if isinstance(username, str):
+ username = username.encode("latin1")
+
+ if isinstance(password, str):
+ password = password.encode("latin1")
+
+ userpass = b":".join((username, password))
+ token = b64encode(userpass).decode().strip()
+
+ self.authorization_header = f"Basic {token}"
+
+ async def __call__(
+ self, request: AsyncRequest, get_response: typing.Callable
+ ) -> AsyncResponse:
+ request.headers["Authorization"] = self.authorization_header
+ return await get_response(request)
+
+
+class CustomAuthMiddleware(BaseMiddleware):
+ def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
+ self.auth = auth
+
+ async def __call__(
+ self, request: AsyncRequest, get_response: typing.Callable
+ ) -> AsyncResponse:
+ request = self.auth(request)
+ return await get_response(request)
+
+
+class RedirectMiddleware(BaseMiddleware):
+ def __init__(
+ self,
+ allow_redirects: bool = True,
+ max_redirects: int = DEFAULT_MAX_REDIRECTS,
+ cookies: typing.Optional[Cookies] = None,
+ ):
+ self.allow_redirects = allow_redirects
+ self.max_redirects = max_redirects
+ self.cookies = cookies
+ self.history: typing.List[AsyncResponse] = []
+
+ async def __call__(
+ self, request: AsyncRequest, get_response: typing.Callable
+ ) -> AsyncResponse:
+ if len(self.history) > self.max_redirects:
+ raise TooManyRedirects()
+ if request.url in (response.url for response in self.history):
+ raise RedirectLoop()
+
+ response = await get_response(request)
+ response.history = list(self.history)
+
+ if not response.is_redirect:
+ return response
+
+ self.history.append(response)
+ next_request = self.build_redirect_request(request, response)
+
+ if self.allow_redirects:
+ return await self(next_request, get_response)
+
+ response.call_next = functools.partial(self, next_request, get_response)
+ return response
+
+ def build_redirect_request(
+ self, request: AsyncRequest, response: AsyncResponse
+ ) -> AsyncRequest:
+ method = self.redirect_method(request, response)
+ url = self.redirect_url(request, response)
+ headers = self.redirect_headers(request, url) # TODO: merge headers?
+ content = self.redirect_content(request, method)
+ cookies = Cookies(self.cookies)
+ cookies.update(request.cookies)
+ return AsyncRequest(
+ method=method, url=url, headers=headers, data=content, cookies=cookies
+ )
+
+ def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
+ """
+ When being redirected we may want to change the method of the request
+ based on certain specs or browser behavior.
+ """
+ method = request.method
+
+ # https://tools.ietf.org/html/rfc7231#section-6.4.4
+ if response.status_code == codes.SEE_OTHER and method != "HEAD":
+ method = "GET"
+
+ # Do what the browsers do, despite standards...
+ # Turn 302s into GETs.
+ if response.status_code == codes.FOUND and method != "HEAD":
+ method = "GET"
+
+ # If a POST is responded to with a 301, turn it into a GET.
+ # This bizarre behaviour is explained in 'requests' issue 1704.
+ if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
+ method = "GET"
+
+ return method
+
+ def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
+ """
+ Return the URL for the redirect to follow.
+ """
+ location = response.headers["Location"]
+
+ url = URL(location, allow_relative=True)
+
+ # Facilitate relative 'Location' headers, as allowed by RFC 7231.
+ # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
+ if url.is_relative_url:
+ url = request.url.join(url)
+
+ # Attach previous fragment if needed (RFC 7231 7.1.2)
+ if request.url.fragment and not url.fragment:
+ url = url.copy_with(fragment=request.url.fragment)
+
+ return url
+
+ def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
+ """
+ Strip Authorization headers when responses are redirected away from
+ the origin.
+ """
+ headers = Headers(request.headers)
+ if url.origin != request.url.origin:
+ del headers["Authorization"]
+ del headers["host"]
+ return headers
+
+ def redirect_content(self, request: AsyncRequest, method: str) -> bytes:
+ """
+ Return the body that should be used for the redirect request.
+ """
+ if method != request.method and method == "GET":
+ return b""
+ if request.is_streaming:
+ raise RedirectBodyUnavailable()
+ return request.content