From: Tom Christie Date: Mon, 29 Apr 2019 15:05:05 +0000 (+0100) Subject: Nicer headers interface X-Git-Tag: 0.3.0~66^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=05e0649e8fee91913b59c61515a387f0a389a8ac;p=thirdparty%2Fhttpx.git Nicer headers interface --- diff --git a/httpcore/models.py b/httpcore/models.py index 92352403..43f799a6 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -11,6 +11,7 @@ from .decoders import ( MultiDecoder, ) from .exceptions import ResponseClosed, StreamConsumed +from .utils import normalize_header_key, normalize_header_value class URL: @@ -105,7 +106,11 @@ class Origin: return hash((self.is_ssl, self.hostname, self.port)) -HeaderTypes = typing.Union["Headers", typing.List[typing.Tuple[bytes, bytes]]] +HeaderTypes = typing.Union[ + "Headers", + typing.Dict[typing.AnyStr, typing.AnyStr], + typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]], +] class Headers(typing.MutableMapping[str, str]): @@ -118,8 +123,15 @@ class Headers(typing.MutableMapping[str, str]): self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]] elif isinstance(headers, Headers): self._list = list(headers.raw) + elif isinstance(headers, dict): + self._list = [ + (normalize_header_key(k), normalize_header_value(v)) + for k, v in headers.items() + ] else: - self._list = [(k.lower(), v) for k, v in headers] + self._list = [ + (normalize_header_key(k), normalize_header_value(v)) for k, v in headers + ] @property def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: @@ -239,6 +251,17 @@ class Request: self.body_aiter = body self.headers = Headers(headers) + async def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "body"): + body = b"" + async for part in self.stream(): + body += part + self.body = body + return self.body + async def stream(self) -> typing.AsyncIterator[bytes]: if self.is_streaming: async for part in self.body_aiter: diff --git a/httpcore/utils.py b/httpcore/utils.py index cd11858a..419e7ec2 100644 --- a/httpcore/utils.py +++ b/httpcore/utils.py @@ -1,3 +1,4 @@ +import typing from urllib.parse import quote from .exceptions import InvalidURL @@ -50,3 +51,21 @@ def requote_uri(uri: str) -> str: # there may be unquoted '%'s in the URI. We need to make sure they're # properly quoted so they do not cause issues elsewhere. return quote(uri, safe=safe_without_percent) + + +def normalize_header_key(value: typing.AnyStr) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header key. + """ + if isinstance(value, bytes): + return value.lower() + return value.encode("latin-1").lower() + + +def normalize_header_value(value: typing.AnyStr) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header value. + """ + if isinstance(value, bytes): + return value + return value.encode("latin-1") diff --git a/tests/adapters/test_redirects.py b/tests/adapters/test_redirects.py index 6499c762..56ee9197 100644 --- a/tests/adapters/test_redirects.py +++ b/tests/adapters/test_redirects.py @@ -21,32 +21,28 @@ class MockDispatch(Adapter): pass async def send(self, request: Request, **options) -> Response: - if request.url.path == "/redirect_301": # "Moved Permanently" - return Response( - 301, headers=[(b"location", b"https://example.org/")], request=request - ) + if request.url.path == "/redirect_301": + status_code = codes.moved_permanently + headers = {"location": "https://example.org/"} + return Response(status_code, headers=headers, request=request) - elif request.url.path == "/redirect_302": # "Found" - return Response( - 302, headers=[(b"location", b"https://example.org/")], request=request - ) + elif request.url.path == "/redirect_302": + status_code = codes.found + headers = {"location": "https://example.org/"} + return Response(status_code, headers=headers, request=request) - elif request.url.path == "/redirect_303": # "See Other" - return Response( - 303, headers=[(b"location", b"https://example.org/")], request=request - ) + elif request.url.path == "/redirect_303": + status_code = codes.see_other + headers = {"location": "https://example.org/"} + return Response(status_code, headers=headers, request=request) elif request.url.path == "/relative_redirect": - return Response( - codes.see_other, headers=[(b"location", b"/")], request=request - ) + headers = {"location": "/"} + return Response(codes.see_other, headers=headers, request=request) elif request.url.path == "/no_scheme_redirect": - return Response( - codes.see_other, - headers=[(b"location", b"//example.org/")], - request=request, - ) + headers = {"location": "//example.org/"} + return Response(codes.see_other, headers=headers, request=request) elif request.url.path == "/multiple_redirects": params = parse_qs(request.url.query) @@ -56,31 +52,30 @@ class MockDispatch(Adapter): location = "/multiple_redirects" if redirect_count: location += "?count=" + str(redirect_count) - headers = [(b"location", location.encode())] if count else [] + headers = {"location": location} if count else {} return Response(code, headers=headers, request=request) if request.url.path == "/redirect_loop": - return Response( - codes.see_other, - headers=[(b"location", b"/redirect_loop")], - request=request, - ) + headers = {"location": "/redirect_loop"} + return Response(codes.see_other, headers=headers, request=request) elif request.url.path == "/cross_domain": - location = b"https://example.org/cross_domain_target" - return Response(301, headers=[(b"location", location)], request=request) + headers = {"location": "https://example.org/cross_domain_target"} + return Response(codes.see_other, headers=headers, request=request) elif request.url.path == "/cross_domain_target": - headers = {k.decode(): v.decode() for k, v in request.headers.raw} + headers = dict(request.headers.items()) body = json.dumps({"headers": headers}).encode() return Response(codes.ok, body=body, request=request) elif request.url.path == "/redirect_body": - headers = [(b"location", b"/redirect_body_target")] + body = await request.read() + headers = {"location": "/redirect_body_target"} return Response(codes.permanent_redirect, headers=headers, request=request) elif request.url.path == "/redirect_body_target": - body = json.dumps({"body": request.body.decode()}).encode() + body = await request.read() + body = json.dumps({"body": body.decode()}).encode() return Response(codes.ok, body=body, request=request) return Response(codes.ok, body=b"Hello, world!", request=request) @@ -134,9 +129,8 @@ async def test_no_scheme_redirect(): @pytest.mark.asyncio async def test_fragment_redirect(): client = RedirectAdapter(MockDispatch()) - response = await client.request( - "GET", "https://example.org/relative_redirect#fragment" - ) + url = "https://example.org/relative_redirect#fragment" + response = await client.request("GET", url) assert response.status_code == codes.ok assert response.url == URL("https://example.org/#fragment") assert len(response.history) == 1 @@ -145,9 +139,8 @@ async def test_fragment_redirect(): @pytest.mark.asyncio async def test_multiple_redirects(): client = RedirectAdapter(MockDispatch()) - response = await client.request( - "GET", "https://example.org/multiple_redirects?count=20" - ) + url = "https://example.org/multiple_redirects?count=20" + response = await client.request("GET", url) assert response.status_code == codes.ok assert response.url == URL("https://example.org/multiple_redirects") assert len(response.history) == 20 @@ -170,8 +163,8 @@ async def test_redirect_loop(): @pytest.mark.asyncio async def test_cross_domain_redirect(): client = RedirectAdapter(MockDispatch()) - headers = [(b"Authorization", b"abc")] url = "https://example.com/cross_domain" + headers = {"Authorization": "abc"} response = await client.request("GET", url, headers=headers) data = json.loads(response.body.decode()) assert response.url == URL("https://example.org/cross_domain_target") @@ -181,8 +174,8 @@ async def test_cross_domain_redirect(): @pytest.mark.asyncio async def test_same_domain_redirect(): client = RedirectAdapter(MockDispatch()) - headers = [(b"Authorization", b"abc")] url = "https://example.org/cross_domain" + headers = {"Authorization": "abc"} response = await client.request("GET", url, headers=headers) data = json.loads(response.body.decode()) assert response.url == URL("https://example.org/cross_domain_target") @@ -192,8 +185,8 @@ async def test_same_domain_redirect(): @pytest.mark.asyncio async def test_body_redirect(): client = RedirectAdapter(MockDispatch()) - body = b"Example request body" url = "https://example.org/redirect_body" + body = b"Example request body" response = await client.request("POST", url, body=body) data = json.loads(response.body.decode()) assert response.url == URL("https://example.org/redirect_body_target")