from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectLoop, TooManyRedirects
from ..interfaces import Adapter
-from ..models import URL, Request, Response
+from ..models import URL, Headers, Request, Response
from ..status_codes import codes
from ..utils import requote_uri
async def send(self, request: Request, **options: typing.Any) -> Response:
allow_redirects = options.pop("allow_redirects", True)
- history = []
+ history = [] # type: typing.List[Response]
seen_urls = set((request.url,))
while True:
response = await self.dispatch.send(request, **options)
+ response.history = list(history)
if not allow_redirects or not response.is_redirect:
break
history.append(response)
def build_redirect_request(self, request: Request, response: Response) -> Request:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
- return Request(method=method, url=url)
+ headers = self.redirect_headers(request, url)
+ return Request(method=method, url=url, headers=headers)
def redirect_method(self, request: Request, response: Response) -> str:
"""
url = requote_uri(url)
return URL(url)
+
+ def redirect_headers(self, request: Request, url: URL) -> Headers:
+ headers = Headers(request.headers)
+ if url.origin != request.url.origin:
+ del headers["Authorization"]
+ return headers
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
def prepare_request(self, request: Request) -> None:
- pass
+ request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
if self.h11_connection is None and self.h2_connection is None:
SSLConfig,
TimeoutConfig,
)
+from ..decoders import ACCEPT_ENCODING
from ..exceptions import PoolTimeout
from ..interfaces import Adapter
from ..models import Origin, Request, Response
return len(self.keepalive_connections) + len(self.active_connections)
def prepare_request(self, request: Request) -> None:
- pass
+ request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
connection = await self.acquire_connection(request.url.origin)
self.h11_state = h11.Connection(our_role=h11.CLIENT)
def prepare_request(self, request: Request) -> None:
- pass
+ request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
timeout = options.get("timeout")
headers=headers,
body=body,
on_close=self.response_closed,
+ request=request,
)
if not stream:
self.initialized = False
def prepare_request(self, request: Request) -> None:
- pass
+ request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
timeout = options.get("timeout")
headers=headers,
body=body,
on_close=on_close,
+ request=request,
)
if not stream:
def __str__(self) -> str:
return self.components.geturl()
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ url_str = str(self)
+ return f"{class_name}({url_str!r})"
+
class Origin:
def __init__(self, url: typing.Union[str, URL]) -> None:
return hash((self.is_ssl, self.hostname, self.port))
+HeaderTypes = typing.Union["Headers", typing.List[typing.Tuple[bytes, bytes]]]
+
+
class Headers(typing.MutableMapping[str, str]):
"""
A case-insensitive multidict.
"""
- def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
- self._list = [(k.lower(), v) for k, v in headers]
+ def __init__(self, headers: HeaderTypes = None) -> None:
+ if headers is None:
+ self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
+ elif isinstance(headers, Headers):
+ self._list = list(headers.raw)
+ else:
+ self._list = [(k.lower(), v) for k, v in headers]
@property
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
method: str,
url: typing.Union[str, URL],
*,
- headers: typing.List[typing.Tuple[bytes, bytes]] = [],
+ headers: HeaderTypes = None,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method.upper()
else:
self.is_streaming = True
self.body_aiter = body
- self.headers = self.build_headers(headers)
-
- def build_headers(
- self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
- ) -> Headers:
- has_host = False
- has_content_length = False
- has_accept_encoding = False
-
- for header, value in init_headers:
- header = header.strip().lower()
- if header == b"host":
- has_host = True
- elif header in (b"content-length", b"transfer-encoding"):
- has_content_length = True
- elif header == b"accept-encoding":
- has_accept_encoding = True
+ self.headers = Headers(headers)
+
+ async def stream(self) -> typing.AsyncIterator[bytes]:
+ if self.is_streaming:
+ async for part in self.body_aiter:
+ yield part
+ elif self.body:
+ yield self.body
+ def prepare(self) -> None:
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
+ has_host = "host" in self.headers
+ has_content_length = (
+ "content-length" in self.headers or "transfer-encoding" in self.headers
+ )
+ has_accept_encoding = "accept-encoding" in self.headers
+
if not has_host:
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_content_length:
if not has_accept_encoding:
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
- return Headers(auto_headers + init_headers)
-
- async def stream(self) -> typing.AsyncIterator[bytes]:
- if self.is_streaming:
- async for part in self.body_aiter:
- yield part
- elif self.body:
- yield self.body
+ for item in reversed(auto_headers):
+ self.headers.raw.insert(0, item)
class Response:
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
+ request: Request = None,
+ history: typing.List["Response"] = None,
):
self.status_code = status_code
if not reason:
else:
self.body_aiter = body
+ self.request = request
+ self.history = [] if history is None else list(history)
+
+ @property
+ def url(self) -> typing.Optional[URL]:
+ return None if self.request is None else self.request.url
+
async def read(self) -> bytes:
"""
Read and return the response content.
@property
def is_redirect(self) -> bool:
- return self.status_code in (301, 302, 303, 307, 308)
+ return (
+ self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers
+ )
+import json
from urllib.parse import parse_qs
import pytest
from httpcore import (
+ URL,
Adapter,
RedirectAdapter,
RedirectLoop,
async def send(self, request: Request, **options) -> Response:
if request.url.path == "/redirect_301": # "Moved Permanently"
- return Response(301, headers=[(b"location", b"https://example.org/")])
+ return Response(
+ 301, headers=[(b"location", b"https://example.org/")], request=request
+ )
elif request.url.path == "/redirect_302": # "Found"
- return Response(302, headers=[(b"location", b"https://example.org/")])
+ return Response(
+ 302, headers=[(b"location", b"https://example.org/")], request=request
+ )
elif request.url.path == "/redirect_303": # "See Other"
- return Response(303, headers=[(b"location", b"https://example.org/")])
+ return Response(
+ 303, headers=[(b"location", b"https://example.org/")], request=request
+ )
elif request.url.path == "/relative_redirect":
- return Response(codes.see_other, headers=[(b"location", b"/")])
+ return Response(
+ codes.see_other, headers=[(b"location", b"/")], request=request
+ )
elif request.url.path == "/no_scheme_redirect":
- return Response(codes.see_other, headers=[(b"location", b"//example.org/")])
+ return Response(
+ codes.see_other,
+ headers=[(b"location", b"//example.org/")],
+ request=request,
+ )
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
count = int(params.get("count", "0")[0])
+ redirect_count = count - 1
code = codes.see_other if count else codes.ok
- location = "/multiple_redirects?count=" + str(count - 1)
+ location = "/multiple_redirects"
+ if redirect_count:
+ location += "?count=" + str(redirect_count)
headers = [(b"location", location.encode())] if count else []
- return Response(code, headers=headers)
+ return Response(code, headers=headers, request=request)
if request.url.path == "/redirect_loop":
- return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")])
+ return Response(
+ codes.see_other,
+ headers=[(b"location", b"/redirect_loop")],
+ request=request,
+ )
- return Response(codes.ok, body=b"Hello, world!")
+ elif request.url.path == "/cross_domain":
+ location = b"https://example.org/cross_domain_target"
+ return Response(301, headers=[(b"location", location)], request=request)
+
+ elif request.url.path == "/cross_domain_target":
+ headers = {k.decode(): v.decode() for k, v in request.headers.raw}
+ body = json.dumps({"headers": headers}).encode()
+ return Response(codes.ok, body=body, request=request)
+
+ return Response(codes.ok, body=b"Hello, world!", request=request)
@pytest.mark.asyncio
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_301")
assert response.status_code == codes.ok
+ assert response.url == URL("https://example.org/")
+ assert len(response.history) == 1
@pytest.mark.asyncio
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_302")
assert response.status_code == codes.ok
+ assert response.url == URL("https://example.org/")
+ assert len(response.history) == 1
@pytest.mark.asyncio
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/redirect_303")
assert response.status_code == codes.ok
+ assert response.url == URL("https://example.org/")
+ assert len(response.history) == 1
@pytest.mark.asyncio
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/relative_redirect")
assert response.status_code == codes.ok
+ assert response.url == URL("https://example.org/")
+ assert len(response.history) == 1
@pytest.mark.asyncio
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/no_scheme_redirect")
assert response.status_code == codes.ok
+ assert response.url == URL("https://example.org/")
+ assert len(response.history) == 1
@pytest.mark.asyncio
async def test_fragment_redirect():
client = RedirectAdapter(MockDispatch())
- response = await client.request("GET", "https://example.org/relative_redirect#fragment")
+ response = await client.request(
+ "GET", "https://example.org/relative_redirect#fragment"
+ )
assert response.status_code == codes.ok
+ assert response.url == URL("https://example.org/#fragment")
+ assert len(response.history) == 1
@pytest.mark.asyncio
"GET", "https://example.org/multiple_redirects?count=20"
)
assert response.status_code == codes.ok
+ assert response.url == URL("https://example.org/multiple_redirects")
+ assert len(response.history) == 20
@pytest.mark.asyncio
client = RedirectAdapter(MockDispatch())
with pytest.raises(RedirectLoop):
await client.request("GET", "https://example.org/redirect_loop")
+
+
+@pytest.mark.asyncio
+async def test_cross_domain_redirect():
+ client = RedirectAdapter(MockDispatch())
+ headers = [(b"Authorization", b"abc")]
+ url = "https://example.com/cross_domain"
+ response = await client.request("GET", url, headers=headers)
+ data = json.loads(response.body.decode())
+ assert response.url == URL("https://example.org/cross_domain_target")
+ assert data == {"headers": {}}
+
+
+@pytest.mark.asyncio
+async def test_same_domain_redirect():
+ client = RedirectAdapter(MockDispatch())
+ headers = [(b"Authorization", b"abc")]
+ url = "https://example.org/cross_domain"
+ response = await client.request("GET", url, headers=headers)
+ data = json.loads(response.body.decode())
+ assert response.url == URL("https://example.org/cross_domain_target")
+ assert data == {"headers": {"authorization": "abc"}}
def test_host_header():
request = httpcore.Request("GET", "http://example.org")
+ request.prepare()
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
)
def test_content_length_header():
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
+ request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
body = streaming_body(b"test 123")
request = httpcore.Request("POST", "http://example.org", body=body)
+ request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
headers = [(b"host", b"1.2.3.4:80")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
+ request.prepare()
assert request.headers == httpcore.Headers(
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
)
headers = [(b"accept-encoding", b"identity")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
+ request.prepare()
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"identity")]
)
headers = [(b"content-length", b"8")]
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
+ request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),