from .dispatch.base import AsyncDispatcher, Dispatcher
from .dispatch.connection import HTTPConnection
from .dispatch.connection_pool import ConnectionPool
+from .dispatch.proxy_http import HTTPProxy, HTTPProxyMode
from .exceptions import (
ConnectTimeout,
CookieConflict,
NotRedirectResponse,
PoolTimeout,
ProtocolError,
+ ProxyError,
ReadTimeout,
RedirectBodyUnavailable,
RedirectLoop,
"BasePoolSemaphore",
"BaseBackgroundManager",
"ConnectionPool",
+ "HTTPProxy",
+ "HTTPProxyMode",
"ConnectTimeout",
"CookieConflict",
"DecodingError",
"ResponseClosed",
"ResponseNotRead",
"StreamConsumed",
+ "ProxyError",
"Timeout",
"TooManyRedirects",
"WriteTimeout",
auth: AuthTypes = None,
timeout: TimeoutTypes = None,
allow_redirects: bool = True,
- # proxies
cert: CertTypes = None,
verify: VerifyTypes = True,
stream: bool = False,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
- allow_redirects: bool = False, # Note: Differs to usual default.
+ allow_redirects: bool = False, # NOTE: Differs to usual default.
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
- allow_redirects: bool = False, # Note: Differs to usual default.
+ allow_redirects: bool = False, # NOTE: Differs to usual default.
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
return response
async def acquire_connection(self, origin: Origin) -> HTTPConnection:
- logger.debug(f"acquire_connection origin={origin!r}")
- connection = self.active_connections.pop_by_origin(origin, http2_only=True)
- if connection is None:
- connection = self.keepalive_connections.pop_by_origin(origin)
-
- if connection is not None and connection.is_connection_dropped():
- self.max_connections.release()
- connection = None
+ logger.debug(f"acquire_connection origin={origin!r}")
+ # Reuse a pooled connection for this origin when one is available.
+ connection = self.pop_connection(origin)
if connection is None:
await self.max_connections.acquire()
self.keepalive_connections.clear()
for connection in connections:
await connection.close()
+
+    def pop_connection(self, origin: Origin) -> typing.Optional[HTTPConnection]:
+        """
+        Take a reusable connection for `origin` out of the pool, if there is one.
+
+        The active connections are queried first (with `http2_only=True`),
+        then the keep-alive connections. A connection that turns out to be
+        dropped is discarded and its `max_connections` slot is released.
+        Returns `None` when no usable pooled connection exists.
+        """
+        pooled = self.active_connections.pop_by_origin(origin, http2_only=True)
+        if pooled is None:
+            pooled = self.keepalive_connections.pop_by_origin(origin)
+
+        if pooled is None:
+            return None
+
+        if pooled.is_connection_dropped():
+            self.max_connections.release()
+            return None
+
+        return pooled
if isinstance(event, h11.Data):
yield bytes(event.data)
else:
- assert isinstance(event, h11.EndOfMessage)
+ assert isinstance(event, h11.EndOfMessage) or event is h11.PAUSED
break # pragma: no cover
async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
- # Start sending the request.
+ # Start sending the request.
if not self.initialized:
self.initiate_connection()
--- /dev/null
+import enum
+
+import h11
+
+from ..concurrency.base import ConcurrencyBackend
+from ..config import (
+ DEFAULT_POOL_LIMITS,
+ DEFAULT_TIMEOUT_CONFIG,
+ CertTypes,
+ PoolLimits,
+ SSLConfig,
+ TimeoutTypes,
+ VerifyTypes,
+)
+from ..exceptions import ProxyError
+from ..middleware.basic_auth import build_basic_auth_header
+from ..models import (
+ URL,
+ AsyncRequest,
+ AsyncResponse,
+ Headers,
+ HeaderTypes,
+ Origin,
+ URLTypes,
+)
+from ..utils import get_logger
+from .connection import HTTPConnection
+from .connection_pool import ConnectionPool
+from .http2 import HTTP2Connection
+from .http11 import HTTP11Connection
+
+logger = get_logger(__name__)
+
+
+class HTTPProxyMode(enum.Enum):
+ """Selects whether an HTTPProxy forwards requests or tunnels them."""
+
+ # Forward requests to non-SSL origins and tunnel requests to SSL origins.
+ DEFAULT = "DEFAULT"
+ # Forward every request, whatever the origin.
+ FORWARD_ONLY = "FORWARD_ONLY"
+ # Tunnel every request, whatever the origin.
+ TUNNEL_ONLY = "TUNNEL_ONLY"
+
+
+class HTTPProxy(ConnectionPool):
+    """A connection pool that sends requests to the recipient server
+    through an HTTP proxy, on behalf of the connecting client.
+
+    Each request is either *forwarded* (sent to the proxy with the absolute
+    target URL as its request target) or *tunneled* (sent over a connection
+    opened with a CONNECT request to the proxy). `should_forward_origin()`
+    makes that choice from `proxy_mode` and the request origin.
+    """
+
+    def __init__(
+        self,
+        proxy_url: URLTypes,
+        *,
+        proxy_headers: HeaderTypes = None,
+        proxy_mode: HTTPProxyMode = HTTPProxyMode.DEFAULT,
+        verify: VerifyTypes = True,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
+        pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
+        backend: ConcurrencyBackend = None,
+    ):
+        """Configure the pool and normalise the proxy URL.
+
+        Credentials embedded in `proxy_url` are turned into a
+        `Proxy-Authorization` header (unless `proxy_headers` already
+        provides one) and are stripped from the stored `proxy_url`.
+        """
+        super().__init__(
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            pool_limits=pool_limits,
+            backend=backend,
+        )
+
+        self.proxy_url = URL(proxy_url)
+        self.proxy_mode = proxy_mode
+        self.proxy_headers = Headers(proxy_headers)
+
+        url = self.proxy_url
+        if url.username or url.password:
+            # setdefault: an explicitly supplied header wins over URL userinfo.
+            self.proxy_headers.setdefault(
+                "Proxy-Authorization",
+                build_basic_auth_header(url.username, url.password),
+            )
+            # Remove userinfo from the URL authority, e.g.:
+            # 'username:password@proxy_host:proxy_port' -> 'proxy_host:proxy_port'
+            _, _, authority = url.authority.rpartition("@")
+            self.proxy_url = url.copy_with(authority=authority)
+
+    async def acquire_connection(self, origin: Origin) -> HTTPConnection:
+        """Return a connection for `origin`, forwarding or tunneling as
+        decided by `should_forward_origin()`.
+
+        Forwarded requests share pooled connections keyed on the proxy's own
+        origin; tunneled requests get a connection keyed on the target origin.
+        """
+        if self.should_forward_origin(origin):
+            logger.debug(
+                f"forward_connection proxy_url={self.proxy_url!r} origin={origin!r}"
+            )
+            return await super().acquire_connection(self.proxy_url.origin)
+        else:
+            logger.debug(
+                f"tunnel_connection proxy_url={self.proxy_url!r} origin={origin!r}"
+            )
+            return await self.tunnel_connection(origin)
+
+    async def tunnel_connection(self, origin: Origin) -> HTTPConnection:
+        """Creates a new HTTPConnection via the CONNECT method
+        usually reserved for proxying HTTPS connections.
+
+        An existing pooled tunnel for `origin` is reused when available;
+        otherwise a new tunnel is requested from the proxy.
+        """
+        connection = self.pop_connection(origin)
+
+        if connection is None:
+            connection = await self.request_tunnel_proxy_connection(origin)
+
+            # After we receive the 2XX response from the proxy that our
+            # tunnel is open we switch the connection's origin
+            # to the original so the tunnel can be re-used.
+            self.active_connections.remove(connection)
+            connection.origin = origin
+            self.active_connections.add(connection)
+
+            await self.tunnel_start_tls(origin, connection)
+        else:
+            # pop_connection() took the tunnel out of the pool; mark it
+            # as in use again.
+            self.active_connections.add(connection)
+
+        return connection
+
+    async def request_tunnel_proxy_connection(self, origin: Origin) -> HTTPConnection:
+        """Creates an HTTPConnection by setting up a TCP tunnel.
+
+        Sends a CONNECT request for `origin` to the proxy and returns the
+        connection, still registered as active, once the proxy has answered
+        with a 2XX status.
+
+        Raises:
+            ProxyError: the proxy answered the CONNECT request with a
+                non-2XX status.
+        """
+        proxy_headers = self.proxy_headers.copy()
+        proxy_headers.setdefault("Accept", "*/*")
+        proxy_request = AsyncRequest(
+            method="CONNECT", url=self.proxy_url.copy_with(), headers=proxy_headers
+        )
+        # The CONNECT request target is the origin's 'host:port'.
+        proxy_request.url.full_path = f"{origin.host}:{origin.port}"
+
+        await self.max_connections.acquire()
+
+        connection = HTTPConnection(
+            self.proxy_url.origin,
+            verify=self.verify,
+            cert=self.cert,
+            timeout=self.timeout,
+            backend=self.backend,
+            http_versions=["HTTP/1.1"],  # Short-lived 'connection'
+            trust_env=self.trust_env,
+            release_func=self.release_connection,
+        )
+        self.active_connections.add(connection)
+
+        # See if our tunnel has been opened successfully
+        try:
+            proxy_response = await connection.send(proxy_request)
+        except BaseException:
+            # No response was produced, so nothing will ever hand this
+            # connection back to the pool. Undo the bookkeeping done above,
+            # otherwise the acquired pool slot is lost for good.
+            self.active_connections.remove(connection)
+            self.max_connections.release()
+            raise
+
+        logger.debug(
+            f"tunnel_response "
+            f"proxy_url={self.proxy_url!r} "
+            f"origin={origin!r} "
+            f"response={proxy_response!r}"
+        )
+        if not (200 <= proxy_response.status_code <= 299):
+            await proxy_response.read()
+            raise ProxyError(
+                "Non-2XX response received from HTTP proxy "
+                f"({proxy_response.status_code})",
+                request=proxy_request,
+                response=proxy_response,
+            )
+
+        # NOTE(review): the close hook is detached before reading, presumably
+        # so that finishing the CONNECT response does not release the
+        # connection while the tunnel is still going to be used -- confirm.
+        proxy_response.on_close = None
+        await proxy_response.read()
+
+        return connection
+
+    async def tunnel_start_tls(
+        self, origin: Origin, connection: HTTPConnection
+    ) -> None:
+        """Runs start_tls() on a TCP-tunneled connection.
+
+        Replaces the connection's internal HTTP/1.1 object (left in the
+        'SWITCHED_PROTOCOL' state by the CONNECT exchange) with a fresh
+        HTTP/1.1 or HTTP/2 object on the same stream. The stream is upgraded
+        to TLS first when `origin` is an SSL origin.
+        """
+
+        # Store this information here so that we can transfer
+        # it to the new internal connection object after
+        # the old one goes to 'SWITCHED_PROTOCOL'.
+        http_version = "HTTP/1.1"
+        http_connection = connection.h11_connection
+        assert http_connection is not None
+        assert http_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
+        on_release = http_connection.on_release
+        stream = http_connection.stream
+
+        # If we need to start TLS again for the target server
+        # we need to pull the TCP stream off the internal
+        # HTTP connection object and run start_tls()
+        if origin.is_ssl:
+            ssl_config = SSLConfig(cert=self.cert, verify=self.verify)
+            timeout = connection.timeout
+            ssl_context = await connection.get_ssl_context(ssl_config)
+            assert ssl_context is not None
+
+            logger.debug(
+                f"tunnel_start_tls "
+                f"proxy_url={self.proxy_url!r} "
+                f"origin={origin!r}"
+            )
+            stream = await self.backend.start_tls(
+                stream=stream,
+                hostname=origin.host,
+                ssl_context=ssl_context,
+                timeout=timeout,
+            )
+            # The HTTP version is only re-read from the stream after the
+            # TLS upgrade; without TLS the tunnel stays on HTTP/1.1.
+            http_version = stream.get_http_version()
+            logger.debug(
+                f"tunnel_tls_complete "
+                f"proxy_url={self.proxy_url!r} "
+                f"origin={origin!r} "
+                f"http_version={http_version!r}"
+            )
+
+        if http_version == "HTTP/2":
+            connection.h2_connection = HTTP2Connection(
+                stream, self.backend, on_release=on_release
+            )
+        else:
+            assert http_version == "HTTP/1.1"
+            connection.h11_connection = HTTP11Connection(
+                stream, self.backend, on_release=on_release
+            )
+
+    def should_forward_origin(self, origin: Origin) -> bool:
+        """Determines if the given origin should
+        be forwarded or tunneled. If 'proxy_mode' is 'DEFAULT'
+        then the proxy will forward all 'HTTP' requests and
+        tunnel all 'HTTPS' requests.
+
+        Returns True to forward and False to tunnel.
+        """
+        return (
+            self.proxy_mode == HTTPProxyMode.DEFAULT and not origin.is_ssl
+        ) or self.proxy_mode == HTTPProxyMode.FORWARD_ONLY
+
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        """Send `request` through the proxy and return the response.
+
+        For a forwarded request, `request` is modified in place: its URL
+        becomes the proxy URL with the original absolute URL as `full_path`,
+        and the proxy headers are added where not already present.
+        """
+
+        if self.should_forward_origin(request.url.origin):
+            # Change the request to have the target URL
+            # as its full_path and switch the proxy URL
+            # for where the request will be sent.
+            target_url = str(request.url)
+            request.url = self.proxy_url.copy_with()
+            request.url.full_path = target_url
+            for name, value in self.proxy_headers.items():
+                request.headers.setdefault(name, value)
+
+        return await super().send(
+            request=request, verify=verify, cert=cert, timeout=timeout
+        )
+
+    def __repr__(self) -> str:
+        """Return a debugging representation with URL, headers and mode."""
+        return (
+            f"HTTPProxy(proxy_url={self.proxy_url!r} "
+            f"proxy_headers={self.proxy_headers!r} "
+            f"proxy_mode={self.proxy_mode!r})"
+        )
"""
+class ProxyError(HTTPError):
+ """
+ Error from within a proxy, e.g. a non-2XX response to a CONNECT request.
+ """
+
+
# HTTP exceptions...
if not self.host:
raise InvalidURL("No host included in URL.")
+ # Allow overriding full_path with a custom request target, as needed
+ # for OPTIONS, CONNECT, and forwarding proxy requests.
+ self._full_path: typing.Optional[str] = None
+
@property
def scheme(self) -> str:
return self._uri_reference.scheme or ""
@property
def full_path(self) -> str:
+ """
+ Return the explicitly assigned full path if one was set, otherwise
+ the URL path followed by '?' and the query string when there is one.
+ """
+ if self._full_path is not None:
+ return self._full_path
path = self.path
if self.query:
path += "?" + self.query
return path
+ @full_path.setter
+ def full_path(self, value: typing.Optional[str]) -> None:
+ """Override the full path; assigning None restores the computed value."""
+ self._full_path = value
+
@property
def fragment(self) -> str:
return self._uri_reference.fragment or ""
for header in headers:
self[header] = headers[header]
+ def copy(self) -> "Headers":
+ """Return a new Headers object built from this one's items and encoding."""
+ return Headers(self.items(), encoding=self.encoding)
+
def __getitem__(self, key: str) -> str:
"""
Return a single header value.
from httpx import URL, AsyncioBackend
-ENVIRONMENT_VARIABLES = (
+ENVIRONMENT_VARIABLES = {
"SSL_CERT_FILE",
- "REQUESTS_CA_BUNDLE",
- "CURL_CA_BUNDLE",
+ "SSL_CERT_DIR",
"HTTP_PROXY",
"HTTPS_PROXY",
"ALL_PROXY",
"NO_PROXY",
"SSLKEYLOGFILE",
-)
+}
@pytest.fixture(scope="function", autouse=True)
def clean_environ() -> typing.Dict[str, typing.Any]:
"""Keeps os.environ clean for every test without having to mock os.environ"""
original_environ = os.environ.copy()
- for key in ENVIRONMENT_VARIABLES:
- os.environ.pop(key, None)
+ os.environ.clear()
+ # ENVIRONMENT_VARIABLES lists upper-case names only, so compare the
+ # upper-cased key to also drop lower-case variants such as 'http_proxy'.
+ os.environ.update(
+ {
+ k: v
+ for k, v in original_environ.items()
+ if k.upper() not in ENVIRONMENT_VARIABLES
+ }
+ )
yield
os.environ.clear()
os.environ.update(original_environ)
--- /dev/null
+import pytest
+
+import httpx
+
+from .utils import MockRawSocketBackend
+
+
+async def test_proxy_tunnel_success(backend):
+    """In TUNNEL_ONLY mode a plain-HTTP request goes through a CONNECT tunnel."""
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 200 OK\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: proxy-server\r\n"
+                b"\r\n",
+                b"HTTP/1.1 404 Not Found\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: origin-server\r\n"
+                b"\r\n",
+            ]
+        ),
+        backend=backend,
+    )
+    async with httpx.HTTPProxy(
+        proxy_url="http://127.0.0.1:8000",
+        backend=raw_io,
+        proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
+    ) as proxy:
+        response = await proxy.request("GET", "http://example.com")
+
+        assert response.status_code == 404
+        assert response.headers["Server"] == "origin-server"
+
+        assert response.request.method == "GET"
+        assert response.request.url == "http://example.com"
+        assert response.request.headers["Host"] == "example.com"
+
+    # Expected traffic: TCP connect to the proxy, CONNECT, then the request.
+    recv = raw_io.received_data
+    assert len(recv) == 3
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
+    )
+    assert recv[2].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
+
+
+@pytest.mark.parametrize("status_code", [300, 304, 308, 401, 500])
+async def test_proxy_tunnel_non_2xx_response(backend, status_code):
+    """A non-2XX answer to CONNECT raises ProxyError and sends nothing further."""
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 %d Not Good\r\n" % status_code,
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: proxy-server\r\n"
+                b"\r\n",
+            ]
+        ),
+        backend=backend,
+    )
+
+    with pytest.raises(httpx.ProxyError) as e:
+        async with httpx.HTTPProxy(
+            proxy_url="http://127.0.0.1:8000",
+            backend=raw_io,
+            proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
+        ) as proxy:
+            await proxy.request("GET", "http://example.com")
+
+    # ProxyError.request should be the CONNECT request not the original request
+    assert e.value.request.method == "CONNECT"
+    assert e.value.request.headers["Host"] == "127.0.0.1:8000"
+    assert e.value.request.url.full_path == "example.com:80"
+
+    # ProxyError.response should be the CONNECT response
+    assert e.value.response.status_code == status_code
+    assert e.value.response.headers["Server"] == "proxy-server"
+
+    # Verify that the request wasn't sent after receiving an error from CONNECT
+    recv = raw_io.received_data
+    assert len(recv) == 2
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
+    )
+
+
+async def test_proxy_tunnel_start_tls(backend):
+    """An HTTPS request through a tunnel starts TLS after the CONNECT exchange."""
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 200 OK\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: proxy-server\r\n"
+                b"\r\n",
+                b"HTTP/1.1 404 Not Found\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: origin-server\r\n"
+                b"\r\n",
+            ]
+        ),
+        backend=backend,
+    )
+    async with httpx.HTTPProxy(
+        proxy_url="http://127.0.0.1:8000",
+        backend=raw_io,
+        proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
+    ) as proxy:
+        response = await proxy.request("GET", "https://example.com")
+
+        assert response.status_code == 404
+        assert response.headers["Server"] == "origin-server"
+
+        assert response.request.method == "GET"
+        assert response.request.url == "https://example.com"
+        assert response.request.headers["Host"] == "example.com"
+
+    # Expected traffic: TCP connect, CONNECT, TLS upgrade, then the request.
+    recv = raw_io.received_data
+    assert len(recv) == 4
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"CONNECT example.com:443 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
+    )
+    assert recv[2] == b"--- START_TLS(example.com) ---"
+    assert recv[3].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
+
+
+@pytest.mark.parametrize(
+    "proxy_mode", [httpx.HTTPProxyMode.FORWARD_ONLY, httpx.HTTPProxyMode.DEFAULT]
+)
+async def test_proxy_forwarding(backend, proxy_mode):
+    """A plain-HTTP request is forwarded with the absolute URL as its target."""
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 200 OK\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: origin-server\r\n"
+                b"\r\n"
+            ]
+        ),
+        backend=backend,
+    )
+    async with httpx.HTTPProxy(
+        proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode=proxy_mode
+    ) as proxy:
+        response = await proxy.request("GET", "http://example.com")
+
+        assert response.status_code == 200
+        assert response.headers["Server"] == "origin-server"
+
+        assert response.request.method == "GET"
+        assert response.request.url == "http://127.0.0.1:8000"
+        assert response.request.url.full_path == "http://example.com"
+        assert response.request.headers["Host"] == "example.com"
+
+    # Expected traffic: TCP connect to the proxy, then the forwarded request.
+    recv = raw_io.received_data
+    assert len(recv) == 2
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"GET http://example.com HTTP/1.1\r\nhost: example.com\r\n"
+    )
+
+
+def test_proxy_url_with_username_and_password():
+    """Userinfo in the proxy URL moves into a Proxy-Authorization header."""
+    proxy = httpx.HTTPProxy("http://user:password@example.com:1080")
+
+    stored_url = proxy.proxy_url
+    auth_header = proxy.proxy_headers["Proxy-Authorization"]
+
+    assert stored_url == "http://example.com:1080"
+    assert auth_header == "Basic dXNlcjpwYXNzd29yZA=="
+
+
+def test_proxy_repr():
+    """repr() of an HTTPProxy shows its URL, headers and mode."""
+    expected = (
+        "HTTPProxy(proxy_url=URL('http://127.0.0.1:1080') "
+        "proxy_headers=Headers({'custom': 'Header'}) "
+        "proxy_mode=<HTTPProxyMode.DEFAULT: 'DEFAULT'>)"
+    )
+
+    proxy = httpx.HTTPProxy(
+        "http://127.0.0.1:1080",
+        proxy_headers={"Custom": "Header"},
+        proxy_mode=httpx.HTTPProxyMode.DEFAULT,
+    )
+
+    assert repr(proxy) == expected
self.returning[stream_id] = False
self.conn.end_stream(stream_id)
self.buffer += self.conn.data_to_send()
+
+
+class MockRawSocketBackend:
+    """
+    Test backend that hands out a single mock stream instead of real sockets.
+
+    Connect and start-TLS calls are logged as marker entries in
+    `received_data`, alongside the bytes written to the stream. Reads on
+    the stream are served from `data_to_send`. Anything else is delegated
+    to the wrapped concurrency backend.
+    """
+
+    def __init__(self, data_to_send=b"", backend=None):
+        if backend is None:
+            backend = AsyncioBackend()
+        self.backend = backend
+        self.data_to_send = data_to_send
+        self.received_data = []
+        self.stream = MockRawSocketStream(self)
+
+    async def connect(
+        self,
+        hostname: str,
+        port: int,
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> BaseStream:
+        """Log a CONNECT marker and return the shared mock stream."""
+        marker = b"--- CONNECT(%s, %d) ---" % (hostname.encode(), port)
+        self.received_data.append(marker)
+        return self.stream
+
+    async def start_tls(
+        self,
+        stream: BaseStream,
+        hostname: str,
+        ssl_context: ssl.SSLContext,
+        timeout: TimeoutConfig,
+    ) -> BaseStream:
+        """Log a START_TLS marker and return the shared mock stream."""
+        marker = b"--- START_TLS(%s) ---" % hostname.encode()
+        self.received_data.append(marker)
+        return self.stream
+
+    # Defer all other attributes and methods to the underlying backend.
+    def __getattr__(self, name: str) -> typing.Any:
+        wrapped = self.backend
+        return getattr(wrapped, name)
+
+
+class MockRawSocketStream(BaseStream):
+    """
+    Stream used by MockRawSocketBackend: writes are appended to the
+    backend's `received_data`, reads are popped from its `data_to_send`.
+    """
+
+    def __init__(self, backend: MockRawSocketBackend):
+        self.backend = backend
+
+    def get_http_version(self) -> str:
+        return "HTTP/1.1"
+
+    def write_no_block(self, data: bytes) -> None:
+        log = self.backend.received_data
+        log.append(data)
+
+    async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
+        # Empty writes are not recorded.
+        if not data:
+            return
+        self.write_no_block(data)
+
+    async def read(self, n, timeout, flag=None) -> bytes:
+        await sleep(self.backend.backend, 0)
+        queued = self.backend.data_to_send
+        if queued:
+            return queued.pop(0)
+        # Nothing left to serve: report end of stream.
+        return b""
+
+    async def close(self) -> None:
+        return None
assert all(url in urls for url in url_set)
+def test_url_full_path_setter():
+    """Assigning full_path overrides the value the URL would otherwise report."""
+    url = URL("http://example.org")
+    override = "http://example.net"
+
+    url.full_path = override
+
+    assert url.full_path == "http://example.net"
+
+
def test_origin_from_url_string():
origin = Origin("https://example.com")
assert origin.scheme == "https"