]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor client functionality into middleware (#268)
authorFlorimond Manca <florimond.manca@gmail.com>
Sun, 1 Sep 2019 21:01:14 +0000 (23:01 +0200)
committerGitHub <noreply@github.com>
Sun, 1 Sep 2019 21:01:14 +0000 (23:01 +0200)
Co-authored-by: yeraydiazdiaz <yeraydiazdiaz@gmail.com>
* Dispatcher middlewares

* Redirect and BasicAuth dispatchers
* Remove HTTPBasicAuth and reinstate trust_env logic
* Call resolve dispatcher correctly
* Fix redirection tests
* Add basic and custom auth dispatchers
* Reinstate extracting cookies from response
* Fix linting

* Refactor middleware interface

httpx/auth.py [deleted file]
httpx/client.py
httpx/middleware.py [new file with mode: 0644]
tests/client/test_auth.py
tests/client/test_redirects.py

diff --git a/httpx/auth.py b/httpx/auth.py
deleted file mode 100644 (file)
index 6a39c1b..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-import typing
-from base64 import b64encode
-
-from .models import AsyncRequest
-
-
-class AuthBase:
-    """
-    Base class that all auth implementations derive from.
-    """
-
-    def __call__(self, request: AsyncRequest) -> AsyncRequest:
-        raise NotImplementedError("Auth hooks must be callable.")  # pragma: nocover
-
-
-class HTTPBasicAuth(AuthBase):
-    def __init__(
-        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
-    ) -> None:
-        self.username = username
-        self.password = password
-
-    def __call__(self, request: AsyncRequest) -> AsyncRequest:
-        request.headers["Authorization"] = self.build_auth_header()
-        return request
-
-    def build_auth_header(self) -> str:
-        username, password = self.username, self.password
-
-        if isinstance(username, str):
-            username = username.encode("latin1")
-
-        if isinstance(password, str):
-            password = password.encode("latin1")
-
-        userpass = b":".join((username, password))
-        token = b64encode(userpass).decode().strip()
-        return f"Basic {token}"
index 3c135e29cafb15ad21a88e7eef8efe8267b1fbd7..471473bf9ec66cc260833b52233d9a086131b0a1 100644 (file)
@@ -1,10 +1,10 @@
+import functools
 import inspect
 import typing
 from types import TracebackType
 
 import hstspreload
 
-from .auth import HTTPBasicAuth
 from .concurrency.asyncio import AsyncioBackend
 from .concurrency.base import ConcurrencyBackend
 from .config import (
@@ -22,12 +22,12 @@ from .dispatch.base import AsyncDispatcher, Dispatcher
 from .dispatch.connection_pool import ConnectionPool
 from .dispatch.threaded import ThreadedDispatcher
 from .dispatch.wsgi import WSGIDispatch
-from .exceptions import (
-    HTTPError,
-    InvalidURL,
-    RedirectBodyUnavailable,
-    RedirectLoop,
-    TooManyRedirects,
+from .exceptions import HTTPError, InvalidURL
+from .middleware import (
+    BaseMiddleware,
+    BasicAuthMiddleware,
+    CustomAuthMiddleware,
+    RedirectMiddleware,
 )
 from .models import (
     URL,
@@ -47,7 +47,6 @@ from .models import (
     ResponseContent,
     URLTypes,
 )
-from .status_codes import codes
 from .utils import get_netrc_login
 
 
@@ -67,7 +66,7 @@ class BaseClient:
         dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
         app: typing.Callable = None,
         backend: ConcurrencyBackend = None,
-        trust_env: bool = None,
+        trust_env: bool = True,
     ):
         if backend is None:
             backend = AsyncioBackend()
@@ -166,188 +165,77 @@ class BaseClient:
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
     ) -> AsyncResponse:
-        if auth is None:
-            auth = self.auth
-
-        url = request.url
-
-        if url.scheme not in ("http", "https"):
+        if request.url.scheme not in ("http", "https"):
             raise InvalidURL('URL scheme must be "http" or "https".')
 
-        if auth is None:
-            if url.username or url.password:
-                auth = HTTPBasicAuth(username=url.username, password=url.password)
-            elif self.trust_env if trust_env is None else trust_env:
-                netrc_login = get_netrc_login(url.authority)
-                if netrc_login:
-                    netrc_username, _, netrc_password = netrc_login
-                    auth = HTTPBasicAuth(
-                        username=netrc_username, password=netrc_password
-                    )
-
-        if auth is not None:
-            if isinstance(auth, tuple):
-                auth = HTTPBasicAuth(username=auth[0], password=auth[1])
-            request = auth(request)
-
-        try:
-            response = await self.send_handling_redirects(
-                request,
-                verify=verify,
-                cert=cert,
-                timeout=timeout,
-                allow_redirects=allow_redirects,
-            )
-        except HTTPError as exc:
-            # Add the original request to any HTTPError
-            exc.request = request
-            raise
-
-        if not stream:
+        async def get_response(request: AsyncRequest) -> AsyncResponse:
             try:
-                await response.read()
-            finally:
-                await response.close()
-
-        return response
-
-    async def send_handling_redirects(
-        self,
-        request: AsyncRequest,
-        *,
-        cert: CertTypes = None,
-        verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
-        allow_redirects: bool = True,
-        history: typing.List[AsyncResponse] = None,
-    ) -> AsyncResponse:
-        if history is None:
-            history = []
-
-        while True:
-            # We perform these checks here, so that calls to `response.next()`
-            # will raise redirect errors if appropriate.
-            if len(history) > self.max_redirects:
-                raise TooManyRedirects(response=history[-1])
-            if request.url in (response.url for response in history):
-                raise RedirectLoop(response=history[-1])
-
-            response = await self.dispatch.send(
-                request, verify=verify, cert=cert, timeout=timeout
-            )
-
-            should_close_response = True
-            try:
-                assert isinstance(response, AsyncResponse)
-                response.history = list(history)
-                self.cookies.extract_cookies(response)
-                history.append(response)
-
-                if allow_redirects and response.is_redirect:
-                    request = self.build_redirect_request(request, response)
-                else:
-                    should_close_response = False
-                    break
-            finally:
-                if should_close_response:
-                    await response.close()
-
-        if response.is_redirect:
-
-            async def call_next() -> AsyncResponse:
-                nonlocal request, response, verify, cert
-                nonlocal allow_redirects, timeout, history
-                request = self.build_redirect_request(request, response)
-                response = await self.send_handling_redirects(
-                    request,
-                    allow_redirects=allow_redirects,
-                    verify=verify,
-                    cert=cert,
-                    timeout=timeout,
-                    history=history,
+                response = await self.dispatch.send(
+                    request, verify=verify, cert=cert, timeout=timeout
                 )
-                return response
+            except HTTPError as exc:
+                # Add the original request to any HTTPError
+                exc.request = request
+                raise
+
+            self.cookies.extract_cookies(response)
+            if not stream:
+                try:
+                    await response.read()
+                finally:
+                    await response.close()
 
-            response.call_next = call_next  # type: ignore
+            return response
 
-        return response
+        def wrap(
+            get_response: typing.Callable, middleware: BaseMiddleware
+        ) -> typing.Callable:
+            return functools.partial(middleware, get_response=get_response)
 
-    def build_redirect_request(
-        self, request: AsyncRequest, response: AsyncResponse
-    ) -> AsyncRequest:
-        method = self.redirect_method(request, response)
-        url = self.redirect_url(request, response)
-        headers = self.redirect_headers(request, url)
-        content = self.redirect_content(request, method, response)
-        cookies = self.merge_cookies(request.cookies)
-        return AsyncRequest(
-            method=method, url=url, headers=headers, data=content, cookies=cookies
+        get_response = wrap(
+            get_response,
+            RedirectMiddleware(allow_redirects=allow_redirects, cookies=self.cookies),
         )
 
-    def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
-        """
-        When being redirected we may want to change the method of the request
-        based on certain specs or browser behavior.
-        """
-        method = request.method
-
-        # https://tools.ietf.org/html/rfc7231#section-6.4.4
-        if response.status_code == codes.SEE_OTHER and method != "HEAD":
-            method = "GET"
-
-        # Do what the browsers do, despite standards...
-        # Turn 302s into GETs.
-        if response.status_code == codes.FOUND and method != "HEAD":
-            method = "GET"
-
-        # If a POST is responded to with a 301, turn it into a GET.
-        # This bizarre behaviour is explained in 'requests' issue 1704.
-        if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
-            method = "GET"
+        auth_middleware = self._get_auth_middleware(
+            request=request,
+            trust_env=self.trust_env if trust_env is None else trust_env,
+            auth=self.auth if auth is None else auth,
+        )
 
-        return method
+        if auth_middleware is not None:
+            get_response = wrap(get_response, auth_middleware)
 
-    def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
-        """
-        Return the URL for the redirect to follow.
-        """
-        location = response.headers["Location"]
+        return await get_response(request)
 
-        url = URL(location, allow_relative=True)
+    def _get_auth_middleware(
+        self, request: AsyncRequest, trust_env: bool, auth: AuthTypes = None
+    ) -> typing.Optional[BaseMiddleware]:
+        if isinstance(auth, tuple):
+            return BasicAuthMiddleware(username=auth[0], password=auth[1])
 
-        # Facilitate relative 'Location' headers, as allowed by RFC 7231.
-        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
-        if url.is_relative_url:
-            url = request.url.join(url)
+        if callable(auth):
+            return CustomAuthMiddleware(auth=auth)
 
-        # Attach previous fragment if needed (RFC 7231 7.1.2)
-        if request.url.fragment and not url.fragment:
-            url = url.copy_with(fragment=request.url.fragment)
+        if auth is not None:
+            raise TypeError(
+                'When specified, "auth" must be a (username, password) tuple or '
+                "a callable with signature (AsyncRequest) -> AsyncRequest "
+                f"(got {auth!r})"
+            )
 
-        return url
+        if request.url.username or request.url.password:
+            return BasicAuthMiddleware(
+                username=request.url.username, password=request.url.password
+            )
 
-    def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
-        """
-        Strip Authorization headers when responses are redirected away from
-        the origin.
-        """
-        headers = Headers(request.headers)
-        if url.origin != request.url.origin:
-            del headers["Authorization"]
-            del headers["host"]
-        return headers
+        if trust_env:
+            netrc_login = get_netrc_login(request.url.authority)
+            if netrc_login:
+                username, _, password = netrc_login
+                return BasicAuthMiddleware(username=username, password=password)
 
-    def redirect_content(
-        self, request: AsyncRequest, method: str, response: AsyncResponse
-    ) -> bytes:
-        """
-        Return the body that should be used for the redirect request.
-        """
-        if method != request.method and method == "GET":
-            return b""
-        if request.is_streaming:
-            raise RedirectBodyUnavailable(response=response)
-        return request.content
+        return None
 
 
 class AsyncClient(BaseClient):
diff --git a/httpx/middleware.py b/httpx/middleware.py
new file mode 100644 (file)
index 0000000..4ed750e
--- /dev/null
@@ -0,0 +1,160 @@
+import functools
+import typing
+from base64 import b64encode
+
+from .config import DEFAULT_MAX_REDIRECTS
+from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
+from .models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
+from .status_codes import codes
+
+
+class BaseMiddleware:
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        raise NotImplementedError  # pragma: no cover
+
+
+class BasicAuthMiddleware(BaseMiddleware):
+    def __init__(
+        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+    ):
+        if isinstance(username, str):
+            username = username.encode("latin1")
+
+        if isinstance(password, str):
+            password = password.encode("latin1")
+
+        userpass = b":".join((username, password))
+        token = b64encode(userpass).decode().strip()
+
+        self.authorization_header = f"Basic {token}"
+
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        request.headers["Authorization"] = self.authorization_header
+        return await get_response(request)
+
+
+class CustomAuthMiddleware(BaseMiddleware):
+    def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
+        self.auth = auth
+
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        request = self.auth(request)
+        return await get_response(request)
+
+
+class RedirectMiddleware(BaseMiddleware):
+    def __init__(
+        self,
+        allow_redirects: bool = True,
+        max_redirects: int = DEFAULT_MAX_REDIRECTS,
+        cookies: typing.Optional[Cookies] = None,
+    ):
+        self.allow_redirects = allow_redirects
+        self.max_redirects = max_redirects
+        self.cookies = cookies
+        self.history: typing.List[AsyncResponse] = []
+
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        if len(self.history) > self.max_redirects:
+            raise TooManyRedirects()
+        if request.url in (response.url for response in self.history):
+            raise RedirectLoop()
+
+        response = await get_response(request)
+        response.history = list(self.history)
+
+        if not response.is_redirect:
+            return response
+
+        self.history.append(response)
+        next_request = self.build_redirect_request(request, response)
+
+        if self.allow_redirects:
+            return await self(next_request, get_response)
+
+        response.call_next = functools.partial(self, next_request, get_response)
+        return response
+
+    def build_redirect_request(
+        self, request: AsyncRequest, response: AsyncResponse
+    ) -> AsyncRequest:
+        method = self.redirect_method(request, response)
+        url = self.redirect_url(request, response)
+        headers = self.redirect_headers(request, url)  # TODO: merge headers?
+        content = self.redirect_content(request, method)
+        cookies = Cookies(self.cookies)
+        cookies.update(request.cookies)
+        return AsyncRequest(
+            method=method, url=url, headers=headers, data=content, cookies=cookies
+        )
+
+    def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
+        """
+        When being redirected we may want to change the method of the request
+        based on certain specs or browser behavior.
+        """
+        method = request.method
+
+        # https://tools.ietf.org/html/rfc7231#section-6.4.4
+        if response.status_code == codes.SEE_OTHER and method != "HEAD":
+            method = "GET"
+
+        # Do what the browsers do, despite standards...
+        # Turn 302s into GETs.
+        if response.status_code == codes.FOUND and method != "HEAD":
+            method = "GET"
+
+        # If a POST is responded to with a 301, turn it into a GET.
+        # This bizarre behaviour is explained in 'requests' issue 1704.
+        if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
+            method = "GET"
+
+        return method
+
+    def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
+        """
+        Return the URL for the redirect to follow.
+        """
+        location = response.headers["Location"]
+
+        url = URL(location, allow_relative=True)
+
+        # Facilitate relative 'Location' headers, as allowed by RFC 7231.
+        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
+        if url.is_relative_url:
+            url = request.url.join(url)
+
+        # Attach previous fragment if needed (RFC 7231 7.1.2)
+        if request.url.fragment and not url.fragment:
+            url = url.copy_with(fragment=request.url.fragment)
+
+        return url
+
+    def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
+        """
+        Strip Authorization headers when responses are redirected away from
+        the origin.
+        """
+        headers = Headers(request.headers)
+        if url.origin != request.url.origin:
+            del headers["Authorization"]
+            del headers["host"]
+        return headers
+
+    def redirect_content(self, request: AsyncRequest, method: str) -> bytes:
+        """
+        Return the body that should be used for the redirect request.
+        """
+        if method != request.method and method == "GET":
+            return b""
+        if request.is_streaming:
+            raise RedirectBodyUnavailable()
+        return request.content
index 725ea56c00491aec0819ea66d17be5acabbb28ac..fc3b192f496d17ff5bd2d28a0c62903b5be36450 100644 (file)
@@ -1,6 +1,8 @@
 import json
 import os
 
+import pytest
+
 from httpx import (
     URL,
     AsyncDispatcher,
@@ -118,3 +120,10 @@ def test_auth_hidden_header():
         response = client.get(url, auth=auth)
 
     assert "'authorization': '[secure]'" in str(response.request.headers)
+
+
+def test_auth_invalid_type():
+    url = "https://example.org/"
+    with Client(dispatch=MockDispatch(), auth="not a tuple, not a callable") as client:
+        with pytest.raises(TypeError):
+            client.get(url)
index 38c330e2efb824859e8c07ef375ccf4e01e7e3df..7daf6c801e2cc5d567f071983257d7ea39f5c886 100644 (file)
@@ -262,7 +262,7 @@ async def test_cannot_redirect_streaming_body(backend):
         await client.post(url, data=streaming_body())
 
 
-async def test_cross_dubdomain_redirect(backend):
+async def test_cross_subdomain_redirect(backend):
     client = AsyncClient(dispatch=MockDispatch(), backend=backend)
     url = "https://example.com/cross_subdomain"
     response = await client.get(url)