]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Rejig request preparing
authorTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 13:25:24 +0000 (14:25 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 13:25:24 +0000 (14:25 +0100)
httpcore/adapters/redirects.py
httpcore/dispatch/connection.py
httpcore/dispatch/connection_pool.py
httpcore/dispatch/http11.py
httpcore/dispatch/http2.py
httpcore/models.py
tests/adapters/test_redirects.py
tests/test_requests.py

index 9ef287f720a1e0d9ece961404a4630fe69d4da8f..aa6c557e5743f0d36309b44ddebcdbd6db5f8159 100644 (file)
@@ -4,7 +4,7 @@ from urllib.parse import urljoin, urlparse
 from ..config import DEFAULT_MAX_REDIRECTS
 from ..exceptions import RedirectLoop, TooManyRedirects
 from ..interfaces import Adapter
-from ..models import URL, Request, Response
+from ..models import URL, Headers, Request, Response
 from ..status_codes import codes
 from ..utils import requote_uri
 
@@ -19,11 +19,12 @@ class RedirectAdapter(Adapter):
 
     async def send(self, request: Request, **options: typing.Any) -> Response:
         allow_redirects = options.pop("allow_redirects", True)
-        history = []
+        history = []  # type: typing.List[Response]
         seen_urls = set((request.url,))
 
         while True:
             response = await self.dispatch.send(request, **options)
+            response.history = list(history)
             if not allow_redirects or not response.is_redirect:
                 break
             history.append(response)
@@ -42,7 +43,8 @@ class RedirectAdapter(Adapter):
     def build_redirect_request(self, request: Request, response: Response) -> Request:
         method = self.redirect_method(request, response)
         url = self.redirect_url(request, response)
-        return Request(method=method, url=url)
+        headers = self.redirect_headers(request, url)
+        return Request(method=method, url=url, headers=headers)
 
     def redirect_method(self, request: Request, response: Response) -> str:
         """
@@ -89,3 +91,9 @@ class RedirectAdapter(Adapter):
             url = requote_uri(url)
 
         return URL(url)
+
+    def redirect_headers(self, request: Request, url: URL) -> Headers:
+        headers = Headers(request.headers)
+        if url.origin != request.url.origin:
+            del headers["Authorization"]
+        return headers
index 45134456b9ea2468225b29c587b9aa6dcecdff75..ffe82fcb0e4d25056d7756f669b53fb677a12fa5 100644 (file)
@@ -37,7 +37,7 @@ class HTTPConnection(Adapter):
         self.h2_connection = None  # type: typing.Optional[HTTP2Connection]
 
     def prepare_request(self, request: Request) -> None:
-        pass
+        request.prepare()
 
     async def send(self, request: Request, **options: typing.Any) -> Response:
         if self.h11_connection is None and self.h2_connection is None:
index a536d5f38d45f3b87204c53fdce794bbf14f79fa..ca7a837fa5e0bfe49da71d70f3a2169fdb6a489f 100644 (file)
@@ -10,6 +10,7 @@ from ..config import (
     SSLConfig,
     TimeoutConfig,
 )
+from ..decoders import ACCEPT_ENCODING
 from ..exceptions import PoolTimeout
 from ..interfaces import Adapter
 from ..models import Origin, Request, Response
@@ -104,7 +105,7 @@ class ConnectionPool(Adapter):
         return len(self.keepalive_connections) + len(self.active_connections)
 
     def prepare_request(self, request: Request) -> None:
-        pass
+        request.prepare()
 
     async def send(self, request: Request, **options: typing.Any) -> Response:
         connection = await self.acquire_connection(request.url.origin)
index e4a5c83a565bf6309c953105ac718951e2b0aac2..8128dc18fd8027e5d7c925273875b8ed941625ca 100644 (file)
@@ -46,7 +46,7 @@ class HTTP11Connection(Adapter):
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
 
     def prepare_request(self, request: Request) -> None:
-        pass
+        request.prepare()
 
     async def send(self, request: Request, **options: typing.Any) -> Response:
         timeout = options.get("timeout")
@@ -87,6 +87,7 @@ class HTTP11Connection(Adapter):
             headers=headers,
             body=body,
             on_close=self.response_closed,
+            request=request,
         )
 
         if not stream:
index 7bd124b768bf099d3e0c04a7a9fda804294557fd..d44c5f60470c770625224788fc22ef2753a7ee93 100644 (file)
@@ -32,7 +32,7 @@ class HTTP2Connection(Adapter):
         self.initialized = False
 
     def prepare_request(self, request: Request) -> None:
-        pass
+        request.prepare()
 
     async def send(self, request: Request, **options: typing.Any) -> Response:
         timeout = options.get("timeout")
@@ -75,6 +75,7 @@ class HTTP2Connection(Adapter):
             headers=headers,
             body=body,
             on_close=on_close,
+            request=request,
         )
 
         if not stream:
index b0a4723d6a1d4621820d5afa33298fb7e21d9176..9235240309b6a62bb7578f66d1c46d53e6559236 100644 (file)
@@ -79,6 +79,11 @@ class URL:
     def __str__(self) -> str:
         return self.components.geturl()
 
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        url_str = str(self)
+        return f"{class_name}({url_str!r})"
+
 
 class Origin:
     def __init__(self, url: typing.Union[str, URL]) -> None:
@@ -100,13 +105,21 @@ class Origin:
         return hash((self.is_ssl, self.hostname, self.port))
 
 
+HeaderTypes = typing.Union["Headers", typing.List[typing.Tuple[bytes, bytes]]]
+
+
 class Headers(typing.MutableMapping[str, str]):
     """
     A case-insensitive multidict.
     """
 
-    def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
-        self._list = [(k.lower(), v) for k, v in headers]
+    def __init__(self, headers: HeaderTypes = None) -> None:
+        if headers is None:
+            self._list = []  # type: typing.List[typing.Tuple[bytes, bytes]]
+        elif isinstance(headers, Headers):
+            self._list = list(headers.raw)
+        else:
+            self._list = [(k.lower(), v) for k, v in headers]
 
     @property
     def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
@@ -213,7 +226,7 @@ class Request:
         method: str,
         url: typing.Union[str, URL],
         *,
-        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
+        headers: HeaderTypes = None,
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
     ):
         self.method = method.upper()
@@ -224,26 +237,24 @@ class Request:
         else:
             self.is_streaming = True
             self.body_aiter = body
-        self.headers = self.build_headers(headers)
-
-    def build_headers(
-        self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
-    ) -> Headers:
-        has_host = False
-        has_content_length = False
-        has_accept_encoding = False
-
-        for header, value in init_headers:
-            header = header.strip().lower()
-            if header == b"host":
-                has_host = True
-            elif header in (b"content-length", b"transfer-encoding"):
-                has_content_length = True
-            elif header == b"accept-encoding":
-                has_accept_encoding = True
+        self.headers = Headers(headers)
+
+    async def stream(self) -> typing.AsyncIterator[bytes]:
+        if self.is_streaming:
+            async for part in self.body_aiter:
+                yield part
+        elif self.body:
+            yield self.body
 
+    def prepare(self) -> None:
         auto_headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
 
+        has_host = "host" in self.headers
+        has_content_length = (
+            "content-length" in self.headers or "transfer-encoding" in self.headers
+        )
+        has_accept_encoding = "accept-encoding" in self.headers
+
         if not has_host:
             auto_headers.append((b"host", self.url.netloc.encode("ascii")))
         if not has_content_length:
@@ -255,14 +266,8 @@ class Request:
         if not has_accept_encoding:
             auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
 
-        return Headers(auto_headers + init_headers)
-
-    async def stream(self) -> typing.AsyncIterator[bytes]:
-        if self.is_streaming:
-            async for part in self.body_aiter:
-                yield part
-        elif self.body:
-            yield self.body
+        for item in reversed(auto_headers):
+            self.headers.raw.insert(0, item)
 
 
 class Response:
@@ -275,6 +280,8 @@ class Response:
         headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
         on_close: typing.Callable = None,
+        request: Request = None,
+        history: typing.List["Response"] = None,
     ):
         self.status_code = status_code
         if not reason:
@@ -310,6 +317,13 @@ class Response:
         else:
             self.body_aiter = body
 
+        self.request = request
+        self.history = [] if history is None else list(history)
+
+    @property
+    def url(self) -> typing.Optional[URL]:
+        return None if self.request is None else self.request.url
+
     async def read(self) -> bytes:
         """
         Read and return the response content.
@@ -358,4 +372,6 @@ class Response:
 
     @property
     def is_redirect(self) -> bool:
-        return self.status_code in (301, 302, 303, 307, 308)
+        return (
+            self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers
+        )
index 3609dd7342d972e3f6b8815df5eaedd87c3da575..ab371a047d481ec3479c966ce535743f7f94c842 100644 (file)
@@ -1,8 +1,10 @@
+import json
 from urllib.parse import parse_qs
 
 import pytest
 
 from httpcore import (
+    URL,
     Adapter,
     RedirectAdapter,
     RedirectLoop,
@@ -19,32 +21,60 @@ class MockDispatch(Adapter):
 
     async def send(self, request: Request, **options) -> Response:
         if request.url.path == "/redirect_301":  # "Moved Permanently"
-            return Response(301, headers=[(b"location", b"https://example.org/")])
+            return Response(
+                301, headers=[(b"location", b"https://example.org/")], request=request
+            )
 
         elif request.url.path == "/redirect_302":  # "Found"
-            return Response(302, headers=[(b"location", b"https://example.org/")])
+            return Response(
+                302, headers=[(b"location", b"https://example.org/")], request=request
+            )
 
         elif request.url.path == "/redirect_303":  # "See Other"
-            return Response(303, headers=[(b"location", b"https://example.org/")])
+            return Response(
+                303, headers=[(b"location", b"https://example.org/")], request=request
+            )
 
         elif request.url.path == "/relative_redirect":
-            return Response(codes.see_other, headers=[(b"location", b"/")])
+            return Response(
+                codes.see_other, headers=[(b"location", b"/")], request=request
+            )
 
         elif request.url.path == "/no_scheme_redirect":
-            return Response(codes.see_other, headers=[(b"location", b"//example.org/")])
+            return Response(
+                codes.see_other,
+                headers=[(b"location", b"//example.org/")],
+                request=request,
+            )
 
         elif request.url.path == "/multiple_redirects":
             params = parse_qs(request.url.query)
             count = int(params.get("count", "0")[0])
+            redirect_count = count - 1
             code = codes.see_other if count else codes.ok
-            location = "/multiple_redirects?count=" + str(count - 1)
+            location = "/multiple_redirects"
+            if redirect_count:
+                location += "?count=" + str(redirect_count)
             headers = [(b"location", location.encode())] if count else []
-            return Response(code, headers=headers)
+            return Response(code, headers=headers, request=request)
 
         if request.url.path == "/redirect_loop":
-            return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")])
+            return Response(
+                codes.see_other,
+                headers=[(b"location", b"/redirect_loop")],
+                request=request,
+            )
 
-        return Response(codes.ok, body=b"Hello, world!")
+        elif request.url.path == "/cross_domain":
+            location = b"https://example.org/cross_domain_target"
+            return Response(301, headers=[(b"location", location)], request=request)
+
+        elif request.url.path == "/cross_domain_target":
+            headers = {k.decode(): v.decode() for k, v in request.headers.raw}
+            body = json.dumps({"headers": headers}).encode()
+            return Response(codes.ok, body=body, request=request)
+
+        return Response(codes.ok, body=b"Hello, world!", request=request)
 
 
 @pytest.mark.asyncio
@@ -52,6 +82,8 @@ async def test_redirect_301():
     client = RedirectAdapter(MockDispatch())
     response = await client.request("POST", "https://example.org/redirect_301")
     assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/")
+    assert len(response.history) == 1
 
 
 @pytest.mark.asyncio
@@ -59,6 +91,8 @@ async def test_redirect_302():
     client = RedirectAdapter(MockDispatch())
     response = await client.request("POST", "https://example.org/redirect_302")
     assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/")
+    assert len(response.history) == 1
 
 
 @pytest.mark.asyncio
@@ -66,6 +100,8 @@ async def test_redirect_303():
     client = RedirectAdapter(MockDispatch())
     response = await client.request("GET", "https://example.org/redirect_303")
     assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/")
+    assert len(response.history) == 1
 
 
 @pytest.mark.asyncio
@@ -73,6 +109,8 @@ async def test_relative_redirect():
     client = RedirectAdapter(MockDispatch())
     response = await client.request("GET", "https://example.org/relative_redirect")
     assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/")
+    assert len(response.history) == 1
 
 
 @pytest.mark.asyncio
@@ -80,13 +118,19 @@ async def test_no_scheme_redirect():
     client = RedirectAdapter(MockDispatch())
     response = await client.request("GET", "https://example.org/no_scheme_redirect")
     assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/")
+    assert len(response.history) == 1
 
 
 @pytest.mark.asyncio
 async def test_fragment_redirect():
     client = RedirectAdapter(MockDispatch())
-    response = await client.request("GET", "https://example.org/relative_redirect#fragment")
+    response = await client.request(
+        "GET", "https://example.org/relative_redirect#fragment"
+    )
     assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/#fragment")
+    assert len(response.history) == 1
 
 
 @pytest.mark.asyncio
@@ -96,6 +140,8 @@ async def test_multiple_redirects():
         "GET", "https://example.org/multiple_redirects?count=20"
     )
     assert response.status_code == codes.ok
+    assert response.url == URL("https://example.org/multiple_redirects")
+    assert len(response.history) == 20
 
 
 @pytest.mark.asyncio
@@ -110,3 +156,25 @@ async def test_redirect_loop():
     client = RedirectAdapter(MockDispatch())
     with pytest.raises(RedirectLoop):
         await client.request("GET", "https://example.org/redirect_loop")
+
+
+@pytest.mark.asyncio
+async def test_cross_domain_redirect():
+    client = RedirectAdapter(MockDispatch())
+    headers = [(b"Authorization", b"abc")]
+    url = "https://example.com/cross_domain"
+    response = await client.request("GET", url, headers=headers)
+    data = json.loads(response.body.decode())
+    assert response.url == URL("https://example.org/cross_domain_target")
+    assert data == {"headers": {}}
+
+
+@pytest.mark.asyncio
+async def test_same_domain_redirect():
+    client = RedirectAdapter(MockDispatch())
+    headers = [(b"Authorization", b"abc")]
+    url = "https://example.org/cross_domain"
+    response = await client.request("GET", url, headers=headers)
+    data = json.loads(response.body.decode())
+    assert response.url == URL("https://example.org/cross_domain_target")
+    assert data == {"headers": {"authorization": "abc"}}
index 840dde1a8aa910107f108568b1f15ce068abd72e..4df3529cd5b4b5a401bd4cbf7963fc5f6730abd7 100644 (file)
@@ -5,6 +5,7 @@ import httpcore
 
 def test_host_header():
     request = httpcore.Request("GET", "http://example.org")
+    request.prepare()
     assert request.headers == httpcore.Headers(
         [(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
     )
@@ -12,6 +13,7 @@ def test_host_header():
 
 def test_content_length_header():
     request = httpcore.Request("POST", "http://example.org", body=b"test 123")
+    request.prepare()
     assert request.headers == httpcore.Headers(
         [
             (b"host", b"example.org"),
@@ -28,6 +30,7 @@ def test_transfer_encoding_header():
     body = streaming_body(b"test 123")
 
     request = httpcore.Request("POST", "http://example.org", body=body)
+    request.prepare()
     assert request.headers == httpcore.Headers(
         [
             (b"host", b"example.org"),
@@ -41,6 +44,7 @@ def test_override_host_header():
     headers = [(b"host", b"1.2.3.4:80")]
 
     request = httpcore.Request("GET", "http://example.org", headers=headers)
+    request.prepare()
     assert request.headers == httpcore.Headers(
         [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
     )
@@ -50,6 +54,7 @@ def test_override_accept_encoding_header():
     headers = [(b"accept-encoding", b"identity")]
 
     request = httpcore.Request("GET", "http://example.org", headers=headers)
+    request.prepare()
     assert request.headers == httpcore.Headers(
         [(b"host", b"example.org"), (b"accept-encoding", b"identity")]
     )
@@ -63,6 +68,7 @@ def test_override_content_length_header():
     headers = [(b"content-length", b"8")]
 
     request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
+    request.prepare()
     assert request.headers == httpcore.Headers(
         [
             (b"host", b"example.org"),