]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add QueryParams class
authorTom Christie <tom@tomchristie.com>
Wed, 8 May 2019 11:01:48 +0000 (12:01 +0100)
committerTom Christie <tom@tomchristie.com>
Wed, 8 May 2019 11:01:48 +0000 (12:01 +0100)
httpcore/__init__.py
httpcore/backends/default.py
httpcore/models.py
tests/models/test_queryparams.py [new file with mode: 0644]

index f637dc583606a66e78d16b98451b0595ed884d4a..8ab2b3c9199852d8add3e1406bb8953ef644c4a4 100644 (file)
@@ -23,6 +23,6 @@ from .exceptions import (
     TooManyRedirects,
 )
 from .interfaces import Adapter, BaseReader, BaseWriter
-from .models import URL, Headers, Origin, Request, Response
+from .models import URL, Headers, Origin, QueryParams, Request, Response
 
 __version__ = "0.2.1"
index 1e8996fba7691f624f440f4987c052eb88d07c9f..6e80b99a8e1fe7901bcb750702f6e96ad9bed94b 100644 (file)
@@ -28,13 +28,16 @@ OptionalTimeout = typing.Optional[TimeoutConfig]
 # Clients which have been opened using a `with` block, or which have
 # had `close()` closed, will not exhibit this issue in the first place.
 
-_write = asyncio.selector_events._SelectorSocketTransport.write
 
-def _fixed_write(self, exc):
+_write = asyncio.selector_events._SelectorSocketTransport.write  # type: ignore
+
+
+def _fixed_write(self, data: bytes) -> None:  # type: ignore
     if not self._loop.is_closed():
-        _write(self, exc)
+        _write(self, data)
 
-asyncio.selector_events._SelectorSocketTransport.write = _fixed_write
+
+asyncio.selector_events._SelectorSocketTransport.write = _fixed_write  # type: ignore
 
 
 class Reader(BaseReader):
