From: Tom Christie Date: Tue, 30 Apr 2019 16:39:40 +0000 (+0100) Subject: Use rfc3986 for URL handling X-Git-Tag: 0.3.0~56 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=302fe93df9fa44218b5089379fb9f3236169dec0;p=thirdparty%2Fhttpx.git Use rfc3986 for URL handling --- diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 618e41d8..98c51c22 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -8,6 +8,7 @@ from .dispatch.http11 import HTTP11Connection from .exceptions import ( ConnectTimeout, DecodingError, + InvalidURL, PoolTimeout, ProtocolError, ReadTimeout, diff --git a/httpcore/adapters/redirects.py b/httpcore/adapters/redirects.py index 17498727..9fea6c96 100644 --- a/httpcore/adapters/redirects.py +++ b/httpcore/adapters/redirects.py @@ -1,12 +1,10 @@ import typing -from urllib.parse import urljoin, urlparse from ..config import DEFAULT_MAX_REDIRECTS from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects from ..interfaces import Adapter from ..models import URL, Headers, Request, Response from ..status_codes import codes -from ..utils import requote_uri class RedirectAdapter(Adapter): @@ -98,25 +96,19 @@ class RedirectAdapter(Adapter): """ location = response.headers["Location"] - # Handle redirection without scheme (see: RFC 1808 Section 4) - if location.startswith("//"): - location = f"{request.url.scheme}:{location}" - - # Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2) - parsed = urlparse(location) - if parsed.fragment == "" and request.url.fragment: - parsed = parsed._replace(fragment=request.url.fragment) - url = parsed.geturl() + url = URL(location, allow_relative=True) # Facilitate relative 'location' headers, as allowed by RFC 7231. # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') # Compliant with RFC3986, we percent encode the url. - if not parsed.netloc: - url = urljoin(str(request.url), requote_uri(url)) - else: - url = requote_uri(url) + if not url.is_absolute: + url = url.resolve_with(request.url.copy_with(fragment=None)) + + # Attach previous fragment if needed (RFC 7231 7.1.2) + if request.url.fragment and not url.fragment: + url = url.copy_with(fragment=request.url.fragment) - return URL(url) + return url def redirect_headers(self, request: Request, url: URL) -> Headers: """ diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index ffe82fcb..c342a0c3 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -57,7 +57,7 @@ class HTTPConnection(Adapter): assert isinstance(ssl, SSLConfig) assert isinstance(timeout, TimeoutConfig) - hostname = self.origin.hostname + host = self.origin.host port = self.origin.port ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None @@ -66,7 +66,7 @@ class HTTPConnection(Adapter): else: on_release = functools.partial(self.release_func, self) - reader, writer, protocol = await connect(hostname, port, ssl_context, timeout) + reader, writer, protocol = await connect(host, port, ssl_context, timeout) if protocol == Protocol.HTTP_2: self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release) else: diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index 062e8f54..787f40e5 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -99,7 +99,7 @@ class HTTP2Connection(Adapter): stream_id = self.h2_state.get_next_available_stream_id() headers = [ (b":method", request.method.encode()), - (b":authority", request.url.hostname.encode()), + (b":authority", request.url.host.encode()), (b":scheme", request.url.scheme.encode()), (b":path", request.url.full_path.encode()), ] + request.headers.raw diff --git a/httpcore/models.py b/httpcore/models.py index 77031b68..3ad1fcef 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -1,8 +1,8 @@ import cgi import typing -from urllib.parse import urlsplit import chardet +import rfc3986 from .config import SSLConfig, TimeoutConfig from .decoders import ( @@ -12,7 +12,7 @@ from .decoders import ( IdentityDecoder, MultiDecoder, ) -from .exceptions import ResponseClosed, ResponseNotRead, StreamConsumed +from .exceptions import InvalidURL, ResponseClosed, ResponseNotRead, StreamConsumed from .status_codes import codes from .utils import ( get_reason_phrase, @@ -33,42 +33,45 @@ ByteOrByteStream = typing.Union[bytes, typing.AsyncIterator[bytes]] class URL: - def __init__(self, url: URLTypes) -> None: + def __init__(self, url: URLTypes, allow_relative: bool = False) -> None: if isinstance(url, str): - self.components = urlsplit(url) + self.components = rfc3986.api.uri_reference(url).normalize() + elif isinstance(url, rfc3986.uri.URIReference): + self.components = url else: self.components = url.components - if not self.components.scheme: - raise ValueError("No scheme included in URL.") - if self.components.scheme not in ("http", "https"): - raise ValueError('URL scheme must be "http" or "https".') - if not self.components.hostname: - raise ValueError("No hostname included in URL.") + if not allow_relative: + if not self.scheme: + raise InvalidURL("No scheme included in URL.") + if self.scheme not in ("http", "https"): + raise InvalidURL('URL scheme must be "http" or "https".') + if not self.host: + raise InvalidURL("No hostname included in URL.") @property def scheme(self) -> str: - return self.components.scheme + return self.components.scheme or "" @property - def netloc(self) -> str: - return self.components.netloc + def authority(self) -> str: + return self.components.authority or "" @property def path(self) -> str: - return self.components.path + return self.components.path or "/" @property def query(self) -> str: - return self.components.query + return self.components.query or "" @property def fragment(self) -> str: - return self.components.fragment + return self.components.fragment or "" @property - def hostname(self) -> str: - return self.components.hostname + def host(self) -> str: + return self.components.host or "" @property def port(self) -> int: @@ -89,10 +92,22 @@ class URL: def is_secure(self) -> bool: return self.components.scheme == "https" + @property + def is_absolute(self) -> bool: + return self.components.is_absolute() + @property def origin(self) -> "Origin": return Origin(self) + def copy_with(self, **kwargs: typing.Any) -> "URL": + return URL(self.components.copy_with(**kwargs)) + + def resolve_with(self, base_url: URLTypes) -> "URL": + if isinstance(base_url, URL): + base_url = base_url.components + return URL(self.components.resolve_with(base_url)) + def __hash__(self) -> int: return hash(str(self)) @@ -100,7 +115,7 @@ class URL: return isinstance(other, URL) and str(self) == str(other) def __str__(self) -> str: - return self.components.geturl() + return self.components.unsplit() def __repr__(self) -> str: class_name = self.__class__.__name__ @@ -109,23 +124,23 @@ class URL: class Origin: - def __init__(self, url: typing.Union[str, URL]) -> None: - if isinstance(url, str): + def __init__(self, url: URLTypes) -> None: + if not isinstance(url, URL): url = URL(url) self.is_ssl = url.scheme == "https" - self.hostname = url.hostname.lower() + self.host = url.host self.port = url.port def __eq__(self, other: typing.Any) -> bool: return ( isinstance(other, self.__class__) and self.is_ssl == other.is_ssl - and self.hostname == other.hostname + and self.host == other.host and self.port == other.port ) def __hash__(self) -> int: - return hash((self.is_ssl, self.hostname, self.port)) + return hash((self.is_ssl, self.host, self.port)) class Headers(typing.MutableMapping[str, str]): @@ -365,8 +380,8 @@ class Request: ) has_accept_encoding = "accept-encoding" in self.headers - if not has_host: - auto_headers.append((b"host", self.url.netloc.encode("ascii"))) + if not has_host and self.url.authority: + auto_headers.append((b"host", self.url.authority.encode("ascii"))) if not has_content_length: if self.is_streaming: auto_headers.append((b"transfer-encoding", b"chunked")) diff --git a/httpcore/utils.py b/httpcore/utils.py index 33c0d3c1..d3638585 100644 --- a/httpcore/utils.py +++ b/httpcore/utils.py @@ -1,58 +1,6 @@ import codecs import http import typing -from urllib.parse import quote - -from .exceptions import InvalidURL - -# The unreserved URI characters (RFC 3986) -UNRESERVED_SET = frozenset( - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~" -) - - -def unquote_unreserved(uri: str) -> str: - """ - Un-escape any percent-escape sequences in a URI that are unreserved - characters. This leaves all reserved, illegal and non-ASCII bytes encoded. - """ - parts = uri.split("%") - for i in range(1, len(parts)): - h = parts[i][0:2] - if len(h) == 2 and h.isalnum(): - try: - c = chr(int(h, 16)) - except ValueError: - raise InvalidURL("Invalid percent-escape sequence: '%s'" % h) - - if c in UNRESERVED_SET: - parts[i] = c + parts[i][2:] - else: - parts[i] = "%" + parts[i] - else: - parts[i] = "%" + parts[i] - return "".join(parts) - - -def requote_uri(uri: str) -> str: - """ - Re-quote the given URI. - - This function passes the given URI through an unquote/quote cycle to - ensure that it is fully and consistently quoted. - """ - safe_with_percent = "!#$%&'()*+,/:;=?@[]~" - safe_without_percent = "!#$&'()*+,/:;=?@[]~" - try: - # Unquote only the unreserved characters - # Then quote only illegal characters (do not quote reserved, - # unreserved, or '%') - return quote(unquote_unreserved(uri), safe=safe_with_percent) - except InvalidURL: - # We couldn't unquote the given URI, so let's try quoting it, but - # there may be unquoted '%'s in the URI. We need to make sure they're - # properly quoted so they do not cause issues elsewhere. - return quote(uri, safe=safe_without_percent) def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes: diff --git a/requirements.txt b/requirements.txt index dd8ea66f..a6c986b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ certifi chardet h11 h2 +rfc3986 # Optional brotlipy diff --git a/setup.py b/setup.py index 6be5ebaf..ea9c2897 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( author_email="tom@tomchristie.com", packages=get_packages("httpcore"), data_files=[("", ["LICENSE.md"])], - install_requires=["h11", "h2", "certifi", "chardet"], + install_requires=["h11", "h2", "certifi", "chardet", "rfc3986"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Web Environment", diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py index eb0c3668..77baa774 100644 --- a/tests/models/test_requests.py +++ b/tests/models/test_requests.py @@ -93,11 +93,11 @@ def test_url(): def test_invalid_urls(): - with pytest.raises(ValueError): + with pytest.raises(httpcore.InvalidURL): httpcore.Request("GET", "example.org") - with pytest.raises(ValueError): + with pytest.raises(httpcore.InvalidURL): httpcore.Request("GET", "invalid://example.org") - with pytest.raises(ValueError): + with pytest.raises(httpcore.InvalidURL): httpcore.Request("GET", "http:///foo") diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 83de2671..16c2af4a 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -37,11 +37,13 @@ def test_response_default_encoding(): def test_response_set_explicit_encoding(): - headers = {"Content-Type": "text-plain; charset=utf-8"} # Deliberately incorrect charset + headers = { + "Content-Type": "text-plain; charset=utf-8" + } # Deliberately incorrect charset response = httpcore.Response( 200, content="Latin 1: ÿ".encode("latin-1"), headers=headers ) - response.encoding = 'latin-1' + response.encoding = "latin-1" assert response.text == "Latin 1: ÿ" assert response.encoding == "latin-1" @@ -71,6 +73,38 @@ async def test_read_response(): assert response.is_closed +@pytest.mark.asyncio +async def test_raw_interface(): + response = httpcore.Response(200, content=b"Hello, world!") + + raw = b"" + async for part in response.raw(): + raw += part + assert raw == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_stream_interface(): + response = httpcore.Response(200, content=b"Hello, world!") + + content = b"" + async for part in response.stream(): + content += part + assert content == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_stream_interface_after_read(): + response = httpcore.Response(200, content=b"Hello, world!") + + await response.read() + + content = b"" + async for part in response.stream(): + content += part + assert content == b"Hello, world!" + + @pytest.mark.asyncio async def test_streaming_response(): response = httpcore.Response(200, content=streaming_body())