]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Auth (#65)
authorTom Christie <tom@tomchristie.com>
Wed, 15 May 2019 15:43:35 +0000 (16:43 +0100)
committerGitHub <noreply@github.com>
Wed, 15 May 2019 15:43:35 +0000 (16:43 +0100)
* Initial work towards auth

* Add auth support

* Add test for custom auth

* Support auth-in-URL

* Support auth-on-session

httpcore/auth.py [new file with mode: 0644]
httpcore/client.py
httpcore/models.py
tests/client/test_auth.py [new file with mode: 0644]

diff --git a/httpcore/auth.py b/httpcore/auth.py
new file mode 100644 (file)
index 0000000..49ff998
--- /dev/null
@@ -0,0 +1,38 @@
+import typing
+from base64 import b64encode
+
+from .models import Request
+
+
+class AuthBase:
+    """
+    Base class that all auth implementations derive from.
+    """
+
+    def __call__(self, request: Request) -> Request:
+        raise NotImplementedError("Auth hooks must be callable.")  # pragma: nocover
+
+
+class HTTPBasicAuth(AuthBase):
+    def __init__(
+        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
+    ) -> None:
+        self.username = username
+        self.password = password
+
+    def __call__(self, request: Request) -> Request:
+        request.headers["Authorization"] = self.build_auth_header()
+        return request
+
+    def build_auth_header(self) -> str:
+        username, password = self.username, self.password
+
+        if isinstance(username, str):
+            username = username.encode("latin1")
+
+        if isinstance(password, str):
+            password = password.encode("latin1")
+
+        userpass = b":".join((username, password))
+        token = b64encode(userpass).decode().strip()
+        return f"Basic {token}"
index cb8ead9fbdad7c4f09972c943fad2509c81afffb..61db5975c7d4149490299db26b87e5f449a4f406 100644 (file)
@@ -2,6 +2,7 @@ import asyncio
 import typing
 from types import TracebackType
 
+from .auth import HTTPBasicAuth
 from .config import (
     DEFAULT_MAX_REDIRECTS,
     DEFAULT_POOL_LIMITS,
@@ -16,6 +17,7 @@ from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
 from .interfaces import ConcurrencyBackend, Dispatcher
 from .models import (
     URL,
+    AuthTypes,
     Headers,
     HeaderTypes,
     QueryParamTypes,
@@ -31,6 +33,7 @@ from .status_codes import codes
 class AsyncClient:
     def __init__(
         self,
+        auth: AuthTypes = None,
         ssl: SSLConfig = DEFAULT_SSL_CONFIG,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
@@ -43,6 +46,7 @@ class AsyncClient:
                 ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
             )
 
+        self.auth = auth
         self.max_redirects = max_redirects
         self.dispatch = dispatch
 
@@ -53,6 +57,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -63,6 +68,7 @@ class AsyncClient:
             query_params=query_params,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -75,6 +81,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -85,6 +92,7 @@ class AsyncClient:
             query_params=query_params,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -97,6 +105,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = False,  #  Note: Differs to usual default.
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -107,6 +116,7 @@ class AsyncClient:
             query_params=query_params,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -120,6 +130,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -131,6 +142,7 @@ class AsyncClient:
             query_params=query_params,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -144,6 +156,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -155,6 +168,7 @@ class AsyncClient:
             query_params=query_params,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -168,6 +182,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -179,6 +194,7 @@ class AsyncClient:
             query_params=query_params,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -192,6 +208,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -203,6 +220,7 @@ class AsyncClient:
             query_params=query_params,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -217,6 +235,7 @@ class AsyncClient:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -228,6 +247,7 @@ class AsyncClient:
         response = await self.send(
             request,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -238,6 +258,37 @@ class AsyncClient:
         request.prepare()
 
     async def send(
+        self,
+        request: Request,
+        *,
+        stream: bool = False,
+        auth: AuthTypes = None,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
+        allow_redirects: bool = True,
+    ) -> Response:
+        if auth is None:
+            auth = self.auth
+
+        url = request.url
+        if auth is None and (url.username or url.password):
+            auth = HTTPBasicAuth(username=url.username, password=url.password)
+
+        if auth is not None:
+            if isinstance(auth, tuple):
+                auth = HTTPBasicAuth(username=auth[0], password=auth[1])
+            request = auth(request)
+
+        response = await self.send_handling_redirects(
+            request,
+            stream=stream,
+            ssl=ssl,
+            timeout=timeout,
+            allow_redirects=allow_redirects,
+        )
+        return response
+
+    async def send_handling_redirects(
         self,
         request: Request,
         *,
@@ -273,7 +324,7 @@ class AsyncClient:
                 async def send_next() -> Response:
                     nonlocal request, response, ssl, allow_redirects, timeout, history
                     request = self.build_redirect_request(request, response)
-                    response = await self.send(
+                    response = await self.send_handling_redirects(
                         request,
                         stream=stream,
                         allow_redirects=allow_redirects,
@@ -375,6 +426,7 @@ class AsyncClient:
 class Client:
     def __init__(
         self,
+        auth: AuthTypes = None,
         ssl: SSLConfig = DEFAULT_SSL_CONFIG,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
@@ -383,6 +435,7 @@ class Client:
         backend: ConcurrencyBackend = None,
     ) -> None:
         self._client = AsyncClient(
+            auth=auth,
             ssl=ssl,
             timeout=timeout,
             pool_limits=pool_limits,
@@ -401,6 +454,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -412,6 +466,7 @@ class Client:
         response = self.send(
             request,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -425,6 +480,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -434,6 +490,7 @@ class Client:
             url,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -446,6 +503,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -455,6 +513,7 @@ class Client:
             url,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -467,6 +526,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = False,  #  Note: Differs to usual default.
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -476,6 +536,7 @@ class Client:
             url,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -489,6 +550,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -499,6 +561,7 @@ class Client:
             data=data,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -512,6 +575,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -522,6 +586,7 @@ class Client:
             data=data,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -535,6 +600,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -545,6 +611,7 @@ class Client:
             data=data,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -558,6 +625,7 @@ class Client:
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -568,6 +636,7 @@ class Client:
             data=data,
             headers=headers,
             stream=stream,
+            auth=auth,
             allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
@@ -581,6 +650,7 @@ class Client:
         request: Request,
         *,
         stream: bool = False,
+        auth: AuthTypes = None,
         allow_redirects: bool = True,
         ssl: SSLConfig = None,
         timeout: TimeoutConfig = None,
@@ -589,6 +659,7 @@ class Client:
             self._client.send(
                 request,
                 stream=stream,
+                auth=auth,
                 allow_redirects=allow_redirects,
                 ssl=ssl,
                 timeout=timeout,
index 68ab4b2d72f31ec50ee949b93b4e93517dec8e64..663b64007533cab3ef717f3798b60671bb8f6994 100644 (file)
@@ -43,6 +43,11 @@ HeaderTypes = typing.Union[
     typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
 ]
 
+AuthTypes = typing.Union[
+    typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
+    typing.Callable[["Request"], "Request"]
+]
+
 RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
 
 ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
@@ -93,6 +98,16 @@ class URL:
     def authority(self) -> str:
         return self.components.authority or ""
 
+    @property
+    def username(self) -> str:
+        userinfo = self.components.userinfo or ""
+        return userinfo.partition(':')[0]
+
+    @property
+    def password(self) -> str:
+        userinfo = self.components.userinfo or ""
+        return userinfo.partition(':')[2]
+
     @property
     def host(self) -> str:
         return self.components.host or ""
diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py
new file mode 100644 (file)
index 0000000..e044acd
--- /dev/null
@@ -0,0 +1,72 @@
+import json
+from urllib.parse import parse_qs
+
+import pytest
+
+from httpcore import (
+    URL,
+    Client,
+    Dispatcher,
+    Request,
+    Response,
+    SSLConfig,
+    TimeoutConfig,
+)
+
+
+class MockDispatch(Dispatcher):
+    async def send(
+        self,
+        request: Request,
+        stream: bool = False,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
+    ) -> Response:
+        body = json.dumps({"auth": request.headers.get('Authorization')}).encode()
+        return Response(200, content=body, request=request)
+
+
+def test_basic_auth():
+    url = "https://example.org/"
+    auth = ('tomchristie', 'password123')
+
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
+
+
+def test_basic_auth_in_url():
+    url = "https://tomchristie:password123@example.org/"
+
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.get(url)
+
+    assert response.status_code == 200
+    assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
+
+
+def test_basic_auth_on_session():
+    url = "https://example.org/"
+    auth = ('tomchristie', 'password123')
+
+    with Client(dispatch=MockDispatch(), auth=auth) as client:
+        response = client.get(url)
+
+    assert response.status_code == 200
+    assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
+
+
+def test_custom_auth():
+    url = "https://example.org/"
+
+    def auth(request):
+        request.headers['Authorization'] = 'Token 123'
+        return request
+
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.get(url, auth=auth)
+
+    assert response.status_code == 200
+    assert json.loads(response.text) == {'auth': 'Token 123'}