import typing
from urllib.parse import urlsplit
-from .decoders import SUPPORTED_DECODERS, Decoder, IdentityDecoder, MultiDecoder
+from .decoders import (
+ ACCEPT_ENCODING,
+ SUPPORTED_DECODERS,
+ Decoder,
+ IdentityDecoder,
+ MultiDecoder,
+)
from .exceptions import ResponseClosed, StreamConsumed
def __init__(
self,
method: str,
- url: URL,
+ url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method
- self.url = url
+ self.url = URL(url) if isinstance(url, str) else url
self.headers = list(headers)
if isinstance(body, bytes):
self.is_streaming = False
else:
self.is_streaming = True
self.body_aiter = body
+ self.headers = self._auto_headers() + self.headers
+
+ def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
+ has_host = False
+ has_content_length = False
+ has_accept_encoding = False
+
+ for header, value in self.headers:
+ header = header.strip().lower()
+ if header == b"host":
+ has_host = True
+ elif header in (b"content-length", b"transfer-encoding"):
+ has_content_length = True
+ elif header == b"accept-encoding":
+ has_accept_encoding = True
+
+ headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
+
+ if not has_host:
+ headers.append((b"host", self.url.netloc.encode("ascii")))
+ if not has_content_length:
+ if self.is_streaming:
+ headers.append((b"transfer-encoding", b"chunked"))
+ elif self.body:
+ content_length = str(len(self.body)).encode()
+ headers.append((b"content-length", content_length))
+ if not has_accept_encoding:
+ headers.append((b"accept-encoding", ACCEPT_ENCODING))
+
+ return headers
async def stream(self) -> typing.AsyncIterator[bytes]:
assert self.is_streaming
async def stream(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
- This will allow us to handle gzip, deflate, and brotli encoded responses.
+ This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "body"):
yield self.body
--- /dev/null
+import pytest
+
+import httpcore
+
+
+def test_host_header():
+ request = httpcore.Request("GET", "http://example.org")
+ assert request.headers == [
+ (b"host", b"example.org"),
+ (b"accept-encoding", b"deflate, gzip, br"),
+ ]
+
+
+def test_content_length_header():
+ request = httpcore.Request("POST", "http://example.org", body=b"test 123")
+ assert request.headers == [
+ (b"host", b"example.org"),
+ (b"content-length", b"8"),
+ (b"accept-encoding", b"deflate, gzip, br"),
+ ]
+
+
+def test_transfer_encoding_header():
+ async def streaming_body(data):
+ yield data
+
+ body = streaming_body(b"test 123")
+
+ request = httpcore.Request("POST", "http://example.org", body=body)
+ assert request.headers == [
+ (b"host", b"example.org"),
+ (b"transfer-encoding", b"chunked"),
+ (b"accept-encoding", b"deflate, gzip, br"),
+ ]
+
+
+def test_override_host_header():
+ headers = [(b"host", b"1.2.3.4:80")]
+
+ request = httpcore.Request("GET", "http://example.org", headers=headers)
+ assert request.headers == [
+ (b"accept-encoding", b"deflate, gzip, br"),
+ (b"host", b"1.2.3.4:80"),
+ ]
+
+
+def test_override_accept_encoding_header():
+ headers = [(b"accept-encoding", b"identity")]
+
+ request = httpcore.Request("GET", "http://example.org", headers=headers)
+ assert request.headers == [
+ (b"host", b"example.org"),
+ (b"accept-encoding", b"identity"),
+ ]
+
+
+def test_override_content_length_header():
+ async def streaming_body(data):
+ yield data
+
+ body = streaming_body(b"test 123")
+ headers = [(b"content-length", b"8")]
+
+ request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
+ assert request.headers == [
+ (b"host", b"example.org"),
+ (b"accept-encoding", b"deflate, gzip, br"),
+ (b"content-length", b"8"),
+ ]
+
+
+def test_url():
+ request = httpcore.Request("GET", "http://example.org")
+ assert request.url.scheme == "http"
+ assert request.url.port == 80
+ assert request.url.target == "/"
+
+ request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
+ assert request.url.scheme == "https"
+ assert request.url.port == 443
+ assert request.url.target == "/abc?foo=bar"
+
+
+def test_invalid_urls():
+ with pytest.raises(ValueError):
+ httpcore.Request("GET", "example.org")
+
+ with pytest.raises(ValueError):
+ httpcore.Request("GET", "invalid://example.org")
+
+ with pytest.raises(ValueError):
+ httpcore.Request("GET", "http:///foo")