From: Tom Christie Date: Wed, 8 May 2019 11:01:48 +0000 (+0100) Subject: Add QueryParams class X-Git-Tag: 0.3.0~38^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=211bef130b972e007a9d79e84ac38ee3c566f0b3;p=thirdparty%2Fhttpx.git Add QueryParams class --- diff --git a/httpcore/__init__.py b/httpcore/__init__.py index f637dc58..8ab2b3c9 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -23,6 +23,6 @@ from .exceptions import ( TooManyRedirects, ) from .interfaces import Adapter, BaseReader, BaseWriter -from .models import URL, Headers, Origin, Request, Response +from .models import URL, Headers, Origin, QueryParams, Request, Response __version__ = "0.2.1" diff --git a/httpcore/backends/default.py b/httpcore/backends/default.py index 1e8996fb..6e80b99a 100644 --- a/httpcore/backends/default.py +++ b/httpcore/backends/default.py @@ -28,13 +28,16 @@ OptionalTimeout = typing.Optional[TimeoutConfig] # Clients which have been opened using a `with` block, or which have # had `close()` closed, will not exhibit this issue in the first place. -_write = asyncio.selector_events._SelectorSocketTransport.write -def _fixed_write(self, exc): +_write = asyncio.selector_events._SelectorSocketTransport.write # type: ignore + + +def _fixed_write(self, data: bytes) -> None: # type: ignore if not self._loop.is_closed(): - _write(self, exc) + _write(self, data) -asyncio.selector_events._SelectorSocketTransport.write = _fixed_write + +asyncio.selector_events._SelectorSocketTransport.write = _fixed_write # type: ignore class Reader(BaseReader): @@ -138,8 +141,6 @@ async def connect( if ident is None: ident = ssl_object.selected_npn_protocol() - stream_writer.transport.set_write_buffer_limits(high=0, low=0) - reader = Reader(stream_reader=stream_reader, timeout=timeout) writer = Writer(stream_writer=stream_writer, timeout=timeout) protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11 diff --git a/httpcore/models.py b/httpcore/models.py index 97a63cc8..495d6f82 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -1,5 +1,6 @@ import cgi import typing +from urllib.parse import parse_qsl, urlencode import chardet import idna @@ -30,6 +31,13 @@ from .utils import ( URLTypes = typing.Union["URL", str] +QueryParamTypes = typing.Union[ + "QueryParams", + typing.Mapping[str, str], + typing.List[typing.Tuple[typing.Any, typing.Any]], + str, +] + HeaderTypes = typing.Union[ "Headers", typing.Dict[typing.AnyStr, typing.AnyStr], @@ -40,7 +48,9 @@ ByteOrByteStream = typing.Union[bytes, typing.AsyncIterator[bytes]] class URL: - def __init__(self, url: URLTypes, allow_relative: bool = False) -> None: + def __init__( + self, url: URLTypes, allow_relative: bool = False, params: QueryParamTypes = None + ) -> None: if isinstance(url, rfc3986.uri.URIReference): self.components = url elif isinstance(url, str): @@ -156,6 +166,10 @@ class URL: class Origin: + """ + The URL scheme and authority information, as a comparable, hashable object. + """ + def __init__(self, url: URLTypes) -> None: if not isinstance(url, URL): url = URL(url) @@ -175,9 +189,74 @@ class Origin: return hash((self.is_ssl, self.host, self.port)) +class QueryParams(typing.Mapping): + def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None: + assert len(args) < 2, "Too many arguments." + assert not (args and kwargs), "Cannot mix named and unnamed arguments." + + value = args[0] if args else kwargs + + if isinstance(value, str): + items = parse_qsl(value) + elif isinstance(value, QueryParams): + items = value.multi_items() + elif isinstance(value, list): + items = value + else: + items = value.items() # type: ignore + + self._list = [(str(k), str(v)) for k, v in items] + self._dict = {str(k): str(v) for k, v in items} + + def getlist(self, key: typing.Any) -> typing.List[str]: + return [item_value for item_key, item_value in self._list if item_key == key] + + def keys(self) -> typing.KeysView: + return self._dict.keys() + + def values(self) -> typing.ValuesView: + return self._dict.values() + + def items(self) -> typing.ItemsView: + return self._dict.items() + + def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + return list(self._list) + + def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + if key in self._dict: + return self._dict[key] + return default + + def __getitem__(self, key: typing.Any) -> str: + return self._dict[key] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, self.__class__): + return False + return sorted(self._list) == sorted(other._list) + + def __str__(self) -> str: + return urlencode(self._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + query_string = str(self) + return f"{class_name}({query_string!r})" + + class Headers(typing.MutableMapping[str, str]): """ - A case-insensitive multidict. + HTTP headers, as a case-insensitive multi-dict. """ def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None: @@ -200,8 +279,8 @@ class Headers(typing.MutableMapping[str, str]): @property def encoding(self) -> str: """ - Header encoding is mandated as ascii, but utf-8 or iso-8859-1 may be - seen in the wild. + Header encoding is mandated as ascii, but we allow fallbacks to utf-8 + or iso-8859-1. """ if self._encoding is None: for encoding in ["ascii", "utf-8"]: diff --git a/tests/models/test_queryparams.py b/tests/models/test_queryparams.py new file mode 100644 index 00000000..90e4a4b7 --- /dev/null +++ b/tests/models/test_queryparams.py @@ -0,0 +1,33 @@ +from httpcore import QueryParams + + +def test_queryparams(): + q = QueryParams("a=123&a=456&b=789") + assert "a" in q + assert "A" not in q + assert "c" not in q + assert q["a"] == "456" + assert q.get("a") == "456" + assert q.get("nope", default=None) is None + assert q.getlist("a") == ["123", "456"] + assert list(q.keys()) == ["a", "b"] + assert list(q.values()) == ["456", "789"] + assert list(q.items()) == [("a", "456"), ("b", "789")] + assert len(q) == 2 + assert list(q) == ["a", "b"] + assert dict(q) == {"a": "456", "b": "789"} + assert str(q) == "a=123&a=456&b=789" + assert repr(q) == "QueryParams('a=123&a=456&b=789')" + assert QueryParams({"a": "123", "b": "456"}) == QueryParams( + [("a", "123"), ("b", "456")] + ) + assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456") + assert QueryParams({"a": "123", "b": "456"}) == QueryParams( + {"b": "456", "a": "123"} + ) + assert QueryParams() == QueryParams({}) + assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456") + assert QueryParams({"a": "123", "b": "456"}) != "invalid" + + q = QueryParams([("a", "123"), ("a", "456")]) + assert QueryParams(q) == q