]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
response.next()
authorTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 16:14:08 +0000 (17:14 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 16:14:08 +0000 (17:14 +0100)
httpcore/adapters/redirects.py
httpcore/decoders.py
httpcore/exceptions.py
httpcore/models.py
tests/adapters/test_redirects.py

index ecbb172825bde8200edadd5a3e771953eaeb3c90..2d4df0d5cb1242559d9a247c40b809ff4231fa68 100644 (file)
@@ -1,4 +1,3 @@
-import functools
 import typing
 from urllib.parse import urljoin, urlparse
 
@@ -19,29 +18,44 @@ class RedirectAdapter(Adapter):
         self.dispatch.prepare_request(request)
 
     async def send(self, request: Request, **options: typing.Any) -> Response:
-        allow_redirects = options.pop("allow_redirects", True)
+        allow_redirects = options.pop("allow_redirects", True)  # type: bool
+
+        # The following will not typically be specified by the end-user developer,
+        # but are included in `response.next()` calls.
         history = options.pop("history", [])  # type: typing.List[Response]
         seen_urls = options.pop("seen_urls", set())  # type: typing.Set[URL]
-        seen_urls.add(request.url)
 
         while True:
+            # We perform these checks here, so that calls to `response.next()`
+            # will raise redirect errors if appropriate.
+            if len(history) > self.max_redirects:
+                raise TooManyRedirects()
+            if request.url in seen_urls:
+                raise RedirectLoop()
+
             response = await self.dispatch.send(request, **options)
             response.history = list(history)
             if not response.is_redirect:
                 break
-            history.append(response)
-            request = self.build_redirect_request(request, response)
-            if not allow_redirects:
+
+            history.insert(0, response)
+            seen_urls.add(request.url)
+
+            if allow_redirects:
+                request = self.build_redirect_request(request, response)
+            else:
                 next_options = dict(options)
                 next_options["seen_urls"] = seen_urls
                 next_options["history"] = history
-                response.next = functools.partial(self.send, request=request, **next_options)
+
+                async def send_next() -> Response:
+                    nonlocal request, response, next_options
+                    request = self.build_redirect_request(request, response)
+                    response = await self.send(request, **next_options)
+                    return response
+
+                response.next = send_next  # type: ignore
                 break
-            if len(history) > self.max_redirects:
-                raise TooManyRedirects()
-            if request.url in seen_urls:
-                raise RedirectLoop()
-            seen_urls.add(request.url)
 
         return response
 
index 4b2b67bb34d102dff83e9d2577683d61cd18c065..e58b8ee28c80b257910626802fcfc48fd9302114 100644 (file)
@@ -6,15 +6,14 @@ See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
 import typing
 import zlib
 
+import httpcore.exceptions
+
 try:
     import brotli
 except ImportError:  # pragma: nocover
     brotli = None
 
 
-import httpcore.exceptions
-
-
 class Decoder:
     def decode(self, data: bytes) -> bytes:
         raise NotImplementedError()  # pragma: nocover
index e8c8b72ee836038b665de8d5e961107b6a8a1951..6c7fc605344095c10b7ed1f99ef8856d5f7a35f3 100644 (file)
@@ -78,7 +78,7 @@ class DecodingError(Exception):
     Decoding of the response failed.
     """
 
-    
+
 class InvalidURL(Exception):
     """
     URL was missing a hostname, or was not one of HTTP/HTTPS.
index 43f799a6299dcb0af0a791bded7bfbbf1d4ae526..4b97211c1fcbc3757fd2116b0a0945c3467bb0b8 100644 (file)
@@ -342,6 +342,7 @@ class Response:
 
         self.request = request
         self.history = [] if history is None else list(history)
+        self.next = None  # typing.Optional[typing.Callable]
 
     @property
     def url(self) -> typing.Optional[URL]:
index ce6d4d3f8e8d078f1c3bab3bbd55503eaf40bf46..97a5393e33bfe9ebfe2ed6e5ce1b2fe95e9d3f8b 100644 (file)
@@ -111,7 +111,9 @@ async def test_redirect_303():
 @pytest.mark.asyncio
 async def test_disallow_redirects():
     client = RedirectAdapter(MockDispatch())
-    response = await client.request("POST", "https://example.org/redirect_303", allow_redirects=False)
+    response = await client.request(
+        "POST", "https://example.org/redirect_303", allow_redirects=False
+    )
     assert response.status_code == codes.see_other
     assert response.url == URL("https://example.org/redirect_303")
     assert len(response.history) == 0
@@ -167,6 +169,16 @@ async def test_too_many_redirects():
         await client.request("GET", "https://example.org/multiple_redirects?count=21")
 
 
+@pytest.mark.asyncio
+async def test_too_many_redirects_calling_next():
+    client = RedirectAdapter(MockDispatch())
+    url = "https://example.org/multiple_redirects?count=21"
+    response = await client.request("GET", url, allow_redirects=False)
+    with pytest.raises(TooManyRedirects):
+        while response.is_redirect:
+            response = await response.next()
+
+
 @pytest.mark.asyncio
 async def test_redirect_loop():
     client = RedirectAdapter(MockDispatch())
@@ -174,6 +186,16 @@ async def test_redirect_loop():
         await client.request("GET", "https://example.org/redirect_loop")
 
 
+@pytest.mark.asyncio
+async def test_redirect_loop_calling_next():
+    client = RedirectAdapter(MockDispatch())
+    url = "https://example.org/redirect_loop"
+    response = await client.request("GET", url, allow_redirects=False)
+    with pytest.raises(RedirectLoop):
+        while response.is_redirect:
+            response = await response.next()
+
+
 @pytest.mark.asyncio
 async def test_cross_domain_redirect():
     client = RedirectAdapter(MockDispatch())