]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add auto headers on Request
authorTom Christie <tom@tomchristie.com>
Tue, 16 Apr 2019 11:14:20 +0000 (12:14 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 16 Apr 2019 11:14:20 +0000 (12:14 +0100)
httpcore/connections.py
httpcore/datastructures.py
httpcore/decoders.py
tests/test_requests.py [new file with mode: 0644]

index 312fc46cd946f27e9120c8291d69278548d95365..e181cb50ad76f587927557752163b312647f82ba 100644 (file)
@@ -43,13 +43,7 @@ class Connection:
     async def send(self, request: Request, stream: bool = False) -> Response:
         method = request.method.encode()
         target = request.url.target
-        host_header = (b"host", request.url.netloc.encode("ascii"))
-        if request.is_streaming:
-            content_length = (b"transfer-encoding", b"chunked")
-        else:
-            content_length = (b"content-length", str(len(request.body)).encode())
-
-        headers = [host_header, content_length] + request.headers
+        headers = request.headers
 
         #  Start sending the request.
         event = h11.Request(method=method, target=target, headers=headers)
index bf8dae786541438020550c3f7a225574ba7d6f29..28f3bf91f92c546208f21ebdbd26adfc7001240e 100644 (file)
@@ -1,7 +1,13 @@
 import typing
 from urllib.parse import urlsplit
 
-from .decoders import SUPPORTED_DECODERS, Decoder, IdentityDecoder, MultiDecoder
+from .decoders import (
+    ACCEPT_ENCODING,
+    SUPPORTED_DECODERS,
+    Decoder,
+    IdentityDecoder,
+    MultiDecoder,
+)
 from .exceptions import ResponseClosed, StreamConsumed
 
 
@@ -59,13 +65,13 @@ class Request:
     def __init__(
         self,
         method: str,
-        url: URL,
+        url: typing.Union[str, URL],
         *,
         headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
     ):
         self.method = method
-        self.url = url
+        self.url = URL(url) if isinstance(url, str) else url
         self.headers = list(headers)
         if isinstance(body, bytes):
             self.is_streaming = False
@@ -73,6 +79,36 @@ class Request:
         else:
             self.is_streaming = True
             self.body_aiter = body
+        self.headers = self._auto_headers() + self.headers
+
+    def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
+        has_host = False
+        has_content_length = False
+        has_accept_encoding = False
+
+        for header, value in self.headers:
+            header = header.strip().lower()
+            if header == b"host":
+                has_host = True
+            elif header in (b"content-length", b"transfer-encoding"):
+                has_content_length = True
+            elif header == b"accept-encoding":
+                has_accept_encoding = True
+
+        headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
+
+        if not has_host:
+            headers.append((b"host", self.url.netloc.encode("ascii")))
+        if not has_content_length:
+            if self.is_streaming:
+                headers.append((b"transfer-encoding", b"chunked"))
+            elif self.body:
+                content_length = str(len(self.body)).encode()
+                headers.append((b"content-length", content_length))
+        if not has_accept_encoding:
+            headers.append((b"accept-encoding", ACCEPT_ENCODING))
+
+        return headers
 
     async def stream(self) -> typing.AsyncIterator[bytes]:
         assert self.is_streaming
@@ -131,7 +167,7 @@ class Response:
     async def stream(self) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the decoded response content.
-        This will allow us to handle gzip, deflate, and brotli encoded responses.
+        This allows us to handle gzip, deflate, and brotli encoded responses.
         """
         if hasattr(self, "body"):
             yield self.body
index 8b464f5c4fdda372057dddabdfb48d3c06b649e6..1e61998a3ad53ce116a559e7155fb5b69a9fc167 100644 (file)
@@ -107,12 +107,17 @@ class MultiDecoder(Decoder):
 
 
 SUPPORTED_DECODERS = {
-    b"gzip": GZipDecoder,
-    b"deflate": DeflateDecoder,
     b"identity": IdentityDecoder,
+    b"deflate": DeflateDecoder,
+    b"gzip": GZipDecoder,
     b"br": BrotliDecoder,
 }
 
 
 if brotli is None:
     SUPPORTED_DECODERS.pop(b"br")  # pragma: nocover
+
+
+ACCEPT_ENCODING = b", ".join(
+    [key for key in SUPPORTED_DECODERS.keys() if key != b"identity"]
+)
diff --git a/tests/test_requests.py b/tests/test_requests.py
new file mode 100644 (file)
index 0000000..a433e0e
--- /dev/null
@@ -0,0 +1,92 @@
+import pytest
+
+import httpcore
+
+
+def test_host_header():
+    request = httpcore.Request("GET", "http://example.org")
+    assert request.headers == [
+        (b"host", b"example.org"),
+        (b"accept-encoding", b"deflate, gzip, br"),
+    ]
+
+
+def test_content_length_header():
+    request = httpcore.Request("POST", "http://example.org", body=b"test 123")
+    assert request.headers == [
+        (b"host", b"example.org"),
+        (b"content-length", b"8"),
+        (b"accept-encoding", b"deflate, gzip, br"),
+    ]
+
+
+def test_transfer_encoding_header():
+    async def streaming_body(data):
+        yield data
+
+    body = streaming_body(b"test 123")
+
+    request = httpcore.Request("POST", "http://example.org", body=body)
+    assert request.headers == [
+        (b"host", b"example.org"),
+        (b"transfer-encoding", b"chunked"),
+        (b"accept-encoding", b"deflate, gzip, br"),
+    ]
+
+
+def test_override_host_header():
+    headers = [(b"host", b"1.2.3.4:80")]
+
+    request = httpcore.Request("GET", "http://example.org", headers=headers)
+    assert request.headers == [
+        (b"accept-encoding", b"deflate, gzip, br"),
+        (b"host", b"1.2.3.4:80"),
+    ]
+
+
+def test_override_accept_encoding_header():
+    headers = [(b"accept-encoding", b"identity")]
+
+    request = httpcore.Request("GET", "http://example.org", headers=headers)
+    assert request.headers == [
+        (b"host", b"example.org"),
+        (b"accept-encoding", b"identity"),
+    ]
+
+
+def test_override_content_length_header():
+    async def streaming_body(data):
+        yield data
+
+    body = streaming_body(b"test 123")
+    headers = [(b"content-length", b"8")]
+
+    request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
+    assert request.headers == [
+        (b"host", b"example.org"),
+        (b"accept-encoding", b"deflate, gzip, br"),
+        (b"content-length", b"8"),
+    ]
+
+
+def test_url():
+    request = httpcore.Request("GET", "http://example.org")
+    assert request.url.scheme == "http"
+    assert request.url.port == 80
+    assert request.url.target == "/"
+
+    request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
+    assert request.url.scheme == "https"
+    assert request.url.port == 443
+    assert request.url.target == "/abc?foo=bar"
+
+
+def test_invalid_urls():
+    with pytest.raises(ValueError):
+        httpcore.Request("GET", "example.org")
+
+    with pytest.raises(ValueError):
+        httpcore.Request("GET", "invalid://example.org")
+
+    with pytest.raises(ValueError):
+        httpcore.Request("GET", "http:///foo")