]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Base URL improvements (#1130)
authorTom Christie <tom@tomchristie.com>
Wed, 5 Aug 2020 17:56:25 +0000 (18:56 +0100)
committerGitHub <noreply@github.com>
Wed, 5 Aug 2020 17:56:25 +0000 (18:56 +0100)
* URL.join(url=...), not URL.join(relative_url=...)

* Fix URL.join()

* Support no argument 'httpx.URL()' usage

* Support client.base_url as a property

* Resolve base_url joining behaviour

* Fix coverage

* Update _client.py

httpx/_client.py
httpx/_models.py
tests/client/test_client.py
tests/client/test_properties.py

index 0a2e525a2f68d98e78b8f4e8643209a41d81bf5c..645c83e0f171f02832cd6bd4de9669ebe71b761f 100644 (file)
@@ -66,13 +66,10 @@ class BaseClient:
         cookies: CookieTypes = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        base_url: URLTypes = None,
+        base_url: URLTypes = "",
         trust_env: bool = True,
     ):
-        if base_url is None:
-            self.base_url = URL("")
-        else:
-            self.base_url = URL(base_url)
+        self._base_url = self._enforce_trailing_slash(URL(base_url))
 
         self.auth = auth
         self._params = QueryParams(params)
@@ -87,6 +84,11 @@ class BaseClient:
     def trust_env(self) -> bool:
         return self._trust_env
 
+    def _enforce_trailing_slash(self, url: URL) -> URL:
+        if url.path.endswith("/"):
+            return url
+        return url.copy_with(path=url.path + "/")
+
     def _get_proxy_map(
         self, proxies: typing.Optional[ProxiesTypes], allow_env_proxies: bool,
     ) -> typing.Dict[str, typing.Optional[Proxy]]:
@@ -107,6 +109,17 @@ class BaseClient:
             proxy = Proxy(url=proxies) if isinstance(proxies, (str, URL)) else proxies
             return {"all": proxy}
 
+    @property
+    def base_url(self) -> URL:
+        """
+        Base URL to use when sending requests with relative URLs.
+        """
+        return self._base_url
+
+    @base_url.setter
+    def base_url(self, url: URLTypes) -> None:
+        self._base_url = self._enforce_trailing_slash(URL(url))
+
     @property
     def headers(self) -> Headers:
         """
@@ -208,7 +221,13 @@ class BaseClient:
         Merge a URL argument together with any 'base_url' on the client,
         to create the URL used for the outgoing request.
         """
-        return self.base_url.join(url)
+        merge_url = URL(url)
+        if merge_url.is_relative_url:
+            # We always ensure the base_url paths include the trailing '/',
+            # and always strip any leading '/' from the merge URL.
+            merge_url = merge_url.copy_with(path=merge_url.path.lstrip("/"))
+            return self.base_url.join(merge_url)
+        return merge_url
 
     def _merge_cookies(
         self, cookies: CookieTypes = None
@@ -441,7 +460,7 @@ class Client(BaseClient):
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        base_url: URLTypes = None,
+        base_url: URLTypes = "",
         transport: httpcore.SyncHTTPTransport = None,
         app: typing.Callable = None,
         trust_env: bool = True,
@@ -972,7 +991,7 @@ class AsyncClient(BaseClient):
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        base_url: URLTypes = None,
+        base_url: URLTypes = "",
         transport: httpcore.AsyncHTTPTransport = None,
         app: typing.Callable = None,
         trust_env: bool = True,
index 41c7a274e6a628d27b5760a3194410c4b643118f..e89851529190c279ff398b1f5d1a83359d9fc9a6 100644 (file)
@@ -55,7 +55,7 @@ from ._utils import (
 
 
 class URL:
-    def __init__(self, url: URLTypes, params: QueryParamTypes = None) -> None:
+    def __init__(self, url: URLTypes = "", params: QueryParamTypes = None) -> None:
         if isinstance(url, str):
             self._uri_reference = rfc3986.api.iri_reference(url).encode()
         else:
index ea57c11c3545de0ba5d696231a3b3a9fc74829f3..b05735ea5e20812ec6797de84ad4e27520176d5e 100644 (file)
@@ -174,13 +174,31 @@ def test_base_url(server):
     assert response.url == base_url
 
 
-def test_merge_url():
+def test_merge_absolute_url():
     client = httpx.Client(base_url="https://www.example.com/")
-    request = client.build_request("GET", "http://www.example.com")
-    assert request.url.scheme == "http"
+    request = client.build_request("GET", "http://www.example.com/")
+    assert request.url == httpx.URL("http://www.example.com/")
     assert not request.url.is_ssl
 
 
+def test_merge_relative_url():
+    client = httpx.Client(base_url="https://www.example.com/")
+    request = client.build_request("GET", "/testing/123")
+    assert request.url == httpx.URL("https://www.example.com/testing/123")
+
+
+def test_merge_relative_url_with_path():
+    client = httpx.Client(base_url="https://www.example.com/some/path")
+    request = client.build_request("GET", "/testing/123")
+    assert request.url == httpx.URL("https://www.example.com/some/path/testing/123")
+
+
+def test_merge_relative_url_with_dotted_path():
+    client = httpx.Client(base_url="https://www.example.com/some/path")
+    request = client.build_request("GET", "../testing/123")
+    assert request.url == httpx.URL("https://www.example.com/some/testing/123")
+
+
 def test_pool_limits_deprecated():
     limits = httpx.Limits()
 
index 011c593cd3a5d3e2690bb054bab5df07cabbb590..3532774727acd26e062946d7245dae2513962c79 100644 (file)
@@ -1,4 +1,25 @@
-from httpx import AsyncClient, Cookies, Headers
+from httpx import URL, AsyncClient, Cookies, Headers
+
+
+def test_client_base_url():
+    client = AsyncClient()
+    client.base_url = "https://www.example.org/"  # type: ignore
+    assert isinstance(client.base_url, URL)
+    assert client.base_url == URL("https://www.example.org/")
+
+
+def test_client_base_url_without_trailing_slash():
+    client = AsyncClient()
+    client.base_url = "https://www.example.org/path"  # type: ignore
+    assert isinstance(client.base_url, URL)
+    assert client.base_url == URL("https://www.example.org/path/")
+
+
+def test_client_base_url_with_trailing_slash():
+    client = AsyncClient()
+    client.base_url = "https://www.example.org/path/"  # type: ignore
+    assert isinstance(client.base_url, URL)
+    assert client.base_url == URL("https://www.example.org/path/")
 
 
 def test_client_headers():