]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
URL matching (#1098)
authorTom Christie <tom@tomchristie.com>
Fri, 31 Jul 2020 09:11:49 +0000 (10:11 +0100)
committerGitHub <noreply@github.com>
Fri, 31 Jul 2020 09:11:49 +0000 (10:11 +0100)
* Add internal URLMatcher class

* Use URLMatcher for proxy lookups in transport_for_url

* Docstring

* Pin pytest

* Update httpx/_utils.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/_client.py
httpx/_utils.py
tests/client/test_proxies.py
tests/test_utils.py

index 518531986d4e49345d75babb4261e194d41893fc..dca8718fcfed28cef081b5dc1e1573749af20c02 100644 (file)
@@ -44,6 +44,7 @@ from ._types import (
 )
 from ._utils import (
     NetRCInfo,
+    URLMatcher,
     enforce_http_url,
     get_environment_proxies,
     get_logger,
@@ -471,8 +472,8 @@ class Client(BaseClient):
             app=app,
             trust_env=trust_env,
         )
-        self._proxies: typing.Dict[str, httpcore.SyncHTTPTransport] = {
-            key: self._init_proxy_transport(
+        self._proxies: typing.Dict[URLMatcher, httpcore.SyncHTTPTransport] = {
+            URLMatcher(key): self._init_proxy_transport(
                 proxy,
                 verify=verify,
                 cert=cert,
@@ -482,6 +483,7 @@ class Client(BaseClient):
             )
             for key, proxy in proxy_map.items()
         }
+        self._proxies = dict(sorted(self._proxies.items()))
 
     def _init_transport(
         self,
@@ -539,21 +541,8 @@ class Client(BaseClient):
         enforce_http_url(url)
 
         if self._proxies and not should_not_be_proxied(url):
-            default_port = {"http": 80, "https": 443}[url.scheme]
-            is_default_port = url.port is None or url.port == default_port
-            port = url.port or default_port
-            hostname = f"{url.host}:{port}"
-            proxy_keys = (
-                f"{url.scheme}://{hostname}",
-                f"{url.scheme}://{url.host}" if is_default_port else None,
-                f"all://{hostname}",
-                f"all://{url.host}" if is_default_port else None,
-                url.scheme,
-                "all",
-            )
-            for proxy_key in proxy_keys:
-                if proxy_key and proxy_key in self._proxies:
-                    transport = self._proxies[proxy_key]
+            for matcher, transport in self._proxies.items():
+                if matcher.matches(url):
                     return transport
 
         return self._transport
@@ -1000,8 +989,8 @@ class AsyncClient(BaseClient):
             app=app,
             trust_env=trust_env,
         )
-        self._proxies: typing.Dict[str, httpcore.AsyncHTTPTransport] = {
-            key: self._init_proxy_transport(
+        self._proxies: typing.Dict[URLMatcher, httpcore.AsyncHTTPTransport] = {
+            URLMatcher(key): self._init_proxy_transport(
                 proxy,
                 verify=verify,
                 cert=cert,
@@ -1011,6 +1000,7 @@ class AsyncClient(BaseClient):
             )
             for key, proxy in proxy_map.items()
         }
+        self._proxies = dict(sorted(self._proxies.items()))
 
     def _init_transport(
         self,
@@ -1068,21 +1058,8 @@ class AsyncClient(BaseClient):
         enforce_http_url(url)
 
         if self._proxies and not should_not_be_proxied(url):
-            default_port = {"http": 80, "https": 443}[url.scheme]
-            is_default_port = url.port is None or url.port == default_port
-            port = url.port or default_port
-            hostname = f"{url.host}:{port}"
-            proxy_keys = (
-                f"{url.scheme}://{hostname}",
-                f"{url.scheme}://{url.host}" if is_default_port else None,
-                f"all://{hostname}",
-                f"all://{url.host}" if is_default_port else None,
-                url.scheme,
-                "all",
-            )
-            for proxy_key in proxy_keys:
-                if proxy_key and proxy_key in self._proxies:
-                    transport = self._proxies[proxy_key]
+            for matcher, transport in self._proxies.items():
+                if matcher.matches(url):
                     return transport
 
         return self._transport
index f01cfe4ecc5dbfdf1347c81214df3f8ba8535728..2533a86d2a557442d25461a329cc2bdb714a2338 100644 (file)
@@ -429,5 +429,89 @@ class ElapsedTimer:
         return timedelta(seconds=self.end - self.start)
 
 
+class URLMatcher:
+    """
+    A utility class currently used for making lookups against proxy keys...
+
+    # Wildcard matching...
+    >>> pattern = URLMatcher("all")
+    >>> pattern.matches(httpx.URL("http://example.com"))
+    True
+
+    # With scheme matching...
+    >>> pattern = URLMatcher("https")
+    >>> pattern.matches(httpx.URL("https://example.com"))
+    True
+    >>> pattern.matches(httpx.URL("http://example.com"))
+    False
+
+    # With domain matching...
+    >>> pattern = URLMatcher("https://example.com")
+    >>> pattern.matches(httpx.URL("https://example.com"))
+    True
+    >>> pattern.matches(httpx.URL("http://example.com"))
+    False
+    >>> pattern.matches(httpx.URL("https://other.com"))
+    False
+
+    # Wildcard scheme, with domain matching...
+    >>> pattern = URLMatcher("all://example.com")
+    >>> pattern.matches(httpx.URL("https://example.com"))
+    True
+    >>> pattern.matches(httpx.URL("http://example.com"))
+    True
+    >>> pattern.matches(httpx.URL("https://other.com"))
+    False
+
+    # With port matching...
+    >>> pattern = URLMatcher("https://example.com:1234")
+    >>> pattern.matches(httpx.URL("https://example.com:1234"))
+    True
+    >>> pattern.matches(httpx.URL("https://example.com"))
+    False
+    """
+
+    def __init__(self, pattern: str) -> None:
+        from ._models import URL
+
+        if pattern and ":" not in pattern:
+            pattern += "://"
+
+        url = URL(pattern)
+        self.pattern = pattern
+        self.scheme = "" if url.scheme == "all" else url.scheme
+        self.host = url.host
+        self.port = url.port
+
+    def matches(self, other: "URL") -> bool:
+        if self.scheme and self.scheme != other.scheme:
+            return False
+        if self.host and self.host != other.host:
+            return False
+        if self.port is not None and self.port != other.port:
+            return False
+        return True
+
+    @property
+    def priority(self) -> tuple:
+        """
+        The priority allows URLMatcher instances to be sortable, so that
+        we can match from most specific to least specific.
+        """
+        port_priority = -1 if self.port is not None else 0
+        host_priority = -len(self.host)
+        scheme_priority = -len(self.scheme)
+        return (port_priority, host_priority, scheme_priority)
+
+    def __hash__(self) -> int:
+        return hash(self.pattern)
+
+    def __lt__(self, other: "URLMatcher") -> bool:
+        return self.priority < other.priority
+
+    def __eq__(self, other: typing.Any) -> bool:
+        return isinstance(other, URLMatcher) and self.pattern == other.pattern
+
+
 def warn_deprecated(message: str) -> None:  # pragma: nocover
     warnings.warn(message, DeprecationWarning, stacklevel=2)
index f5af90cc81d9915bb77b0466e7c740cc632930dc..6120bc1d385cd9fb3006b60b76b9468d5ebdfd06 100644 (file)
@@ -2,6 +2,7 @@ import httpcore
 import pytest
 
 import httpx
+from httpx._utils import URLMatcher
 
 
 def url_to_origin(url: str):
@@ -36,8 +37,9 @@ def test_proxies_parameter(proxies, expected_proxies):
     client = httpx.AsyncClient(proxies=proxies)
 
     for proxy_key, url in expected_proxies:
-        assert proxy_key in client._proxies
-        proxy = client._proxies[proxy_key]
+        matcher = URLMatcher(proxy_key)
+        assert matcher in client._proxies
+        proxy = client._proxies[matcher]
         assert isinstance(proxy, httpcore.AsyncHTTPProxy)
         assert proxy.proxy_origin == url_to_origin(url)
 
@@ -54,15 +56,15 @@ PROXY_URL = "http://[::1]"
         ("http://example.com", {}, None),
         ("http://example.com", {"https": PROXY_URL}, None),
         ("http://example.com", {"http://example.net": PROXY_URL}, None),
-        ("http://example.com:443", {"http://example.com": PROXY_URL}, None),
+        ("http://example.com:443", {"http://example.com": PROXY_URL}, PROXY_URL),
         ("http://example.com", {"all": PROXY_URL}, PROXY_URL),
         ("http://example.com", {"http": PROXY_URL}, PROXY_URL),
         ("http://example.com", {"all://example.com": PROXY_URL}, PROXY_URL),
-        ("http://example.com", {"all://example.com:80": PROXY_URL}, PROXY_URL),
+        ("http://example.com", {"all://example.com:80": PROXY_URL}, None),
         ("http://example.com", {"http://example.com": PROXY_URL}, PROXY_URL),
-        ("http://example.com", {"http://example.com:80": PROXY_URL}, PROXY_URL),
+        ("http://example.com", {"http://example.com:80": PROXY_URL}, None),
         ("http://example.com:8080", {"http://example.com:8080": PROXY_URL}, PROXY_URL),
-        ("http://example.com:8080", {"http://example.com": PROXY_URL}, None),
+        ("http://example.com:8080", {"http://example.com": PROXY_URL}, PROXY_URL),
         (
             "http://example.com",
             {
index fa30ee87cd9bfb320254926e04c71b70826cca3a..88fb1000bbac33b5339cecbcd709008445d46909 100644 (file)
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import random
 
 import pytest
 
@@ -7,6 +8,7 @@ import httpx
 from httpx._utils import (
     ElapsedTimer,
     NetRCInfo,
+    URLMatcher,
     get_ca_bundle_from_env,
     get_environment_proxies,
     guess_json_utf,
@@ -307,3 +309,43 @@ def test_not_same_origin():
     origin1 = httpx.URL("https://example.com")
     origin2 = httpx.URL("HTTP://EXAMPLE.COM")
     assert not same_origin(origin1, origin2)
+
+
+@pytest.mark.parametrize(
+    ["pattern", "url", "expected"],
+    [
+        ("http://example.com", "http://example.com", True,),
+        ("http://example.com", "https://example.com", False,),
+        ("http://example.com", "http://other.com", False,),
+        ("http://example.com:123", "http://example.com:123", True,),
+        ("http://example.com:123", "http://example.com:456", False,),
+        ("http://example.com:123", "http://example.com", False,),
+        ("all://example.com", "http://example.com", True,),
+        ("all://example.com", "https://example.com", True,),
+        ("http://", "http://example.com", True,),
+        ("http://", "https://example.com", False,),
+        ("http", "http://example.com", True,),
+        ("http", "https://example.com", False,),
+        ("all", "https://example.com:123", True,),
+        ("", "https://example.com:123", True,),
+    ],
+)
+def test_url_matches(pattern, url, expected):
+    matcher = URLMatcher(pattern)
+    assert matcher.matches(httpx.URL(url)) == expected
+
+
+def test_matcher_priority():
+    matchers = [
+        URLMatcher("all://"),
+        URLMatcher("http://"),
+        URLMatcher("http://example.com"),
+        URLMatcher("http://example.com:123"),
+    ]
+    random.shuffle(matchers)
+    assert sorted(matchers) == [
+        URLMatcher("http://example.com:123"),
+        URLMatcher("http://example.com"),
+        URLMatcher("http://"),
+        URLMatcher("all://"),
+    ]