]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Nicer headers interface
authorTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 15:05:05 +0000 (16:05 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 15:05:05 +0000 (16:05 +0100)
httpcore/models.py
httpcore/utils.py
tests/adapters/test_redirects.py

index 9235240309b6a62bb7578f66d1c46d53e6559236..43f799a6299dcb0af0a791bded7bfbbf1d4ae526 100644 (file)
@@ -11,6 +11,7 @@ from .decoders import (
     MultiDecoder,
 )
 from .exceptions import ResponseClosed, StreamConsumed
+from .utils import normalize_header_key, normalize_header_value
 
 
 class URL:
@@ -105,7 +106,11 @@ class Origin:
         return hash((self.is_ssl, self.hostname, self.port))
 
 
-HeaderTypes = typing.Union["Headers", typing.List[typing.Tuple[bytes, bytes]]]
+HeaderTypes = typing.Union[
+    "Headers",
+    typing.Dict[typing.AnyStr, typing.AnyStr],
+    typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
+]
 
 
 class Headers(typing.MutableMapping[str, str]):
@@ -118,8 +123,15 @@ class Headers(typing.MutableMapping[str, str]):
             self._list = []  # type: typing.List[typing.Tuple[bytes, bytes]]
         elif isinstance(headers, Headers):
             self._list = list(headers.raw)
+        elif isinstance(headers, dict):
+            self._list = [
+                (normalize_header_key(k), normalize_header_value(v))
+                for k, v in headers.items()
+            ]
         else:
-            self._list = [(k.lower(), v) for k, v in headers]
+            self._list = [
+                (normalize_header_key(k), normalize_header_value(v)) for k, v in headers
+            ]
 
     @property
     def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
@@ -239,6 +251,17 @@ class Request:
             self.body_aiter = body
         self.headers = Headers(headers)
 
+    async def read(self) -> bytes:
+        """
+        Read and return the response content.
+        """
+        if not hasattr(self, "body"):
+            body = b""
+            async for part in self.stream():
+                body += part
+            self.body = body
+        return self.body
+
     async def stream(self) -> typing.AsyncIterator[bytes]:
         if self.is_streaming:
             async for part in self.body_aiter:
index cd11858a8a8aab08bf1ab52883c3afd7fc03ec92..419e7ec27424b1097d897dc155e887f78340ff08 100644 (file)
@@ -1,3 +1,4 @@
+import typing
 from urllib.parse import quote
 
 from .exceptions import InvalidURL
@@ -50,3 +51,21 @@ def requote_uri(uri: str) -> str:
         # there may be unquoted '%'s in the URI. We need to make sure they're
         # properly quoted so they do not cause issues elsewhere.
         return quote(uri, safe=safe_without_percent)
+
+
+def normalize_header_key(value: typing.AnyStr) -> bytes:
+    """
+    Coerce str/bytes into a strictly byte-wise HTTP header key.
+    """
+    if isinstance(value, bytes):
+        return value.lower()
+    return value.encode("latin-1").lower()
+
+
+def normalize_header_value(value: typing.AnyStr) -> bytes:
+    """
+    Coerce str/bytes into a strictly byte-wise HTTP header value.
+    """
+    if isinstance(value, bytes):
+        return value
+    return value.encode("latin-1")
index 6499c762f2e77ee5c18b0505b54e6bd26f4005a6..56ee919751048b5623a3dc7f752d9f7dc382ac36 100644 (file)
@@ -21,32 +21,28 @@ class MockDispatch(Adapter):
         pass
 
     async def send(self, request: Request, **options) -> Response:
-        if request.url.path == "/redirect_301":  # "Moved Permanently"
-            return Response(
-                301, headers=[(b"location", b"https://example.org/")], request=request
-            )
+        if request.url.path == "/redirect_301":
+            status_code = codes.moved_permanently
+            headers = {"location": "https://example.org/"}
+            return Response(status_code, headers=headers, request=request)
 
-        elif request.url.path == "/redirect_302":  # "Found"
-            return Response(
-                302, headers=[(b"location", b"https://example.org/")], request=request
-            )
+        elif request.url.path == "/redirect_302":
+            status_code = codes.found
+            headers = {"location": "https://example.org/"}
+            return Response(status_code, headers=headers, request=request)
 
-        elif request.url.path == "/redirect_303":  # "See Other"
-            return Response(
-                303, headers=[(b"location", b"https://example.org/")], request=request
-            )
+        elif request.url.path == "/redirect_303":
+            status_code = codes.see_other
+            headers = {"location": "https://example.org/"}
+            return Response(status_code, headers=headers, request=request)
 
         elif request.url.path == "/relative_redirect":
-            return Response(
-                codes.see_other, headers=[(b"location", b"/")], request=request
-            )
+            headers = {"location": "/"}
+            return Response(codes.see_other, headers=headers, request=request)
 
         elif request.url.path == "/no_scheme_redirect":
-            return Response(
-                codes.see_other,
-                headers=[(b"location", b"//example.org/")],
-                request=request,
-            )
+            headers = {"location": "//example.org/"}
+            return Response(codes.see_other, headers=headers, request=request)
 
         elif request.url.path == "/multiple_redirects":
             params = parse_qs(request.url.query)
@@ -56,31 +52,30 @@ class MockDispatch(Adapter):
             location = "/multiple_redirects"
             if redirect_count:
                 location += "?count=" + str(redirect_count)
-            headers = [(b"location", location.encode())] if count else []
+            headers = {"location": location} if count else {}
             return Response(code, headers=headers, request=request)
 
         if request.url.path == "/redirect_loop":
-            return Response(
-                codes.see_other,
-                headers=[(b"location", b"/redirect_loop")],
-                request=request,
-            )
+            headers = {"location": "/redirect_loop"}
+            return Response(codes.see_other, headers=headers, request=request)
 
         elif request.url.path == "/cross_domain":
-            location = b"https://example.org/cross_domain_target"
-            return Response(301, headers=[(b"location", location)], request=request)
+            headers = {"location": "https://example.org/cross_domain_target"}
+            return Response(codes.see_other, headers=headers, request=request)
 
         elif request.url.path == "/cross_domain_target":
-            headers = {k.decode(): v.decode() for k, v in request.headers.raw}
+            headers = dict(request.headers.items())
             body = json.dumps({"headers": headers}).encode()
             return Response(codes.ok, body=body, request=request)
 
         elif request.url.path == "/redirect_body":
-            headers = [(b"location", b"/redirect_body_target")]
+            body = await request.read()
+            headers = {"location": "/redirect_body_target"}
             return Response(codes.permanent_redirect, headers=headers, request=request)
 
         elif request.url.path == "/redirect_body_target":
-            body = json.dumps({"body": request.body.decode()}).encode()
+            body = await request.read()
+            body = json.dumps({"body": body.decode()}).encode()
             return Response(codes.ok, body=body, request=request)
 
         return Response(codes.ok, body=b"Hello, world!", request=request)
@@ -134,9 +129,8 @@ async def test_no_scheme_redirect():
 @pytest.mark.asyncio
 async def test_fragment_redirect():
     client = RedirectAdapter(MockDispatch())
-    response = await client.request(
-        "GET", "https://example.org/relative_redirect#fragment"
-    )
+    url = "https://example.org/relative_redirect#fragment"
+    response = await client.request("GET", url)
     assert response.status_code == codes.ok
     assert response.url == URL("https://example.org/#fragment")
     assert len(response.history) == 1
@@ -145,9 +139,8 @@ async def test_fragment_redirect():
 @pytest.mark.asyncio
 async def test_multiple_redirects():
     client = RedirectAdapter(MockDispatch())
-    response = await client.request(
-        "GET", "https://example.org/multiple_redirects?count=20"
-    )
+    url = "https://example.org/multiple_redirects?count=20"
+    response = await client.request("GET", url)
     assert response.status_code == codes.ok
     assert response.url == URL("https://example.org/multiple_redirects")
     assert len(response.history) == 20
@@ -170,8 +163,8 @@ async def test_redirect_loop():
 @pytest.mark.asyncio
 async def test_cross_domain_redirect():
     client = RedirectAdapter(MockDispatch())
-    headers = [(b"Authorization", b"abc")]
     url = "https://example.com/cross_domain"
+    headers = {"Authorization": "abc"}
     response = await client.request("GET", url, headers=headers)
     data = json.loads(response.body.decode())
     assert response.url == URL("https://example.org/cross_domain_target")
@@ -181,8 +174,8 @@ async def test_cross_domain_redirect():
 @pytest.mark.asyncio
 async def test_same_domain_redirect():
     client = RedirectAdapter(MockDispatch())
-    headers = [(b"Authorization", b"abc")]
     url = "https://example.org/cross_domain"
+    headers = {"Authorization": "abc"}
     response = await client.request("GET", url, headers=headers)
     data = json.loads(response.body.decode())
     assert response.url == URL("https://example.org/cross_domain_target")
@@ -192,8 +185,8 @@ async def test_same_domain_redirect():
 @pytest.mark.asyncio
 async def test_body_redirect():
     client = RedirectAdapter(MockDispatch())
-    body = b"Example request body"
     url = "https://example.org/redirect_body"
+    body = b"Example request body"
     response = await client.request("POST", url, body=body)
     data = json.loads(response.body.decode())
     assert response.url == URL("https://example.org/redirect_body_target")