From: Tom Christie Date: Wed, 15 May 2019 15:43:35 +0000 (+0100) Subject: Auth (#65) X-Git-Tag: 0.3.0~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5afa7dd5ccc223cbb9118b65e278136e8f7a4dc2;p=thirdparty%2Fhttpx.git Auth (#65) * Initial work towards auth * Add auth support * Add test for custom auth * Support auth-in-URL * Support auth-on-session --- diff --git a/httpcore/auth.py b/httpcore/auth.py new file mode 100644 index 00000000..49ff998b --- /dev/null +++ b/httpcore/auth.py @@ -0,0 +1,38 @@ +import typing +from base64 import b64encode + +from .models import Request + + +class AuthBase: + """ + Base class that all auth implementations derive from. + """ + + def __call__(self, request: Request) -> Request: + raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover + + +class HTTPBasicAuth(AuthBase): + def __init__( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ) -> None: + self.username = username + self.password = password + + def __call__(self, request: Request) -> Request: + request.headers["Authorization"] = self.build_auth_header() + return request + + def build_auth_header(self) -> str: + username, password = self.username, self.password + + if isinstance(username, str): + username = username.encode("latin1") + + if isinstance(password, str): + password = password.encode("latin1") + + userpass = b":".join((username, password)) + token = b64encode(userpass).decode().strip() + return f"Basic {token}" diff --git a/httpcore/client.py b/httpcore/client.py index cb8ead9f..61db5975 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -2,6 +2,7 @@ import asyncio import typing from types import TracebackType +from .auth import HTTPBasicAuth from .config import ( DEFAULT_MAX_REDIRECTS, DEFAULT_POOL_LIMITS, @@ -16,6 +17,7 @@ from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects from .interfaces import ConcurrencyBackend, Dispatcher from .models import ( URL, + AuthTypes, Headers, HeaderTypes, QueryParamTypes, @@ -31,6 +33,7 @@ from .status_codes import codes class AsyncClient: def __init__( self, + auth: AuthTypes = None, ssl: SSLConfig = DEFAULT_SSL_CONFIG, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, @@ -43,6 +46,7 @@ class AsyncClient: ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend ) + self.auth = auth self.max_redirects = max_redirects self.dispatch = dispatch @@ -53,6 +57,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -63,6 +68,7 @@ class AsyncClient: query_params=query_params, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -75,6 +81,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -85,6 +92,7 @@ class AsyncClient: query_params=query_params, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -97,6 +105,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = False, #  Note: Differs to usual default. ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -107,6 +116,7 @@ class AsyncClient: query_params=query_params, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -120,6 +130,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -131,6 +142,7 @@ class AsyncClient: query_params=query_params, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -144,6 +156,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -155,6 +168,7 @@ class AsyncClient: query_params=query_params, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -168,6 +182,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -179,6 +194,7 @@ class AsyncClient: query_params=query_params, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -192,6 +208,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -203,6 +220,7 @@ class AsyncClient: query_params=query_params, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -217,6 +235,7 @@ class AsyncClient: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -228,6 +247,7 @@ class AsyncClient: response = await self.send( request, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -238,6 +258,37 @@ class AsyncClient: request.prepare() async def send( + self, + request: Request, + *, + stream: bool = False, + auth: AuthTypes = None, + ssl: SSLConfig = None, + timeout: TimeoutConfig = None, + allow_redirects: bool = True, + ) -> Response: + if auth is None: + auth = self.auth + + url = request.url + if auth is None and (url.username or url.password): + auth = HTTPBasicAuth(username=url.username, password=url.password) + + if auth is not None: + if isinstance(auth, tuple): + auth = HTTPBasicAuth(username=auth[0], password=auth[1]) + request = auth(request) + + response = await self.send_handling_redirects( + request, + stream=stream, + ssl=ssl, + timeout=timeout, + allow_redirects=allow_redirects, + ) + return response + + async def send_handling_redirects( self, request: Request, *, @@ -273,7 +324,7 @@ class AsyncClient: async def send_next() -> Response: nonlocal request, response, ssl, allow_redirects, timeout, history request = self.build_redirect_request(request, response) - response = await self.send( + response = await self.send_handling_redirects( request, stream=stream, allow_redirects=allow_redirects, @@ -375,6 +426,7 @@ class AsyncClient: class Client: def __init__( self, + auth: AuthTypes = None, ssl: SSLConfig = DEFAULT_SSL_CONFIG, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, @@ -383,6 +435,7 @@ class Client: backend: ConcurrencyBackend = None, ) -> None: self._client = AsyncClient( + auth=auth, ssl=ssl, timeout=timeout, pool_limits=pool_limits, @@ -401,6 +454,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -412,6 +466,7 @@ class Client: response = self.send( request, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -425,6 +480,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -434,6 +490,7 @@ class Client: url, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -446,6 +503,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -455,6 +513,7 @@ class Client: url, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -467,6 +526,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = False, #  Note: Differs to usual default. ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -476,6 +536,7 @@ class Client: url, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -489,6 +550,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -499,6 +561,7 @@ class Client: data=data, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -512,6 +575,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -522,6 +586,7 @@ class Client: data=data, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -535,6 +600,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -545,6 +611,7 @@ class Client: data=data, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -558,6 +625,7 @@ class Client: query_params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -568,6 +636,7 @@ class Client: data=data, headers=headers, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, @@ -581,6 +650,7 @@ class Client: request: Request, *, stream: bool = False, + auth: AuthTypes = None, allow_redirects: bool = True, ssl: SSLConfig = None, timeout: TimeoutConfig = None, @@ -589,6 +659,7 @@ class Client: self._client.send( request, stream=stream, + auth=auth, allow_redirects=allow_redirects, ssl=ssl, timeout=timeout, diff --git a/httpcore/models.py b/httpcore/models.py index 68ab4b2d..663b6400 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -43,6 +43,11 @@ HeaderTypes = typing.Union[ typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]], ] +AuthTypes = typing.Union[ + typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]], + typing.Callable[["Request"], "Request"] +] + RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]] @@ -93,6 +98,16 @@ class URL: def authority(self) -> str: return self.components.authority or "" + @property + def username(self) -> str: + userinfo = self.components.userinfo or "" + return userinfo.partition(':')[0] + + @property + def password(self) -> str: + userinfo = self.components.userinfo or "" + return userinfo.partition(':')[2] + @property def host(self) -> str: return self.components.host or "" diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py new file mode 100644 index 00000000..e044acd8 --- /dev/null +++ b/tests/client/test_auth.py @@ -0,0 +1,72 @@ +import json +from urllib.parse import parse_qs + +import pytest + +from httpcore import ( + URL, + Client, + Dispatcher, + Request, + Response, + SSLConfig, + TimeoutConfig, +) + + +class MockDispatch(Dispatcher): + async def send( + self, + request: Request, + stream: bool = False, + ssl: SSLConfig = None, + timeout: TimeoutConfig = None, + ) -> Response: + body = json.dumps({"auth": request.headers.get('Authorization')}).encode() + return Response(200, content=body, request=request) + + +def test_basic_auth(): + url = "https://example.org/" + auth = ('tomchristie', 'password123') + + with Client(dispatch=MockDispatch()) as client: + response = client.get(url, auth=auth) + + assert response.status_code == 200 + assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='} + + +def test_basic_auth_in_url(): + url = "https://tomchristie:password123@example.org/" + + with Client(dispatch=MockDispatch()) as client: + response = client.get(url) + + assert response.status_code == 200 + assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='} + + +def test_basic_auth_on_session(): + url = "https://example.org/" + auth = ('tomchristie', 'password123') + + with Client(dispatch=MockDispatch(), auth=auth) as client: + response = client.get(url) + + assert response.status_code == 200 + assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='} + + +def test_custom_auth(): + url = "https://example.org/" + + def auth(request): + request.headers['Authorization'] = 'Token 123' + return request + + with Client(dispatch=MockDispatch()) as client: + response = client.get(url, auth=auth) + + assert response.status_code == 200 + assert json.loads(response.text) == {'auth': 'Token 123'}