]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Use rfc3986 for URL handling
authorTom Christie <tom@tomchristie.com>
Tue, 30 Apr 2019 16:39:40 +0000 (17:39 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 30 Apr 2019 16:39:40 +0000 (17:39 +0100)
httpcore/__init__.py
httpcore/adapters/redirects.py
httpcore/dispatch/connection.py
httpcore/dispatch/http2.py
httpcore/models.py
httpcore/utils.py
requirements.txt
setup.py
tests/models/test_requests.py
tests/models/test_responses.py

index 618e41d8c3e86520e86cecdc725736531514843d..98c51c22984afea809705f87997008dc777a3217 100644 (file)
@@ -8,6 +8,7 @@ from .dispatch.http11 import HTTP11Connection
 from .exceptions import (
     ConnectTimeout,
     DecodingError,
+    InvalidURL,
     PoolTimeout,
     ProtocolError,
     ReadTimeout,
index 1749872719b37735d31ff103f6a6bfa3086f0754..9fea6c961468865d15d0ddf594202221b0323eb0 100644 (file)
@@ -1,12 +1,10 @@
 import typing
-from urllib.parse import urljoin, urlparse
 
 from ..config import DEFAULT_MAX_REDIRECTS
 from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
 from ..interfaces import Adapter
 from ..models import URL, Headers, Request, Response
 from ..status_codes import codes
-from ..utils import requote_uri
 
 
 class RedirectAdapter(Adapter):
@@ -98,25 +96,19 @@ class RedirectAdapter(Adapter):
         """
         location = response.headers["Location"]
 
-        # Handle redirection without scheme (see: RFC 1808 Section 4)
-        if location.startswith("//"):
-            location = f"{request.url.scheme}:{location}"
-
-        # Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
-        parsed = urlparse(location)
-        if parsed.fragment == "" and request.url.fragment:
-            parsed = parsed._replace(fragment=request.url.fragment)
-        url = parsed.geturl()
+        url = URL(location, allow_relative=True)
 
         # Facilitate relative 'location' headers, as allowed by RFC 7231.
         # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
         # Compliant with RFC3986, we percent encode the url.
-        if not parsed.netloc:
-            url = urljoin(str(request.url), requote_uri(url))
-        else:
-            url = requote_uri(url)
+        if not url.is_absolute:
+            url = url.resolve_with(request.url.copy_with(fragment=None))
+
+        # Attach previous fragment if needed (RFC 7231 7.1.2)
+        if request.url.fragment and not url.fragment:
+            url = url.copy_with(fragment=request.url.fragment)
 
-        return URL(url)
+        return url
 
     def redirect_headers(self, request: Request, url: URL) -> Headers:
         """
index ffe82fcb0e4d25056d7756f669b53fb677a12fa5..c342a0c3858eb0e3535a45e4fe8bd31ed0a945aa 100644 (file)
@@ -57,7 +57,7 @@ class HTTPConnection(Adapter):
         assert isinstance(ssl, SSLConfig)
         assert isinstance(timeout, TimeoutConfig)
 
-        hostname = self.origin.hostname
+        host = self.origin.host
         port = self.origin.port
         ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
 
@@ -66,7 +66,7 @@ class HTTPConnection(Adapter):
         else:
             on_release = functools.partial(self.release_func, self)
 
-        reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
+        reader, writer, protocol = await connect(host, port, ssl_context, timeout)
         if protocol == Protocol.HTTP_2:
             self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
         else:
index 062e8f544554f25b01eb132ebead1a45a304d198..787f40e55604893458df368112cad97daf5df49e 100644 (file)
@@ -99,7 +99,7 @@ class HTTP2Connection(Adapter):
         stream_id = self.h2_state.get_next_available_stream_id()
         headers = [
             (b":method", request.method.encode()),
-            (b":authority", request.url.hostname.encode()),
+            (b":authority", request.url.host.encode()),
             (b":scheme", request.url.scheme.encode()),
             (b":path", request.url.full_path.encode()),
         ] + request.headers.raw
index 77031b68366267bf2bae79a17c2c705de44363be..3ad1fcefcddcde631d387f13311936b00dfb7696 100644 (file)
@@ -1,8 +1,8 @@
 import cgi
 import typing
-from urllib.parse import urlsplit
 
 import chardet
+import rfc3986
 
 from .config import SSLConfig, TimeoutConfig
 from .decoders import (
@@ -12,7 +12,7 @@ from .decoders import (
     IdentityDecoder,
     MultiDecoder,
 )
-from .exceptions import ResponseClosed, ResponseNotRead, StreamConsumed
+from .exceptions import InvalidURL, ResponseClosed, ResponseNotRead, StreamConsumed
 from .status_codes import codes
 from .utils import (
     get_reason_phrase,
@@ -33,42 +33,45 @@ ByteOrByteStream = typing.Union[bytes, typing.AsyncIterator[bytes]]
 
 
 class URL:
-    def __init__(self, url: URLTypes) -> None:
+    def __init__(self, url: URLTypes, allow_relative: bool = False) -> None:
         if isinstance(url, str):
-            self.components = urlsplit(url)
+            self.components = rfc3986.api.uri_reference(url).normalize()
+        elif isinstance(url, rfc3986.uri.URIReference):
+            self.components = url
         else:
             self.components = url.components
 
-        if not self.components.scheme:
-            raise ValueError("No scheme included in URL.")
-        if self.components.scheme not in ("http", "https"):
-            raise ValueError('URL scheme must be "http" or "https".')
-        if not self.components.hostname:
-            raise ValueError("No hostname included in URL.")
+        if not allow_relative:
+            if not self.scheme:
+                raise InvalidURL("No scheme included in URL.")
+            if self.scheme not in ("http", "https"):
+                raise InvalidURL('URL scheme must be "http" or "https".')
+            if not self.host:
+                raise InvalidURL("No hostname included in URL.")
 
     @property
     def scheme(self) -> str:
-        return self.components.scheme
+        return self.components.scheme or ""
 
     @property
-    def netloc(self) -> str:
-        return self.components.netloc
+    def authority(self) -> str:
+        return self.components.authority or ""
 
     @property
     def path(self) -> str:
-        return self.components.path
+        return self.components.path or "/"
 
     @property
     def query(self) -> str:
-        return self.components.query
+        return self.components.query or ""
 
     @property
     def fragment(self) -> str:
-        return self.components.fragment
+        return self.components.fragment or ""
 
     @property
-    def hostname(self) -> str:
-        return self.components.hostname
+    def host(self) -> str:
+        return self.components.host or ""
 
     @property
     def port(self) -> int:
@@ -89,10 +92,22 @@ class URL:
     def is_secure(self) -> bool:
         return self.components.scheme == "https"
 
+    @property
+    def is_absolute(self) -> bool:
+        return self.components.is_absolute()
+
     @property
     def origin(self) -> "Origin":
         return Origin(self)
 
+    def copy_with(self, **kwargs: typing.Any) -> "URL":
+        return URL(self.components.copy_with(**kwargs))
+
+    def resolve_with(self, base_url: URLTypes) -> "URL":
+        if isinstance(base_url, URL):
+            base_url = base_url.components
+        return URL(self.components.resolve_with(base_url))
+
     def __hash__(self) -> int:
         return hash(str(self))
 
@@ -100,7 +115,7 @@ class URL:
         return isinstance(other, URL) and str(self) == str(other)
 
     def __str__(self) -> str:
-        return self.components.geturl()
+        return self.components.unsplit()
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
@@ -109,23 +124,23 @@ class URL:
 
 
 class Origin:
-    def __init__(self, url: typing.Union[str, URL]) -> None:
-        if isinstance(url, str):
+    def __init__(self, url: URLTypes) -> None:
+        if not isinstance(url, URL):
             url = URL(url)
         self.is_ssl = url.scheme == "https"
-        self.hostname = url.hostname.lower()
+        self.host = url.host
         self.port = url.port
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
             isinstance(other, self.__class__)
             and self.is_ssl == other.is_ssl
-            and self.hostname == other.hostname
+            and self.host == other.host
             and self.port == other.port
         )
 
     def __hash__(self) -> int:
-        return hash((self.is_ssl, self.hostname, self.port))
+        return hash((self.is_ssl, self.host, self.port))
 
 
 class Headers(typing.MutableMapping[str, str]):
@@ -365,8 +380,8 @@ class Request:
         )
         has_accept_encoding = "accept-encoding" in self.headers
 
-        if not has_host:
-            auto_headers.append((b"host", self.url.netloc.encode("ascii")))
+        if not has_host and self.url.authority:
+            auto_headers.append((b"host", self.url.authority.encode("ascii")))
         if not has_content_length:
             if self.is_streaming:
                 auto_headers.append((b"transfer-encoding", b"chunked"))
index 33c0d3c1cc7add08242da601d78bef83e952d696..d3638585ce6bfe54751a32e88b6e335ac362be5f 100644 (file)
@@ -1,58 +1,6 @@
 import codecs
 import http
 import typing
-from urllib.parse import quote
-
-from .exceptions import InvalidURL
-
-# The unreserved URI characters (RFC 3986)
-UNRESERVED_SET = frozenset(
-    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
-)
-
-
-def unquote_unreserved(uri: str) -> str:
-    """
-    Un-escape any percent-escape sequences in a URI that are unreserved
-    characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
-    """
-    parts = uri.split("%")
-    for i in range(1, len(parts)):
-        h = parts[i][0:2]
-        if len(h) == 2 and h.isalnum():
-            try:
-                c = chr(int(h, 16))
-            except ValueError:
-                raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
-
-            if c in UNRESERVED_SET:
-                parts[i] = c + parts[i][2:]
-            else:
-                parts[i] = "%" + parts[i]
-        else:
-            parts[i] = "%" + parts[i]
-    return "".join(parts)
-
-
-def requote_uri(uri: str) -> str:
-    """
-    Re-quote the given URI.
-
-    This function passes the given URI through an unquote/quote cycle to
-    ensure that it is fully and consistently quoted.
-    """
-    safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
-    safe_without_percent = "!#$&'()*+,/:;=?@[]~"
-    try:
-        # Unquote only the unreserved characters
-        # Then quote only illegal characters (do not quote reserved,
-        # unreserved, or '%')
-        return quote(unquote_unreserved(uri), safe=safe_with_percent)
-    except InvalidURL:
-        # We couldn't unquote the given URI, so let's try quoting it, but
-        # there may be unquoted '%'s in the URI. We need to make sure they're
-        # properly quoted so they do not cause issues elsewhere.
-        return quote(uri, safe=safe_without_percent)
 
 
 def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
index dd8ea66f3209b1a6411e17328bb2becb0a1ed07c..a6c986b1e04591e6752be3d671cede59d8da5a36 100644 (file)
@@ -2,6 +2,7 @@ certifi
 chardet
 h11
 h2
+rfc3986
 
 # Optional
 brotlipy
index 6be5ebaf6a663d15af66659ae1d663a883be1c12..ea9c289760a966c8d85746bdd635d69ca86729c7 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,7 @@ setup(
     author_email="tom@tomchristie.com",
     packages=get_packages("httpcore"),
     data_files=[("", ["LICENSE.md"])],
-    install_requires=["h11", "h2", "certifi", "chardet"],
+    install_requires=["h11", "h2", "certifi", "chardet", "rfc3986"],
     classifiers=[
         "Development Status :: 3 - Alpha",
         "Environment :: Web Environment",
index eb0c3668c3f2b2296dd57d9785d115ae63080c9e..77baa774256c2aba3906f5b2aa97b6e13b0f3a6d 100644 (file)
@@ -93,11 +93,11 @@ def test_url():
 
 
 def test_invalid_urls():
-    with pytest.raises(ValueError):
+    with pytest.raises(httpcore.InvalidURL):
         httpcore.Request("GET", "example.org")
 
-    with pytest.raises(ValueError):
+    with pytest.raises(httpcore.InvalidURL):
         httpcore.Request("GET", "invalid://example.org")
 
-    with pytest.raises(ValueError):
+    with pytest.raises(httpcore.InvalidURL):
         httpcore.Request("GET", "http:///foo")
index 83de2671ff0cc0bed2059c73891cd88d713101fb..16c2af4ad80b539a02b966c70f6c6b8a90eaba49 100644 (file)
@@ -37,11 +37,13 @@ def test_response_default_encoding():
 
 
 def test_response_set_explicit_encoding():
-    headers = {"Content-Type": "text-plain; charset=utf-8"}  # Deliberately incorrect charset
+    headers = {
+        "Content-Type": "text-plain; charset=utf-8"
+    }  # Deliberately incorrect charset
     response = httpcore.Response(
         200, content="Latin 1: ÿ".encode("latin-1"), headers=headers
     )
-    response.encoding = 'latin-1'
+    response.encoding = "latin-1"
     assert response.text == "Latin 1: ÿ"
     assert response.encoding == "latin-1"
 
@@ -71,6 +73,38 @@ async def test_read_response():
     assert response.is_closed
 
 
+@pytest.mark.asyncio
+async def test_raw_interface():
+    response = httpcore.Response(200, content=b"Hello, world!")
+
+    raw = b""
+    async for part in response.raw():
+        raw += part
+    assert raw == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_stream_interface():
+    response = httpcore.Response(200, content=b"Hello, world!")
+
+    content = b""
+    async for part in response.stream():
+        content += part
+    assert content == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_stream_interface_after_read():
+    response = httpcore.Response(200, content=b"Hello, world!")
+
+    await response.read()
+
+    content = b""
+    async for part in response.stream():
+        content += part
+    assert content == b"Hello, world!"
+
+
 @pytest.mark.asyncio
 async def test_streaming_response():
     response = httpcore.Response(200, content=streaming_body())