Implement HTTPProxy dispatcher (#259)
author     Seth Michael Larson <sethmichaellarson@gmail.com>
           Sun, 15 Sep 2019 15:47:35 +0000 (10:47 -0500)
committer  GitHub <noreply@github.com>
           Sun, 15 Sep 2019 15:47:35 +0000 (10:47 -0500)
13 files changed:
httpx/__init__.py
httpx/api.py
httpx/client.py
httpx/dispatch/connection_pool.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py
httpx/dispatch/proxy_http.py [new file with mode: 0644]
httpx/exceptions.py
httpx/models.py
tests/conftest.py
tests/dispatch/test_proxy_http.py [new file with mode: 0644]
tests/dispatch/utils.py
tests/models/test_url.py

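For orientation, a minimal usage sketch based on the new test-suite below: the HTTPProxy dispatcher can be driven directly as an async context manager. The proxy address and mode shown here are illustrative placeholders, not values taken from this commit:

    import asyncio

    import httpx

    async def main() -> None:
        # Sketch only: route a request through a local HTTP proxy,
        # always tunneling via CONNECT (see HTTPProxyMode below).
        async with httpx.HTTPProxy(
            proxy_url="http://127.0.0.1:8000",
            proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
        ) as proxy:
            response = await proxy.request("GET", "https://example.com")
            print(response.status_code)

    asyncio.run(main())
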
diff --git a/httpx/__init__.py b/httpx/__init__.py
index bd58a005af1596c45d33f191f252e49d5f61765c..a070d6856271e7e36bc553f924cb32d0c5ea816d 100644
@@ -22,6 +22,7 @@ from .config import (
 from .dispatch.base import AsyncDispatcher, Dispatcher
 from .dispatch.connection import HTTPConnection
 from .dispatch.connection_pool import ConnectionPool
+from .dispatch.proxy_http import HTTPProxy, HTTPProxyMode
 from .exceptions import (
     ConnectTimeout,
     CookieConflict,
@@ -30,6 +31,7 @@ from .exceptions import (
     NotRedirectResponse,
     PoolTimeout,
     ProtocolError,
+    ProxyError,
     ReadTimeout,
     RedirectBodyUnavailable,
     RedirectLoop,
@@ -90,6 +92,8 @@ __all__ = [
     "BasePoolSemaphore",
     "BaseBackgroundManager",
     "ConnectionPool",
+    "HTTPProxy",
+    "HTTPProxyMode",
     "ConnectTimeout",
     "CookieConflict",
     "DecodingError",
@@ -103,6 +107,7 @@ __all__ = [
     "ResponseClosed",
     "ResponseNotRead",
     "StreamConsumed",
+    "ProxyError",
     "Timeout",
     "TooManyRedirects",
     "WriteTimeout",
diff --git a/httpx/api.py b/httpx/api.py
index 60d9ea7644b09736c71795592d228dcabc676060..5f0d639cda60d83db4a07eeb398f752add27d984 100644
@@ -27,7 +27,6 @@ def request(
     auth: AuthTypes = None,
     timeout: TimeoutTypes = None,
     allow_redirects: bool = True,
-    # proxies
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     stream: bool = False,
diff --git a/httpx/client.py b/httpx/client.py
index 430f072442e6a22518ce5fcb4abaf38fff8d1fff..ea6dae6e1794c9ca02b71b180b02ace7c07cc4cc 100644
@@ -334,7 +334,7 @@ class AsyncClient(BaseClient):
         cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
-        allow_redirects: bool = False,  #  Note: Differs to usual default.
+        allow_redirects: bool = False,  # NOTE: Differs to usual default.
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
@@ -784,7 +784,7 @@ class Client(BaseClient):
         cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
-        allow_redirects: bool = False,  #  Note: Differs to usual default.
+        allow_redirects: bool = False,  # NOTE: Differs to usual default.
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py
index befa88cb400713ae49690514124e21071feb09a6..2cc11c9884e61d856f4b2314bb349c60759d1684 100644
@@ -128,14 +128,8 @@ class ConnectionPool(AsyncDispatcher):
         return response
 
     async def acquire_connection(self, origin: Origin) -> HTTPConnection:
-        logger.debug(f"acquire_connection origin={origin!r}")
-        connection = self.active_connections.pop_by_origin(origin, http2_only=True)
-        if connection is None:
-            connection = self.keepalive_connections.pop_by_origin(origin)
-
-        if connection is not None and connection.is_connection_dropped():
-            self.max_connections.release()
-            connection = None
+        logger.debug("acquire_connection origin={origin!r}")
+        connection = self.pop_connection(origin)
 
         if connection is None:
             await self.max_connections.acquire()
@@ -179,3 +173,14 @@ class ConnectionPool(AsyncDispatcher):
         self.keepalive_connections.clear()
         for connection in connections:
             await connection.close()
+
+    def pop_connection(self, origin: Origin) -> typing.Optional[HTTPConnection]:
+        connection = self.active_connections.pop_by_origin(origin, http2_only=True)
+        if connection is None:
+            connection = self.keepalive_connections.pop_by_origin(origin)
+
+        if connection is not None and connection.is_connection_dropped():
+            self.max_connections.release()
+            connection = None
+
+        return connection
diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py
index 7452b226bad48a403058b0358d72b19c34289d5c..cfd227f0787feddb3d7fb91ef70d90f319fb09f2 100644
@@ -153,7 +153,7 @@ class HTTP11Connection:
             if isinstance(event, h11.Data):
                 yield bytes(event.data)
             else:
-                assert isinstance(event, h11.EndOfMessage)
+                assert isinstance(event, h11.EndOfMessage) or event is h11.PAUSED
                 break  # pragma: no cover
 
     async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py
index fd40964d265432e26762d18fe0945d8d5cc2d0c7..cf518f1b6ea74e31f698ab5134b19fe25499f86d 100644
@@ -35,7 +35,7 @@ class HTTP2Connection:
     ) -> AsyncResponse:
         timeout = None if timeout is None else TimeoutConfig(timeout)
 
-        #  Start sending the request.
+        # Start sending the request.
         if not self.initialized:
             self.initiate_connection()
 
diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py
new file mode 100644
index 0000000..fe8bbd3
--- /dev/null
@@ -0,0 +1,251 @@
+import enum
+
+import h11
+
+from ..concurrency.base import ConcurrencyBackend
+from ..config import (
+    DEFAULT_POOL_LIMITS,
+    DEFAULT_TIMEOUT_CONFIG,
+    CertTypes,
+    PoolLimits,
+    SSLConfig,
+    TimeoutTypes,
+    VerifyTypes,
+)
+from ..exceptions import ProxyError
+from ..middleware.basic_auth import build_basic_auth_header
+from ..models import (
+    URL,
+    AsyncRequest,
+    AsyncResponse,
+    Headers,
+    HeaderTypes,
+    Origin,
+    URLTypes,
+)
+from ..utils import get_logger
+from .connection import HTTPConnection
+from .connection_pool import ConnectionPool
+from .http2 import HTTP2Connection
+from .http11 import HTTP11Connection
+
+logger = get_logger(__name__)
+
+
+class HTTPProxyMode(enum.Enum):
+    DEFAULT = "DEFAULT"
+    FORWARD_ONLY = "FORWARD_ONLY"
+    TUNNEL_ONLY = "TUNNEL_ONLY"
+
+
+class HTTPProxy(ConnectionPool):
+    """A proxy that sends requests to the recipient server
+    on behalf of the connecting client.
+    """
+
+    def __init__(
+        self,
+        proxy_url: URLTypes,
+        *,
+        proxy_headers: HeaderTypes = None,
+        proxy_mode: HTTPProxyMode = HTTPProxyMode.DEFAULT,
+        verify: VerifyTypes = True,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
+        pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
+        backend: ConcurrencyBackend = None,
+    ):
+
+        super(HTTPProxy, self).__init__(
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            pool_limits=pool_limits,
+            backend=backend,
+        )
+
+        self.proxy_url = URL(proxy_url)
+        self.proxy_mode = proxy_mode
+        self.proxy_headers = Headers(proxy_headers)
+
+        url = self.proxy_url
+        if url.username or url.password:
+            self.proxy_headers.setdefault(
+                "Proxy-Authorization",
+                build_basic_auth_header(url.username, url.password),
+            )
+            # Remove userinfo from the URL authority, e.g.:
+            # 'username:password@proxy_host:proxy_port' -> 'proxy_host:proxy_port'
+            credentials, _, authority = url.authority.rpartition("@")
+            self.proxy_url = url.copy_with(authority=authority)
+
+    async def acquire_connection(self, origin: Origin) -> HTTPConnection:
+        if self.should_forward_origin(origin):
+            logger.debug(
+                f"forward_connection proxy_url={self.proxy_url!r} origin={origin!r}"
+            )
+            return await super().acquire_connection(self.proxy_url.origin)
+        else:
+            logger.debug(
+                f"tunnel_connection proxy_url={self.proxy_url!r} origin={origin!r}"
+            )
+            return await self.tunnel_connection(origin)
+
+    async def tunnel_connection(self, origin: Origin) -> HTTPConnection:
+        """Creates a new HTTPConnection via the CONNECT method
+        usually reserved for proxying HTTPS connections.
+        """
+        connection = self.pop_connection(origin)
+
+        if connection is None:
+            connection = await self.request_tunnel_proxy_connection(origin)
+
+            # After we receive the 2XX response from the proxy confirming
+            # that our tunnel is open, we switch the connection's origin
+            # to the original one so the tunnel can be reused.
+            self.active_connections.remove(connection)
+            connection.origin = origin
+            self.active_connections.add(connection)
+
+            await self.tunnel_start_tls(origin, connection)
+        else:
+            self.active_connections.add(connection)
+
+        return connection
+
+    async def request_tunnel_proxy_connection(self, origin: Origin) -> HTTPConnection:
+        """Creates an HTTPConnection by setting up a TCP tunnel"""
+        proxy_headers = self.proxy_headers.copy()
+        proxy_headers.setdefault("Accept", "*/*")
+        proxy_request = AsyncRequest(
+            method="CONNECT", url=self.proxy_url.copy_with(), headers=proxy_headers
+        )
+        proxy_request.url.full_path = f"{origin.host}:{origin.port}"
+
+        await self.max_connections.acquire()
+
+        connection = HTTPConnection(
+            self.proxy_url.origin,
+            verify=self.verify,
+            cert=self.cert,
+            timeout=self.timeout,
+            backend=self.backend,
+            http_versions=["HTTP/1.1"],  # Short-lived 'connection'
+            trust_env=self.trust_env,
+            release_func=self.release_connection,
+        )
+        self.active_connections.add(connection)
+
+        # See if our tunnel has been opened successfully
+        proxy_response = await connection.send(proxy_request)
+        logger.debug(
+            f"tunnel_response "
+            f"proxy_url={self.proxy_url!r} "
+            f"origin={origin!r} "
+            f"response={proxy_response!r}"
+        )
+        if not (200 <= proxy_response.status_code <= 299):
+            await proxy_response.read()
+            raise ProxyError(
+                f"Non-2XX response received from HTTP proxy "
+                f"({proxy_response.status_code})",
+                request=proxy_request,
+                response=proxy_response,
+            )
+        else:
+            proxy_response.on_close = None
+            await proxy_response.read()
+
+        return connection
+
+    async def tunnel_start_tls(
+        self, origin: Origin, connection: HTTPConnection
+    ) -> None:
+        """Runs start_tls() on a TCP-tunneled connection"""
+
+        # Store this information here so that we can transfer
+        # it to the new internal connection object after
+        # the old one goes to 'SWITCHED_PROTOCOL'.
+        http_version = "HTTP/1.1"
+        http_connection = connection.h11_connection
+        assert http_connection is not None
+        assert http_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
+        on_release = http_connection.on_release
+        stream = http_connection.stream
+
+        # If we need to start TLS again for the target server
+        # we need to pull the TCP stream off the internal
+        # HTTP connection object and run start_tls()
+        if origin.is_ssl:
+            ssl_config = SSLConfig(cert=self.cert, verify=self.verify)
+            timeout = connection.timeout
+            ssl_context = await connection.get_ssl_context(ssl_config)
+            assert ssl_context is not None
+
+            logger.debug(
+                f"tunnel_start_tls "
+                f"proxy_url={self.proxy_url!r} "
+                f"origin={origin!r}"
+            )
+            stream = await self.backend.start_tls(
+                stream=stream,
+                hostname=origin.host,
+                ssl_context=ssl_context,
+                timeout=timeout,
+            )
+            http_version = stream.get_http_version()
+            logger.debug(
+                f"tunnel_tls_complete "
+                f"proxy_url={self.proxy_url!r} "
+                f"origin={origin!r} "
+                f"http_version={http_version!r}"
+            )
+
+        if http_version == "HTTP/2":
+            connection.h2_connection = HTTP2Connection(
+                stream, self.backend, on_release=on_release
+            )
+        else:
+            assert http_version == "HTTP/1.1"
+            connection.h11_connection = HTTP11Connection(
+                stream, self.backend, on_release=on_release
+            )
+
+    def should_forward_origin(self, origin: Origin) -> bool:
+        """Determines if the given origin should
+        be forwarded or tunneled. If 'proxy_mode' is 'DEFAULT'
+        then the proxy will forward all 'HTTP' requests and
+        tunnel all 'HTTPS' requests.
+        """
+        return (
+            self.proxy_mode == HTTPProxyMode.DEFAULT and not origin.is_ssl
+        ) or self.proxy_mode == HTTPProxyMode.FORWARD_ONLY
+
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+
+        if self.should_forward_origin(request.url.origin):
+            # Change the request to have the target URL
+            # as its full_path and switch the proxy URL
+            # for where the request will be sent.
+            target_url = str(request.url)
+            request.url = self.proxy_url.copy_with()
+            request.url.full_path = target_url
+            for name, value in self.proxy_headers.items():
+                request.headers.setdefault(name, value)
+
+        return await super().send(
+            request=request, verify=verify, cert=cert, timeout=timeout
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"HTTPProxy(proxy_url={self.proxy_url!r} "
+            f"proxy_headers={self.proxy_headers!r} "
+            f"proxy_mode={self.proxy_mode!r})"
+        )
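A short, hedged sketch of how the three proxy modes defined above select behaviour (per should_forward_origin: DEFAULT forwards plain-HTTP requests and tunnels HTTPS requests, FORWARD_ONLY always forwards, TUNNEL_ONLY always issues CONNECT). The proxy address is a placeholder:

    import httpx

    # DEFAULT: forward http:// origins, tunnel https:// origins.
    default_proxy = httpx.HTTPProxy("http://127.0.0.1:8000")

    # FORWARD_ONLY: always rewrite the request to absolute-form and forward it.
    forward_only = httpx.HTTPProxy(
        "http://127.0.0.1:8000", proxy_mode=httpx.HTTPProxyMode.FORWARD_ONLY
    )

    # TUNNEL_ONLY: always open a CONNECT tunnel first, even for plain HTTP.
    tunnel_only = httpx.HTTPProxy(
        "http://127.0.0.1:8000", proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY
    )
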
diff --git a/httpx/exceptions.py b/httpx/exceptions.py
index 42531dc79b91b953feaf43edda53d573e1bd8143..81df38b79d7d1990b45fbcc5991667b407d2ab2f 100644
@@ -53,6 +53,12 @@ class PoolTimeout(Timeout):
     """
 
 
+class ProxyError(HTTPError):
+    """
+    Error from within a proxy
+    """
+
+
 # HTTP exceptions...
 
 
diff --git a/httpx/models.py b/httpx/models.py
index 29fc1b931a07e5636edd2a56df8ef4c34509bcbe..e60453bd72cc9d53422f8d4449721aa7076ff4ff 100644
@@ -115,6 +115,10 @@ class URL:
             if not self.host:
                 raise InvalidURL("No host included in URL.")
 
+        # Allow setting full_path to a custom value for requests
+        # like OPTIONS, CONNECT, and forwarded proxy requests.
+        self._full_path: typing.Optional[str] = None
+
     @property
     def scheme(self) -> str:
         return self._uri_reference.scheme or ""
@@ -154,11 +158,17 @@ class URL:
 
     @property
     def full_path(self) -> str:
+        if self._full_path is not None:
+            return self._full_path
         path = self.path
         if self.query:
             path += "?" + self.query
         return path
 
+    @full_path.setter
+    def full_path(self, value: typing.Optional[str]) -> None:
+        self._full_path = value
+
     @property
     def fragment(self) -> str:
         return self._uri_reference.fragment or ""
@@ -426,6 +436,9 @@ class Headers(typing.MutableMapping[str, str]):
         for header in headers:
             self[header] = headers[header]
 
+    def copy(self) -> "Headers":
+        return Headers(self.items(), encoding=self.encoding)
+
     def __getitem__(self, key: str) -> str:
         """
         Return a single header value.
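A minimal sketch of the two model additions above, the full_path setter (used for CONNECT and absolute-form proxy targets) and Headers.copy(); the URL and header values are illustrative:

    from httpx import URL
    from httpx.models import Headers  # Headers lives in httpx/models.py (see hunk above)

    # Override the computed path/query with an explicit request target,
    # as the CONNECT request in proxy_http.py does.
    url = URL("http://127.0.0.1:8000")
    url.full_path = "example.com:80"
    assert url.full_path == "example.com:80"

    # Copy a Headers instance; items and encoding are preserved.
    headers = Headers({"Proxy-Authorization": "Basic <token>"})
    copied = headers.copy()
    assert copied["Proxy-Authorization"] == headers["Proxy-Authorization"]
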
diff --git a/tests/conftest.py b/tests/conftest.py
index 4f8a11a57b38844f7db968088dcdf820d6a49383..71f3007c214b786904bf8a4179d371e971ff1f2e 100644
@@ -17,24 +17,29 @@ from uvicorn.main import Server
 
 from httpx import URL, AsyncioBackend
 
-ENVIRONMENT_VARIABLES = (
+ENVIRONMENT_VARIABLES = {
     "SSL_CERT_FILE",
-    "REQUESTS_CA_BUNDLE",
-    "CURL_CA_BUNDLE",
+    "SSL_CERT_DIR",
     "HTTP_PROXY",
     "HTTPS_PROXY",
     "ALL_PROXY",
     "NO_PROXY",
     "SSLKEYLOGFILE",
-)
+}
 
 
 @pytest.fixture(scope="function", autouse=True)
 def clean_environ() -> typing.Dict[str, typing.Any]:
     """Keeps os.environ clean for every test without having to mock os.environ"""
     original_environ = os.environ.copy()
-    for key in ENVIRONMENT_VARIABLES:
-        os.environ.pop(key, None)
+    os.environ.clear()
+    os.environ.update(
+        {
+            k: v
+            for k, v in original_environ.items()
+            if k not in ENVIRONMENT_VARIABLES and k.lower() not in ENVIRONMENT_VARIABLES
+        }
+    )
     yield
     os.environ.clear()
     os.environ.update(original_environ)
diff --git a/tests/dispatch/test_proxy_http.py b/tests/dispatch/test_proxy_http.py
new file mode 100644
index 0000000..0302f46
--- /dev/null
@@ -0,0 +1,181 @@
+import pytest
+
+import httpx
+
+from .utils import MockRawSocketBackend
+
+
+async def test_proxy_tunnel_success(backend):
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 200 OK\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: proxy-server\r\n"
+                b"\r\n",
+                b"HTTP/1.1 404 Not Found\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: origin-server\r\n"
+                b"\r\n",
+            ]
+        ),
+        backend=backend,
+    )
+    async with httpx.HTTPProxy(
+        proxy_url="http://127.0.0.1:8000",
+        backend=raw_io,
+        proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
+    ) as proxy:
+        response = await proxy.request("GET", f"http://example.com")
+
+        assert response.status_code == 404
+        assert response.headers["Server"] == "origin-server"
+
+        assert response.request.method == "GET"
+        assert response.request.url == "http://example.com"
+        assert response.request.headers["Host"] == "example.com"
+
+    recv = raw_io.received_data
+    assert len(recv) == 3
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
+    )
+    assert recv[2].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
+
+
+@pytest.mark.parametrize("status_code", [300, 304, 308, 401, 500])
+async def test_proxy_tunnel_non_2xx_response(backend, status_code):
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 %d Not Good\r\n" % status_code,
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: proxy-server\r\n"
+                b"\r\n",
+            ]
+        ),
+        backend=backend,
+    )
+
+    with pytest.raises(httpx.ProxyError) as e:
+        async with httpx.HTTPProxy(
+            proxy_url="http://127.0.0.1:8000",
+            backend=raw_io,
+            proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
+        ) as proxy:
+            await proxy.request("GET", f"http://example.com")
+
+    # ProxyError.request should be the CONNECT request, not the original request
+    assert e.value.request.method == "CONNECT"
+    assert e.value.request.headers["Host"] == "127.0.0.1:8000"
+    assert e.value.request.url.full_path == "example.com:80"
+
+    # ProxyError.response should be the CONNECT response
+    assert e.value.response.status_code == status_code
+    assert e.value.response.headers["Server"] == "proxy-server"
+
+    # Verify that the request wasn't sent after receiving an error from CONNECT
+    recv = raw_io.received_data
+    assert len(recv) == 2
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"CONNECT example.com:80 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
+    )
+
+
+async def test_proxy_tunnel_start_tls(backend):
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 200 OK\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: proxy-server\r\n"
+                b"\r\n",
+                b"HTTP/1.1 404 Not Found\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: origin-server\r\n"
+                b"\r\n",
+            ]
+        ),
+        backend=backend,
+    )
+    async with httpx.HTTPProxy(
+        proxy_url="http://127.0.0.1:8000",
+        backend=raw_io,
+        proxy_mode=httpx.HTTPProxyMode.TUNNEL_ONLY,
+    ) as proxy:
+        response = await proxy.request("GET", f"https://example.com")
+
+        assert response.status_code == 404
+        assert response.headers["Server"] == "origin-server"
+
+        assert response.request.method == "GET"
+        assert response.request.url == "https://example.com"
+        assert response.request.headers["Host"] == "example.com"
+
+    recv = raw_io.received_data
+    assert len(recv) == 4
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"CONNECT example.com:443 HTTP/1.1\r\nhost: 127.0.0.1:8000\r\n"
+    )
+    assert recv[2] == b"--- START_TLS(example.com) ---"
+    assert recv[3].startswith(b"GET / HTTP/1.1\r\nhost: example.com\r\n")
+
+
+@pytest.mark.parametrize(
+    "proxy_mode", [httpx.HTTPProxyMode.FORWARD_ONLY, httpx.HTTPProxyMode.DEFAULT]
+)
+async def test_proxy_forwarding(backend, proxy_mode):
+    raw_io = MockRawSocketBackend(
+        data_to_send=(
+            [
+                b"HTTP/1.1 200 OK\r\n"
+                b"Date: Sun, 10 Oct 2010 23:26:07 GMT\r\n"
+                b"Server: origin-server\r\n"
+                b"\r\n"
+            ]
+        ),
+        backend=backend,
+    )
+    async with httpx.HTTPProxy(
+        proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode=proxy_mode
+    ) as proxy:
+        response = await proxy.request("GET", f"http://example.com")
+
+        assert response.status_code == 200
+        assert response.headers["Server"] == "origin-server"
+
+        assert response.request.method == "GET"
+        assert response.request.url == "http://127.0.0.1:8000"
+        assert response.request.url.full_path == "http://example.com"
+        assert response.request.headers["Host"] == "example.com"
+
+    recv = raw_io.received_data
+    assert len(recv) == 2
+    assert recv[0] == b"--- CONNECT(127.0.0.1, 8000) ---"
+    assert recv[1].startswith(
+        b"GET http://example.com HTTP/1.1\r\nhost: example.com\r\n"
+    )
+
+
+def test_proxy_url_with_username_and_password():
+    proxy = httpx.HTTPProxy("http://user:password@example.com:1080")
+
+    assert proxy.proxy_url == "http://example.com:1080"
+    assert proxy.proxy_headers["Proxy-Authorization"] == "Basic dXNlcjpwYXNzd29yZA=="
+
+
+def test_proxy_repr():
+    proxy = httpx.HTTPProxy(
+        "http://127.0.0.1:1080",
+        proxy_headers={"Custom": "Header"},
+        proxy_mode=httpx.HTTPProxyMode.DEFAULT,
+    )
+
+    assert repr(proxy) == (
+        "HTTPProxy(proxy_url=URL('http://127.0.0.1:1080') "
+        "proxy_headers=Headers({'custom': 'Header'}) "
+        "proxy_mode=<HTTPProxyMode.DEFAULT: 'DEFAULT'>)"
+    )
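As an aside, the "Basic dXNlcjpwYXNzd29yZA==" value asserted in test_proxy_url_with_username_and_password is simply the base64-encoded userinfo stripped from the proxy URL (assuming build_basic_auth_header emits a standard RFC 7617 credential); a quick sketch:

    import base64

    # "user:password" is the userinfo removed from the proxy URL authority.
    token = base64.b64encode(b"user:password").decode()
    print(f"Proxy-Authorization: Basic {token}")
    # -> Proxy-Authorization: Basic dXNlcjpwYXNzd29yZA==
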
diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py
index a9ab231712699f2dc87d38dc2672c5013fb72d8e..c7cbedb52fba0a1b728b811f8bf0d501c674b9ed 100644
@@ -160,3 +160,61 @@ class MockHTTP2Server(BaseStream):
         self.returning[stream_id] = False
         self.conn.end_stream(stream_id)
         self.buffer += self.conn.data_to_send()
+
+
+class MockRawSocketBackend:
+    def __init__(self, data_to_send=b"", backend=None):
+        self.backend = AsyncioBackend() if backend is None else backend
+        self.data_to_send = data_to_send
+        self.received_data = []
+        self.stream = MockRawSocketStream(self)
+
+    async def connect(
+        self,
+        hostname: str,
+        port: int,
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> BaseStream:
+        self.received_data.append(
+            b"--- CONNECT(%s, %d) ---" % (hostname.encode(), port)
+        )
+        return self.stream
+
+    async def start_tls(
+        self,
+        stream: BaseStream,
+        hostname: str,
+        ssl_context: ssl.SSLContext,
+        timeout: TimeoutConfig,
+    ) -> BaseStream:
+        self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
+        return self.stream
+
+    # Defer all other attributes and methods to the underlying backend.
+    def __getattr__(self, name: str) -> typing.Any:
+        return getattr(self.backend, name)
+
+
+class MockRawSocketStream(BaseStream):
+    def __init__(self, backend: MockRawSocketBackend):
+        self.backend = backend
+
+    def get_http_version(self) -> str:
+        return "HTTP/1.1"
+
+    def write_no_block(self, data: bytes) -> None:
+        self.backend.received_data.append(data)
+
+    async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
+        if data:
+            self.write_no_block(data)
+
+    async def read(self, n, timeout, flag=None) -> bytes:
+        await sleep(self.backend.backend, 0)
+        if not self.backend.data_to_send:
+            return b""
+        return self.backend.data_to_send.pop(0)
+
+    async def close(self) -> None:
+        pass
diff --git a/tests/models/test_url.py b/tests/models/test_url.py
index 0a8efe12b69faf4736022de9e02425a8a979ff76..0bba256690ec31a7e161a6786126e3340965e916 100644
@@ -177,6 +177,13 @@ def test_url_set():
     assert all(url in urls for url in url_set)
 
 
+def test_url_full_path_setter():
+    url = URL("http://example.org")
+
+    url.full_path = "http://example.net"
+    assert url.full_path == "http://example.net"
+
+
 def test_origin_from_url_string():
     origin = Origin("https://example.com")
     assert origin.scheme == "https"