# Clients which have been opened using a `with` block, or which have
# had `close()` closed, will not exhibit this issue in the first place.
-_write = asyncio.selector_events._SelectorSocketTransport.write
-def _fixed_write(self, exc):
+_write = asyncio.selector_events._SelectorSocketTransport.write # type: ignore
+
+
+def _fixed_write(self, data: bytes) -> None: # type: ignore
if not self._loop.is_closed():
- _write(self, exc)
+ _write(self, data)
-asyncio.selector_events._SelectorSocketTransport.write = _fixed_write
+
+asyncio.selector_events._SelectorSocketTransport.write = _fixed_write # type: ignore
class Reader(BaseReader):
if ident is None:
ident = ssl_object.selected_npn_protocol()
- stream_writer.transport.set_write_buffer_limits(high=0, low=0)
-
reader = Reader(stream_reader=stream_reader, timeout=timeout)
writer = Writer(stream_writer=stream_writer, timeout=timeout)
protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
import cgi
import typing
+from urllib.parse import parse_qsl, urlencode
import chardet
import idna
URLTypes = typing.Union["URL", str]
+QueryParamTypes = typing.Union[
+ "QueryParams",
+ typing.Mapping[str, str],
+ typing.List[typing.Tuple[typing.Any, typing.Any]],
+ str,
+]
+
HeaderTypes = typing.Union[
"Headers",
typing.Dict[typing.AnyStr, typing.AnyStr],
class URL:
- def __init__(self, url: URLTypes, allow_relative: bool = False) -> None:
+ def __init__(
+ self, url: URLTypes, allow_relative: bool = False, params: QueryParamTypes = None
+ ) -> None:
if isinstance(url, rfc3986.uri.URIReference):
self.components = url
elif isinstance(url, str):
class Origin:
+ """
+ The URL scheme and authority information, as a comparable, hashable object.
+ """
+
def __init__(self, url: URLTypes) -> None:
if not isinstance(url, URL):
url = URL(url)
return hash((self.is_ssl, self.host, self.port))
+class QueryParams(typing.Mapping):
+ def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None:
+ assert len(args) < 2, "Too many arguments."
+ assert not (args and kwargs), "Cannot mix named and unnamed arguments."
+
+ value = args[0] if args else kwargs
+
+ if isinstance(value, str):
+ items = parse_qsl(value)
+ elif isinstance(value, QueryParams):
+ items = value.multi_items()
+ elif isinstance(value, list):
+ items = value
+ else:
+ items = value.items() # type: ignore
+
+ self._list = [(str(k), str(v)) for k, v in items]
+ self._dict = {str(k): str(v) for k, v in items}
+
+ def getlist(self, key: typing.Any) -> typing.List[str]:
+ return [item_value for item_key, item_value in self._list if item_key == key]
+
+ def keys(self) -> typing.KeysView:
+ return self._dict.keys()
+
+ def values(self) -> typing.ValuesView:
+ return self._dict.values()
+
+ def items(self) -> typing.ItemsView:
+ return self._dict.items()
+
+ def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
+ return list(self._list)
+
+ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
+ if key in self._dict:
+ return self._dict[key]
+ return default
+
+ def __getitem__(self, key: typing.Any) -> str:
+ return self._dict[key]
+
+ def __contains__(self, key: typing.Any) -> bool:
+ return key in self._dict
+
+ def __iter__(self) -> typing.Iterator[typing.Any]:
+ return iter(self.keys())
+
+ def __len__(self) -> int:
+ return len(self._dict)
+
+ def __eq__(self, other: typing.Any) -> bool:
+ if not isinstance(other, self.__class__):
+ return False
+ return sorted(self._list) == sorted(other._list)
+
+ def __str__(self) -> str:
+ return urlencode(self._list)
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ query_string = str(self)
+ return f"{class_name}({query_string!r})"
+
+
class Headers(typing.MutableMapping[str, str]):
"""
- A case-insensitive multidict.
+ HTTP headers, as a case-insensitive multi-dict.
"""
def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None:
@property
def encoding(self) -> str:
"""
- Header encoding is mandated as ascii, but utf-8 or iso-8859-1 may be
- seen in the wild.
+ Header encoding is mandated as ascii, but we allow fallbacks to utf-8
+ or iso-8859-1.
"""
if self._encoding is None:
for encoding in ["ascii", "utf-8"]:
--- /dev/null
+from httpcore import QueryParams
+
+
+def test_queryparams():
+ q = QueryParams("a=123&a=456&b=789")
+ assert "a" in q
+ assert "A" not in q
+ assert "c" not in q
+ assert q["a"] == "456"
+ assert q.get("a") == "456"
+ assert q.get("nope", default=None) is None
+ assert q.getlist("a") == ["123", "456"]
+ assert list(q.keys()) == ["a", "b"]
+ assert list(q.values()) == ["456", "789"]
+ assert list(q.items()) == [("a", "456"), ("b", "789")]
+ assert len(q) == 2
+ assert list(q) == ["a", "b"]
+ assert dict(q) == {"a": "456", "b": "789"}
+ assert str(q) == "a=123&a=456&b=789"
+ assert repr(q) == "QueryParams('a=123&a=456&b=789')"
+ assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
+ [("a", "123"), ("b", "456")]
+ )
+ assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456")
+ assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
+ {"b": "456", "a": "123"}
+ )
+ assert QueryParams() == QueryParams({})
+ assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456")
+ assert QueryParams({"a": "123", "b": "456"}) != "invalid"
+
+ q = QueryParams([("a", "123"), ("a", "456")])
+ assert QueryParams(q) == q