]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Expand URL interface (#1601)
authorTom Christie <tom@tomchristie.com>
Tue, 27 Apr 2021 08:01:14 +0000 (09:01 +0100)
committerGitHub <noreply@github.com>
Tue, 27 Apr 2021 08:01:14 +0000 (09:01 +0100)
* Expand URL interface

* Add URL query param manipulation methods

httpx/_models.py
tests/models/test_url.py

index 7b749ceeb4a9a7d4e510fc10fe95666674a325d7..9325f444fdafa024c0a60f338d725df699260c63 100644 (file)
@@ -112,7 +112,7 @@ class URL:
     """
 
     def __init__(
-        self, url: typing.Union["URL", str, RawURL] = "", params: QueryParamTypes = None
+        self, url: typing.Union["URL", str, RawURL] = "", **kwargs: typing.Any
     ) -> None:
         if isinstance(url, (str, tuple)):
             if isinstance(url, tuple):
@@ -144,14 +144,8 @@ class URL:
                 f"Invalid type for url.  Expected str or httpx.URL, got {type(url)}: {url!r}"
             )
 
-        # Add any query parameters, merging with any in the URL if needed.
-        if params:
-            if self._uri_reference.query:
-                url_params = QueryParams(self._uri_reference.query).merge(params)
-                query_string = str(url_params)
-            else:
-                query_string = str(QueryParams(params))
-            self._uri_reference = self._uri_reference.copy_with(query=query_string)
+        if kwargs:
+            self._uri_reference = self.copy_with(**kwargs)._uri_reference
 
     @property
     def scheme(self) -> str:
@@ -293,12 +287,27 @@ class URL:
     def query(self) -> bytes:
         """
         The URL query string, as raw bytes, excluding the leading b"?".
-        Note that URL decoding can only be applied on URL query strings
-        at the point of decoding the individual parameter names/values.
+
+        This is neccessarily a bytewise interface, because we cannot
+        perform URL decoding of this representation until we've parsed
+        the keys and values into a QueryParams instance.
+
+        For example:
+
+        url = httpx.URL("https://example.com/?filter=some%20search%20terms")
+        assert url.query == b"filter=some%20search%20terms"
         """
         query = self._uri_reference.query or ""
         return query.encode("ascii")
 
+    @property
+    def params(self) -> "QueryParams":
+        """
+        The URL query parameters, neatly parsed and packaged into an immutable
+        multidict representation.
+        """
+        return QueryParams(self._uri_reference.query)
+
     @property
     def raw_path(self) -> bytes:
         """
@@ -382,6 +391,7 @@ class URL:
             "query": bytes,
             "raw_path": bytes,
             "fragment": str,
+            "params": object,
         }
         for key, value in kwargs.items():
             if key not in allowed:
@@ -434,12 +444,28 @@ class URL:
             if kwargs.get("path") is not None:
                 kwargs["path"] = quote(kwargs["path"])
 
-            # Ensure query=<str> for rfc3986
             if kwargs.get("query") is not None:
+                # Ensure query=<str> for rfc3986
                 kwargs["query"] = kwargs["query"].decode("ascii")
 
+            if "params" in kwargs:
+                params = kwargs.pop("params")
+                kwargs["query"] = None if not params else str(QueryParams(params))
+
         return URL(self._uri_reference.copy_with(**kwargs).unsplit())
 
+    def copy_set_param(self, key: str, value: typing.Any = None) -> "URL":
+        return self.copy_with(params=self.params.set(key, value))
+
+    def copy_add_param(self, key: str, value: typing.Any = None) -> "URL":
+        return self.copy_with(params=self.params.add(key, value))
+
+    def copy_remove_param(self, key: str) -> "URL":
+        return self.copy_with(params=self.params.remove(key))
+
+    def copy_merge_params(self, params: QueryParamTypes) -> "URL":
+        return self.copy_with(params=self.params.merge(params))
+
     def join(self, url: URLTypes) -> "URL":
         """
         Return an absolute URL, using this URL as the base.
@@ -595,7 +621,7 @@ class QueryParams(typing.Mapping[str, str]):
             return self._dict[str(key)][0]
         return default
 
-    def get_list(self, key: typing.Any) -> typing.List[str]:
+    def get_list(self, key: str) -> typing.List[str]:
         """
         Get all values from the query param for a given key.
 
