From: Tom Christie Date: Mon, 29 Apr 2019 16:14:08 +0000 (+0100) Subject: response.next() X-Git-Tag: 0.3.0~64 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d652479b074a97b1829b9fa7663bf400d5e3af04;p=thirdparty%2Fhttpx.git response.next() --- diff --git a/httpcore/adapters/redirects.py b/httpcore/adapters/redirects.py index ecbb1728..2d4df0d5 100644 --- a/httpcore/adapters/redirects.py +++ b/httpcore/adapters/redirects.py @@ -1,4 +1,3 @@ -import functools import typing from urllib.parse import urljoin, urlparse @@ -19,29 +18,44 @@ class RedirectAdapter(Adapter): self.dispatch.prepare_request(request) async def send(self, request: Request, **options: typing.Any) -> Response: - allow_redirects = options.pop("allow_redirects", True) + allow_redirects = options.pop("allow_redirects", True) # type: bool + + # The following will not typically be specified by the end-user developer, + # but are included in `response.next()` calls. history = options.pop("history", []) # type: typing.List[Response] seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL] - seen_urls.add(request.url) while True: + # We perform these checks here, so that calls to `response.next()` + # will raise redirect errors if appropriate. + if len(history) > self.max_redirects: + raise TooManyRedirects() + if request.url in seen_urls: + raise RedirectLoop() + response = await self.dispatch.send(request, **options) response.history = list(history) if not response.is_redirect: break - history.append(response) - request = self.build_redirect_request(request, response) - if not allow_redirects: + + history.insert(0, response) + seen_urls.add(request.url) + + if allow_redirects: + request = self.build_redirect_request(request, response) + else: next_options = dict(options) next_options["seen_urls"] = seen_urls next_options["history"] = history - response.next = functools.partial(self.send, request=request, **next_options) + + async def send_next() -> Response: + nonlocal request, response, next_options + request = self.build_redirect_request(request, response) + response = await self.send(request, **next_options) + return response + + response.next = send_next # type: ignore break - if len(history) > self.max_redirects: - raise TooManyRedirects() - if request.url in seen_urls: - raise RedirectLoop() - seen_urls.add(request.url) return response diff --git a/httpcore/decoders.py b/httpcore/decoders.py index 4b2b67bb..e58b8ee2 100644 --- a/httpcore/decoders.py +++ b/httpcore/decoders.py @@ -6,15 +6,14 @@ See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding import typing import zlib +import httpcore.exceptions + try: import brotli except ImportError: # pragma: nocover brotli = None -import httpcore.exceptions - - class Decoder: def decode(self, data: bytes) -> bytes: raise NotImplementedError() # pragma: nocover diff --git a/httpcore/exceptions.py b/httpcore/exceptions.py index e8c8b72e..6c7fc605 100644 --- a/httpcore/exceptions.py +++ b/httpcore/exceptions.py @@ -78,7 +78,7 @@ class DecodingError(Exception): Decoding of the response failed. """ - + class InvalidURL(Exception): """ URL was missing a hostname, or was not one of HTTP/HTTPS. diff --git a/httpcore/models.py b/httpcore/models.py index 43f799a6..4b97211c 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -342,6 +342,7 @@ class Response: self.request = request self.history = [] if history is None else list(history) + self.next = None # typing.Optional[typing.Callable] @property def url(self) -> typing.Optional[URL]: diff --git a/tests/adapters/test_redirects.py b/tests/adapters/test_redirects.py index ce6d4d3f..97a5393e 100644 --- a/tests/adapters/test_redirects.py +++ b/tests/adapters/test_redirects.py @@ -111,7 +111,9 @@ async def test_redirect_303(): @pytest.mark.asyncio async def test_disallow_redirects(): client = RedirectAdapter(MockDispatch()) - response = await client.request("POST", "https://example.org/redirect_303", allow_redirects=False) + response = await client.request( + "POST", "https://example.org/redirect_303", allow_redirects=False + ) assert response.status_code == codes.see_other assert response.url == URL("https://example.org/redirect_303") assert len(response.history) == 0 @@ -167,6 +169,16 @@ async def test_too_many_redirects(): await client.request("GET", "https://example.org/multiple_redirects?count=21") +@pytest.mark.asyncio +async def test_too_many_redirects_calling_next(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.org/multiple_redirects?count=21" + response = await client.request("GET", url, allow_redirects=False) + with pytest.raises(TooManyRedirects): + while response.is_redirect: + response = await response.next() + + @pytest.mark.asyncio async def test_redirect_loop(): client = RedirectAdapter(MockDispatch()) @@ -174,6 +186,16 @@ async def test_redirect_loop(): await client.request("GET", "https://example.org/redirect_loop") +@pytest.mark.asyncio +async def test_redirect_loop_calling_next(): + client = RedirectAdapter(MockDispatch()) + url = "https://example.org/redirect_loop" + response = await client.request("GET", url, allow_redirects=False) + with pytest.raises(RedirectLoop): + while response.is_redirect: + response = await response.next() + + @pytest.mark.asyncio async def test_cross_domain_redirect(): client = RedirectAdapter(MockDispatch())