]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Swap auth/redirects ordering (#1267)
authorTom Christie <tom@tomchristie.com>
Thu, 10 Sep 2020 08:12:05 +0000 (09:12 +0100)
committerGitHub <noreply@github.com>
Thu, 10 Sep 2020 08:12:05 +0000 (09:12 +0100)
* Internal refactoring to swap auth/redirects ordering

* Test for auth with cross domain redirect

httpx/_client.py
tests/client/test_redirects.py

index 61a862bde367649eb568c0d1e5d9555acc7307ab..afe4f9d9c5db5ada33f996f9129bc8e3249f609e 100644 (file)
@@ -725,8 +725,12 @@ class Client(BaseClient):
 
         auth = self._build_request_auth(request, auth)
 
-        response = self._send_handling_redirects(
-            request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
+        response = self._send_handling_auth(
+            request,
+            auth=auth,
+            timeout=timeout,
+            allow_redirects=allow_redirects,
+            history=[],
         )
 
         if not stream:
@@ -740,23 +744,17 @@ class Client(BaseClient):
     def _send_handling_redirects(
         self,
         request: Request,
-        auth: Auth,
         timeout: Timeout,
-        allow_redirects: bool = True,
-        history: typing.List[Response] = None,
+        allow_redirects: bool,
+        history: typing.List[Response],
     ) -> Response:
-        if history is None:
-            history = []
-
         while True:
             if len(history) > self.max_redirects:
                 raise TooManyRedirects(
                     "Exceeded maximum allowed redirects.", request=request
                 )
 
-            response = self._send_handling_auth(
-                request, auth=auth, timeout=timeout, history=history
-            )
+            response = self._send_single_request(request, timeout)
             response.history = list(history)
 
             if not response.is_redirect:
@@ -771,7 +769,6 @@ class Client(BaseClient):
                 response.call_next = functools.partial(
                     self._send_handling_redirects,
                     request=request,
-                    auth=auth,
                     timeout=timeout,
                     allow_redirects=False,
                     history=history,
@@ -781,16 +778,21 @@ class Client(BaseClient):
     def _send_handling_auth(
         self,
         request: Request,
-        history: typing.List[Response],
         auth: Auth,
         timeout: Timeout,
+        allow_redirects: bool,
+        history: typing.List[Response],
     ) -> Response:
         auth_flow = auth.sync_auth_flow(request)
         request = next(auth_flow)
 
         while True:
-            response = self._send_single_request(request, timeout)
-
+            response = self._send_handling_redirects(
+                request,
+                timeout=timeout,
+                allow_redirects=allow_redirects,
+                history=history,
+            )
             try:
                 next_request = auth_flow.send(response)
             except StopIteration:
@@ -1346,8 +1348,12 @@ class AsyncClient(BaseClient):
 
         auth = self._build_request_auth(request, auth)
 
-        response = await self._send_handling_redirects(
-            request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
+        response = await self._send_handling_auth(
+            request,
+            auth=auth,
+            timeout=timeout,
+            allow_redirects=allow_redirects,
+            history=[],
         )
 
         if not stream:
@@ -1361,23 +1367,17 @@ class AsyncClient(BaseClient):
     async def _send_handling_redirects(
         self,
         request: Request,
-        auth: Auth,
         timeout: Timeout,
-        allow_redirects: bool = True,
-        history: typing.List[Response] = None,
+        allow_redirects: bool,
+        history: typing.List[Response],
     ) -> Response:
-        if history is None:
-            history = []
-
         while True:
             if len(history) > self.max_redirects:
                 raise TooManyRedirects(
                     "Exceeded maximum allowed redirects.", request=request
                 )
 
-            response = await self._send_handling_auth(
-                request, auth=auth, timeout=timeout, history=history
-            )
+            response = await self._send_single_request(request, timeout)
             response.history = list(history)
 
             if not response.is_redirect:
@@ -1392,7 +1392,6 @@ class AsyncClient(BaseClient):
                 response.call_next = functools.partial(
                     self._send_handling_redirects,
                     request=request,
-                    auth=auth,
                     timeout=timeout,
                     allow_redirects=False,
                     history=history,
@@ -1402,16 +1401,21 @@ class AsyncClient(BaseClient):
     async def _send_handling_auth(
         self,
         request: Request,
-        history: typing.List[Response],
         auth: Auth,
         timeout: Timeout,
+        allow_redirects: bool,
+        history: typing.List[Response],
     ) -> Response:
         auth_flow = auth.async_auth_flow(request)
         request = await auth_flow.__anext__()
 
         while True:
-            response = await self._send_single_request(request, timeout)
-
+            response = await self._send_handling_redirects(
+                request,
+                timeout=timeout,
+                allow_redirects=allow_redirects,
+                history=history,
+            )
             try:
                 next_request = await auth_flow.asend(response)
             except StopAsyncIteration:
index b18feee95b9404779064aab9f7bd4b36f191d189..4b00133e313168ed3220fa8be4832563ec975212 100644 (file)
@@ -323,7 +323,7 @@ def test_redirect_loop():
         client.get("https://example.org/redirect_loop")
 
 
-def test_cross_domain_redirect():
+def test_cross_domain_redirect_with_auth_header():
     client = httpx.Client(transport=SyncMockTransport())
     url = "https://example.com/cross_domain"
     headers = {"Authorization": "abc"}
@@ -332,6 +332,14 @@ def test_cross_domain_redirect():
     assert "authorization" not in response.json()["headers"]
 
 
+def test_cross_domain_redirect_with_auth():
+    client = httpx.Client(transport=SyncMockTransport())
+    url = "https://example.com/cross_domain"
+    response = client.get(url, auth=("user", "pass"))
+    assert response.url == "https://example.org/cross_domain_target"
+    assert "authorization" not in response.json()["headers"]
+
+
 def test_same_domain_redirect():
     client = httpx.Client(transport=SyncMockTransport())
     url = "https://example.org/cross_domain"