]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Move HSTS preload checking to client (#184)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Thu, 1 Aug 2019 09:26:45 +0000 (04:26 -0500)
committerTom Christie <tom@tomchristie.com>
Thu, 1 Aug 2019 09:26:45 +0000 (10:26 +0100)
httpx/client.py
httpx/models.py
tests/client/test_client.py
tests/models/test_url.py

index 46be340ccff4c6890d0c2e4936a121d432c5bf98..5e814b38ee48a755ec3bce8becc76ff637bb9ffe 100644 (file)
@@ -2,6 +2,8 @@ import inspect
 import typing
 from types import TracebackType
 
+import hstspreload
+
 from .auth import HTTPBasicAuth
 from .concurrency import AsyncioBackend
 from .config import (
@@ -105,6 +107,12 @@ class BaseClient:
         self.concurrency_backend = backend
         self.trust_env = trust_env
 
+    def merge_url(self, url: URLTypes) -> URL:
+        url = self.base_url.join(relative_url=url)
+        if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
+            url = url.copy_with(scheme="https")
+        return url
+
     def merge_cookies(
         self, cookies: CookieTypes = None
     ) -> typing.Optional[CookieTypes]:
@@ -564,7 +572,7 @@ class AsyncClient(BaseClient):
         timeout: TimeoutTypes = None,
         trust_env: bool = True,
     ) -> AsyncResponse:
-        url = self.base_url.join(url)
+        url = self.merge_url(url)
         headers = self.merge_headers(headers)
         cookies = self.merge_cookies(cookies)
         request = AsyncRequest(
@@ -648,7 +656,7 @@ class Client(BaseClient):
         timeout: TimeoutTypes = None,
         trust_env: bool = True,
     ) -> Response:
-        url = self.base_url.join(url)
+        url = self.merge_url(url)
         headers = self.merge_headers(headers)
         cookies = self.merge_cookies(cookies)
         request = AsyncRequest(
index 32e412fc762343612a1ef7133c2705665c14b50c..5e0c827ed5f78d7e33f4bbeb1eb43519bf9334e2 100644 (file)
@@ -8,7 +8,6 @@ from http.cookiejar import Cookie, CookieJar
 from urllib.parse import parse_qsl, urlencode
 
 import chardet
-import hstspreload
 import rfc3986
 
 from .config import USER_AGENT
@@ -109,14 +108,6 @@ class URL:
             if not self.host:
                 raise InvalidURL("No host included in URL.")
 
-        # If the URL is HTTP but the host is on the HSTS preload list switch to HTTPS.
-        if (
-            self.scheme == "http"
-            and self.host
-            and hstspreload.in_hsts_preload(self.host)
-        ):
-            self._uri_reference = self._uri_reference.copy_with(scheme="https")
-
     @property
     def scheme(self) -> str:
         return self._uri_reference.scheme or ""
index 8c402f5443bad5e5fd7976df3718913237cd4754..f85fe77f48008569aa1da1e4523e66ba90a3f55a 100644 (file)
@@ -150,3 +150,11 @@ def test_base_url(server):
         response = http.get("/")
     assert response.status_code == 200
     assert str(response.url) == base_url
+
+
+def test_merge_url():
+    client = httpx.Client(base_url="https://www.paypal.com/")
+    url = client.merge_url("http://www.paypal.com")
+
+    assert url.scheme == "https"
+    assert url.is_ssl
index a556ed89a1224ea317a35b1c4345e58fa2112875..7cbfc1c9ebdeded0a2e29c5f9155f8bba1ab5db9 100644 (file)
@@ -176,11 +176,3 @@ def test_url_set():
     url_set = set(urls)
 
     assert all(url in urls for url in url_set)
-
-
-def test_hsts_preload_converted_to_https():
-    url = URL("http://www.paypal.com")
-
-    assert url.is_ssl
-    assert url.scheme == "https"
-    assert url == "https://www.paypal.com"