-import functools
import typing
from urllib.parse import urljoin, urlparse
self.dispatch.prepare_request(request)
async def send(self, request: Request, **options: typing.Any) -> Response:
- allow_redirects = options.pop("allow_redirects", True)
+ allow_redirects = options.pop("allow_redirects", True) # type: bool
+
+ # The following will not typically be specified by the end-user developer,
+ # but are included in `response.next()` calls.
history = options.pop("history", []) # type: typing.List[Response]
seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL]
- seen_urls.add(request.url)
while True:
+ # We perform these checks here, so that calls to `response.next()`
+ # will raise redirect errors if appropriate.
+ if len(history) > self.max_redirects:
+ raise TooManyRedirects()
+ if request.url in seen_urls:
+ raise RedirectLoop()
+
response = await self.dispatch.send(request, **options)
response.history = list(history)
if not response.is_redirect:
break
- history.append(response)
- request = self.build_redirect_request(request, response)
- if not allow_redirects:
+
+ history.insert(0, response)
+ seen_urls.add(request.url)
+
+ if allow_redirects:
+ request = self.build_redirect_request(request, response)
+ else:
next_options = dict(options)
next_options["seen_urls"] = seen_urls
next_options["history"] = history
- response.next = functools.partial(self.send, request=request, **next_options)
+
+ async def send_next() -> Response:
+ nonlocal request, response, next_options
+ request = self.build_redirect_request(request, response)
+ response = await self.send(request, **next_options)
+ return response
+
+ response.next = send_next # type: ignore
break
- if len(history) > self.max_redirects:
- raise TooManyRedirects()
- if request.url in seen_urls:
- raise RedirectLoop()
- seen_urls.add(request.url)
return response
@pytest.mark.asyncio
async def test_disallow_redirects():
client = RedirectAdapter(MockDispatch())
- response = await client.request("POST", "https://example.org/redirect_303", allow_redirects=False)
+ response = await client.request(
+ "POST", "https://example.org/redirect_303", allow_redirects=False
+ )
assert response.status_code == codes.see_other
assert response.url == URL("https://example.org/redirect_303")
assert len(response.history) == 0
await client.request("GET", "https://example.org/multiple_redirects?count=21")
+@pytest.mark.asyncio
+async def test_too_many_redirects_calling_next():
+ client = RedirectAdapter(MockDispatch())
+ url = "https://example.org/multiple_redirects?count=21"
+ response = await client.request("GET", url, allow_redirects=False)
+ with pytest.raises(TooManyRedirects):
+ while response.is_redirect:
+ response = await response.next()
+
+
@pytest.mark.asyncio
async def test_redirect_loop():
client = RedirectAdapter(MockDispatch())
await client.request("GET", "https://example.org/redirect_loop")
+@pytest.mark.asyncio
+async def test_redirect_loop_calling_next():
+ client = RedirectAdapter(MockDispatch())
+ url = "https://example.org/redirect_loop"
+ response = await client.request("GET", url, allow_redirects=False)
+ with pytest.raises(RedirectLoop):
+ while response.is_redirect:
+ response = await response.next()
+
+
@pytest.mark.asyncio
async def test_cross_domain_redirect():
client = RedirectAdapter(MockDispatch())