]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Propogate HTTPProxy config from Client(#377)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Wed, 25 Sep 2019 12:17:20 +0000 (07:17 -0500)
committerGitHub <noreply@github.com>
Wed, 25 Sep 2019 12:17:20 +0000 (07:17 -0500)
httpx/client.py
httpx/dispatch/proxy_http.py
tests/client/test_proxies.py

index 30beeb75fd328345d6281c1f4c1abc0071e1a956..73b30699b006aa4d29d7d9d46b31a7f885c87057 100644 (file)
@@ -107,13 +107,6 @@ class BaseClient:
         else:
             self.base_url = URL(base_url)
 
-        if proxies is None and trust_env:
-            proxies = typing.cast(ProxiesTypes, get_environment_proxies())
-
-        self.proxies: typing.Dict[str, AsyncDispatcher] = _proxies_to_dispatchers(
-            proxies
-        )
-
         if params is None:
             params = {}
 
@@ -125,6 +118,20 @@ class BaseClient:
         self.dispatch = async_dispatch
         self.concurrency_backend = backend
 
+        if proxies is None and trust_env:
+            proxies = typing.cast(ProxiesTypes, get_environment_proxies())
+
+        self.proxies: typing.Dict[str, AsyncDispatcher] = _proxies_to_dispatchers(
+            proxies,
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            http_versions=http_versions,
+            pool_limits=pool_limits,
+            backend=backend,
+            trust_env=trust_env,
+        )
+
     @property
     def headers(self) -> Headers:
         return self._headers
@@ -196,16 +203,11 @@ class BaseClient:
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         if request.url.scheme not in ("http", "https"):
             raise InvalidURL('URL scheme must be "http" or "https".')
 
-        if proxies is not None:
-            dispatch_proxies = _proxies_to_dispatchers(proxies)
-        else:
-            dispatch_proxies = self.proxies
-        dispatch = self._dispatcher_for_request(request, dispatch_proxies)
+        dispatch = self._dispatcher_for_request(request, self.proxies)
 
         async def get_response(request: AsyncRequest) -> AsyncResponse:
             try:
@@ -351,7 +353,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "GET",
@@ -366,7 +367,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def options(
@@ -383,7 +383,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "OPTIONS",
@@ -398,7 +397,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def head(
@@ -415,7 +413,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "HEAD",
@@ -430,7 +427,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def post(
@@ -450,7 +446,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "POST",
@@ -468,7 +463,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def put(
@@ -488,7 +482,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "PUT",
@@ -506,7 +499,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def patch(
@@ -526,7 +518,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "PATCH",
@@ -544,7 +535,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def delete(
@@ -564,7 +554,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "DELETE",
@@ -582,7 +571,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def request(
@@ -603,7 +591,6 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         request = self.build_request(
             method=method,
@@ -624,7 +611,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
         return response
 
@@ -639,7 +625,6 @@ class AsyncClient(BaseClient):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self._get_response(
             request=request,
@@ -650,7 +635,6 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     async def close(self) -> None:
@@ -729,7 +713,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         request = self.build_request(
             method=method,
@@ -750,7 +733,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
         return response
 
@@ -765,7 +747,6 @@ class Client(BaseClient):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         concurrency_backend = self.concurrency_backend
 
@@ -779,7 +760,6 @@ class Client(BaseClient):
             "cert": cert,
             "timeout": timeout,
             "trust_env": trust_env,
-            "proxies": proxies,
         }
         async_response = concurrency_backend.run(coroutine, *args, **kwargs)
 
@@ -824,7 +804,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "GET",
@@ -839,7 +818,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     def options(
@@ -856,7 +834,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "OPTIONS",
@@ -871,7 +848,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     def head(
@@ -888,7 +864,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "HEAD",
@@ -903,7 +878,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     def post(
@@ -923,7 +897,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "POST",
@@ -941,7 +914,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     def put(
@@ -961,7 +933,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "PUT",
@@ -979,7 +950,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     def patch(
@@ -999,7 +969,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "PATCH",
@@ -1017,7 +986,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     def delete(
@@ -1037,7 +1005,6 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
-        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "DELETE",
@@ -1055,7 +1022,6 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
-            proxies=proxies,
         )
 
     def close(self) -> None:
@@ -1074,16 +1040,32 @@ class Client(BaseClient):
         self.close()
 
 
-def _proxy_from_url(url: URLTypes) -> AsyncDispatcher:
-    url = URL(url)
-    if url.scheme in ("http", "https"):
-        return HTTPProxy(url)
-    raise ValueError(f"Unknown proxy for {url!r}")
-
-
 def _proxies_to_dispatchers(
-    proxies: typing.Optional[ProxiesTypes]
+    proxies: typing.Optional[ProxiesTypes],
+    verify: VerifyTypes,
+    cert: typing.Optional[CertTypes],
+    timeout: TimeoutTypes,
+    http_versions: typing.Optional[HTTPVersionTypes],
+    pool_limits: PoolLimits,
+    backend: ConcurrencyBackend,
+    trust_env: bool,
 ) -> typing.Dict[str, AsyncDispatcher]:
+    def _proxy_from_url(url: URLTypes) -> AsyncDispatcher:
+        nonlocal verify, cert, timeout, http_versions, pool_limits, backend, trust_env
+        url = URL(url)
+        if url.scheme in ("http", "https"):
+            return HTTPProxy(
+                url,
+                verify=verify,
+                cert=cert,
+                timeout=timeout,
+                pool_limits=pool_limits,
+                backend=backend,
+                trust_env=trust_env,
+                http_versions=http_versions,
+            )
+        raise ValueError(f"Unknown proxy for {url!r}")
+
     if proxies is None:
         return {}
     elif isinstance(proxies, (str, URL)):
index fe8bbd3c44e958e66217b5ef5704239eff56a9af..be2e289ffcc94796c202ad6d93763654ff8e9dd4 100644 (file)
@@ -7,6 +7,7 @@ from ..config import (
     DEFAULT_POOL_LIMITS,
     DEFAULT_TIMEOUT_CONFIG,
     CertTypes,
+    HTTPVersionTypes,
     PoolLimits,
     SSLConfig,
     TimeoutTypes,
@@ -51,8 +52,10 @@ class HTTPProxy(ConnectionPool):
         proxy_mode: HTTPProxyMode = HTTPProxyMode.DEFAULT,
         verify: VerifyTypes = True,
         cert: CertTypes = None,
+        trust_env: bool = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
+        http_versions: HTTPVersionTypes = None,
         backend: ConcurrencyBackend = None,
     ):
 
@@ -62,6 +65,8 @@ class HTTPProxy(ConnectionPool):
             timeout=timeout,
             pool_limits=pool_limits,
             backend=backend,
+            trust_env=trust_env,
+            http_versions=http_versions,
         )
 
         self.proxy_url = URL(proxy_url)
index 19521b7e43e25f09bbf2eaec3c0f47201f6f3576..ba59677d7a12274f4a2b1cf590b86278e9ec31f0 100644 (file)
@@ -27,3 +27,22 @@ def test_proxies_parameter(proxies, expected_proxies):
         assert client.proxies[proxy_key].proxy_url == url
 
     assert len(expected_proxies) == len(client.proxies)
+
+
+def test_proxies_has_same_properties_as_dispatch():
+    client = httpx.AsyncClient(proxies="http://127.0.0.1")
+    pool = client.dispatch
+    proxy = client.proxies["all"]
+
+    assert isinstance(pool, httpx.ConnectionPool)
+    assert isinstance(proxy, httpx.HTTPProxy)
+
+    for prop in [
+        "verify",
+        "cert",
+        "timeout",
+        "pool_limits",
+        "http_versions",
+        "backend",
+    ]:
+        assert getattr(pool, prop) == getattr(proxy, prop)