@@ -138,8 +141,6 @@ async def connect(
         if ident is None:
             ident = ssl_object.selected_npn_protocol()
 
-    stream_writer.transport.set_write_buffer_limits(high=0, low=0)
-
     reader = Reader(stream_reader=stream_reader, timeout=timeout)
     writer = Writer(stream_writer=stream_writer, timeout=timeout)
     protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
index 97a63cc83ce884a0bf7a0d7f4742f6770cc955af..495d6f829ce819c10e1f2a4a13a058971ed5b219 100644 (file)
@@ -1,5 +1,6 @@
 import cgi
 import typing
+from urllib.parse import parse_qsl, urlencode
 
 import chardet
 import idna
@@ -30,6 +31,13 @@ from .utils import (
 
 URLTypes = typing.Union["URL", str]
 
+QueryParamTypes = typing.Union[
+    "QueryParams",
+    typing.Mapping[str, str],
+    typing.List[typing.Tuple[typing.Any, typing.Any]],
+    str,
+]
+
 HeaderTypes = typing.Union[
     "Headers",
     typing.Dict[typing.AnyStr, typing.AnyStr],
@@ -40,7 +48,9 @@ ByteOrByteStream = typing.Union[bytes, typing.AsyncIterator[bytes]]
 
 
 class URL:
-    def __init__(self, url: URLTypes, allow_relative: bool = False) -> None:
+    def __init__(
+        self, url: URLTypes, allow_relative: bool = False, params: QueryParamTypes = None
+    ) -> None:
         if isinstance(url, rfc3986.uri.URIReference):
             self.components = url
         elif isinstance(url, str):
@@ -156,6 +166,10 @@ class URL:
 
 
 class Origin:
+    """
+    The URL scheme and authority information, as a comparable, hashable object.
+    """
+
     def __init__(self, url: URLTypes) -> None:
         if not isinstance(url, URL):
             url = URL(url)
@@ -175,9 +189,74 @@ class Origin:
         return hash((self.is_ssl, self.host, self.port))
 
 
+class QueryParams(typing.Mapping):
+    def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None:
+        assert len(args) < 2, "Too many arguments."
+        assert not (args and kwargs), "Cannot mix named and unnamed arguments."
+
+        value = args[0] if args else kwargs
+
+        if isinstance(value, str):
+            items = parse_qsl(value)
+        elif isinstance(value, QueryParams):
+            items = value.multi_items()
+        elif isinstance(value, list):
+            items = value
+        else:
+            items = value.items()  # type: ignore
+
+        self._list = [(str(k), str(v)) for k, v in items]
+        self._dict = {str(k): str(v) for k, v in items}
+
+    def getlist(self, key: typing.Any) -> typing.List[str]:
+        return [item_value for item_key, item_value in self._list if item_key == key]
+
+    def keys(self) -> typing.KeysView:
+        return self._dict.keys()
+
+    def values(self) -> typing.ValuesView:
+        return self._dict.values()
+
+    def items(self) -> typing.ItemsView:
+        return self._dict.items()
+
+    def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
+        return list(self._list)
+
+    def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
+        if key in self._dict:
+            return self._dict[key]
+        return default
+
+    def __getitem__(self, key: typing.Any) -> str:
+        return self._dict[key]
+
+    def __contains__(self, key: typing.Any) -> bool:
+        return key in self._dict
+
+    def __iter__(self) -> typing.Iterator[typing.Any]:
+        return iter(self.keys())
+
+    def __len__(self) -> int:
+        return len(self._dict)
+
+    def __eq__(self, other: typing.Any) -> bool:
+        if not isinstance(other, self.__class__):
+            return False
+        return sorted(self._list) == sorted(other._list)
+
+    def __str__(self) -> str:
+        return urlencode(self._list)
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        query_string = str(self)
+        return f"{class_name}({query_string!r})"
+
+
 class Headers(typing.MutableMapping[str, str]):
     """
-    A case-insensitive multidict.
+    HTTP headers, as a case-insensitive multi-dict.
     """
 
     def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None:
@@ -200,8 +279,8 @@ class Headers(typing.MutableMapping[str, str]):
     @property
     def encoding(self) -> str:
         """
-        Header encoding is mandated as ascii, but utf-8 or iso-8859-1 may be
-        seen in the wild.
+        Header encoding is mandated as ascii, but we allow fallbacks to utf-8
+        or iso-8859-1.
         """
         if self._encoding is None:
             for encoding in ["ascii", "utf-8"]:
diff --git a/tests/models/test_queryparams.py b/tests/models/test_queryparams.py
new file mode 100644 (file)
index 0000000..90e4a4b
--- /dev/null
@@ -0,0 +1,33 @@
+from httpcore import QueryParams
+
+
+def test_queryparams():
+    q = QueryParams("a=123&a=456&b=789")
+    assert "a" in q
+    assert "A" not in q
+    assert "c" not in q
+    assert q["a"] == "456"
+    assert q.get("a") == "456"
+    assert q.get("nope", default=None) is None
+    assert q.getlist("a") == ["123", "456"]
+    assert list(q.keys()) == ["a", "b"]
+    assert list(q.values()) == ["456", "789"]
+    assert list(q.items()) == [("a", "456"), ("b", "789")]
+    assert len(q) == 2
+    assert list(q) == ["a", "b"]
+    assert dict(q) == {"a": "456", "b": "789"}
+    assert str(q) == "a=123&a=456&b=789"
+    assert repr(q) == "QueryParams('a=123&a=456&b=789')"
+    assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
+        [("a", "123"), ("b", "456")]
+    )
+    assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456")
+    assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
+        {"b": "456", "a": "123"}
+    )
+    assert QueryParams() == QueryParams({})
+    assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456")
+    assert QueryParams({"a": "123", "b": "456"}) != "invalid"
+
+    q = QueryParams([("a", "123"), ("a", "456")])
+    assert QueryParams(q) == q