]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Client handles redirect + auth (#552)
authorTom Christie <tom@tomchristie.com>
Wed, 27 Nov 2019 12:10:10 +0000 (12:10 +0000)
committerGitHub <noreply@github.com>
Wed, 27 Nov 2019 12:10:10 +0000 (12:10 +0000)
* Drop sync client

* Drop unused imports

* Async only

* Update tests/test_decoders.py

Co-Authored-By: Florimond Manca <florimond.manca@gmail.com>
* Linting

* Update docs for async-only

* Import sorting

* Add async notes to docs

* Update README for 0.8 async switch

* Move auth away from middleware where possible

* Drop middleware sub-package

* Client.dispatcher -> Client.dispatch

* Docs tweak

* Linting

* Fix type checking issue

* Import ordering

* Fix up docstrings

* Minor docs fixes

* Linting

* Remove unused import

docs/index.md
httpx/__init__.py
httpx/auth.py [moved from httpx/middleware/digest_auth.py with 86% similarity]
httpx/client.py
httpx/dispatch/proxy_http.py
httpx/middleware.py [moved from httpx/middleware/base.py with 74% similarity]
httpx/middleware/__init__.py [deleted file]
httpx/middleware/basic_auth.py [deleted file]
httpx/middleware/custom_auth.py [deleted file]
httpx/middleware/redirect.py [deleted file]

index b0f98a0ab29568ca770e0864d4c779b9ff560f0d..846d58a49624539d3a0de61324286c7b04b91848 100644 (file)
@@ -24,13 +24,13 @@ HTTPX
 <em>A next-generation HTTP client for Python.</em>
 </div>
 
-HTTPX is an asynchronous HTTP client, that supports HTTP/2 and HTTP/1.1.
+HTTPX is an asynchronous client library that supports HTTP/1.1 and HTTP/2.
 
 It can be used in high-performance async web frameworks, using either asyncio
-or trio, and is able to support making large numbers of requests concurrently.
+or trio, and is able to support making large numbers of concurrent requests.
 
 !!! note
-    The 0.8 release switched HTTPX into focusing exclusively on the async
+    The 0.8 release switched HTTPX into focusing exclusively on providing an async
     client. It is possible that we'll look at re-introducing a sync API at a
     later date.
 
@@ -38,11 +38,10 @@ or trio, and is able to support making large numbers of requests concurrently.
 
 Let's get started...
 
-!!! note
-    The standard Python REPL does not allow top-level async statements.
+The standard Python REPL does not allow top-level async statements.
 
-    To run async examples directly you'll probably want to either use `ipython`,
-    or use Python 3.8 with `python -m asyncio`.
+To run these async examples you'll probably want to either use `ipython`,
+or use Python 3.8 with `python -m asyncio`.
 
 ```python
 >>> import httpx
index 9dec3af8cef076931b38dd975ee113151529ce83..bda76ae309ccac6058e0459b3e0a4421803e84e9 100644 (file)
@@ -1,5 +1,6 @@
 from .__version__ import __description__, __title__, __version__
 from .api import delete, get, head, options, patch, post, put, request
+from .auth import BasicAuth, DigestAuth
 from .client import Client
 from .concurrency.asyncio import AsyncioBackend
 from .concurrency.base import (
@@ -43,7 +44,6 @@ from .exceptions import (
     TooManyRedirects,
     WriteTimeout,
 )
-from .middleware.digest_auth import DigestAuth
 from .models import (
     URL,
     AuthTypes,
@@ -76,7 +76,9 @@ __all__ = [
     "patch",
     "put",
     "request",
+    "BasicAuth",
     "Client",
+    "DigestAuth",
     "AsyncioBackend",
     "USER_AGENT",
     "CertTypes",
similarity index 86%
rename from httpx/middleware/digest_auth.py
rename to httpx/auth.py
index ab1a1d775aa8e29dafecf347a5dab89f0661753a..eb93ff35928a6adb4b823f0f2ad47f963a3fefce 100644 (file)
@@ -3,15 +3,34 @@ import os
 import re
 import time
 import typing
+from base64 import b64encode
 from urllib.request import parse_http_list
 
-from ..exceptions import ProtocolError
-from ..models import Request, Response, StatusCode
-from ..utils import to_bytes, to_str, unquote
-from .base import BaseMiddleware
+from .exceptions import ProtocolError
+from .middleware import Middleware
+from .models import Request, Response
+from .utils import to_bytes, to_str, unquote
 
 
-class DigestAuth(BaseMiddleware):
+class BasicAuth:
+    def __init__(
+        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+    ):
+        self.auth_header = self.build_auth_header(username, password)
+
+    def __call__(self, request: Request) -> Request:
+        request.headers["Authorization"] = self.auth_header
+        return request
+
+    def build_auth_header(
+        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+    ) -> str:
+        userpass = b":".join((to_bytes(username), to_bytes(password)))
+        token = b64encode(userpass).decode().strip()
+        return f"Basic {token}"
+
+
+class DigestAuth(Middleware):
     ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
         "MD5": hashlib.md5,
         "MD5-SESS": hashlib.md5,
@@ -33,12 +52,10 @@ class DigestAuth(BaseMiddleware):
         self, request: Request, get_response: typing.Callable
     ) -> Response:
         response = await get_response(request)
-        if not (
-            StatusCode.is_client_error(response.status_code)
-            and "www-authenticate" in response.headers
-        ):
+        if response.status_code != 401 or "www-authenticate" not in response.headers:
             return response
 
+        await response.close()
         header = response.headers["www-authenticate"]
         try:
             challenge = DigestAuthChallenge.from_header(header)
index 50b33f782da29d76b7b33da411c9b004d27cde40..7d3034323d12830ebdb1f501116b3a348edc50cf 100644 (file)
@@ -5,6 +5,7 @@ from types import TracebackType
 
 import hstspreload
 
+from .auth import BasicAuth
 from .concurrency.asyncio import AsyncioBackend
 from .concurrency.base import ConcurrencyBackend
 from .config import (
@@ -21,11 +22,14 @@ from .dispatch.asgi import ASGIDispatch
 from .dispatch.base import Dispatcher
 from .dispatch.connection_pool import ConnectionPool
 from .dispatch.proxy_http import HTTPProxy
-from .exceptions import HTTPError, InvalidURL
-from .middleware.base import BaseMiddleware
-from .middleware.basic_auth import BasicAuthMiddleware
-from .middleware.custom_auth import CustomAuthMiddleware
-from .middleware.redirect import RedirectMiddleware
+from .exceptions import (
+    HTTPError,
+    InvalidURL,
+    RedirectBodyUnavailable,
+    RedirectLoop,
+    TooManyRedirects,
+)
+from .middleware import Middleware
 from .models import (
     URL,
     AuthTypes,
@@ -42,6 +46,7 @@ from .models import (
     Response,
     URLTypes,
 )
+from .status_codes import codes
 from .utils import ElapsedTimer, get_environment_proxies, get_logger, get_netrc
 
 logger = get_logger(__name__)
@@ -125,8 +130,6 @@ class Client:
         if app is not None:
             dispatch = ASGIDispatch(app=app, backend=backend)
 
-        self.trust_env = True if trust_env is None else trust_env
-
         if dispatch is None:
             dispatch = ConnectionPool(
                 verify=verify,
@@ -135,7 +138,7 @@ class Client:
                 http_versions=http_versions,
                 pool_limits=pool_limits,
                 backend=backend,
-                trust_env=self.trust_env,
+                trust_env=trust_env,
                 uds=uds,
             )
 
@@ -152,6 +155,7 @@ class Client:
         self._headers = Headers(headers)
         self._cookies = Cookies(cookies)
         self.max_redirects = max_redirects
+        self.trust_env = trust_env
         self.dispatch = dispatch
         self.concurrency_backend = backend
 
@@ -202,7 +206,82 @@ class Client:
     def params(self, params: QueryParamTypes) -> None:
         self._params = QueryParams(params)
 
+    async def request(
+        self,
+        method: str,
+        url: URLTypes,
+        *,
+        data: RequestData = None,
+        files: RequestFiles = None,
+        json: typing.Any = None,
+        params: QueryParamTypes = None,
+        headers: HeaderTypes = None,
+        cookies: CookieTypes = None,
+        stream: bool = False,
+        auth: AuthTypes = None,
+        allow_redirects: bool = True,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
+        trust_env: bool = None,
+    ) -> Response:
+        request = self.build_request(
+            method=method,
+            url=url,
+            data=data,
+            files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+        )
+        response = await self.send(
+            request,
+            stream=stream,
+            auth=auth,
+            allow_redirects=allow_redirects,
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            trust_env=trust_env,
+        )
+        return response
+
+    def build_request(
+        self,
+        method: str,
+        url: URLTypes,
+        *,
+        data: RequestData = None,
+        files: RequestFiles = None,
+        json: typing.Any = None,
+        params: QueryParamTypes = None,
+        headers: HeaderTypes = None,
+        cookies: CookieTypes = None,
+    ) -> Request:
+        """
+        Build and return a request instance.
+        """
+        url = self.merge_url(url)
+        headers = self.merge_headers(headers)
+        cookies = self.merge_cookies(cookies)
+        params = self.merge_queryparams(params)
+        return Request(
+            method,
+            url,
+            data=data,
+            files=files,
+            json=json,
+            params=params,
+            headers=headers,
+            cookies=cookies,
+        )
+
     def merge_url(self, url: URLTypes) -> URL:
+        """
+        Merge a URL argument together with any 'base_url' on the client,
+        to create the URL used for the outgoing request.
+        """
         url = self.base_url.join(relative_url=url)
         if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
             url = url.copy_with(scheme="https")
@@ -211,6 +290,10 @@ class Client:
     def merge_cookies(
         self, cookies: CookieTypes = None
     ) -> typing.Optional[CookieTypes]:
+        """
+        Merge a cookies argument together with any cookies on the client,
+        to create the cookies used for the outgoing request.
+        """
         if cookies or self.cookies:
             merged_cookies = Cookies(self.cookies)
             merged_cookies.update(cookies)
@@ -220,6 +303,10 @@ class Client:
     def merge_headers(
         self, headers: HeaderTypes = None
     ) -> typing.Optional[HeaderTypes]:
+        """
+        Merge a headers argument together with any headers on the client,
+        to create the headers used for the outgoing request.
+        """
         if headers or self.headers:
             merged_headers = Headers(self.headers)
             merged_headers.update(headers)
@@ -229,13 +316,17 @@ class Client:
     def merge_queryparams(
         self, params: QueryParamTypes = None
     ) -> typing.Optional[QueryParamTypes]:
+        """
+        Merge a queryparams argument together with any queryparams on the client,
+        to create the queryparams used for the outgoing request.
+        """
         if params or self.params:
             merged_queryparams = QueryParams(self.params)
             merged_queryparams.update(params)
             return merged_queryparams
         return params
 
-    async def _get_response(
+    async def send(
         self,
         request: Request,
         *,
@@ -250,92 +341,227 @@ class Client:
         if request.url.scheme not in ("http", "https"):
             raise InvalidURL('URL scheme must be "http" or "https".')
 
-        dispatch = self._dispatcher_for_request(request, self.proxies)
+        auth = self.auth if auth is None else auth
+        trust_env = self.trust_env if trust_env is None else trust_env
 
-        async def get_response(request: Request) -> Response:
-            try:
-                with ElapsedTimer() as timer:
-                    response = await dispatch.send(
-                        request, verify=verify, cert=cert, timeout=timeout
-                    )
-                response.elapsed = timer.elapsed
-                response.request = request
-            except HTTPError as exc:
-                # Add the original request to any HTTPError unless
-                # there'a already a request attached in the case of
-                # a ProxyError.
-                if exc.request is None:
-                    exc.request = request
-                raise
-
-            self.cookies.extract_cookies(response)
-            if not stream:
-                try:
-                    await response.read()
-                finally:
-                    await response.close()
-
-            status = f"{response.status_code} {response.reason_phrase}"
-            response_line = f"{response.http_version} {status}"
-            logger.debug(
-                f'HTTP Request: {request.method} {request.url} "{response_line}"'
+        if not isinstance(auth, Middleware):
+            request = self.authenticate(request, trust_env, auth)
+            response = await self.send_handling_redirects(
+                request,
+                verify=verify,
+                cert=cert,
+                timeout=timeout,
+                allow_redirects=allow_redirects,
             )
+        else:
+            get_response = functools.partial(
+                self.send_handling_redirects,
+                verify=verify,
+                cert=cert,
+                timeout=timeout,
+                allow_redirects=allow_redirects,
+            )
+            response = await auth(request, get_response)
 
-            return response
+        if not stream:
+            try:
+                await response.read()
+            finally:
+                await response.close()
 
-        def wrap(
-            get_response: typing.Callable, middleware: BaseMiddleware
-        ) -> typing.Callable:
-            return functools.partial(middleware, get_response=get_response)
+        return response
 
-        get_response = wrap(
-            get_response,
-            RedirectMiddleware(allow_redirects=allow_redirects, cookies=self.cookies),
-        )
+    def authenticate(
+        self, request: Request, trust_env: bool, auth: AuthTypes = None
+    ) -> "Request":
+        if auth is not None:
+            if isinstance(auth, tuple):
+                auth = BasicAuth(username=auth[0], password=auth[1])
+            return auth(request)
 
-        auth_middleware = self._get_auth_middleware(
-            request=request,
-            trust_env=self.trust_env if trust_env is None else trust_env,
-            auth=self.auth if auth is None else auth,
-        )
+        username, password = request.url.username, request.url.password
+        if username or password:
+            auth = BasicAuth(username=username, password=password)
+            return auth(request)
 
-        if auth_middleware is not None:
-            get_response = wrap(get_response, auth_middleware)
+        if trust_env:
+            netrc_info = self._get_netrc()
+            if netrc_info is not None:
+                netrc_login = netrc_info.authenticators(request.url.authority)
+                netrc_username, _, netrc_password = netrc_login or ("", None, None)
+                if netrc_password is not None:
+                    auth = BasicAuth(username=netrc_username, password=netrc_password)
+                    return auth(request)
 
-        return await get_response(request)
+        return request
 
-    def _get_auth_middleware(
-        self, request: Request, trust_env: bool, auth: AuthTypes = None
-    ) -> typing.Optional[BaseMiddleware]:
-        if isinstance(auth, tuple):
-            return BasicAuthMiddleware(username=auth[0], password=auth[1])
-        elif isinstance(auth, BaseMiddleware):
-            return auth
-        elif callable(auth):
-            return CustomAuthMiddleware(auth=auth)
+    async def send_handling_redirects(
+        self,
+        request: Request,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+        allow_redirects: bool = True,
+        history: typing.List[Response] = None,
+    ) -> Response:
+        if history is None:
+            history = []
 
-        if auth is not None:
-            raise TypeError(
-                'When specified, "auth" must be a (username, password) tuple or '
-                "a callable with signature (Request) -> Request "
-                f"(got {auth!r})"
-            )
+        while True:
+            if len(history) > self.max_redirects:
+                raise TooManyRedirects()
+            if request.url in (response.url for response in history):
+                raise RedirectLoop()
 
-        if request.url.username or request.url.password:
-            return BasicAuthMiddleware(
-                username=request.url.username, password=request.url.password
+            response = await self.send_single_request(
+                request, verify=verify, cert=cert, timeout=timeout
             )
+            response.history = list(history)
+
+            if not response.is_redirect:
+                return response
+
+            await response.close()
+            request = self.build_redirect_request(request, response)
+            history = history + [response]
+
+            if not allow_redirects:
+                response.call_next = functools.partial(
+                    self.send_handling_redirects,
+                    request=request,
+                    verify=verify,
+                    cert=cert,
+                    timeout=timeout,
+                    allow_redirects=False,
+                    history=history,
+                )
+                return response
+
+    def build_redirect_request(self, request: Request, response: Response) -> Request:
+        """
+        Given a request and a redirect response, return a new request that
+        should be used to effect the redirect.
+        """
+        method = self.redirect_method(request, response)
+        url = self.redirect_url(request, response)
+        headers = self.redirect_headers(request, url, method)
+        content = self.redirect_content(request, method)
+        cookies = Cookies(self.cookies)
+        return Request(
+            method=method, url=url, headers=headers, data=content, cookies=cookies
+        )
 
-        if trust_env:
-            netrc_info = self._get_netrc()
-            if netrc_info:
-                netrc_login = netrc_info.authenticators(request.url.authority)
-                if netrc_login:
-                    username, _, password = netrc_login
-                    assert password is not None
-                    return BasicAuthMiddleware(username=username, password=password)
+    def redirect_method(self, request: Request, response: Response) -> str:
+        """
+        When being redirected we may want to change the method of the request
+        based on certain specs or browser behavior.
+        """
+        method = request.method
+
+        # https://tools.ietf.org/html/rfc7231#section-6.4.4
+        if response.status_code == codes.SEE_OTHER and method != "HEAD":
+            method = "GET"
+
+        # Do what the browsers do, despite standards...
+        # Turn 302s into GETs.
+        if response.status_code == codes.FOUND and method != "HEAD":
+            method = "GET"
+
+        # If a POST is responded to with a 301, turn it into a GET.
+        # This bizarre behaviour is explained in 'requests' issue 1704.
+        if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
+            method = "GET"
 
-        return None
+        return method
+
+    def redirect_url(self, request: Request, response: Response) -> URL:
+        """
+        Return the URL for the redirect to follow.
+        """
+        location = response.headers["Location"]
+
+        url = URL(location, allow_relative=True)
+
+        # Facilitate relative 'Location' headers, as allowed by RFC 7231.
+        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
+        if url.is_relative_url:
+            url = request.url.join(url)
+
+        # Attach previous fragment if needed (RFC 7231 7.1.2)
+        if request.url.fragment and not url.fragment:
+            url = url.copy_with(fragment=request.url.fragment)
+
+        return url
+
+    def redirect_headers(self, request: Request, url: URL, method: str) -> Headers:
+        """
+        Return the headers that should be used for the redirect request.
+        """
+        headers = Headers(request.headers)
+
+        if url.origin != request.url.origin:
+            # Strip Authorization headers when responses are redirected away from
+            # the origin.
+            headers.pop("Authorization", None)
+            headers["Host"] = url.authority
+
+        if method != request.method and method == "GET":
+            # If we've switch to a 'GET' request, then strip any headers which
+            # are only relevant to the request body.
+            headers.pop("Content-Length", None)
+            headers.pop("Transfer-Encoding", None)
+
+        # We should use the client cookie store to determine any cookie header,
+        # rather than whatever was on the original outgoing request.
+        headers.pop("Cookie", None)
+
+        return headers
+
+    def redirect_content(self, request: Request, method: str) -> bytes:
+        """
+        Return the body that should be used for the redirect request.
+        """
+        if method != request.method and method == "GET":
+            return b""
+        if request.is_streaming:
+            raise RedirectBodyUnavailable()
+        return request.content
+
+    async def send_single_request(
+        self,
+        request: Request,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> Response:
+        """
+        Sends a single request, without handling any redirections.
+        """
+
+        dispatcher = self._dispatcher_for_request(request, self.proxies)
+
+        try:
+            with ElapsedTimer() as timer:
+                response = await dispatcher.send(
+                    request, verify=verify, cert=cert, timeout=timeout
+                )
+            response.elapsed = timer.elapsed
+            response.request = request
+        except HTTPError as exc:
+            # Add the original request to any HTTPError unless
+            # there'a already a request attached in the case of
+            # a ProxyError.
+            if exc.request is None:
+                exc.request = request
+            raise
+
+        self.cookies.extract_cookies(response)
+
+        status = f"{response.status_code} {response.reason_phrase}"
+        response_line = f"{response.http_version} {status}"
+        logger.debug(f'HTTP Request: {request.method} {request.url} "{response_line}"')
+
+        return response
 
     @functools.lru_cache(1)
     def _get_netrc(self) -> typing.Optional[netrc.netrc]:
@@ -366,36 +592,6 @@ class Client:
 
         return self.dispatch
 
-    def build_request(
-        self,
-        method: str,
-        url: URLTypes,
-        *,
-        data: RequestData = None,
-        files: RequestFiles = None,
-        json: typing.Any = None,
-        params: QueryParamTypes = None,
-        headers: HeaderTypes = None,
-        cookies: CookieTypes = None,
-    ) -> Request:
-        """
-        Build and return a request instance.
-        """
-        url = self.merge_url(url)
-        headers = self.merge_headers(headers)
-        cookies = self.merge_cookies(cookies)
-        params = self.merge_queryparams(params)
-        return Request(
-            method,
-            url,
-            data=data,
-            files=files,
-            json=json,
-            params=params,
-            headers=headers,
-            cookies=cookies,
-        )
-
     async def get(
         self,
         url: URLTypes,
@@ -624,70 +820,6 @@ class Client:
             trust_env=trust_env,
         )
 
-    async def request(
-        self,
-        method: str,
-        url: URLTypes,
-        *,
-        data: RequestData = None,
-        files: RequestFiles = None,
-        json: typing.Any = None,
-        params: QueryParamTypes = None,
-        headers: HeaderTypes = None,
-        cookies: CookieTypes = None,
-        stream: bool = False,
-        auth: AuthTypes = None,
-        allow_redirects: bool = True,
-        cert: CertTypes = None,
-        verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
-        trust_env: bool = None,
-    ) -> Response:
-        request = self.build_request(
-            method=method,
-            url=url,
-            data=data,
-            files=files,
-            json=json,
-            params=params,
-            headers=headers,
-            cookies=cookies,
-        )
-        response = await self.send(
-            request,
-            stream=stream,
-            auth=auth,
-            allow_redirects=allow_redirects,
-            verify=verify,
-            cert=cert,
-            timeout=timeout,
-            trust_env=trust_env,
-        )
-        return response
-
-    async def send(
-        self,
-        request: Request,
-        *,
-        stream: bool = False,
-        auth: AuthTypes = None,
-        allow_redirects: bool = True,
-        verify: VerifyTypes = None,
-        cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
-        trust_env: bool = None,
-    ) -> Response:
-        return await self._get_response(
-            request=request,
-            stream=stream,
-            auth=auth,
-            allow_redirects=allow_redirects,
-            verify=verify,
-            cert=cert,
-            timeout=timeout,
-            trust_env=trust_env,
-        )
-
     async def close(self) -> None:
         await self.dispatch.close()
 
index aa7dbc5f004539cb1509d320e880aac04210a737..54516348f8a4a908f2eb77f48a5b19098b698384 100644 (file)
@@ -1,4 +1,5 @@
 import enum
+from base64 import b64encode
 
 import h11
 
@@ -14,7 +15,6 @@ from ..config import (
     VerifyTypes,
 )
 from ..exceptions import ProxyError
-from ..middleware.basic_auth import build_basic_auth_header
 from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
 from ..utils import get_logger
 from .connection import HTTPConnection
@@ -69,13 +69,18 @@ class HTTPProxy(ConnectionPool):
         if url.username or url.password:
             self.proxy_headers.setdefault(
                 "Proxy-Authorization",
-                build_basic_auth_header(url.username, url.password),
+                self.build_auth_header(url.username, url.password),
             )
             # Remove userinfo from the URL authority, e.g.:
             # 'username:password@proxy_host:proxy_port' -> 'proxy_host:proxy_port'
             credentials, _, authority = url.authority.rpartition("@")
             self.proxy_url = url.copy_with(authority=authority)
 
+    def build_auth_header(self, username: str, password: str) -> str:
+        userpass = (username.encode("utf-8"), password.encode("utf-8"))
+        token = b64encode(b":".join(userpass)).decode().strip()
+        return f"Basic {token}"
+
     async def acquire_connection(self, origin: Origin) -> HTTPConnection:
         if self.should_forward_origin(origin):
             logger.trace(
similarity index 74%
rename from httpx/middleware/base.py
rename to httpx/middleware.py
index e87176ce70f85679e02aa897f54e18e74fbc7b5c..7ef87076c8ae9741afcc1bf05317ce860d4dcaa2 100644 (file)
@@ -1,9 +1,9 @@
 import typing
 
-from ..models import Request, Response
+from .models import Request, Response
 
 
-class BaseMiddleware:
+class Middleware:
     async def __call__(
         self, request: Request, get_response: typing.Callable
     ) -> Response:
diff --git a/httpx/middleware/__init__.py b/httpx/middleware/__init__.py
deleted file mode 100644 (file)
index e69de29..0000000
diff --git a/httpx/middleware/basic_auth.py b/httpx/middleware/basic_auth.py
deleted file mode 100644 (file)
index cb945f1..0000000
+++ /dev/null
@@ -1,27 +0,0 @@
-import typing
-from base64 import b64encode
-
-from ..models import Request, Response
-from ..utils import to_bytes
-from .base import BaseMiddleware
-
-
-class BasicAuthMiddleware(BaseMiddleware):
-    def __init__(
-        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
-    ):
-        self.authorization_header = build_basic_auth_header(username, password)
-
-    async def __call__(
-        self, request: Request, get_response: typing.Callable
-    ) -> Response:
-        request.headers["Authorization"] = self.authorization_header
-        return await get_response(request)
-
-
-def build_basic_auth_header(
-    username: typing.Union[str, bytes], password: typing.Union[str, bytes]
-) -> str:
-    userpass = b":".join((to_bytes(username), to_bytes(password)))
-    token = b64encode(userpass).decode().strip()
-    return f"Basic {token}"
diff --git a/httpx/middleware/custom_auth.py b/httpx/middleware/custom_auth.py
deleted file mode 100644 (file)
index 2fd5e22..0000000
+++ /dev/null
@@ -1,15 +0,0 @@
-import typing
-
-from ..models import Request, Response
-from .base import BaseMiddleware
-
-
-class CustomAuthMiddleware(BaseMiddleware):
-    def __init__(self, auth: typing.Callable[[Request], Request]):
-        self.auth = auth
-
-    async def __call__(
-        self, request: Request, get_response: typing.Callable
-    ) -> Response:
-        request = self.auth(request)
-        return await get_response(request)
diff --git a/httpx/middleware/redirect.py b/httpx/middleware/redirect.py
deleted file mode 100644 (file)
index 2e0ca33..0000000
+++ /dev/null
@@ -1,130 +0,0 @@
-import functools
-import typing
-
-from ..config import DEFAULT_MAX_REDIRECTS
-from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
-from ..models import URL, Cookies, Headers, Request, Response
-from ..status_codes import codes
-from .base import BaseMiddleware
-
-
-class RedirectMiddleware(BaseMiddleware):
-    def __init__(
-        self,
-        allow_redirects: bool = True,
-        max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        cookies: typing.Optional[Cookies] = None,
-    ):
-        self.allow_redirects = allow_redirects
-        self.max_redirects = max_redirects
-        self.cookies = cookies
-        self.history: typing.List[Response] = []
-
-    async def __call__(
-        self, request: Request, get_response: typing.Callable
-    ) -> Response:
-        if len(self.history) > self.max_redirects:
-            raise TooManyRedirects()
-        if request.url in (response.url for response in self.history):
-            raise RedirectLoop()
-
-        response = await get_response(request)
-        response.history = list(self.history)
-
-        if not response.is_redirect:
-            return response
-
-        self.history.append(response)
-        next_request = self.build_redirect_request(request, response)
-
-        if self.allow_redirects:
-            return await self(next_request, get_response)
-
-        response.call_next = functools.partial(self, next_request, get_response)
-        return response
-
-    def build_redirect_request(self, request: Request, response: Response) -> Request:
-        method = self.redirect_method(request, response)
-        url = self.redirect_url(request, response)
-        headers = self.redirect_headers(request, url, method)  # TODO: merge headers?
-        content = self.redirect_content(request, method)
-        cookies = Cookies(self.cookies)
-        return Request(
-            method=method, url=url, headers=headers, data=content, cookies=cookies
-        )
-
-    def redirect_method(self, request: Request, response: Response) -> str:
-        """
-        When being redirected we may want to change the method of the request
-        based on certain specs or browser behavior.
-        """
-        method = request.method
-
-        # https://tools.ietf.org/html/rfc7231#section-6.4.4
-        if response.status_code == codes.SEE_OTHER and method != "HEAD":
-            method = "GET"
-
-        # Do what the browsers do, despite standards...
-        # Turn 302s into GETs.
-        if response.status_code == codes.FOUND and method != "HEAD":
-            method = "GET"
-
-        # If a POST is responded to with a 301, turn it into a GET.
-        # This bizarre behaviour is explained in 'requests' issue 1704.
-        if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
-            method = "GET"
-
-        return method
-
-    def redirect_url(self, request: Request, response: Response) -> URL:
-        """
-        Return the URL for the redirect to follow.
-        """
-        location = response.headers["Location"]
-
-        url = URL(location, allow_relative=True)
-
-        # Facilitate relative 'Location' headers, as allowed by RFC 7231.
-        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
-        if url.is_relative_url:
-            url = request.url.join(url)
-
-        # Attach previous fragment if needed (RFC 7231 7.1.2)
-        if request.url.fragment and not url.fragment:
-            url = url.copy_with(fragment=request.url.fragment)
-
-        return url
-
-    def redirect_headers(self, request: Request, url: URL, method: str) -> Headers:
-        """
-        Return the headers that should be used for the redirect request.
-        """
-        headers = Headers(request.headers)
-
-        if url.origin != request.url.origin:
-            # Strip Authorization headers when responses are redirected away from
-            # the origin.
-            headers.pop("Authorization", None)
-            headers["Host"] = url.authority
-
-        if method != request.method and method == "GET":
-            # If we've switch to a 'GET' request, then strip any headers which
-            # are only relevant to the request body.
-            headers.pop("Content-Length", None)
-            headers.pop("Transfer-Encoding", None)
-
-        # We should use the client cookie store to determine any cookie header,
-        # rather than whatever was on the original outgoing request.
-        headers.pop("Cookie", None)
-
-        return headers
-
-    def redirect_content(self, request: Request, method: str) -> bytes:
-        """
-        Return the body that should be used for the redirect request.
-        """
-        if method != request.method and method == "GET":
-            return b""
-        if request.is_streaming:
-            raise RedirectBodyUnavailable()
-        return request.content