]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop `URL(allow_relative=bool)` (#1073)
authorTom Christie <tom@tomchristie.com>
Thu, 23 Jul 2020 09:16:51 +0000 (10:16 +0100)
committerGitHub <noreply@github.com>
Thu, 23 Jul 2020 09:16:51 +0000 (10:16 +0100)
* Drop URL(allow_relative=bool)

* Update httpx/_models.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Linting

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/_client.py
httpx/_models.py
httpx/_utils.py
tests/client/test_async_client.py
tests/client/test_client.py
tests/models/test_requests.py
tests/models/test_url.py

index 24812d93d70bee4db94cbb4d062b11da9b44b2cf..e7d018bb44588c06563fe8677bdcbcd940aa6b21 100644 (file)
@@ -44,6 +44,7 @@ from ._types import (
 )
 from ._utils import (
     NetRCInfo,
+    enforce_http_url,
     get_environment_proxies,
     get_logger,
     same_origin,
@@ -69,7 +70,7 @@ class BaseClient:
         trust_env: bool = True,
     ):
         if base_url is None:
-            self.base_url = URL("", allow_relative=True)
+            self.base_url = URL("")
         else:
             self.base_url = URL(base_url)
 
@@ -318,7 +319,7 @@ class BaseClient:
         """
         location = response.headers["Location"]
 
-        url = URL(location, allow_relative=True)
+        url = URL(location)
 
         # Check that we can handle the scheme
         if url.scheme and url.scheme not in ("http", "https"):
@@ -539,6 +540,8 @@ class Client(BaseClient):
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
         """
+        enforce_http_url(url)
+
         if self._proxies and not should_not_be_proxied(url):
             is_default_port = (url.scheme == "http" and url.port == 80) or (
                 url.scheme == "https" and url.port == 443
@@ -690,7 +693,6 @@ class Client(BaseClient):
         """
         Sends a single request, without handling any redirections.
         """
-
         transport = self._transport_for_url(request.url)
 
         with map_exceptions(HTTPCORE_EXC_MAP, request=request):
@@ -1071,6 +1073,8 @@ class AsyncClient(BaseClient):
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
         """
+        enforce_http_url(url)
+
         if self._proxies and not should_not_be_proxied(url):
             is_default_port = (url.scheme == "http" and url.port == 80) or (
                 url.scheme == "https" and url.port == 443
@@ -1130,9 +1134,6 @@ class AsyncClient(BaseClient):
         allow_redirects: bool = True,
         timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
     ) -> Response:
-        if request.url.scheme not in ("http", "https"):
-            raise InvalidURL('URL scheme must be "http" or "https".')
-
         timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
 
         auth = self._build_auth(request, auth)
@@ -1225,7 +1226,6 @@ class AsyncClient(BaseClient):
         """
         Sends a single request, without handling any redirections.
         """
-
         transport = self._transport_for_url(request.url)
 
         with map_exceptions(HTTPCORE_EXC_MAP, request=request):
index b992d9d7842fff7d0c14577efcf855b5c4efdec6..73db360d4c2b5c100d8581a71387557570521222 100644 (file)
@@ -24,7 +24,6 @@ from ._decoders import (
 from ._exceptions import (
     CookieConflict,
     HTTPStatusError,
-    InvalidURL,
     NotRedirectResponse,
     RequestNotRead,
     ResponseClosed,
@@ -55,12 +54,7 @@ from ._utils import (
 
 
 class URL:
-    def __init__(
-        self,
-        url: URLTypes,
-        allow_relative: bool = False,
-        params: QueryParamTypes = None,
-    ) -> None:
+    def __init__(self, url: URLTypes, params: QueryParamTypes = None) -> None:
         if isinstance(url, str):
             self._uri_reference = rfc3986.api.iri_reference(url).encode()
         else:
@@ -80,13 +74,6 @@ class URL:
                 query_string = str(QueryParams(params))
             self._uri_reference = self._uri_reference.copy_with(query=query_string)
 
-        # Enforce absolute URLs by default.
-        if not allow_relative:
-            if not self.scheme:
-                raise InvalidURL("No scheme included in URL.")
-            if not self.host:
-                raise InvalidURL("No host included in URL.")
-
     @property
     def scheme(self) -> str:
         return self._uri_reference.scheme or ""
@@ -195,10 +182,7 @@ class URL:
 
             kwargs["authority"] = authority
 
-        return URL(
-            self._uri_reference.copy_with(**kwargs).unsplit(),
-            allow_relative=self.is_relative_url,
-        )
+        return URL(self._uri_reference.copy_with(**kwargs).unsplit(),)
 
     def join(self, relative_url: URLTypes) -> "URL":
         """
@@ -210,7 +194,7 @@ class URL:
         # We drop any fragment portion, because RFC 3986 strictly
         # treats URLs with a fragment portion as not being absolute URLs.
         base_uri = self._uri_reference.copy_with(fragment=None)
-        relative_url = URL(relative_url, allow_relative=True)
+        relative_url = URL(relative_url)
         return URL(relative_url._uri_reference.resolve_with(base_uri).unsplit())
 
     def __hash__(self) -> int:
index b85ec758547359082f14a9e6da4f30d75527b37c..b43f38ee31bc561d35d99dd8f4558aa0ec13d64d 100644 (file)
@@ -14,6 +14,7 @@ from time import perf_counter
 from types import TracebackType
 from urllib.request import getproxies
 
+from ._exceptions import InvalidURL
 from ._types import PrimitiveData
 
 if typing.TYPE_CHECKING:  # pragma: no cover
@@ -260,6 +261,18 @@ def get_logger(name: str) -> Logger:
     return typing.cast(Logger, logger)
 
 
+def enforce_http_url(url: "URL") -> None:
+    """
+    Raise an appropriate InvalidURL for any non-HTTP URLs.
+    """
+    if not url.scheme:
+        raise InvalidURL("No scheme included in URL.")
+    if not url.host:
+        raise InvalidURL("No host included in URL.")
+    if url.scheme not in ("http", "https"):
+        raise InvalidURL('URL scheme must be "http" or "https".')
+
+
 def same_origin(url: "URL", other: "URL") -> bool:
     """
     Return 'True' if the given URLs share the same origin.
index aabee25049b4e2f8c1beb9898fb2b10068689d67..df2af9178ae4ae8b26649e638b99f146fc36f0ac 100644 (file)
@@ -23,6 +23,10 @@ async def test_get_invalid_url(server):
     async with httpx.AsyncClient() as client:
         with pytest.raises(httpx.InvalidURL):
             await client.get("invalid://example.org")
+        with pytest.raises(httpx.InvalidURL):
+            await client.get("://example.org")
+        with pytest.raises(httpx.InvalidURL):
+            await client.get("http://")
 
 
 @pytest.mark.usefixtures("async_environment")
index 75384f2662773941e3937efa45674a6124db3c7d..3d72582b08b3edba9b29a6cf7dc1762d317ba113 100644 (file)
@@ -22,6 +22,16 @@ def test_get(server):
     assert response.elapsed > timedelta(0)
 
 
+def test_get_invalid_url(server):
+    with httpx.Client() as client:
+        with pytest.raises(httpx.InvalidURL):
+            client.get("invalid://example.org")
+        with pytest.raises(httpx.InvalidURL):
+            client.get("://example.org")
+        with pytest.raises(httpx.InvalidURL):
+            client.get("http://")
+
+
 def test_build_request(server):
     url = server.url.copy_with(path="/echo_headers")
     headers = {"Custom-header": "value"}
index c72e0af96423c4247b25d45f71a2574c11ed9ceb..7fd61a15a4c081b8d537580defa0be5be32fd6d9 100644 (file)
@@ -110,11 +110,3 @@ def test_url():
     assert request.url.scheme == "https"
     assert request.url.port == 443
     assert request.url.full_path == "/abc?foo=bar"
-
-
-def test_invalid_urls():
-    with pytest.raises(httpx.InvalidURL):
-        httpx.Request("GET", "example.org")
-
-    with pytest.raises(httpx.InvalidURL):
-        httpx.Request("GET", "http:///foo")
index de9da7d2f7d49db6a409e09a73e3bec11f142ba8..daa84e0ac559d3c439fa0faf28a33d845ddaab49 100644 (file)
@@ -1,6 +1,6 @@
 import pytest
 
-from httpx import URL, InvalidURL
+from httpx import URL
 
 
 @pytest.mark.parametrize(
@@ -116,9 +116,6 @@ def test_url_join_rfc3986():
 
     url = URL("http://example.com/b/c/d;p?q")
 
-    with pytest.raises(InvalidURL):
-        assert url.join("g:h") == "g:h"
-
     assert url.join("g") == "http://example.com/b/c/g"
     assert url.join("./g") == "http://example.com/b/c/g"
     assert url.join("g/") == "http://example.com/b/c/g/"