From: Tom Christie Date: Tue, 16 Apr 2019 11:14:20 +0000 (+0100) Subject: Add auto headers on Request X-Git-Tag: 0.1.0~9^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=042dd2205846c9ac0e4796ec754aef9127a752c7;p=thirdparty%2Fhttpx.git Add auto headers on Request --- diff --git a/httpcore/connections.py b/httpcore/connections.py index 312fc46c..e181cb50 100644 --- a/httpcore/connections.py +++ b/httpcore/connections.py @@ -43,13 +43,7 @@ class Connection: async def send(self, request: Request, stream: bool = False) -> Response: method = request.method.encode() target = request.url.target - host_header = (b"host", request.url.netloc.encode("ascii")) - if request.is_streaming: - content_length = (b"transfer-encoding", b"chunked") - else: - content_length = (b"content-length", str(len(request.body)).encode()) - - headers = [host_header, content_length] + request.headers + headers = request.headers #  Start sending the request. event = h11.Request(method=method, target=target, headers=headers) diff --git a/httpcore/datastructures.py b/httpcore/datastructures.py index bf8dae78..28f3bf91 100644 --- a/httpcore/datastructures.py +++ b/httpcore/datastructures.py @@ -1,7 +1,13 @@ import typing from urllib.parse import urlsplit -from .decoders import SUPPORTED_DECODERS, Decoder, IdentityDecoder, MultiDecoder +from .decoders import ( + ACCEPT_ENCODING, + SUPPORTED_DECODERS, + Decoder, + IdentityDecoder, + MultiDecoder, +) from .exceptions import ResponseClosed, StreamConsumed @@ -59,13 +65,13 @@ class Request: def __init__( self, method: str, - url: URL, + url: typing.Union[str, URL], *, headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", ): self.method = method - self.url = url + self.url = URL(url) if isinstance(url, str) else url self.headers = list(headers) if isinstance(body, bytes): self.is_streaming = False @@ -73,6 +79,36 @@ class Request: else: self.is_streaming = True self.body_aiter = body + self.headers = self._auto_headers() + self.headers + + def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]: + has_host = False + has_content_length = False + has_accept_encoding = False + + for header, value in self.headers: + header = header.strip().lower() + if header == b"host": + has_host = True + elif header in (b"content-length", b"transfer-encoding"): + has_content_length = True + elif header == b"accept-encoding": + has_accept_encoding = True + + headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + + if not has_host: + headers.append((b"host", self.url.netloc.encode("ascii"))) + if not has_content_length: + if self.is_streaming: + headers.append((b"transfer-encoding", b"chunked")) + elif self.body: + content_length = str(len(self.body)).encode() + headers.append((b"content-length", content_length)) + if not has_accept_encoding: + headers.append((b"accept-encoding", ACCEPT_ENCODING)) + + return headers async def stream(self) -> typing.AsyncIterator[bytes]: assert self.is_streaming @@ -131,7 +167,7 @@ class Response: async def stream(self) -> typing.AsyncIterator[bytes]: """ A byte-iterator over the decoded response content. - This will allow us to handle gzip, deflate, and brotli encoded responses. + This allows us to handle gzip, deflate, and brotli encoded responses. """ if hasattr(self, "body"): yield self.body diff --git a/httpcore/decoders.py b/httpcore/decoders.py index 8b464f5c..1e61998a 100644 --- a/httpcore/decoders.py +++ b/httpcore/decoders.py @@ -107,12 +107,17 @@ class MultiDecoder(Decoder): SUPPORTED_DECODERS = { - b"gzip": GZipDecoder, - b"deflate": DeflateDecoder, b"identity": IdentityDecoder, + b"deflate": DeflateDecoder, + b"gzip": GZipDecoder, b"br": BrotliDecoder, } if brotli is None: SUPPORTED_DECODERS.pop(b"br") # pragma: nocover + + +ACCEPT_ENCODING = b", ".join( + [key for key in SUPPORTED_DECODERS.keys() if key != b"identity"] +) diff --git a/tests/test_requests.py b/tests/test_requests.py new file mode 100644 index 00000000..a433e0e3 --- /dev/null +++ b/tests/test_requests.py @@ -0,0 +1,92 @@ +import pytest + +import httpcore + + +def test_host_header(): + request = httpcore.Request("GET", "http://example.org") + assert request.headers == [ + (b"host", b"example.org"), + (b"accept-encoding", b"deflate, gzip, br"), + ] + + +def test_content_length_header(): + request = httpcore.Request("POST", "http://example.org", body=b"test 123") + assert request.headers == [ + (b"host", b"example.org"), + (b"content-length", b"8"), + (b"accept-encoding", b"deflate, gzip, br"), + ] + + +def test_transfer_encoding_header(): + async def streaming_body(data): + yield data + + body = streaming_body(b"test 123") + + request = httpcore.Request("POST", "http://example.org", body=body) + assert request.headers == [ + (b"host", b"example.org"), + (b"transfer-encoding", b"chunked"), + (b"accept-encoding", b"deflate, gzip, br"), + ] + + +def test_override_host_header(): + headers = [(b"host", b"1.2.3.4:80")] + + request = httpcore.Request("GET", "http://example.org", headers=headers) + assert request.headers == [ + (b"accept-encoding", b"deflate, gzip, br"), + (b"host", b"1.2.3.4:80"), + ] + + +def test_override_accept_encoding_header(): + headers = [(b"accept-encoding", b"identity")] + + request = httpcore.Request("GET", "http://example.org", headers=headers) + assert request.headers == [ + (b"host", b"example.org"), + (b"accept-encoding", b"identity"), + ] + + +def test_override_content_length_header(): + async def streaming_body(data): + yield data + + body = streaming_body(b"test 123") + headers = [(b"content-length", b"8")] + + request = httpcore.Request("POST", "http://example.org", body=body, headers=headers) + assert request.headers == [ + (b"host", b"example.org"), + (b"accept-encoding", b"deflate, gzip, br"), + (b"content-length", b"8"), + ] + + +def test_url(): + request = httpcore.Request("GET", "http://example.org") + assert request.url.scheme == "http" + assert request.url.port == 80 + assert request.url.target == "/" + + request = httpcore.Request("GET", "https://example.org/abc?foo=bar") + assert request.url.scheme == "https" + assert request.url.port == 443 + assert request.url.target == "/abc?foo=bar" + + +def test_invalid_urls(): + with pytest.raises(ValueError): + httpcore.Request("GET", "example.org") + + with pytest.raises(ValueError): + httpcore.Request("GET", "invalid://example.org") + + with pytest.raises(ValueError): + httpcore.Request("GET", "http:///foo")