]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor middleware (#325)
authorFlorimond Manca <florimond.manca@gmail.com>
Sat, 7 Sep 2019 21:22:11 +0000 (23:22 +0200)
committerGitHub <noreply@github.com>
Sat, 7 Sep 2019 21:22:11 +0000 (23:22 +0200)
* Split middleware into a subpackage

* Refactor basic auth header building

* Add encoding parameter to to_bytes()

httpx/client.py
httpx/middleware/__init__.py [new file with mode: 0644]
httpx/middleware/base.py [new file with mode: 0644]
httpx/middleware/basic_auth.py [new file with mode: 0644]
httpx/middleware/custom_auth.py [new file with mode: 0644]
httpx/middleware/redirect.py [moved from httpx/middleware.py with 74% similarity]
httpx/utils.py

index 471473bf9ec66cc260833b52233d9a086131b0a1..0e493e1993ccbf0e6c804e2b741298e61d3c0725 100644 (file)
@@ -23,12 +23,10 @@ from .dispatch.connection_pool import ConnectionPool
 from .dispatch.threaded import ThreadedDispatcher
 from .dispatch.wsgi import WSGIDispatch
 from .exceptions import HTTPError, InvalidURL
-from .middleware import (
-    BaseMiddleware,
-    BasicAuthMiddleware,
-    CustomAuthMiddleware,
-    RedirectMiddleware,
-)
+from .middleware.base import BaseMiddleware
+from .middleware.basic_auth import BasicAuthMiddleware
+from .middleware.custom_auth import CustomAuthMiddleware
+from .middleware.redirect import RedirectMiddleware
 from .models import (
     URL,
     AsyncRequest,
diff --git a/httpx/middleware/__init__.py b/httpx/middleware/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/httpx/middleware/base.py b/httpx/middleware/base.py
new file mode 100644 (file)
index 0000000..4ed76ea
--- /dev/null
@@ -0,0 +1,10 @@
+import typing
+
+from ..models import AsyncRequest, AsyncResponse
+
+
+class BaseMiddleware:
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        raise NotImplementedError  # pragma: no cover
diff --git a/httpx/middleware/basic_auth.py b/httpx/middleware/basic_auth.py
new file mode 100644 (file)
index 0000000..faffba2
--- /dev/null
@@ -0,0 +1,27 @@
+import typing
+from base64 import b64encode
+
+from ..models import AsyncRequest, AsyncResponse
+from ..utils import to_bytes
+from .base import BaseMiddleware
+
+
+class BasicAuthMiddleware(BaseMiddleware):
+    def __init__(
+        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+    ):
+        self.authorization_header = build_basic_auth_header(username, password)
+
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        request.headers["Authorization"] = self.authorization_header
+        return await get_response(request)
+
+
+def build_basic_auth_header(
+    username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+) -> str:
+    userpass = b":".join((to_bytes(username), to_bytes(password)))
+    token = b64encode(userpass).decode().strip()
+    return f"Basic {token}"
diff --git a/httpx/middleware/custom_auth.py b/httpx/middleware/custom_auth.py
new file mode 100644 (file)
index 0000000..86548dd
--- /dev/null
@@ -0,0 +1,15 @@
+import typing
+
+from ..models import AsyncRequest, AsyncResponse
+from .base import BaseMiddleware
+
+
+class CustomAuthMiddleware(BaseMiddleware):
+    def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
+        self.auth = auth
+
+    async def __call__(
+        self, request: AsyncRequest, get_response: typing.Callable
+    ) -> AsyncResponse:
+        request = self.auth(request)
+        return await get_response(request)
similarity index 74%
rename from httpx/middleware.py
rename to httpx/middleware/redirect.py
index c98c85e573f0a70f4d348817c36b464a396a61c3..6f5c02e38656b7969e2db9985ee8b23ad6d31f21 100644 (file)
@@ -1,51 +1,11 @@
 import functools
 import typing
-from base64 import b64encode
 
-from .config import DEFAULT_MAX_REDIRECTS
-from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
-from .models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
-from .status_codes import codes
-
-
-class BaseMiddleware:
-    async def __call__(
-        self, request: AsyncRequest, get_response: typing.Callable
-    ) -> AsyncResponse:
-        raise NotImplementedError  # pragma: no cover
-
-
-class BasicAuthMiddleware(BaseMiddleware):
-    def __init__(
-        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
-    ):
-        if isinstance(username, str):
-            username = username.encode("latin1")
-
-        if isinstance(password, str):
-            password = password.encode("latin1")
-
-        userpass = b":".join((username, password))
-        token = b64encode(userpass).decode().strip()
-
-        self.authorization_header = f"Basic {token}"
-
-    async def __call__(
-        self, request: AsyncRequest, get_response: typing.Callable
-    ) -> AsyncResponse:
-        request.headers["Authorization"] = self.authorization_header
-        return await get_response(request)
-
-
-class CustomAuthMiddleware(BaseMiddleware):
-    def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
-        self.auth = auth
-
-    async def __call__(
-        self, request: AsyncRequest, get_response: typing.Callable
-    ) -> AsyncResponse:
-        request = self.auth(request)
-        return await get_response(request)
+from ..config import DEFAULT_MAX_REDIRECTS
+from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
+from ..models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
+from ..status_codes import codes
+from .base import BaseMiddleware
 
 
 class RedirectMiddleware(BaseMiddleware):
index 0357ea2bd23398217f7fdd4e2b21daeb11980606..b80003ecc4cff2c86f788ef2c36caeadc755fe62 100644 (file)
@@ -169,3 +169,7 @@ def get_logger(name: str) -> logging.Logger:
             logger.addHandler(handler)
 
     return logging.getLogger(name)
+
+
+def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
+    return value.encode(encoding) if isinstance(value, str) else value