From: Stephen Brown II Date: Mon, 29 Jul 2019 01:39:35 +0000 (-0500) Subject: Allow setting headers at the Client level (#159) X-Git-Tag: 0.7.0~36 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=72e6f478970cf2c51d3fd23d5f3aa3150bf0e82a;p=thirdparty%2Fhttpx.git Allow setting headers at the Client level (#159) --- diff --git a/docs/api.md b/docs/api.md index 1c6887cb..dbc3d521 100644 --- a/docs/api.md +++ b/docs/api.md @@ -20,7 +20,7 @@ >>> response = client.get('https://example.org') ``` -* `def __init__([auth], [cookies], [verify], [cert], [timeout], [pool_limits], [max_redirects], [app], [dispatch])` +* `def __init__([auth], [headers], [cookies], [verify], [cert], [timeout], [pool_limits], [max_redirects], [app], [dispatch])` * `def .get(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])` * `def .options(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])` * `def .head(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])` diff --git a/httpx/client.py b/httpx/client.py index 7bf759aa..60088c26 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -49,6 +49,7 @@ class BaseClient: def __init__( self, auth: AuthTypes = None, + headers: HeaderTypes = None, cookies: CookieTypes = None, verify: VerifyTypes = True, cert: CertTypes = None, @@ -95,6 +96,7 @@ class BaseClient: self.base_url = URL(base_url) self.auth = auth + self.headers = Headers(headers) self.cookies = Cookies(cookies) self.max_redirects = max_redirects self.dispatch = async_dispatch @@ -109,6 +111,15 @@ class BaseClient: return merged_cookies return cookies + def merge_headers( + self, headers: HeaderTypes = None + ) -> typing.Optional[HeaderTypes]: + if headers or self.headers: + merged_headers = Headers(self.headers) + merged_headers.update(headers) + return merged_headers + return headers + async def send( self, request: AsyncRequest, @@ -527,6 +538,7 @@ class AsyncClient(BaseClient): timeout: TimeoutTypes = None, ) -> AsyncResponse: url = self.base_url.join(url) + headers = self.merge_headers(headers) cookies = self.merge_cookies(cookies) request = AsyncRequest( method, @@ -608,6 +620,7 @@ class Client(BaseClient): timeout: TimeoutTypes = None, ) -> Response: url = self.base_url.join(url) + headers = self.merge_headers(headers) cookies = self.merge_cookies(cookies) request = AsyncRequest( method, diff --git a/httpx/models.py b/httpx/models.py index 2f29be5f..bea47faf 100644 --- a/httpx/models.py +++ b/httpx/models.py @@ -415,6 +415,11 @@ class Headers(typing.MutableMapping[str, str]): split_values.extend([item.strip() for item in value.split(",")]) return split_values + def update(self, headers: HeaderTypes = None) -> None: # type: ignore + headers = Headers(headers) + for header in headers: + self[header] = headers[header] + def __getitem__(self, key: str) -> str: """ Return a single header value. diff --git a/tests/client/test_headers.py b/tests/client/test_headers.py new file mode 100755 index 00000000..bf82a578 --- /dev/null +++ b/tests/client/test_headers.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 + +import json +from httpx import ( + __version__, + Client, + AsyncRequest, + AsyncResponse, + VerifyTypes, + CertTypes, + TimeoutTypes, + AsyncDispatcher, +) + + +class MockDispatch(AsyncDispatcher): + async def send( + self, + request: AsyncRequest, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> AsyncResponse: + if request.url.path.startswith("/echo_headers"): + request_headers = dict(request.headers.items()) + body = json.dumps({"headers": request_headers}).encode() + return AsyncResponse(200, content=body, request=request) + + +def test_client_header(): + """ + Set a header in the Client. + """ + url = "http://example.org/echo_headers" + headers = {"Example-Header": "example-value"} + + with Client(dispatch=MockDispatch(), headers=headers) as client: + response = client.get(url) + + assert response.status_code == 200 + assert response.json() == { + "headers": { + "accept": "*/*", + "accept-encoding": "gzip, deflate, br", + "connection": "keep-alive", + "example-header": "example-value", + "host": "example.org", + "user-agent": f"python-httpx/{__version__}", + } + } + + +def test_header_merge(): + url = "http://example.org/echo_headers" + client_headers = {"User-Agent": "python-myclient/0.2.1"} + request_headers = {"X-Auth-Token": "FooBarBazToken"} + with Client(dispatch=MockDispatch(), headers=client_headers) as client: + response = client.get(url, headers=request_headers) + + assert response.status_code == 200 + assert response.json() == { + "headers": { + "accept": "*/*", + "accept-encoding": "gzip, deflate, br", + "connection": "keep-alive", + "host": "example.org", + "user-agent": "python-myclient/0.2.1", + "x-auth-token": "FooBarBazToken", + } + } + + +def test_header_merge_conflicting_headers(): + url = "http://example.org/echo_headers" + client_headers = {"X-Auth-Token": "FooBar"} + request_headers = {"X-Auth-Token": "BazToken"} + with Client(dispatch=MockDispatch(), headers=client_headers) as client: + response = client.get(url, headers=request_headers) + + assert response.status_code == 200 + assert response.json() == { + "headers": { + "accept": "*/*", + "accept-encoding": "gzip, deflate, br", + "connection": "keep-alive", + "host": "example.org", + "user-agent": f"python-httpx/{__version__}", + "x-auth-token": "BazToken", + } + } + + +def test_header_update(): + url = "http://example.org/echo_headers" + with Client(dispatch=MockDispatch()) as client: + first_response = client.get(url) + client.headers.update( + {"User-Agent": "python-myclient/0.2.1", "Another-Header": "AThing"} + ) + second_response = client.get(url) + + assert first_response.status_code == 200 + assert first_response.json() == { + "headers": { + "accept": "*/*", + "accept-encoding": "gzip, deflate, br", + "connection": "keep-alive", + "host": "example.org", + "user-agent": f"python-httpx/{__version__}", + } + } + + assert second_response.status_code == 200 + assert second_response.json() == { + "headers": { + "accept": "*/*", + "accept-encoding": "gzip, deflate, br", + "another-header": "AThing", + "connection": "keep-alive", + "host": "example.org", + "user-agent": "python-myclient/0.2.1", + } + }