from .exceptions import (
ConnectTimeout,
DecodingError,
+ InvalidURL,
PoolTimeout,
ProtocolError,
ReadTimeout,
import typing
-from urllib.parse import urljoin, urlparse
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from ..interfaces import Adapter
from ..models import URL, Headers, Request, Response
from ..status_codes import codes
-from ..utils import requote_uri
class RedirectAdapter(Adapter):
"""
location = response.headers["Location"]
- # Handle redirection without scheme (see: RFC 1808 Section 4)
- if location.startswith("//"):
- location = f"{request.url.scheme}:{location}"
-
- # Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
- parsed = urlparse(location)
- if parsed.fragment == "" and request.url.fragment:
- parsed = parsed._replace(fragment=request.url.fragment)
- url = parsed.geturl()
+ url = URL(location, allow_relative=True)
# Facilitate relative 'location' headers, as allowed by RFC 7231.
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
# Compliant with RFC3986, we percent encode the url.
- if not parsed.netloc:
- url = urljoin(str(request.url), requote_uri(url))
- else:
- url = requote_uri(url)
+ if not url.is_absolute:
+ url = url.resolve_with(request.url.copy_with(fragment=None))
+
+ # Attach previous fragment if needed (RFC 7231 7.1.2)
+ if request.url.fragment and not url.fragment:
+ url = url.copy_with(fragment=request.url.fragment)
- return URL(url)
+ return url
def redirect_headers(self, request: Request, url: URL) -> Headers:
"""
assert isinstance(ssl, SSLConfig)
assert isinstance(timeout, TimeoutConfig)
- hostname = self.origin.hostname
+ host = self.origin.host
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
else:
on_release = functools.partial(self.release_func, self)
- reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
+ reader, writer, protocol = await connect(host, port, ssl_context, timeout)
if protocol == Protocol.HTTP_2:
self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
else:
stream_id = self.h2_state.get_next_available_stream_id()
headers = [
(b":method", request.method.encode()),
- (b":authority", request.url.hostname.encode()),
+ (b":authority", request.url.host.encode()),
(b":scheme", request.url.scheme.encode()),
(b":path", request.url.full_path.encode()),
] + request.headers.raw
import cgi
import typing
-from urllib.parse import urlsplit
import chardet
+import rfc3986
from .config import SSLConfig, TimeoutConfig
from .decoders import (
IdentityDecoder,
MultiDecoder,
)
-from .exceptions import ResponseClosed, ResponseNotRead, StreamConsumed
+from .exceptions import InvalidURL, ResponseClosed, ResponseNotRead, StreamConsumed
from .status_codes import codes
from .utils import (
get_reason_phrase,
class URL:
- def __init__(self, url: URLTypes) -> None:
+ def __init__(self, url: URLTypes, allow_relative: bool = False) -> None:
# Build a URL from a string, an rfc3986 URIReference, or an object exposing
# `.components` (the final `else` branch).
# Unless allow_relative is true, the result must have an "http" or "https"
# scheme and a non-empty host; otherwise InvalidURL is raised.
if isinstance(url, str):
- self.components = urlsplit(url)
# Only string input is parsed and normalized here; the other two branches
# store the components exactly as received.
+ self.components = rfc3986.api.uri_reference(url).normalize()
+ elif isinstance(url, rfc3986.uri.URIReference):
+ self.components = url
else:
self.components = url.components
- if not self.components.scheme:
- raise ValueError("No scheme included in URL.")
- if self.components.scheme not in ("http", "https"):
- raise ValueError('URL scheme must be "http" or "https".')
- if not self.components.hostname:
- raise ValueError("No hostname included in URL.")
# Validation now goes through the `scheme` / `host` properties and raises
# InvalidURL where the removed lines raised ValueError.
+ if not allow_relative:
+ if not self.scheme:
+ raise InvalidURL("No scheme included in URL.")
+ if self.scheme not in ("http", "https"):
+ raise InvalidURL('URL scheme must be "http" or "https".')
+ if not self.host:
+ raise InvalidURL("No hostname included in URL.")
@property
def scheme(self) -> str:
# Scheme component; a falsy value from the components is reported as "".
- return self.components.scheme
+ return self.components.scheme or ""
@property
- def netloc(self) -> str:
- return self.components.netloc
+ def authority(self) -> str:
# Authority component (this property replaces the former `netloc`);
# a falsy value from the components is reported as "".
+ return self.components.authority or ""
@property
def path(self) -> str:
# Path component; a falsy (e.g. empty) path is reported as "/".
- return self.components.path
+ return self.components.path or "/"
@property
def query(self) -> str:
# Query component; a falsy value from the components is reported as "".
- return self.components.query
+ return self.components.query or ""
@property
def fragment(self) -> str:
# Fragment component; a falsy value from the components is reported as "".
- return self.components.fragment
+ return self.components.fragment or ""
@property
- def hostname(self) -> str:
- return self.components.hostname
+ def host(self) -> str:
# Host component (this property replaces the former `hostname`);
# a falsy value from the components is reported as "".
+ return self.components.host or ""
@property
def port(self) -> int:
def is_secure(self) -> bool:
return self.components.scheme == "https"
+ @property
+ def is_absolute(self) -> bool:
# Delegates directly to the underlying components' is_absolute().
# NOTE(review): rfc3986's is_absolute() appears to follow RFC 3986
# section 4.3, whose absolute-URI grammar has no fragment part, so a full
# URL that carries a fragment would report False here. Confirm callers
# expect that, or test for "has scheme and host" instead.
+ return self.components.is_absolute()
+
@property
def origin(self) -> "Origin":
# A new Origin is constructed from this URL on every access.
return Origin(self)
+ def copy_with(self, **kwargs: typing.Any) -> "URL":
# Return a new URL built from `self.components.copy_with(**kwargs)`;
# kwargs are forwarded untouched to the components object.
# NOTE(review): the result is constructed with the default
# allow_relative=False, so it is validated for scheme and host. Copying a
# URL that was created with allow_relative=True therefore looks like it
# raises InvalidURL unless kwargs supply the missing parts. Confirm intended.
+ return URL(self.components.copy_with(**kwargs))
+
+ def resolve_with(self, base_url: URLTypes) -> "URL":
# Resolve this URL against `base_url` and wrap the outcome in a new URL.
# The new URL is built with the default allow_relative=False, so the
# resolved result must pass the scheme/host validation.
+ if isinstance(base_url, URL):
# Unwrap to the components object before delegating.
+ base_url = base_url.components
+ return URL(self.components.resolve_with(base_url))
+
def __hash__(self) -> int:
# Hash of the URL's string form.
return hash(str(self))
return isinstance(other, URL) and str(self) == str(other)
def __str__(self) -> str:
# String form is produced by the components object itself
# (unsplit() replaces the former geturl() call).
- return self.components.geturl()
+ return self.components.unsplit()
def __repr__(self) -> str:
class_name = self.__class__.__name__
class Origin:
# Identifies a connection target by the (is_ssl, host, port) triple.
# Equality and hashing use exactly those three attributes.
- def __init__(self, url: typing.Union[str, URL]) -> None:
- if isinstance(url, str):
+ def __init__(self, url: URLTypes) -> None:
# Anything that is not already a URL instance is passed through URL().
+ if not isinstance(url, URL):
url = URL(url)
self.is_ssl = url.scheme == "https"
# The explicit .lower() from the removed line is no longer applied.
# NOTE(review): presumably relies on URL normalization having already
# lower-cased the host — confirm for URLs not built from a string.
- self.hostname = url.hostname.lower()
+ self.host = url.host
self.port = url.port
def __eq__(self, other: typing.Any) -> bool:
# Equal only to another instance of the same class with matching
# is_ssl, host and port.
return (
isinstance(other, self.__class__)
and self.is_ssl == other.is_ssl
- and self.hostname == other.hostname
+ and self.host == other.host
and self.port == other.port
)
def __hash__(self) -> int:
# Hashes the same three fields that __eq__ compares.
- return hash((self.is_ssl, self.hostname, self.port))
+ return hash((self.is_ssl, self.host, self.port))
class Headers(typing.MutableMapping[str, str]):
)
has_accept_encoding = "accept-encoding" in self.headers
- if not has_host:
- auto_headers.append((b"host", self.url.netloc.encode("ascii")))
+ if not has_host and self.url.authority:
+ auto_headers.append((b"host", self.url.authority.encode("ascii")))
if not has_content_length:
if self.is_streaming:
auto_headers.append((b"transfer-encoding", b"chunked"))
import codecs
import http
import typing
-from urllib.parse import quote
-
-from .exceptions import InvalidURL
-
-# The unreserved URI characters (RFC 3986)
-UNRESERVED_SET = frozenset(
- "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
-)
-
-
-def unquote_unreserved(uri: str) -> str:
- """
- Un-escape any percent-escape sequences in a URI that are unreserved
- characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
- """
- parts = uri.split("%")
# parts[0] is the text before the first "%"; every later part begins with
# whatever followed a "%", so the loop starts at index 1.
- for i in range(1, len(parts)):
- h = parts[i][0:2]
- if len(h) == 2 and h.isalnum():
- try:
- c = chr(int(h, 16))
- except ValueError:
- raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
-
- if c in UNRESERVED_SET:
# Unreserved character: substitute it for the two hex digits.
- parts[i] = c + parts[i][2:]
- else:
# Otherwise restore the "%" that split() consumed and keep the escape.
- parts[i] = "%" + parts[i]
- else:
- parts[i] = "%" + parts[i]
- return "".join(parts)
-
-
-def requote_uri(uri: str) -> str:
- """
- Re-quote the given URI.
-
- This function passes the given URI through an unquote/quote cycle to
- ensure that it is fully and consistently quoted.
- """
# The two safe sets differ only in "%": it is left untouched on the normal
# path and gets quoted on the fallback path.
- safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
- safe_without_percent = "!#$&'()*+,/:;=?@[]~"
- try:
- # Unquote only the unreserved characters
- # Then quote only illegal characters (do not quote reserved,
- # unreserved, or '%')
- return quote(unquote_unreserved(uri), safe=safe_with_percent)
- except InvalidURL:
- # We couldn't unquote the given URI, so let's try quoting it, but
- # there may be unquoted '%'s in the URI. We need to make sure they're
- # properly quoted so they do not cause issues elsewhere.
- return quote(uri, safe=safe_without_percent)
def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
chardet
h11
h2
+rfc3986
# Optional
brotlipy
author_email="tom@tomchristie.com",
packages=get_packages("httpcore"),
data_files=[("", ["LICENSE.md"])],
- install_requires=["h11", "h2", "certifi", "chardet"],
+ install_requires=["h11", "h2", "certifi", "chardet", "rfc3986"],
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",
def test_invalid_urls():
# Each request below must be rejected with httpcore.InvalidURL
# (the removed lines expected ValueError).
# Case 1: no scheme.
- with pytest.raises(ValueError):
+ with pytest.raises(httpcore.InvalidURL):
httpcore.Request("GET", "example.org")
# Case 2: a scheme other than http/https.
- with pytest.raises(ValueError):
+ with pytest.raises(httpcore.InvalidURL):
httpcore.Request("GET", "invalid://example.org")
# Case 3: an http scheme but an empty host.
- with pytest.raises(ValueError):
+ with pytest.raises(httpcore.InvalidURL):
httpcore.Request("GET", "http:///foo")
def test_response_set_explicit_encoding():
# The header advertises utf-8 while the body is latin-1 encoded; the test
# checks that assigning response.encoding is what .text decoding and
# .encoding subsequently report.
- headers = {"Content-Type": "text-plain; charset=utf-8"} # Deliberately incorrect charset
+ headers = {
+ "Content-Type": "text-plain; charset=utf-8"
+ } # Deliberately incorrect charset
response = httpcore.Response(
200, content="Latin 1: ÿ".encode("latin-1"), headers=headers
)
# Formatting-only change: quote style.
- response.encoding = 'latin-1'
+ response.encoding = "latin-1"
assert response.text == "Latin 1: ÿ"
assert response.encoding == "latin-1"
assert response.is_closed
+@pytest.mark.asyncio
+async def test_raw_interface():
# Iterating response.raw() on a response built from in-memory bytes must
# yield parts that concatenate back to the original content.
+ response = httpcore.Response(200, content=b"Hello, world!")
+
+ raw = b""
+ async for part in response.raw():
+ raw += part
+ assert raw == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_stream_interface():
# Iterating response.stream() on a response built from in-memory bytes must
# yield parts that concatenate back to the original content.
+ response = httpcore.Response(200, content=b"Hello, world!")
+
+ content = b""
+ async for part in response.stream():
+ content += part
+ assert content == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_stream_interface_after_read():
# Same as the plain stream test, but read() is awaited first: stream() must
# still yield the full content afterwards.
+ response = httpcore.Response(200, content=b"Hello, world!")
+
+ await response.read()
+
+ content = b""
+ async for part in response.stream():
+ content += part
+ assert content == b"Hello, world!"
+
+
@pytest.mark.asyncio
async def test_streaming_response():
response = httpcore.Response(200, content=streaming_body())