From: Tom Christie Date: Mon, 29 Apr 2019 15:18:41 +0000 (+0100) Subject: Add await response.next() interface X-Git-Tag: 0.3.0~66^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=450ea25b5ad8194d180d16ede4537166ab7b3b18;p=thirdparty%2Fhttpx.git Add await response.next() interface --- diff --git a/httpcore/adapters/redirects.py b/httpcore/adapters/redirects.py index 584efec2..ecbb1728 100644 --- a/httpcore/adapters/redirects.py +++ b/httpcore/adapters/redirects.py @@ -1,3 +1,4 @@ +import functools import typing from urllib.parse import urljoin, urlparse @@ -19,18 +20,25 @@ class RedirectAdapter(Adapter): async def send(self, request: Request, **options: typing.Any) -> Response: allow_redirects = options.pop("allow_redirects", True) - history = [] # type: typing.List[Response] - seen_urls = set((request.url,)) + history = options.pop("history", []) # type: typing.List[Response] + seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL] + seen_urls.add(request.url) while True: response = await self.dispatch.send(request, **options) response.history = list(history) - if not allow_redirects or not response.is_redirect: + if not response.is_redirect: break history.append(response) + request = self.build_redirect_request(request, response) + if not allow_redirects: + next_options = dict(options) + next_options["seen_urls"] = seen_urls + next_options["history"] = history + response.next = functools.partial(self.send, request=request, **next_options) + break if len(history) > self.max_redirects: raise TooManyRedirects() - request = self.build_redirect_request(request, response) if request.url in seen_urls: raise RedirectLoop() seen_urls.add(request.url) @@ -71,6 +79,9 @@ class RedirectAdapter(Adapter): return method def redirect_url(self, request: Request, response: Response) -> URL: + """ + Return the URL for the redirect to follow. + """ location = response.headers["Location"] # Handle redirection without scheme (see: RFC 1808 Section 4) @@ -94,12 +105,19 @@ class RedirectAdapter(Adapter): return URL(url) def redirect_headers(self, request: Request, url: URL) -> Headers: + """ + Strip Authorization headers when responses are redirected away from + the origin. + """ headers = Headers(request.headers) if url.origin != request.url.origin: del headers["Authorization"] return headers def redirect_body(self, request: Request, method: str) -> bytes: + """ + Return the body that should be used for the redirect request. + """ if method != request.method and method == "GET": return b"" if request.is_streaming: diff --git a/tests/adapters/test_redirects.py b/tests/adapters/test_redirects.py index 3197cc3a..ce6d4d3f 100644 --- a/tests/adapters/test_redirects.py +++ b/tests/adapters/test_redirects.py @@ -116,6 +116,11 @@ async def test_disallow_redirects(): assert response.url == URL("https://example.org/redirect_303") assert len(response.history) == 0 + response = await response.next() + assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 + @pytest.mark.asyncio async def test_relative_redirect():