@@ -606,7 +632,7 @@ class QueryParams(typing.Mapping[str, str]):
         """
         return list(self._dict.get(str(key), []))
 
-    def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
+    def set(self, key: str, value: typing.Any = None) -> "QueryParams":
         """
         Return a new QueryParams instance, setting the value of a key.
 
@@ -621,7 +647,7 @@ class QueryParams(typing.Mapping[str, str]):
         q._dict[str(key)] = [primitive_value_to_str(value)]
         return q
 
-    def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
+    def add(self, key: str, value: typing.Any = None) -> "QueryParams":
         """
         Return a new QueryParams instance, setting or appending the value of a key.
 
@@ -636,7 +662,7 @@ class QueryParams(typing.Mapping[str, str]):
         q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
         return q
 
-    def remove(self, key: typing.Any) -> "QueryParams":
+    def remove(self, key: str) -> "QueryParams":
         """
         Return a new QueryParams instance, removing the value of a key.
 
@@ -681,6 +707,9 @@ class QueryParams(typing.Mapping[str, str]):
     def __len__(self) -> int:
         return len(self._dict)
 
+    def __bool__(self) -> bool:
+        return bool(self._dict)
+
     def __hash__(self) -> int:
         return hash(str(self))
 
@@ -971,7 +1000,9 @@ class Request:
             self.method = method.decode("ascii").upper()
         else:
             self.method = method.upper()
-        self.url = URL(url, params=params)
+        self.url = URL(url)
+        if params is not None:
+            self.url = self.url.copy_merge_params(params=params)
         self.headers = Headers(headers)
         if cookies:
             Cookies(cookies).set_cookie_header(self)
index 393503107221bf3bfa186725bd7ec4bfe258deef..c28d070f88b6e67edfd68fcc8083f74da7c19d32 100644 (file)
@@ -100,11 +100,13 @@ def test_url_eq_str():
 def test_url_params():
     url = httpx.URL("https://example.org:123/path/to/somewhere", params={"a": "123"})
     assert str(url) == "https://example.org:123/path/to/somewhere?a=123"
+    assert url.params == httpx.QueryParams({"a": "123"})
 
     url = httpx.URL(
         "https://example.org:123/path/to/somewhere?b=456", params={"a": "123"}
     )
-    assert str(url) == "https://example.org:123/path/to/somewhere?b=456&a=123"
+    assert str(url) == "https://example.org:123/path/to/somewhere?a=123"
+    assert url.params == httpx.QueryParams({"a": "123"})
 
 
 def test_url_join():
@@ -122,6 +124,38 @@ def test_url_join():
     assert url.join("../../somewhere-else") == "https://example.org:123/somewhere-else"
 
 
+def test_url_set_param_manipulation():
+    """
+    Some basic URL query parameter manipulation.
+    """
+    url = httpx.URL("https://example.org:123/?a=123")
+    assert url.copy_set_param("a", "456") == "https://example.org:123/?a=456"
+
+
+def test_url_add_param_manipulation():
+    """
+    Some basic URL query parameter manipulation.
+    """
+    url = httpx.URL("https://example.org:123/?a=123")
+    assert url.copy_add_param("a", "456") == "https://example.org:123/?a=123&a=456"
+
+
+def test_url_remove_param_manipulation():
+    """
+    Some basic URL query parameter manipulation.
+    """
+    url = httpx.URL("https://example.org:123/?a=123")
+    assert url.copy_remove_param("a") == "https://example.org:123/"
+
+
+def test_url_merge_params_manipulation():
+    """
+    Some basic URL query parameter manipulation.
+    """
+    url = httpx.URL("https://example.org:123/?a=123")
+    assert url.copy_merge_params({"b": "456"}) == "https://example.org:123/?a=123&b=456"
+
+
 def test_relative_url_join():
     url = httpx.URL("/path/to/somewhere")
     assert url.join("/somewhere-else") == "/somewhere-else"