]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Allow setting headers at the Client level (#159)
authorStephen Brown II <Stephen.Brown2@gmail.com>
Mon, 29 Jul 2019 01:39:35 +0000 (20:39 -0500)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Mon, 29 Jul 2019 01:39:35 +0000 (20:39 -0500)
docs/api.md
httpx/client.py
httpx/models.py
tests/client/test_headers.py [new file with mode: 0755]

index 1c6887cb1a1e4d7bf25a709905598104fcfad2ee..dbc3d5215f523cefafcf1171c1c680606f9a87e2 100644 (file)
@@ -20,7 +20,7 @@
 >>> response = client.get('https://example.org')
 ```
 
-* `def __init__([auth], [cookies], [verify], [cert], [timeout], [pool_limits], [max_redirects], [app], [dispatch])`
+* `def __init__([auth], [headers], [cookies], [verify], [cert], [timeout], [pool_limits], [max_redirects], [app], [dispatch])`
 * `def .get(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
 * `def .options(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
 * `def .head(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
index 7bf759aa4bdd917d056dc3eecd615ac0f6a5c2ae..60088c26574be3b4115eef8ac427d35e51414b22 100644 (file)
@@ -49,6 +49,7 @@ class BaseClient:
     def __init__(
         self,
         auth: AuthTypes = None,
+        headers: HeaderTypes = None,
         cookies: CookieTypes = None,
         verify: VerifyTypes = True,
         cert: CertTypes = None,
@@ -95,6 +96,7 @@ class BaseClient:
             self.base_url = URL(base_url)
 
         self.auth = auth
+        self.headers = Headers(headers)
         self.cookies = Cookies(cookies)
         self.max_redirects = max_redirects
         self.dispatch = async_dispatch
@@ -109,6 +111,15 @@ class BaseClient:
             return merged_cookies
         return cookies
 
+    def merge_headers(
+        self, headers: HeaderTypes = None
+    ) -> typing.Optional[HeaderTypes]:
+        if headers or self.headers:
+            merged_headers = Headers(self.headers)
+            merged_headers.update(headers)
+            return merged_headers
+        return headers
+
     async def send(
         self,
         request: AsyncRequest,
@@ -527,6 +538,7 @@ class AsyncClient(BaseClient):
         timeout: TimeoutTypes = None,
     ) -> AsyncResponse:
         url = self.base_url.join(url)
+        headers = self.merge_headers(headers)
         cookies = self.merge_cookies(cookies)
         request = AsyncRequest(
             method,
@@ -608,6 +620,7 @@ class Client(BaseClient):
         timeout: TimeoutTypes = None,
     ) -> Response:
         url = self.base_url.join(url)
+        headers = self.merge_headers(headers)
         cookies = self.merge_cookies(cookies)
         request = AsyncRequest(
             method,
index 2f29be5f9f8e5380d877f7a661ad0835484c6107..bea47faf0dbda857666de8bf37d5364624ebfb74 100644 (file)
@@ -415,6 +415,11 @@ class Headers(typing.MutableMapping[str, str]):
             split_values.extend([item.strip() for item in value.split(",")])
         return split_values
 
+    def update(self, headers: HeaderTypes = None) -> None:  # type: ignore
+        headers = Headers(headers)
+        for header in headers:
+            self[header] = headers[header]
+
     def __getitem__(self, key: str) -> str:
         """
         Return a single header value.
diff --git a/tests/client/test_headers.py b/tests/client/test_headers.py
new file mode 100755 (executable)
index 0000000..bf82a57
--- /dev/null
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+
+import json
+from httpx import (
+    __version__,
+    Client,
+    AsyncRequest,
+    AsyncResponse,
+    VerifyTypes,
+    CertTypes,
+    TimeoutTypes,
+    AsyncDispatcher,
+)
+
+
+class MockDispatch(AsyncDispatcher):
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        if request.url.path.startswith("/echo_headers"):
+            request_headers = dict(request.headers.items())
+            body = json.dumps({"headers": request_headers}).encode()
+            return AsyncResponse(200, content=body, request=request)
+
+
+def test_client_header():
+    """
+    Set a header in the Client.
+    """
+    url = "http://example.org/echo_headers"
+    headers = {"Example-Header": "example-value"}
+
+    with Client(dispatch=MockDispatch(), headers=headers) as client:
+        response = client.get(url)
+
+    assert response.status_code == 200
+    assert response.json() == {
+        "headers": {
+            "accept": "*/*",
+            "accept-encoding": "gzip, deflate, br",
+            "connection": "keep-alive",
+            "example-header": "example-value",
+            "host": "example.org",
+            "user-agent": f"python-httpx/{__version__}",
+        }
+    }
+
+
+def test_header_merge():
+    url = "http://example.org/echo_headers"
+    client_headers = {"User-Agent": "python-myclient/0.2.1"}
+    request_headers = {"X-Auth-Token": "FooBarBazToken"}
+    with Client(dispatch=MockDispatch(), headers=client_headers) as client:
+        response = client.get(url, headers=request_headers)
+
+    assert response.status_code == 200
+    assert response.json() == {
+        "headers": {
+            "accept": "*/*",
+            "accept-encoding": "gzip, deflate, br",
+            "connection": "keep-alive",
+            "host": "example.org",
+            "user-agent": "python-myclient/0.2.1",
+            "x-auth-token": "FooBarBazToken",
+        }
+    }
+
+
+def test_header_merge_conflicting_headers():
+    url = "http://example.org/echo_headers"
+    client_headers = {"X-Auth-Token": "FooBar"}
+    request_headers = {"X-Auth-Token": "BazToken"}
+    with Client(dispatch=MockDispatch(), headers=client_headers) as client:
+        response = client.get(url, headers=request_headers)
+
+    assert response.status_code == 200
+    assert response.json() == {
+        "headers": {
+            "accept": "*/*",
+            "accept-encoding": "gzip, deflate, br",
+            "connection": "keep-alive",
+            "host": "example.org",
+            "user-agent": f"python-httpx/{__version__}",
+            "x-auth-token": "BazToken",
+        }
+    }
+
+
+def test_header_update():
+    url = "http://example.org/echo_headers"
+    with Client(dispatch=MockDispatch()) as client:
+        first_response = client.get(url)
+        client.headers.update(
+            {"User-Agent": "python-myclient/0.2.1", "Another-Header": "AThing"}
+        )
+        second_response = client.get(url)
+
+    assert first_response.status_code == 200
+    assert first_response.json() == {
+        "headers": {
+            "accept": "*/*",
+            "accept-encoding": "gzip, deflate, br",
+            "connection": "keep-alive",
+            "host": "example.org",
+            "user-agent": f"python-httpx/{__version__}",
+        }
+    }
+
+    assert second_response.status_code == 200
+    assert second_response.json() == {
+        "headers": {
+            "accept": "*/*",
+            "accept-encoding": "gzip, deflate, br",
+            "another-header": "AThing",
+            "connection": "keep-alive",
+            "host": "example.org",
+            "user-agent": "python-myclient/0.2.1",
+        }
+    }