]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Support Client(base_url=...)
authorTom Christie <tom@tomchristie.com>
Fri, 21 Jun 2019 14:03:01 +0000 (15:03 +0100)
committerTom Christie <tom@tomchristie.com>
Fri, 21 Jun 2019 14:03:01 +0000 (15:03 +0100)
http3/client.py
http3/models.py
tests/client/test_client.py

index fd92ef32bcb5dbe836226aa1ee3982d1f78d8d3c..6c0557b88828e5848ee38dd8cd2aedecd073bb36 100644 (file)
@@ -51,6 +51,7 @@ class BaseClient:
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
+        base_url: URLTypes = None,
         dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
         app: typing.Callable = None,
         backend: ConcurrencyBackend = None,
@@ -79,6 +80,11 @@ class BaseClient:
         else:
             async_dispatch = dispatch
 
+        if base_url is None:
+            self.base_url = URL('', allow_relative=True)
+        else:
+            self.base_url = URL(base_url)
+
         self.auth = auth
         self.cookies = Cookies(cookies)
         self.max_redirects = max_redirects
@@ -238,7 +244,7 @@ class BaseClient:
         # Facilitate relative 'Location' headers, as allowed by RFC 7231.
         # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
         if url.is_relative_url:
-            url = url.resolve_with(request.url)
+            url = request.url.join(url)
 
         # Attach previous fragment if needed (RFC 7231 7.1.2)
         if request.url.fragment and not url.fragment:
@@ -506,6 +512,8 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
     ) -> AsyncResponse:
+        url = self.base_url.join(url)
+        cookies = self.merge_cookies(cookies)
         request = AsyncRequest(
             method,
             url,
@@ -514,7 +522,7 @@ class AsyncClient(BaseClient):
             json=json,
             params=params,
             headers=headers,
-            cookies=self.merge_cookies(cookies),
+            cookies=cookies,
         )
         response = await self.send(
             request,
@@ -585,6 +593,8 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
     ) -> Response:
+        url = self.base_url.join(url)
+        cookies = self.merge_cookies(cookies)
         request = AsyncRequest(
             method,
             url,
@@ -593,7 +603,7 @@ class Client(BaseClient):
             json=json,
             params=params,
             headers=headers,
-            cookies=self.merge_cookies(cookies),
+            cookies=cookies,
         )
         concurrency_backend = self.concurrency_backend
 
index 94720a51e1c49691e7caa1d4e39af5681e63e091..bd3333a1e464779b9077b4b7ae524cd15fe281e5 100644 (file)
@@ -183,14 +183,18 @@ class URL:
     def copy_with(self, **kwargs: typing.Any) -> "URL":
         return URL(self.components.copy_with(**kwargs))
 
-    def resolve_with(self, base_url: URLTypes) -> "URL":
+    def join(self, relative_url: URLTypes) -> "URL":
         """
-        Return an absolute URL, using base_url as the base.
+        Return an absolute URL, using given this URL as the base.
         """
+        if self.is_relative_url:
+            return URL(relative_url)
+
         # We drop any fragment portion, because RFC 3986 strictly
         # treats URLs with a fragment portion as not being absolute URLs.
-        base_url = URL(base_url).copy_with(fragment=None)
-        return URL(self.components.resolve_with(base_url.components))
+        base_components = self.components.copy_with(fragment=None)
+        relative_url = URL(relative_url, allow_relative=True)
+        return URL(relative_url.components.resolve_with(base_components))
 
     def __hash__(self) -> int:
         return hash(str(self))
index 48924cb59adee3f08d5694102ed8cb867a3a2901..4e52c6f89ec62b507f4f3ee195e2c0618492ea76 100644 (file)
@@ -141,3 +141,12 @@ def test_delete(server):
         response = http.delete("http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.reason_phrase == "OK"
+
+
+@threadpool
+def test_base_url(server):
+    base_url = "http://127.0.0.1:8000/"
+    with http3.Client(base_url=base_url) as http:
+        response = http.get('/')
+    assert response.status_code == 200
+    assert str(response.url) == base_url