]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Configure Proxy from Client (#353)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Fri, 20 Sep 2019 17:33:06 +0000 (12:33 -0500)
committerGitHub <noreply@github.com>
Fri, 20 Sep 2019 17:33:06 +0000 (12:33 -0500)
* Incorporate suggestions

* Start of proxy config

* Incorporate suggestions

* Add proxies to high-level API, docs

* Update client.py

13 files changed:
docs/advanced.md
docs/api.md
docs/environment_variables.md
httpx/api.py
httpx/client.py
httpx/dispatch/connection_pool.py
httpx/models.py
httpx/utils.py
tests/client/test_proxies.py [new file with mode: 0644]
tests/conftest.py
tests/dispatch/test_proxy_http.py
tests/dispatch/utils.py
tests/test_utils.py

index dba5128ded7d3d8cdb220b0244f51a1b46e45403..6acad3b67bbfa491bf8244a86f953644262b1021 100644 (file)
@@ -111,3 +111,52 @@ password example-password
 
 ...
 ```
+
+## HTTP Proxying
+
+HTTPX supports setting up proxies the same way that Requests does via the `proxies` parameter.
+For example to forward all HTTP traffic to `http://127.0.0.1:3080` and all HTTPS traffic
+to `http://127.0.0.1:3081` your `proxies` config would look like this:
+
+```python
+>>> client = httpx.Client(proxies={
+  "http": "http://127.0.0.1:3080",
+  "https": "http://127.0.0.1:3081"
+})
+```
+
+Proxies can be configured for a specific scheme and host, all schemes of a host,
+all hosts for a scheme, or for all requests. When determining which proxy configuration
+to use for a given request this same order is used.
+
+```python
+>>> client = httpx.Client(proxies={
+    "http://example.com":  "...",  # Host+Scheme
+    "all://example.com":  "...",  # Host
+    "http": "...",  # Scheme
+    "all": "...",  # All
+})
+>>> client = httpx.Client(proxies="...")  # Shortcut for 'all'
+```
+
+!!! warning
+    To make sure that proxies cannot read your traffic,
+    and even if the proxy_url uses HTTPS, it is recommended to
+    use HTTPS and tunnel requests if possible.
+
+By default `HTTPProxy` will operate as a forwarding proxy for `http://...` requests
+and will establish a `CONNECT` TCP tunnel for `https://` requests. This doesn't change
+regardless of the `proxy_url` being `http` or `https`.
+
+Proxies can be configured to have different behavior such as forwarding or tunneling all requests:
+
+```python
+proxy = httpx.HTTPProxy(
+    proxy_url="https://127.0.0.1",
+    proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY
+)
+client = httpx.Client(proxies=proxy)
+
+# This request will be tunnelled instead of forwarded.
+client.get("http://example.com")
+```
index 83fe1e81cb535239c70156277b9607ef9ad85ec0..104ba278c13d5535cfa4989f5a1688cb7bf34d1b 100644 (file)
@@ -8,14 +8,14 @@
     enable HTTP/2 and connection pooling for more efficient and
     long-lived connections.
 
-* `get(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `options(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `head(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `post(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `put(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `patch(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `delete(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `request(method, url, [data], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
+* `get(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `options(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `head(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `post(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `put(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `patch(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `delete(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `request(method, url, [data], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
 * `build_request(method, url, [data], [files], [json], [params], [headers], [cookies])`
 
 ## `Client`
 * `def __init__([auth], [headers], [cookies], [verify], [cert], [timeout], [pool_limits], [max_redirects], [app], [dispatch])`
 * `.headers` - **Headers**
 * `.cookies` - **Cookies**
-* `def .get(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `def .options(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `def .head(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `def .post(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `def .put(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `def .patch(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `def .delete(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
-* `def .request(method, url, [data], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout])`
+* `def .get(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `def .options(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `def .head(url, [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `def .post(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `def .put(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `def .patch(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `def .delete(url, [data], [json], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
+* `def .request(method, url, [data], [params], [headers], [cookies], [auth], [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
 * `def .build_request(method, url, [data], [files], [json], [params], [headers], [cookies])`
-* `def .send(request, [stream], [allow_redirects], [verify], [cert], [timeout])`
+* `def .send(request, [stream], [allow_redirects], [verify], [cert], [timeout], [proxies])`
 * `def .close()`
 
 ## `Response`
index d8f3960d49f4a92ddbc45003e5f326d26dce1468..394fde2524be2ddd178e77a52ecd791c578403f0 100644 (file)
@@ -79,3 +79,17 @@ SERVER_TRAFFIC_SECRET_0 XXXX
 CLIENT_HANDSHAKE_TRAFFIC_SECRET XXXX
 CLIENT_TRAFFIC_SECRET_0 XXXX
 ```
+
+`HTTP_PROXY`, `HTTPS_PROXY`, `ALL_PROXY`
+----------------------------------------
+
+Valid values: A URL to a proxy
+
+Sets the proxy to be used for `http`, `https`, or all requests respectively.
+
+```bash
+export HTTP_PROXY=http://127.0.0.1:3080
+
+# This request will be sent through the proxy
+python -c "import httpx; httpx.get('http://example.com')"
+```
index 5f0d639cda60d83db4a07eeb398f752add27d984..fe6316b13b752ee38f404c10ba26dc781e2e78c6 100644 (file)
@@ -6,6 +6,7 @@ from .models import (
     AuthTypes,
     CookieTypes,
     HeaderTypes,
+    ProxiesTypes,
     QueryParamTypes,
     RequestData,
     RequestFiles,
@@ -31,6 +32,7 @@ def request(
     verify: VerifyTypes = True,
     stream: bool = False,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     with Client(http_versions=["HTTP/1.1"]) as client:
         return client.request(
@@ -65,6 +67,7 @@ def get(
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     return request(
         "GET",
@@ -95,6 +98,7 @@ def options(
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     return request(
         "OPTIONS",
@@ -125,6 +129,7 @@ def head(
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     return request(
         "HEAD",
@@ -158,6 +163,7 @@ def post(
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     return request(
         "POST",
@@ -194,6 +200,7 @@ def put(
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     return request(
         "PUT",
@@ -230,6 +237,7 @@ def patch(
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     return request(
         "PATCH",
@@ -266,6 +274,7 @@ def delete(
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
     trust_env: bool = None,
+    proxies: ProxiesTypes = None,
 ) -> Response:
     return request(
         "DELETE",
index 090b405fb796f9ae6ecb07bec7d223ce828b341a..266d980c9bbe1ba13905929d39e1065ac75d54db 100644 (file)
@@ -20,6 +20,7 @@ from .config import (
 from .dispatch.asgi import ASGIDispatch
 from .dispatch.base import AsyncDispatcher, Dispatcher
 from .dispatch.connection_pool import ConnectionPool
+from .dispatch.proxy_http import HTTPProxy
 from .dispatch.threaded import ThreadedDispatcher
 from .dispatch.wsgi import WSGIDispatch
 from .exceptions import HTTPError, InvalidURL
@@ -38,6 +39,7 @@ from .models import (
     CookieTypes,
     Headers,
     HeaderTypes,
+    ProxiesTypes,
     QueryParamTypes,
     RequestData,
     RequestFiles,
@@ -45,18 +47,20 @@ from .models import (
     ResponseContent,
     URLTypes,
 )
-from .utils import ElapsedTimer, get_netrc_login
+from .utils import ElapsedTimer, get_environment_proxies, get_netrc_login
 
 
 class BaseClient:
     def __init__(
         self,
+        *,
         auth: AuthTypes = None,
         headers: HeaderTypes = None,
         cookies: CookieTypes = None,
         verify: VerifyTypes = True,
         cert: CertTypes = None,
         http_versions: HTTPVersionTypes = None,
+        proxies: ProxiesTypes = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
@@ -101,6 +105,13 @@ class BaseClient:
         else:
             self.base_url = URL(base_url)
 
+        if proxies is None and trust_env:
+            proxies = typing.cast(ProxiesTypes, get_environment_proxies())
+
+        self.proxies: typing.Dict[str, AsyncDispatcher] = _proxies_to_dispatchers(
+            proxies
+        )
+
         self.auth = auth
         self._headers = Headers(headers)
         self._cookies = Cookies(cookies)
@@ -162,20 +173,30 @@ class BaseClient:
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         if request.url.scheme not in ("http", "https"):
             raise InvalidURL('URL scheme must be "http" or "https".')
 
+        if proxies is not None:
+            dispatch_proxies = _proxies_to_dispatchers(proxies)
+        else:
+            dispatch_proxies = self.proxies
+        dispatch = self._dispatcher_for_request(request, dispatch_proxies)
+
         async def get_response(request: AsyncRequest) -> AsyncResponse:
             try:
                 with ElapsedTimer() as timer:
-                    response = await self.dispatch.send(
+                    response = await dispatch.send(
                         request, verify=verify, cert=cert, timeout=timeout
                     )
                 response.elapsed = timer.elapsed
             except HTTPError as exc:
-                # Add the original request to any HTTPError
-                exc.request = request
+                # Add the original request to any HTTPError unless
+                # there'a already a request attached in the case of
+                # a ProxyError.
+                if exc.request is None:
+                    exc.request = request
                 raise
 
             self.cookies.extract_cookies(response)
@@ -238,6 +259,31 @@ class BaseClient:
 
         return None
 
+    def _dispatcher_for_request(
+        self, request: AsyncRequest, proxies: typing.Dict[str, AsyncDispatcher]
+    ) -> AsyncDispatcher:
+        """Gets the AsyncDispatcher instance that should be used for a given Request"""
+        if proxies:
+            url = request.url
+            is_default_port = (url.scheme == "http" and url.port == 80) or (
+                url.scheme == "https" and url.port == 443
+            )
+            hostname = f"{url.host}:{url.port}"
+            proxy_keys = (
+                f"{url.scheme}://{hostname}",
+                f"{url.scheme}://{url.host}" if is_default_port else None,
+                f"all://{hostname}",
+                f"all://{url.host}" if is_default_port else None,
+                url.scheme,
+                "all",
+            )
+            for proxy_key in proxy_keys:
+                if proxy_key and proxy_key in proxies:
+                    dispatcher = proxies[proxy_key]
+                    return dispatcher
+
+        return self.dispatch
+
     def build_request(
         self,
         method: str,
@@ -281,6 +327,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "GET",
@@ -295,6 +342,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def options(
@@ -311,6 +359,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "OPTIONS",
@@ -325,6 +374,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def head(
@@ -341,6 +391,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "HEAD",
@@ -355,6 +406,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def post(
@@ -374,6 +426,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "POST",
@@ -391,6 +444,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def put(
@@ -410,6 +464,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "PUT",
@@ -427,6 +482,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def patch(
@@ -446,6 +502,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "PATCH",
@@ -463,6 +520,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def delete(
@@ -482,6 +540,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self.request(
             "DELETE",
@@ -499,6 +558,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def request(
@@ -519,6 +579,7 @@ class AsyncClient(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         request = self.build_request(
             method=method,
@@ -539,6 +600,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
         return response
 
@@ -553,6 +615,7 @@ class AsyncClient(BaseClient):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> AsyncResponse:
         return await self._get_response(
             request=request,
@@ -563,6 +626,7 @@ class AsyncClient(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     async def close(self) -> None:
@@ -641,6 +705,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         request = self.build_request(
             method=method,
@@ -661,6 +726,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
         return response
 
@@ -675,6 +741,7 @@ class Client(BaseClient):
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         concurrency_backend = self.concurrency_backend
 
@@ -688,6 +755,7 @@ class Client(BaseClient):
             "cert": cert,
             "timeout": timeout,
             "trust_env": trust_env,
+            "proxies": proxies,
         }
         async_response = concurrency_backend.run(coroutine, *args, **kwargs)
 
@@ -732,6 +800,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "GET",
@@ -746,6 +815,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     def options(
@@ -762,6 +832,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "OPTIONS",
@@ -776,6 +847,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     def head(
@@ -792,6 +864,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "HEAD",
@@ -806,6 +879,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     def post(
@@ -825,6 +899,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "POST",
@@ -842,6 +917,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     def put(
@@ -861,6 +937,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "PUT",
@@ -878,6 +955,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     def patch(
@@ -897,6 +975,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "PATCH",
@@ -914,6 +993,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     def delete(
@@ -933,6 +1013,7 @@ class Client(BaseClient):
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
         trust_env: bool = None,
+        proxies: ProxiesTypes = None,
     ) -> Response:
         return self.request(
             "DELETE",
@@ -950,6 +1031,7 @@ class Client(BaseClient):
             cert=cert,
             timeout=timeout,
             trust_env=trust_env,
+            proxies=proxies,
         )
 
     def close(self) -> None:
@@ -966,3 +1048,29 @@ class Client(BaseClient):
         traceback: TracebackType = None,
     ) -> None:
         self.close()
+
+
+def _proxy_from_url(url: URLTypes) -> AsyncDispatcher:
+    url = URL(url)
+    if url.scheme in ("http", "https"):
+        return HTTPProxy(url)
+    raise ValueError(f"Unknown proxy for {url!r}")
+
+
+def _proxies_to_dispatchers(
+    proxies: typing.Optional[ProxiesTypes]
+) -> typing.Dict[str, AsyncDispatcher]:
+    if proxies is None:
+        return {}
+    elif isinstance(proxies, (str, URL)):
+        return {"all": _proxy_from_url(proxies)}
+    elif isinstance(proxies, AsyncDispatcher):
+        return {"all": proxies}
+    else:
+        new_proxies = {}
+        for key, dispatcher_or_url in proxies.items():
+            if isinstance(dispatcher_or_url, (str, URL)):
+                new_proxies[str(key)] = _proxy_from_url(dispatcher_or_url)
+            else:
+                new_proxies[str(key)] = dispatcher_or_url
+        return new_proxies
index 2cc11c9884e61d856f4b2314bb349c60759d1684..1a23e8c80eea1883b5990eab48cc0fef55a93cf7 100644 (file)
@@ -128,7 +128,7 @@ class ConnectionPool(AsyncDispatcher):
         return response
 
     async def acquire_connection(self, origin: Origin) -> HTTPConnection:
-        logger.debug("acquire_connection origin={origin!r}")
+        logger.debug(f"acquire_connection origin={origin!r}")
         connection = self.pop_connection(origin)
 
         if connection is None:
index 126dc578cb84b00a6d9b0f09d4802251e47fac46..8d4b36269233f3aa90841644e5719d80548c2e1e 100644 (file)
@@ -42,6 +42,7 @@ from .utils import (
 
 if typing.TYPE_CHECKING:
     from .middleware.base import BaseMiddleware  # noqa: F401
+    from .dispatch.base import AsyncDispatcher  # noqa: F401
 
 PrimitiveData = typing.Optional[typing.Union[str, int, float, bool]]
 
@@ -68,6 +69,12 @@ AuthTypes = typing.Union[
     "BaseMiddleware",
 ]
 
+ProxiesTypes = typing.Union[
+    URLTypes,
+    "AsyncDispatcher",
+    typing.Dict[URLTypes, typing.Union[URLTypes, "AsyncDispatcher"]],
+]
+
 AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
 
 RequestData = typing.Union[dict, str, bytes, typing.Iterator[bytes]]
index b2bb96c7456754b1f51b18fc390b6c4d28d2511a..cf6010cb128aa2c2d3220b566c0077a681b78216 100644 (file)
@@ -9,6 +9,7 @@ from datetime import timedelta
 from pathlib import Path
 from time import perf_counter
 from types import TracebackType
+from urllib.request import getproxies
 
 
 def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
@@ -174,14 +175,44 @@ def get_logger(name: str) -> logging.Logger:
     return logging.getLogger(name)
 
 
+def get_environment_proxies() -> typing.Dict[str, str]:
+    """Gets proxy information from the environment"""
+
+    # urllib.request.getproxies() falls back on System
+    # Registry and Config for proxies on Windows and macOS.
+    # We don't want to propagate non-HTTP proxies into
+    # our configuration such as 'TRAVIS_APT_PROXY'.
+    proxies = {
+        key: val
+        for key, val in getproxies().items()
+        if ("://" in key or key in ("http", "https"))
+    }
+
+    # Favor lowercase environment variables over uppercase.
+    all_proxy = get_environ_lower_and_upper("ALL_PROXY")
+    if all_proxy is not None:
+        proxies["all"] = all_proxy
+
+    return proxies
+
+
+def get_environ_lower_and_upper(key: str) -> typing.Optional[str]:
+    """Gets a value from os.environ with both the lowercase and uppercase
+    environment variable. Prioritizes the lowercase environment variable.
+    """
+    for key in (key.lower(), key.upper()):
+        value = os.environ.get(key, None)
+        if value is not None and isinstance(value, str):
+            return value
+    return None
+
+
 def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
     return value.encode(encoding) if isinstance(value, str) else value
 
 
-def to_str(str_or_bytes: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
-    return (
-        str_or_bytes if isinstance(str_or_bytes, str) else str_or_bytes.decode(encoding)
-    )
+def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
+    return value if isinstance(value, str) else value.decode(encoding)
 
 
 def unquote(value: str) -> str:
diff --git a/tests/client/test_proxies.py b/tests/client/test_proxies.py
new file mode 100644 (file)
index 0000000..19521b7
--- /dev/null
@@ -0,0 +1,29 @@
+import pytest
+
+import httpx
+
+
+@pytest.mark.parametrize(
+    ["proxies", "expected_proxies"],
+    [
+        ("http://127.0.0.1", [("all", "http://127.0.0.1")]),
+        ({"all": "http://127.0.0.1"}, [("all", "http://127.0.0.1")]),
+        (
+            {"http": "http://127.0.0.1", "https": "https://127.0.0.1"},
+            [("http", "http://127.0.0.1"), ("https", "https://127.0.0.1")],
+        ),
+        (httpx.HTTPProxy("http://127.0.0.1"), [("all", "http://127.0.0.1")]),
+        (
+            {"https": httpx.HTTPProxy("https://127.0.0.1"), "all": "http://127.0.0.1"},
+            [("all", "http://127.0.0.1"), ("https", "https://127.0.0.1")],
+        ),
+    ],
+)
+def test_proxies_parameter(proxies, expected_proxies):
+    client = httpx.Client(proxies=proxies)
+
+    for proxy_key, url in expected_proxies:
+        assert proxy_key in client.proxies
+        assert client.proxies[proxy_key].proxy_url == url
+
+    assert len(expected_proxies) == len(client.proxies)
index 6b68649d6959e22fb8aad70c421132461cfb00d6..658a6f943eec88a3d6175e54e2d6083f515fe7b1 100644 (file)
@@ -9,10 +9,10 @@ import pytest
 import trustme
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives.serialization import (
-    load_pem_private_key,
     BestAvailableEncryption,
     Encoding,
     PrivateFormat,
+    load_pem_private_key,
 )
 from uvicorn.config import Config
 from uvicorn.main import Server
index 0302f46b5f05afab0871d6ce067dbb8fc720a72f..845f363208e854c68fd727444f37a11dc159cdfe 100644 (file)
@@ -88,13 +88,24 @@ async def test_proxy_tunnel_start_tls(backend):
     raw_io = MockRawSocketBackend(
         data_to_send=(
             [
+                # Tunnel Response
                 b"HTTP/1.1 200 OK\r\n"
                 b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
                 b"Server: proxy-server\r\n"
                 b"\r\n",
+                # Response 1
                 b"HTTP/1.1 404 Not Found\r\n"
                 b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
                 b"Server: origin-server\r\n"
+                b"Connection: keep-alive\r\n"
+                b"Content-Length: 0\r\n"
+                b"\r\n",
+                # Response 2
+                b"HTTP/1.1 200 OK\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: origin-server\r\n"
+                b"Connection: keep-alive\r\n"
+                b"Content-Length: 0\r\n"
                 b"\r\n",
             ]
         ),
@@ -105,23 +116,38 @@ async def test_proxy_tunnel_start_tls(backend):
         backend=raw_io,
         proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
     ) as proxy:
-        response = await proxy.request("GET", f"https://example.com")
+        resp = await proxy.request("GET", f"https://example.com")
 
-        assert response.status_code == 404
-        assert response.headers["Server"] == "origin-server"
+        assert resp.status_code == 404
+        assert resp.headers["Server"] == "origin-server"
 
-        assert response.request.method == "GET"
-        assert response.request.url == "https://example.com"
-        assert response.request.headers["Host"] == "example.com"
+        assert resp.request.method == "GET"
+        assert resp.request.url == "https://example.com"
+        assert resp.request.headers["Host"] == "example.com"
+
+        await resp.read()
+
+        # Make another request to see that the tunnel is re-used.
+        resp = await proxy.request("GET", f"https://example.com/target")
+
+        assert resp.status_code == 200
+        assert resp.headers["Server"] == "origin-server"
+
+        assert resp.request.method == "GET"
+        assert resp.request.url == "https://example.com/target"
+        assert resp.request.headers["Host"] == "example.com"
+
+        await resp.read()
 
     recv = raw_io.received_data
-    assert len(recv) == 4
+    assert len(recv) == 5
     assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
     assert recv[1].startswith(
         b"CONNECT example.com:443 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
     )
     assert recv[2] == b"--- START_TLS(example.com) ---"
     assert recv[3].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
+    assert recv[4].startswith(b"GET /target HTTP/1.1\r\nhost: example.com\r\n")
 
 
 @pytest.mark.parametrize(
index b8bbb780648725e35e2abbbd9aee66489aa31aeb..38c1b24ec9d228fc5c335d80edd6f2267537a5cc 100644 (file)
@@ -216,5 +216,8 @@ class MockRawSocketStream(BaseTCPStream):
             return b""
         return self.backend.data_to_send.pop(0)
 
+    def is_connection_dropped(self) -> bool:
+        return False
+
     async def close(self) -> None:
         pass
index 9b28a6ff31bdcb3f343b1ba0496c1fe304f80ddc..8eed10e36173ff4b2807b0d34275ab80ec106743 100644 (file)
@@ -8,6 +8,7 @@ import httpx
 from httpx import utils
 from httpx.utils import (
     ElapsedTimer,
+    get_environment_proxies,
     get_netrc_login,
     guess_json_utf,
     parse_header_links,
@@ -128,3 +129,29 @@ async def test_elapsed_timer():
         0.1
     )  # test to ensure time spent after timer exits isn't accounted for.
     assert timer.elapsed.total_seconds() == pytest.approx(0.1, abs=0.05)
+
+
+@pytest.mark.parametrize(
+    ["environment", "proxies"],
+    [
+        ({}, {}),
+        ({"HTTP_PROXY": "http://127.0.0.1"}, {"http": "http://127.0.0.1"}),
+        (
+            {"https_proxy": "http://127.0.0.1", "HTTP_PROXY": "https://127.0.0.1"},
+            {"https": "http://127.0.0.1", "http": "https://127.0.0.1"},
+        ),
+        (
+            {"all_proxy": "http://127.0.0.1", "ALL_PROXY": "https://1.1.1.1"},
+            {"all": "http://127.0.0.1"},
+        ),
+        (
+            {"https_proxy": "http://127.0.0.1", "HTTPS_PROXY": "https://1.1.1.1"},
+            {"https": "http://127.0.0.1"},
+        ),
+        ({"TRAVIS_APT_PROXY": "http://127.0.0.1"}, {}),
+    ],
+)
+def test_get_environment_proxies(environment, proxies):
+    os.environ.update(environment)
+
+    assert get_environment_proxies() == proxies