]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add `Client.auth` setter (#1185)
authorFlorimond Manca <florimond.manca@gmail.com>
Mon, 17 Aug 2020 12:51:52 +0000 (14:51 +0200)
committerGitHub <noreply@github.com>
Mon, 17 Aug 2020 12:51:52 +0000 (14:51 +0200)
docs/api.md
httpx/_client.py
tests/client/test_auth.py

index 94bbb81cdd6c3664e7e4192b27653ecb5633ec10..e04e5696041dfc0143fca1188b19b9241d76d0a4 100644 (file)
 
 ::: httpx.Client
     :docstring:
-    :members: headers cookies params request get head options post put patch delete build_request send close
+    :members: headers cookies params auth request get head options post put patch delete build_request send close
 
 ## `AsyncClient`
 
 ::: httpx.AsyncClient
     :docstring:
-    :members: headers cookies params request get head options post put patch delete build_request send aclose
+    :members: headers cookies params auth request get head options post put patch delete build_request send aclose
 
 
 ## `Response`
index 53d52e5123c82e58bb154626619fbf45898bc6d5..74c161611133381b0ca1a110dd07813c6180c305 100644 (file)
@@ -71,7 +71,7 @@ class BaseClient:
     ):
         self._base_url = self._enforce_trailing_slash(URL(base_url))
 
-        self.auth = auth
+        self._auth = self._build_auth(auth)
         self._params = QueryParams(params)
         self._headers = Headers(headers)
         self._cookies = Cookies(cookies)
@@ -117,6 +117,21 @@ class BaseClient:
     def timeout(self, timeout: TimeoutTypes) -> None:
         self._timeout = Timeout(timeout)
 
+    @property
+    def auth(self) -> typing.Optional[Auth]:
+        """
+        Authentication class used when none is passed at the request-level.
+
+        See also [Authentication][0].
+
+        [0]: /quickstart/#authentication
+        """
+        return self._auth
+
+    @auth.setter
+    def auth(self, auth: AuthTypes) -> None:
+        self._auth = self._build_auth(auth)
+
     @property
     def base_url(self) -> URL:
         """
@@ -284,19 +299,25 @@ class BaseClient:
             return merged_queryparams
         return params
 
-    def _build_auth(
+    def _build_auth(self, auth: AuthTypes) -> typing.Optional[Auth]:
+        if auth is None:
+            return None
+        elif isinstance(auth, tuple):
+            return BasicAuth(username=auth[0], password=auth[1])
+        elif isinstance(auth, Auth):
+            return auth
+        elif callable(auth):
+            return FunctionAuth(func=auth)
+        else:
+            raise TypeError('Invalid "auth" argument.')
+
+    def _build_request_auth(
         self, request: Request, auth: typing.Union[AuthTypes, UnsetType] = UNSET
     ) -> Auth:
-        auth = self.auth if isinstance(auth, UnsetType) else auth
+        auth = self._auth if isinstance(auth, UnsetType) else self._build_auth(auth)
 
         if auth is not None:
-            if isinstance(auth, tuple):
-                return BasicAuth(username=auth[0], password=auth[1])
-            elif isinstance(auth, Auth):
-                return auth
-            elif callable(auth):
-                return FunctionAuth(func=auth)
-            raise TypeError('Invalid "auth" argument.')
+            return auth
 
         username, password = request.url.username, request.url.password
         if username or password:
@@ -667,7 +688,7 @@ class Client(BaseClient):
         """
         timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
 
-        auth = self._build_auth(request, auth)
+        auth = self._build_request_auth(request, auth)
 
         response = self._send_handling_redirects(
             request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
@@ -1269,7 +1290,7 @@ class AsyncClient(BaseClient):
         """
         timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
 
-        auth = self._build_auth(request, auth)
+        auth = self._build_request_auth(request, auth)
 
         response = await self._send_handling_redirects(
             request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
index d4022eb9a92c003cd2b95d5729954754debc35b2..edfccf0a704de79711b1e9b8fff85d8445babab5 100644 (file)
@@ -9,6 +9,7 @@ from httpx import (
     URL,
     AsyncClient,
     Auth,
+    BasicAuth,
     Client,
     DigestAuth,
     ProtocolError,
@@ -310,14 +311,34 @@ async def test_auth_hidden_header() -> None:
 
 
 @pytest.mark.asyncio
-async def test_auth_invalid_type() -> None:
+async def test_auth_property() -> None:
+    client = AsyncClient(transport=AsyncMockTransport())
+    assert client.auth is None
+
+    client.auth = ("tomchristie", "password123")  # type: ignore
+    assert isinstance(client.auth, BasicAuth)
+
     url = "https://example.org/"
-    client = AsyncClient(
-        transport=AsyncMockTransport(),
-        auth="not a tuple, not a callable",  # type: ignore
-    )
+    response = await client.get(url)
+    assert response.status_code == 200
+    assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
+
+
+@pytest.mark.asyncio
+async def test_auth_invalid_type() -> None:
+    with pytest.raises(TypeError):
+        client = AsyncClient(
+            transport=AsyncMockTransport(),
+            auth="not a tuple, not a callable",  # type: ignore
+        )
+
+    client = AsyncClient(transport=AsyncMockTransport())
+
+    with pytest.raises(TypeError):
+        await client.get(auth="not a tuple, not a callable")  # type: ignore
+
     with pytest.raises(TypeError):
-        await client.get(url)
+        client.auth = "not a tuple, not a callable"  # type: ignore
 
 
 @pytest.mark.asyncio