]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add await response.next() interface 23/head
authorTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 15:18:41 +0000 (16:18 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 15:18:41 +0000 (16:18 +0100)
httpcore/adapters/redirects.py
tests/adapters/test_redirects.py

index 584efec26e6d9cbc992eee092a51dd1ba2a8e1b7..ecbb172825bde8200edadd5a3e771953eaeb3c90 100644 (file)
@@ -1,3 +1,4 @@
+import functools
 import typing
 from urllib.parse import urljoin, urlparse
 
@@ -19,18 +20,25 @@ class RedirectAdapter(Adapter):
 
     async def send(self, request: Request, **options: typing.Any) -> Response:
         allow_redirects = options.pop("allow_redirects", True)
-        history = []  # type: typing.List[Response]
-        seen_urls = set((request.url,))
+        history = options.pop("history", [])  # type: typing.List[Response]
+        seen_urls = options.pop("seen_urls", set())  # type: typing.Set[URL]
+        seen_urls.add(request.url)
 
         while True:
             response = await self.dispatch.send(request, **options)
             response.history = list(history)
-            if not allow_redirects or not response.is_redirect:
+            if not response.is_redirect:
                 break
             history.append(response)
+            request = self.build_redirect_request(request, response)
+            if not allow_redirects:
+                next_options = dict(options)
+                next_options["seen_urls"] = seen_urls
+                next_options["history"] = history
+                response.next = functools.partial(self.send, request=request, **next_options)
+                break
             if len(history) > self.max_redirects:
                 raise TooManyRedirects()
-            request = self.build_redirect_request(request, response)
             if request.url in seen_urls:
                 raise RedirectLoop()
             seen_urls.add(request.url)
@@ -71,6 +79,9 @@ class RedirectAdapter(Adapter):
         return method
 
     def redirect_url(self, request: Request, response: Response) -> URL:
+        """
+        Return the URL for the redirect to follow.
+        """
         location = response.headers["Location"]
 
         # Handle redirection without scheme (see: RFC 1808 Section 4)
@@ -94,12 +105,19 @@ class RedirectAdapter(Adapter):
         return URL(url)
 
     def redirect_headers(self, request: Request, url: URL) -> Headers:
+        """
+        Strip Authorization headers when responses are redirected away from
+        the origin.
+        """
         headers = Headers(request.headers)
         if url.origin != request.url.origin:
             del headers["Authorization"]
         return headers
 
     def redirect_body(self, request: Request, method: str) -> bytes:
+        """
+        Return the body that should be used for the redirect request.
+        """
         if method != request.method and method == "GET":
             return b""
         if request.is_streaming:
index 3197cc3a1f0f6e2939dcbf6d238af1024094e82e..ce6d4d3f8e8d078f1c3bab3bbd55503eaf40bf46 100644 (file)
@@ -116,6 +116,11 @@ async def test_disallow_redirects():
     assert response.url == URL("https://example.org/redirect_303")
     assert len(response.history) == 0
 
+    response = await response.next()
+    assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/")
+    assert len(response.history) == 1
+
 
 @pytest.mark.asyncio
 async def test_relative_redirect():