]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Rollin'
authorTom Christie <tom@tomchristie.com>
Fri, 26 Apr 2019 16:00:47 +0000 (17:00 +0100)
committerTom Christie <tom@tomchristie.com>
Fri, 26 Apr 2019 16:00:47 +0000 (17:00 +0100)
16 files changed:
httpcore/__init__.py
httpcore/adapters.py
httpcore/auth.py
httpcore/client.py
httpcore/cookies.py
httpcore/decoders.py
httpcore/exceptions.py
httpcore/http11.py
httpcore/http2.py
httpcore/models.py
httpcore/redirects.py
httpcore/status_codes.py [new file with mode: 0644]
httpcore/structures.py [new file with mode: 0644]
httpcore/sync.py
httpcore/utils.py [new file with mode: 0644]
tests/test_requests.py

index 30ed38e6072eb39dadacd4fb39027d39661263c2..ef8d6a5d449feb5dbc46e035f36e421d33a03e9f 100644 (file)
@@ -14,7 +14,7 @@ from .exceptions import (
 )
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
-from .models import URL, Origin, Request, Response
+from .models import URL, Headers, Origin, Request, Response
 from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
 from .sync import SyncClient, SyncConnectionPool
 
index 0c14e89ceb765823cf2013cbe59f33a0bf3acf2e..72d5880f0de6cbead08bb6dcdb147f490de52c7c 100644 (file)
@@ -10,7 +10,7 @@ class Adapter:
         method: str,
         url: typing.Union[str, URL],
         *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
         **options: typing.Any,
     ) -> Response:
index 949d8a9fb687cd6e9835ee8f04f46959b3c6a763..38c434a241433133f7f21fe72b84a850d0c94207 100644 (file)
@@ -15,4 +15,4 @@ class AuthAdapter(Adapter):
         return await self.dispatch.send(request, **options)
 
     async def close(self) -> None:
-        self.dispatch.close()
+        await self.dispatch.close()
index 022ae7ffe5378453883a54e5d76051effc11f235..2d98ba30784da4551b815d4af28dd27ef086950d 100644 (file)
@@ -40,7 +40,7 @@ class Client:
         url: typing.Union[str, URL],
         *,
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         stream: bool = False,
         allow_redirects: bool = True,
         ssl: typing.Optional[SSLConfig] = None,
@@ -61,7 +61,7 @@ class Client:
         self,
         url: typing.Union[str, URL],
         *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         stream: bool = False,
         ssl: typing.Optional[SSLConfig] = None,
         timeout: typing.Optional[TimeoutConfig] = None,
@@ -75,7 +75,7 @@ class Client:
         url: typing.Union[str, URL],
         *,
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         stream: bool = False,
         ssl: typing.Optional[SSLConfig] = None,
         timeout: typing.Optional[TimeoutConfig] = None,
index f6fd2b0381f5d0e730dcd6d9faff0b4a7c9331f0..ed9e97c4f97463b0ded987bd04310a2cb7aaf188 100644 (file)
@@ -15,4 +15,4 @@ class CookieAdapter(Adapter):
         return await self.dispatch.send(request, **options)
 
     async def close(self) -> None:
-        self.dispatch.close()
+        await self.dispatch.close()
index b56745c49a5ca64c9c852f2eeae1d20f18c3dc32..e47143d740a7c734d60f1ccae8074b806d6debf3 100644 (file)
@@ -110,17 +110,17 @@ class MultiDecoder(Decoder):
 
 
 SUPPORTED_DECODERS = {
-    b"identity": IdentityDecoder,
-    b"deflate": DeflateDecoder,
-    b"gzip": GZipDecoder,
-    b"br": BrotliDecoder,
+    "identity": IdentityDecoder,
+    "deflate": DeflateDecoder,
+    "gzip": GZipDecoder,
+    "br": BrotliDecoder,
 }
 
 
 if brotli is None:
