]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Adding params to Client (#372)
authorTyrel Souza <923113+tyrelsouza@users.noreply.github.com>
Mon, 23 Sep 2019 14:22:21 +0000 (10:22 -0400)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Mon, 23 Sep 2019 14:22:21 +0000 (09:22 -0500)
docs/api.md
httpx/client.py
httpx/models.py
tests/client/test_queryparams.py [new file with mode: 0644]

index 104ba278c13d5535cfa4989f5a1688cb7bf34d1b..0294a0a531655245281dd9bcc254fb8cb24aa912 100644 (file)
@@ -27,7 +27,8 @@
 >>> response = client.get('https://example.org')
 ```
 
-* `def __init__([auth], [headers], [cookies], [verify], [cert], [timeout], [pool_limits], [max_redirects], [app], [dispatch])`
+* `def __init__([auth], [params], [headers], [cookies], [verify], [cert], [timeout], [pool_limits], [max_redirects], [app], [dispatch])`
+* `.params` - **QueryParams**
 * `.headers` - **Headers**
 * `.cookies` - **Cookies**
 * `def .get(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
index 653e7d86cc28b0e0c8cea1c7a82a3dcb68c1445d..30beeb75fd328345d6281c1f4c1abc0071e1a956 100644 (file)
@@ -40,6 +40,7 @@ from .models import (
     Headers,
     HeaderTypes,
     ProxiesTypes,
+    QueryParams,
     QueryParamTypes,
     RequestData,
     RequestFiles,
@@ -55,6 +56,7 @@ class BaseClient:
         self,
         *,
         auth: AuthTypes = None,
+        params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         cookies: CookieTypes = None,
         verify: VerifyTypes = True,
@@ -112,7 +114,11 @@ class BaseClient:
             proxies
         )
 
+        if params is None:
+            params = {}
+
         self.auth = auth
+        self._params = QueryParams(params)
         self._headers = Headers(headers)
         self._cookies = Cookies(cookies)
         self.max_redirects = max_redirects
@@ -135,6 +141,14 @@ class BaseClient:
     def cookies(self, cookies: CookieTypes) -> None:
         self._cookies = Cookies(cookies)
 
+    @property
+    def params(self) -> QueryParams:
+        return self._params
+
+    @params.setter
+    def params(self, params: QueryParamTypes) -> None:
+        self._params = QueryParams(params)
+
     def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
         pass  # pragma: no cover
 
@@ -162,6 +176,15 @@ class BaseClient:
             return merged_headers
         return headers
 
+    def merge_queryparams(
+        self, params: QueryParamTypes = None
+    ) -> typing.Optional[QueryParamTypes]:
+        if params or self.params:
+            merged_queryparams = QueryParams(self.params)
+            merged_queryparams.update(params)
+            return merged_queryparams
+        return params
+
     async def _get_response(
         self,
         request: AsyncRequest,
@@ -299,6 +322,7 @@ class BaseClient:
         url = self.merge_url(url)
         headers = self.merge_headers(headers)
         cookies = self.merge_cookies(cookies)
+        params = self.merge_queryparams(params)
         request = AsyncRequest(
             method,
             url,
index 8d4b36269233f3aa90841644e5719d80548c2e1e..261eefb24ad6e0729af64d63e390d84d3e6abb73 100644 (file)
@@ -316,9 +316,34 @@ class QueryParams(typing.Mapping[str, str]):
             return self._dict[key]
         return default
 
+    def update(self, params: QueryParamTypes = None) -> None:  # type: ignore
+        if not params:
+            return
+
+        params = QueryParams(params)
+        for param in params:
+            self[param] = params[param]
+
     def __getitem__(self, key: typing.Any) -> str:
         return self._dict[key]
 
+    def __setitem__(self, key: str, value: str) -> None:
+        self._dict[key] = value
+
+        found_indexes = []
+        for idx, (item_key, _) in enumerate(self._list):
+            if item_key == key:
+                found_indexes.append(idx)
+
+        for idx in reversed(found_indexes[1:]):
+            del self._list[idx]
+
+        if found_indexes:
+            idx = found_indexes[0]
+            self._list[idx] = (key, value)
+        else:
+            self._list.append((key, value))
+
     def __contains__(self, key: typing.Any) -> bool:
         return key in self._dict
 
diff --git a/tests/client/test_queryparams.py b/tests/client/test_queryparams.py
new file mode 100644 (file)
index 0000000..0558dd2
--- /dev/null
@@ -0,0 +1,51 @@
+import json
+
+from httpx import (
+    AsyncDispatcher,
+    AsyncRequest,
+    AsyncResponse,
+    CertTypes,
+    Client,
+    QueryParams,
+    TimeoutTypes,
+    VerifyTypes,
+)
+from httpx.models import URL
+
+
+class MockDispatch(AsyncDispatcher):
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        if request.url.path.startswith("/echo_queryparams"):
+            body = json.dumps({"ok": "ok"}).encode()
+            return AsyncResponse(200, content=body, request=request)
+
+
+def test_client_queryparams():
+    client = Client(params={"a": "b"})
+    assert isinstance(client.params, QueryParams)
+    assert client.params["a"] == "b"
+
+
+def test_client_queryparams_string():
+    client = Client(params="a=b")
+    assert isinstance(client.params, QueryParams)
+    assert client.params["a"] == "b"
+
+
+def test_client_queryparams_echo():
+    url = "http://example.org/echo_queryparams"
+    client_queryparams = "first=str"
+    request_queryparams = {"second": "dict"}
+    with Client(dispatch=MockDispatch(), params=client_queryparams) as client:
+        response = client.get(url, params=request_queryparams)
+
+    assert response.status_code == 200
+    assert response.url == URL(
+        "http://example.org/echo_queryparams?first=str&second=dict"
+    )