From: Tom Christie Date: Mon, 29 Apr 2019 13:25:24 +0000 (+0100) Subject: Rejig request preparing X-Git-Tag: 0.3.0~66^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4f23ab4f0d3fa074d44d055fabaa194412e28661;p=thirdparty%2Fhttpx.git Rejig request preparing --- diff --git a/httpcore/adapters/redirects.py b/httpcore/adapters/redirects.py index 9ef287f7..aa6c557e 100644 --- a/httpcore/adapters/redirects.py +++ b/httpcore/adapters/redirects.py @@ -4,7 +4,7 @@ from urllib.parse import urljoin, urlparse from ..config import DEFAULT_MAX_REDIRECTS from ..exceptions import RedirectLoop, TooManyRedirects from ..interfaces import Adapter -from ..models import URL, Request, Response +from ..models import URL, Headers, Request, Response from ..status_codes import codes from ..utils import requote_uri @@ -19,11 +19,12 @@ class RedirectAdapter(Adapter): async def send(self, request: Request, **options: typing.Any) -> Response: allow_redirects = options.pop("allow_redirects", True) - history = [] + history = [] # type: typing.List[Response] seen_urls = set((request.url,)) while True: response = await self.dispatch.send(request, **options) + response.history = list(history) if not allow_redirects or not response.is_redirect: break history.append(response) @@ -42,7 +43,8 @@ class RedirectAdapter(Adapter): def build_redirect_request(self, request: Request, response: Response) -> Request: method = self.redirect_method(request, response) url = self.redirect_url(request, response) - return Request(method=method, url=url) + headers = self.redirect_headers(request, url) + return Request(method=method, url=url, headers=headers) def redirect_method(self, request: Request, response: Response) -> str: """ @@ -89,3 +91,9 @@ class RedirectAdapter(Adapter): url = requote_uri(url) return URL(url) + + def redirect_headers(self, request: Request, url: URL) -> Headers: + headers = Headers(request.headers) + if url.origin != request.url.origin: + del headers["Authorization"] + return headers diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index 45134456..ffe82fcb 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -37,7 +37,7 @@ class HTTPConnection(Adapter): self.h2_connection = None # type: typing.Optional[HTTP2Connection] def prepare_request(self, request: Request) -> None: - pass + request.prepare() async def send(self, request: Request, **options: typing.Any) -> Response: if self.h11_connection is None and self.h2_connection is None: diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index a536d5f3..ca7a837f 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -10,6 +10,7 @@ from ..config import ( SSLConfig, TimeoutConfig, ) +from ..decoders import ACCEPT_ENCODING from ..exceptions import PoolTimeout from ..interfaces import Adapter from ..models import Origin, Request, Response @@ -104,7 +105,7 @@ class ConnectionPool(Adapter): return len(self.keepalive_connections) + len(self.active_connections) def prepare_request(self, request: Request) -> None: - pass + request.prepare() async def send(self, request: Request, **options: typing.Any) -> Response: connection = await self.acquire_connection(request.url.origin) diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index e4a5c83a..8128dc18 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -46,7 +46,7 @@ class HTTP11Connection(Adapter): self.h11_state = h11.Connection(our_role=h11.CLIENT) def prepare_request(self, request: Request) -> None: - pass + request.prepare() async def send(self, request: Request, **options: typing.Any) -> Response: timeout = options.get("timeout") @@ -87,6 +87,7 @@ class HTTP11Connection(Adapter): headers=headers, body=body, on_close=self.response_closed, + request=request, ) if not stream: diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index 7bd124b7..d44c5f60 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -32,7 +32,7 @@ class HTTP2Connection(Adapter): self.initialized = False def prepare_request(self, request: Request) -> None: - pass + request.prepare() async def send(self, request: Request, **options: typing.Any) -> Response: timeout = options.get("timeout") @@ -75,6 +75,7 @@ class HTTP2Connection(Adapter): headers=headers, body=body, on_close=on_close, + request=request, ) if not stream: diff --git a/httpcore/models.py b/httpcore/models.py index b0a4723d..92352403 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -79,6 +79,11 @@ class URL: def __str__(self) -> str: return self.components.geturl() + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url_str = str(self) + return f"{class_name}({url_str!r})" + class Origin: def __init__(self, url: typing.Union[str, URL]) -> None: @@ -100,13 +105,21 @@ class Origin: return hash((self.is_ssl, self.hostname, self.port)) +HeaderTypes = typing.Union["Headers", typing.List[typing.Tuple[bytes, bytes]]] + + class Headers(typing.MutableMapping[str, str]): """ A case-insensitive multidict. """ - def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None: - self._list = [(k.lower(), v) for k, v in headers] + def __init__(self, headers: HeaderTypes = None) -> None: + if headers is None: + self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]] + elif isinstance(headers, Headers): + self._list = list(headers.raw) + else: + self._list = [(k.lower(), v) for k, v in headers] @property def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: @@ -213,7 +226,7 @@ class Request: method: str, url: typing.Union[str, URL], *, - headers: typing.List[typing.Tuple[bytes, bytes]] = [], + headers: HeaderTypes = None, body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", ): self.method = method.upper() @@ -224,26 +237,24 @@ class Request: else: self.is_streaming = True self.body_aiter = body - self.headers = self.build_headers(headers) - - def build_headers( - self, init_headers: typing.List[typing.Tuple[bytes, bytes]] - ) -> Headers: - has_host = False - has_content_length = False - has_accept_encoding = False - - for header, value in init_headers: - header = header.strip().lower() - if header == b"host": - has_host = True - elif header in (b"content-length", b"transfer-encoding"): - has_content_length = True - elif header == b"accept-encoding": - has_accept_encoding = True + self.headers = Headers(headers) + + async def stream(self) -> typing.AsyncIterator[bytes]: + if self.is_streaming: + async for part in self.body_aiter: + yield part + elif self.body: + yield self.body + def prepare(self) -> None: auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + has_host = "host" in self.headers + has_content_length = ( + "content-length" in self.headers or "transfer-encoding" in self.headers + ) + has_accept_encoding = "accept-encoding" in self.headers + if not has_host: auto_headers.append((b"host", self.url.netloc.encode("ascii"))) if not has_content_length: @@ -255,14 +266,8 @@ class Request: if not has_accept_encoding: auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode())) - return Headers(auto_headers + init_headers) - - async def stream(self) -> typing.AsyncIterator[bytes]: - if self.is_streaming: - async for part in self.body_aiter: - yield part - elif self.body: - yield self.body + for item in reversed(auto_headers): + self.headers.raw.insert(0, item) class Response: @@ -275,6 +280,8 @@ class Response: headers: typing.List[typing.Tuple[bytes, bytes]] = [], body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", on_close: typing.Callable = None, + request: Request = None, + history: typing.List["Response"] = None, ): self.status_code = status_code if not reason: @@ -310,6 +317,13 @@ class Response: else: self.body_aiter = body + self.request = request + self.history = [] if history is None else list(history) + + @property + def url(self) -> typing.Optional[URL]: + return None if self.request is None else self.request.url + async def read(self) -> bytes: """ Read and return the response content. @@ -358,4 +372,6 @@ class Response: @property def is_redirect(self) -> bool: - return self.status_code in (301, 302, 303, 307, 308) + return ( + self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers + ) diff --git a/tests/adapters/test_redirects.py b/tests/adapters/test_redirects.py index 3609dd73..ab371a04 100644 --- a/tests/adapters/test_redirects.py +++ b/tests/adapters/test_redirects.py @@ -1,8 +1,10 @@ +import json from urllib.parse import parse_qs import pytest from httpcore import ( + URL, Adapter, RedirectAdapter, RedirectLoop, @@ -19,32 +21,60 @@ class MockDispatch(Adapter): async def send(self, request: Request, **options) -> Response: if request.url.path == "/redirect_301": # "Moved Permanently" - return Response(301, headers=[(b"location", b"https://example.org/")]) + return Response( + 301, headers=[(b"location", b"https://example.org/")], request=request + ) elif request.url.path == "/redirect_302": # "Found" - return Response(302, headers=[(b"location", b"https://example.org/")]) + return Response( + 302, headers=[(b"location", b"https://example.org/")], request=request + ) elif request.url.path == "/redirect_303": # "See Other" - return Response(303, headers=[(b"location", b"https://example.org/")]) + return Response( + 303, headers=[(b"location", b"https://example.org/")], request=request + ) elif request.url.path == "/relative_redirect": - return Response(codes.see_other, headers=[(b"location", b"/")]) + return Response( + codes.see_other, headers=[(b"location", b"/")], request=request + ) elif request.url.path == "/no_scheme_redirect": - return Response(codes.see_other, headers=[(b"location", b"//example.org/")]) + return Response( + codes.see_other, + headers=[(b"location", b"//example.org/")], + request=request, + ) elif request.url.path == "/multiple_redirects": params = parse_qs(request.url.query) count = int(params.get("count", "0")[0]) + redirect_count = count - 1 code = codes.see_other if count else codes.ok - location = "/multiple_redirects?count=" + str(count - 1) + location = "/multiple_redirects" + if redirect_count: + location += "?count=" + str(redirect_count) headers = [(b"location", location.encode())] if count else [] - return Response(code, headers=headers) + return Response(code, headers=headers, request=request) if request.url.path == "/redirect_loop": - return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")]) + return Response( + codes.see_other, + headers=[(b"location", b"/redirect_loop")], + request=request, + ) - return Response(codes.ok, body=b"Hello, world!") + elif request.url.path == "/cross_domain": + location = b"https://example.org/cross_domain_target" + return Response(301, headers=[(b"location", location)], request=request) + + elif request.url.path == "/cross_domain_target": + headers = {k.decode(): v.decode() for k, v in request.headers.raw} + body = json.dumps({"headers": headers}).encode() + return Response(codes.ok, body=body, request=request) + + return Response(codes.ok, body=b"Hello, world!", request=request) @pytest.mark.asyncio @@ -52,6 +82,8 @@ async def test_redirect_301(): client = RedirectAdapter(MockDispatch()) response = await client.request("POST", "https://example.org/redirect_301") assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 @pytest.mark.asyncio @@ -59,6 +91,8 @@ async def test_redirect_302(): client = RedirectAdapter(MockDispatch()) response = await client.request("POST", "https://example.org/redirect_302") assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 @pytest.mark.asyncio @@ -66,6 +100,8 @@ async def test_redirect_303(): client = RedirectAdapter(MockDispatch()) response = await client.request("GET", "https://example.org/redirect_303") assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 @pytest.mark.asyncio @@ -73,6 +109,8 @@ async def test_relative_redirect(): client = RedirectAdapter(MockDispatch()) response = await client.request("GET", "https://example.org/relative_redirect") assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 @pytest.mark.asyncio @@ -80,13 +118,19 @@ async def test_no_scheme_redirect(): client = RedirectAdapter(MockDispatch()) response = await client.request("GET", "https://example.org/no_scheme_redirect") assert response.status_code == codes.ok + assert response.url == URL("https://example.org/") + assert len(response.history) == 1 @pytest.mark.asyncio async def test_fragment_redirect(): client = RedirectAdapter(MockDispatch()) - response = await client.request("GET", "https://example.org/relative_redirect#fragment") + response = await client.request( + "GET", "https://example.org/relative_redirect#fragment" + ) assert response.status_code == codes.ok + assert response.url == URL("https://example.org/#fragment") + assert len(response.history) == 1 @pytest.mark.asyncio @@ -96,6 +140,8 @@ async def test_multiple_redirects(): "GET", "https://example.org/multiple_redirects?count=20" ) assert response.status_code == codes.ok + assert response.url == URL("https://example.org/multiple_redirects") + assert len(response.history) == 20 @pytest.mark.asyncio @@ -110,3 +156,25 @@ async def test_redirect_loop(): client = RedirectAdapter(MockDispatch()) with pytest.raises(RedirectLoop): await client.request("GET", "https://example.org/redirect_loop") + + +@pytest.mark.asyncio +async def test_cross_domain_redirect(): + client = RedirectAdapter(MockDispatch()) + headers = [(b"Authorization", b"abc")] + url = "https://example.com/cross_domain" + response = await client.request("GET", url, headers=headers) + data = json.loads(response.body.decode()) + assert response.url == URL("https://example.org/cross_domain_target") + assert data == {"headers": {}} + + +@pytest.mark.asyncio +async def test_same_domain_redirect(): + client = RedirectAdapter(MockDispatch()) + headers = [(b"Authorization", b"abc")] + url = "https://example.org/cross_domain" + response = await client.request("GET", url, headers=headers) + data = json.loads(response.body.decode()) + assert response.url == URL("https://example.org/cross_domain_target") + assert data == {"headers": {"authorization": "abc"}} diff --git a/tests/test_requests.py b/tests/test_requests.py index 840dde1a..4df3529c 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -5,6 +5,7 @@ import httpcore def test_host_header(): request = httpcore.Request("GET", "http://example.org") + request.prepare() assert request.headers == httpcore.Headers( [(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")] ) @@ -12,6 +13,7 @@ def test_host_header(): def test_content_length_header(): request = httpcore.Request("POST", "http://example.org", body=b"test 123") + request.prepare() assert request.headers == httpcore.Headers( [ (b"host", b"example.org"), @@ -28,6 +30,7 @@ def test_transfer_encoding_header(): body = streaming_body(b"test 123") request = httpcore.Request("POST", "http://example.org", body=body) + request.prepare() assert request.headers == httpcore.Headers( [ (b"host", b"example.org"), @@ -41,6 +44,7 @@ def test_override_host_header(): headers = [(b"host", b"1.2.3.4:80")] request = httpcore.Request("GET", "http://example.org", headers=headers) + request.prepare() assert request.headers == httpcore.Headers( [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")] ) @@ -50,6 +54,7 @@ def test_override_accept_encoding_header(): headers = [(b"accept-encoding", b"identity")] request = httpcore.Request("GET", "http://example.org", headers=headers) + request.prepare() assert request.headers == httpcore.Headers( [(b"host", b"example.org"), (b"accept-encoding", b"identity")] ) @@ -63,6 +68,7 @@ def test_override_content_length_header(): headers = [(b"content-length", b"8")] request = httpcore.Request("POST", "http://example.org", body=body, headers=headers) + request.prepare() assert request.headers == httpcore.Headers( [ (b"host", b"example.org"),