-    SUPPORTED_DECODERS.pop(b"br")  # pragma: nocover
+    SUPPORTED_DECODERS.pop("br")  # pragma: nocover
 
 
-ACCEPT_ENCODING = b", ".join(
-    [key for key in SUPPORTED_DECODERS.keys() if key != b"identity"]
+ACCEPT_ENCODING = ", ".join(
+    [key for key in SUPPORTED_DECODERS.keys() if key != "identity"]
 )
index 94154b073e29b9e644e3595ba4514237f84177a5..337b74a2e4f96a8fb282ad4fe3879deef3eea186 100644 (file)
@@ -52,3 +52,8 @@ class ResponseClosed(Exception):
     Attempted to read or stream response content, but the request has been
     closed without loading the body.
     """
+
+
+class InvalidURL(Exception):
+    """
+    """
index 0075b524cead5ffb4c0addebd3dc8e0ec25d0034..97e8b6a82a41dc4788795ee4e223b2fd6f4e7375 100644 (file)
@@ -51,7 +51,7 @@ class HTTP11Connection(Adapter):
         #  Start sending the request.
         method = request.method.encode()
         target = request.url.full_path
-        headers = request.headers
+        headers = request.headers.raw
         event = h11.Request(method=method, target=target, headers=headers)
         await self._send_event(event, timeout)
 
index e89029cf9467cb8554c98d791a3eabb58a28fba0..40448176230f39e6d20cf4bd0eaf8e76c0c54cae 100644 (file)
@@ -96,7 +96,7 @@ class HTTP2Connection(Adapter):
             (b":authority", request.url.hostname.encode()),
             (b":scheme", request.url.scheme.encode()),
             (b":path", request.url.full_path.encode()),
-        ] + request.headers
+        ] + request.headers.raw
         self.h2_state.send_headers(stream_id, headers)
         data_to_send = self.h2_state.data_to_send()
         await self.writer.write(data_to_send, timeout)
index 7229a76c18c8c0db594e407a61a2ec7623452829..791bbc906fc5a1310d2de85a5ac147aaed8ea710 100644 (file)
@@ -39,6 +39,10 @@ class URL:
     def query(self) -> str:
         return self.components.query
 
+    @property
+    def fragment(self) -> str:
+        return self.components.fragment
+
     @property
     def hostname(self) -> str:
         return self.components.hostname
@@ -87,32 +91,140 @@ class Origin:
         return hash((self.is_ssl, self.hostname, self.port))
 
 
+class Headers(typing.MutableMapping[str, str]):
+    """
+    A case-insensitive multidict.
+    """
+
+    def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
+        self._list = [(k.lower(), v) for k, v in headers]
+
+    @property
+    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
+        return self._list
+
+    def keys(self) -> typing.List[str]:  # type: ignore
+        return [key.decode("latin-1") for key, value in self._list]
+
+    def values(self) -> typing.List[str]:  # type: ignore
+        return [value.decode("latin-1") for key, value in self._list]
+
+    def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
+        return [
+            (key.decode("latin-1"), value.decode("latin-1"))
+            for key, value in self._list
+        ]
+
+    def get(self, key: str, default: typing.Any = None) -> typing.Any:
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def getlist(self, key: str) -> typing.List[str]:
+        get_header_key = key.lower().encode("latin-1")
+        return [
+            item_value.decode("latin-1")
+            for item_key, item_value in self._list
+            if item_key == get_header_key
+        ]
+
+    def __getitem__(self, key: str) -> str:
+        get_header_key = key.lower().encode("latin-1")
+        for header_key, header_value in self._list:
+            if header_key == get_header_key:
+                return header_value.decode("latin-1")
+        raise KeyError(key)
+
+    def __setitem__(self, key: str, value: str) -> None:
+        """
+        Set the header `key` to `value`, removing any duplicate entries.
+        Retains insertion order.
+        """
+        set_key = key.lower().encode("latin-1")
+        set_value = value.encode("latin-1")
+
+        found_indexes = []
+        for idx, (item_key, item_value) in enumerate(self._list):
+            if item_key == set_key:
+                found_indexes.append(idx)
+
+        for idx in reversed(found_indexes[1:]):
+            del self._list[idx]
+
+        if found_indexes:
+            idx = found_indexes[0]
+            self._list[idx] = (set_key, set_value)
+        else:
+            self._list.append((set_key, set_value))
+
+    def __delitem__(self, key: str) -> None:
+        """
+        Remove the header `key`.
+        """
+        del_key = key.lower().encode("latin-1")
+
+        pop_indexes = []
+        for idx, (item_key, item_value) in enumerate(self._list):
+            if item_key == del_key:
+                pop_indexes.append(idx)
+
+        for idx in reversed(pop_indexes):
+            del self._list[idx]
+
+    def __contains__(self, key: typing.Any) -> bool:
+        get_header_key = key.lower().encode("latin-1")
+        for header_key, header_value in self._list:
+            if header_key == get_header_key:
+                return True
+        return False
+
+    def __iter__(self) -> typing.Iterator[typing.Any]:
+        return iter(self.keys())
+
+    def __len__(self) -> int:
+        return len(self._list)
+
+    def __eq__(self, other: typing.Any) -> bool:
+        if not isinstance(other, Headers):
+            return False
+        return sorted(self._list) == sorted(other._list)
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        as_dict = dict(self.items())
+        if len(as_dict) == len(self):
+            return f"{class_name}({as_dict!r})"
+        return f"{class_name}(raw={self.raw!r})"
+
+
 class Request:
     def __init__(
         self,
         method: str,
         url: typing.Union[str, URL],
         *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
     ):
         self.method = method.upper()
         self.url = URL(url) if isinstance(url, str) else url
-        self.headers = list(headers)
         if isinstance(body, bytes):
             self.is_streaming = False
             self.body = body
         else:
             self.is_streaming = True
             self.body_aiter = body
-        self.headers = self._auto_headers() + self.headers
+        self.headers = self.build_headers(headers)
 
-    def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
+    def build_headers(
+        self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
+    ) -> Headers:
         has_host = False
         has_content_length = False
         has_accept_encoding = False
 
-        for header, value in self.headers:
+        for header, value in init_headers:
             header = header.strip().lower()
             if header == b"host":
                 has_host = True
@@ -121,20 +233,20 @@ class Request:
             elif header == b"accept-encoding":
                 has_accept_encoding = True
 
-        headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
+        auto_headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
 
         if not has_host:
-            headers.append((b"host", self.url.netloc.encode("ascii")))
+            auto_headers.append((b"host", self.url.netloc.encode("ascii")))
         if not has_content_length:
             if self.is_streaming:
-                headers.append((b"transfer-encoding", b"chunked"))
+                auto_headers.append((b"transfer-encoding", b"chunked"))
             elif self.body:
                 content_length = str(len(self.body)).encode()
-                headers.append((b"content-length", content_length))
+                auto_headers.append((b"content-length", content_length))
         if not has_accept_encoding:
-            headers.append((b"accept-encoding", ACCEPT_ENCODING))
+            auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
 
-        return headers
+        return Headers(auto_headers + init_headers)
 
     async def stream(self) -> typing.AsyncIterator[bytes]:
         if self.is_streaming:
@@ -151,7 +263,7 @@ class Response:
         *,
         reason: typing.Optional[str] = None,
         protocol: typing.Optional[str] = None,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
         on_close: typing.Callable = None,
     ):
@@ -164,18 +276,17 @@ class Response:
         else:
             self.reason = reason
         self.protocol = protocol
-        self.headers = list(headers)
+        self.headers = Headers(headers)
         self.on_close = on_close
         self.is_closed = False
         self.is_streamed = False
 
         decoders = []  # type: typing.List[Decoder]
-        for header, value in self.headers:
-            if header.strip().lower() == b"content-encoding":
-                for part in value.split(b","):
-                    part = part.strip().lower()
-                    decoder_cls = SUPPORTED_DECODERS[part]
-                    decoders.append(decoder_cls())
+        value = self.headers.get("content-encoding", "identity")
+        for part in value.split(","):
+            part = part.strip().lower()
+            decoder_cls = SUPPORTED_DECODERS[part]
+            decoders.append(decoder_cls())
 
         if len(decoders) == 0:
             self.decoder = IdentityDecoder()  # type: Decoder
index 0657ebca5cced3139eb0df8d94cb9603ad714653..557b7468a608fc8c76b5b1971dfcd78d7fd24667 100644 (file)
@@ -1,8 +1,11 @@
 import typing
+from urllib.parse import urljoin, urlparse
 
 from .adapters import Adapter
 from .exceptions import TooManyRedirects
-from .models import Request, Response
+from .models import URL, Request, Response
+from .status_codes import codes
+from .utils import requote_uri
 
 
 class RedirectAdapter(Adapter):
@@ -29,7 +32,55 @@ class RedirectAdapter(Adapter):
         return response
 
     async def close(self) -> None:
-        self.dispatch.close()
+        await self.dispatch.close()
 
     def build_redirect_request(self, request: Request, response: Response) -> Request:
+        method = self.redirect_method(request, response)
+        url = self.redirect_url(request, response)
         raise NotImplementedError()
+
+    def redirect_method(self, request: Request, response: Response) -> str:
+        """
+        When being redirected we may want to change the method of the request
+        based on certain specs or browser behavior.
+        """
+        method = request.method
+
+        # https://tools.ietf.org/html/rfc7231#section-6.4.4
+        if response.status_code == codes["see_other"] and method != "HEAD":
+            method = "GET"
+
+        # Do what the browsers do, despite standards...
+        # First, turn 302s into GETs.
+        if response.status_code == codes["found"] and method != "HEAD":
+            method = "GET"
+
+        # Second, if a POST is responded to with a 301, turn it into a GET.
+        # This bizarre behaviour is explained in Issue 1704.
+        if response.status_code == codes["moved"] and method == "POST":
+            method = "GET"
+
+        return method
+
+    def redirect_url(self, request: Request, response: Response) -> URL:
+        location = response.headers["Location"]
+
+        # Handle redirection without scheme (see: RFC 1808 Section 4)
+        if location.startswith("//"):
+            location = f"{request.url.scheme}:{location}"
+
+        # Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
+        parsed = urlparse(location)
+        if parsed.fragment == "" and request.url.fragment:
+            parsed = parsed._replace(fragment=request.url.fragment)
+        url = parsed.geturl()
+
+        # Facilitate relative 'location' headers, as allowed by RFC 7231.
+        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
+        # Compliant with RFC3986, we percent encode the url.
+        if not parsed.netloc:
+            url = urljoin(str(request.url), requote_uri(url))
+        else:
+            url = requote_uri(url)
+
+        return URL(url)
diff --git a/httpcore/status_codes.py b/httpcore/status_codes.py
new file mode 100644 (file)
index 0000000..9265a71
--- /dev/null
@@ -0,0 +1,123 @@
+"""
+The ``codes`` object defines a mapping from common names for HTTP statuses
+to their numerical codes, accessible either as attributes or as dictionary
+items.
+>>> requests.codes['temporary_redirect']
+307
+>>> requests.codes.teapot
+418
+Some codes have multiple names, and both upper- and lower-case versions of
+the names are allowed. For example, ``codes.ok``, ``codes.OK``, and
+``codes.okay`` all correspond to the HTTP status code 200.
+"""
+
+import typing
+
+from .structures import LookupDict
+
+_codes = {
+    # Informational.
+    100: ("continue",),
+    101: ("switching_protocols",),
+    102: ("processing",),
+    103: ("checkpoint",),
+    122: ("uri_too_long", "request_uri_too_long"),
+    200: ("ok", "okay", "all_ok", "all_okay", "all_good", "\\o/", "✓"),
+    201: ("created",),
+    202: ("accepted",),
+    203: ("non_authoritative_info", "non_authoritative_information"),
+    204: ("no_content",),
+    205: ("reset_content", "reset"),
+    206: ("partial_content", "partial"),
+    207: ("multi_status", "multiple_status", "multi_stati", "multiple_stati"),
+    208: ("already_reported",),
+    226: ("im_used",),
+    # Redirection.
+    300: ("multiple_choices",),
+    301: ("moved_permanently", "moved", "\\o-"),
+    302: ("found",),
+    303: ("see_other", "other"),
+    304: ("not_modified",),
+    305: ("use_proxy",),
+    306: ("switch_proxy",),
+    307: ("temporary_redirect", "temporary_moved", "temporary"),
+    308: (
+        "permanent_redirect",
+        "resume_incomplete",
+        "resume",
+    ),  # These 2 to be removed in 3.0
+    # Client Error.
+    400: ("bad_request", "bad"),
+    401: ("unauthorized",),
+    402: ("payment_required", "payment"),
+    403: ("forbidden",),
+    404: ("not_found", "-o-"),
+    405: ("method_not_allowed", "not_allowed"),
+    406: ("not_acceptable",),
+    407: ("proxy_authentication_required", "proxy_auth", "proxy_authentication"),
+    408: ("request_timeout", "timeout"),
+    409: ("conflict",),
+    410: ("gone",),
+    411: ("length_required",),
+    412: ("precondition_failed", "precondition"),
+    413: ("request_entity_too_large",),
+    414: ("request_uri_too_large",),
+    415: ("unsupported_media_type", "unsupported_media", "media_type"),
+    416: (
+        "requested_range_not_satisfiable",
+        "requested_range",
+        "range_not_satisfiable",
+    ),
+    417: ("expectation_failed",),
+    418: ("im_a_teapot", "teapot", "i_am_a_teapot"),
+    421: ("misdirected_request",),
+    422: ("unprocessable_entity", "unprocessable"),
+    423: ("locked",),
+    424: ("failed_dependency", "dependency"),
+    425: ("unordered_collection", "unordered"),
+    426: ("upgrade_required", "upgrade"),
+    428: ("precondition_required", "precondition"),
+    429: ("too_many_requests", "too_many"),
+    431: ("header_fields_too_large", "fields_too_large"),
+    444: ("no_response", "none"),
+    449: ("retry_with", "retry"),
+    450: ("blocked_by_windows_parental_controls", "parental_controls"),
+    451: ("unavailable_for_legal_reasons", "legal_reasons"),
+    499: ("client_closed_request",),
+    # Server Error.
+    500: ("internal_server_error", "server_error", "/o\\", "✗"),
+    501: ("not_implemented",),
+    502: ("bad_gateway",),
+    503: ("service_unavailable", "unavailable"),
+    504: ("gateway_timeout",),
+    505: ("http_version_not_supported", "http_version"),
+    506: ("variant_also_negotiates",),
+    507: ("insufficient_storage",),
+    509: ("bandwidth_limit_exceeded", "bandwidth"),
+    510: ("not_extended",),
+    511: ("network_authentication_required", "network_auth", "network_authentication"),
+}  # type: typing.Dict[int, typing.Sequence[str]]
+
+codes = LookupDict(name="status_codes")
+
+
+def _init() -> None:
+    for code, titles in _codes.items():
+        for title in titles:
+            setattr(codes, title, code)
+            if not title.startswith(("\\", "/")):
+                setattr(codes, title.upper(), code)
+
+    def doc(code: int) -> str:
+        names = ", ".join("``%s``" % n for n in _codes[code])
+        return "* %d: %s" % (code, names)
+
+    global __doc__
+    __doc__ = (
+        __doc__ + "\n" + "\n".join(doc(code) for code in sorted(_codes))
+        if __doc__ is not None
+        else None
+    )
+
+
+_init()
diff --git a/httpcore/structures.py b/httpcore/structures.py
new file mode 100644 (file)
index 0000000..127af5c
--- /dev/null
@@ -0,0 +1,20 @@
+import typing
+
+
+class LookupDict(dict):
+    """Dictionary lookup object."""
+
+    def __init__(self, name: str = None) -> None:
+        self.name = name
+        super(LookupDict, self).__init__()
+
+    def __repr__(self) -> str:
+        return "<lookup '%s'>" % (self.name)
+
+    def __getitem__(self, key: typing.Any) -> typing.Any:
+        # We allow fall-through here, so values default to None
+
+        return self.__dict__.get(key, None)
+
+    def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
+        return self.__dict__.get(key, default)
index 737d3fcfe2e8c61557d18895ce60f55067a18c38..89b586287a2c57d377c07bccc9cff635efb9fdd5 100644 (file)
@@ -5,7 +5,7 @@ from types import TracebackType
 from .adapters import Adapter
 from .config import SSLConfig, TimeoutConfig
 from .connection_pool import ConnectionPool
-from .models import URL, Response
+from .models import URL, Headers, Response
 
 
 class SyncResponse:
@@ -22,7 +22,7 @@ class SyncResponse:
         return self._response.reason
 
     @property
-    def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
+    def headers(self) -> Headers:
         return self._response.headers
 
     @property
@@ -54,7 +54,7 @@ class SyncClient:
         method: str,
         url: typing.Union[str, URL],
         *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
         **options: typing.Any
     ) -> SyncResponse:
diff --git a/httpcore/utils.py b/httpcore/utils.py
new file mode 100644 (file)
index 0000000..dc7f1d2
--- /dev/null
@@ -0,0 +1,52 @@
+from urllib.parse import quote
+
+from .exceptions import InvalidURL
+
+# The unreserved URI characters (RFC 3986)
+UNRESERVED_SET = frozenset(
+    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
+)
+
+
+def unquote_unreserved(uri: str) -> str:
+    """Un-escape any percent-escape sequences in a URI that are unreserved
+    characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
+    :rtype: str
+    """
+    parts = uri.split("%")
+    for i in range(1, len(parts)):
+        h = parts[i][0:2]
+        if len(h) == 2 and h.isalnum():
+            try:
+                c = chr(int(h, 16))
+            except ValueError:
+                raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
+
+            if c in UNRESERVED_SET:
+                parts[i] = c + parts[i][2:]
+            else:
+                parts[i] = "%" + parts[i]
+        else:
+            parts[i] = "%" + parts[i]
+    return "".join(parts)
+
+
+def requote_uri(uri: str) -> str:
+    """
+    Re-quote the given URI.
+
+    This function passes the given URI through an unquote/quote cycle to
+    ensure that it is fully and consistently quoted.
+    """
+    safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
+    safe_without_percent = "!#$&'()*+,/:;=?@[]~"
+    try:
+        # Unquote only the unreserved characters
+        # Then quote only illegal characters (do not quote reserved,
+        # unreserved, or '%')
+        return quote(unquote_unreserved(uri), safe=safe_with_percent)
+    except InvalidURL:
+        # We couldn't unquote the given URI, so let's try quoting it, but
+        # there may be unquoted '%'s in the URI. We need to make sure they're
+        # properly quoted so they do not cause issues elsewhere.
+        return quote(uri, safe=safe_without_percent)
index c88b70a037d7205b084c648cf9ef28e2495580a2..840dde1a8aa910107f108568b1f15ce068abd72e 100644 (file)
@@ -5,19 +5,20 @@ import httpcore
 
 def test_host_header():
     request = httpcore.Request("GET", "http://example.org")
-    assert request.headers == [
-        (b"host", b"example.org"),
-        (b"accept-encoding", b"deflate, gzip, br"),
-    ]
+    assert request.headers == httpcore.Headers(
+        [(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
+    )
 
 
 def test_content_length_header():
     request = httpcore.Request("POST", "http://example.org", body=b"test 123")
-    assert request.headers == [
-        (b"host", b"example.org"),
-        (b"content-length", b"8"),
-        (b"accept-encoding", b"deflate, gzip, br"),
-    ]
+    assert request.headers == httpcore.Headers(
+        [
+            (b"host", b"example.org"),
+            (b"content-length", b"8"),
+            (b"accept-encoding", b"deflate, gzip, br"),
+        ]
+    )
 
 
 def test_transfer_encoding_header():
@@ -27,31 +28,31 @@ def test_transfer_encoding_header():
     body = streaming_body(b"test 123")
 
     request = httpcore.Request("POST", "http://example.org", body=body)
-    assert request.headers == [
-        (b"host", b"example.org"),
-        (b"transfer-encoding", b"chunked"),
-        (b"accept-encoding", b"deflate, gzip, br"),
-    ]
+    assert request.headers == httpcore.Headers(
+        [
+            (b"host", b"example.org"),
+            (b"transfer-encoding", b"chunked"),
+            (b"accept-encoding", b"deflate, gzip, br"),
+        ]
+    )
 
 
 def test_override_host_header():
     headers = [(b"host", b"1.2.3.4:80")]
 
     request = httpcore.Request("GET", "http://example.org", headers=headers)
-    assert request.headers == [
-        (b"accept-encoding", b"deflate, gzip, br"),
-        (b"host", b"1.2.3.4:80"),
-    ]
+    assert request.headers == httpcore.Headers(
+        [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
+    )
 
 
 def test_override_accept_encoding_header():
     headers = [(b"accept-encoding", b"identity")]
 
     request = httpcore.Request("GET", "http://example.org", headers=headers)
-    assert request.headers == [
-        (b"host", b"example.org"),
-        (b"accept-encoding", b"identity"),
-    ]
+    assert request.headers == httpcore.Headers(
+        [(b"host", b"example.org"), (b"accept-encoding", b"identity")]
+    )
 
 
 def test_override_content_length_header():
@@ -62,11 +63,13 @@ def test_override_content_length_header():
     headers = [(b"content-length", b"8")]
 
     request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
-    assert request.headers == [
-        (b"host", b"example.org"),
-        (b"accept-encoding", b"deflate, gzip, br"),
-        (b"content-length", b"8"),
-    ]
+    assert request.headers == httpcore.Headers(
+        [
+            (b"host", b"example.org"),
+            (b"accept-encoding", b"deflate, gzip, br"),
+            (b"content-length", b"8"),
+        ]
+    )
 
 
 def test_url():