]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Differentiate between `timeout=None` and `timeout=UNSET`. (#592)
authorTom Christie <tom@tomchristie.com>
Wed, 4 Dec 2019 11:54:39 +0000 (11:54 +0000)
committerGitHub <noreply@github.com>
Wed, 4 Dec 2019 11:54:39 +0000 (11:54 +0000)
* TimeoutConfig -> Timeout

* Timeout=None should mean what it says.

* Drop optional timeout on internal client methods

httpx/client.py
httpx/concurrency/base.py
httpx/dispatch/base.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py
httpx/dispatch/proxy_http.py
tests/client/test_proxies.py

index b845724fd67e3776aa88133e6080a73323b8c952..36f6a2404dfc36c635dfd7e23c2a734f39426701 100644 (file)
@@ -12,6 +12,7 @@ from .config import (
     DEFAULT_TIMEOUT_CONFIG,
     CertTypes,
     PoolLimits,
+    Timeout,
     TimeoutTypes,
     VerifyTypes,
 )
@@ -49,6 +50,13 @@ from .utils import ElapsedTimer, NetRCInfo, get_environment_proxies, get_logger
 logger = get_logger(__name__)
 
 
+class UnsetType:
+    pass  # pragma: nocover
+
+
+UNSET = UnsetType()
+
+
 class Client:
     """
     An HTTP client, with connection pooling, HTTP/2, redirects, cookie persistence, etc.
@@ -129,7 +137,6 @@ class Client:
             dispatch = ConnectionPool(
                 verify=verify,
                 cert=cert,
-                timeout=timeout,
                 http2=http2,
                 pool_limits=pool_limits,
                 backend=backend,
@@ -149,6 +156,7 @@ class Client:
         self._params = QueryParams(params)
         self._headers = Headers(headers)
         self._cookies = Cookies(cookies)
+        self.timeout = Timeout(timeout)
         self.max_redirects = max_redirects
         self.trust_env = trust_env
         self.dispatch = dispatch
@@ -161,7 +169,6 @@ class Client:
             proxies,
             verify=verify,
             cert=cert,
-            timeout=timeout,
             http2=http2,
             pool_limits=pool_limits,
             backend=backend,
@@ -217,7 +224,7 @@ class Client:
         allow_redirects: bool = True,
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         request = self.build_request(
@@ -330,7 +337,7 @@ class Client:
         allow_redirects: bool = True,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         if request.url.scheme not in ("http", "https"):
@@ -338,6 +345,7 @@ class Client:
 
         auth = self.auth if auth is None else auth
         trust_env = self.trust_env if trust_env is None else trust_env
+        timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
 
         if not isinstance(auth, Middleware):
             request = self.authenticate(request, trust_env, auth)
@@ -390,9 +398,9 @@ class Client:
     async def send_handling_redirects(
         self,
         request: Request,
+        timeout: Timeout,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
         allow_redirects: bool = True,
         history: typing.List[Response] = None,
     ) -> Response:
@@ -522,9 +530,9 @@ class Client:
     async def send_single_request(
         self,
         request: Request,
+        timeout: Timeout,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
     ) -> Response:
         """
         Sends a single request, without handling any redirections.
@@ -592,7 +600,7 @@ class Client:
         allow_redirects: bool = True,
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         return await self.request(
@@ -622,7 +630,7 @@ class Client:
         allow_redirects: bool = True,
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         return await self.request(
@@ -652,7 +660,7 @@ class Client:
         allow_redirects: bool = False,  # NOTE: Differs to usual default.
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         return await self.request(
@@ -685,7 +693,7 @@ class Client:
         allow_redirects: bool = True,
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         return await self.request(
@@ -721,7 +729,7 @@ class Client:
         allow_redirects: bool = True,
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         return await self.request(
@@ -757,7 +765,7 @@ class Client:
         allow_redirects: bool = True,
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         return await self.request(
@@ -790,7 +798,7 @@ class Client:
         allow_redirects: bool = True,
         cert: CertTypes = None,
         verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
         trust_env: bool = None,
     ) -> Response:
         return await self.request(
@@ -827,21 +835,19 @@ def _proxies_to_dispatchers(
     proxies: typing.Optional[ProxiesTypes],
     verify: VerifyTypes,
     cert: typing.Optional[CertTypes],
-    timeout: TimeoutTypes,
     http2: bool,
     pool_limits: PoolLimits,
     backend: typing.Union[str, ConcurrencyBackend],
     trust_env: bool,
 ) -> typing.Dict[str, Dispatcher]:
     def _proxy_from_url(url: URLTypes) -> Dispatcher:
-        nonlocal verify, cert, timeout, http2, pool_limits, backend, trust_env
+        nonlocal verify, cert, http2, pool_limits, backend, trust_env
         url = URL(url)
         if url.scheme in ("http", "https"):
             return HTTPProxy(
                 url,
                 verify=verify,
                 cert=cert,
-                timeout=timeout,
                 pool_limits=pool_limits,
                 backend=backend,
                 trust_env=trust_env,
index 53c1c79cef053f149a18ebfe3fc882ab1af43ea4..065996077c0e1c118848f4d79c1ede2c5488e60d 100644 (file)
@@ -74,15 +74,13 @@ class BaseSocketStream:
     ) -> "BaseSocketStream":
         raise NotImplementedError()  # pragma: no cover
 
-    async def read(
-        self, n: int, timeout: Timeout = None, flag: typing.Any = None
-    ) -> bytes:
+    async def read(self, n: int, timeout: Timeout, flag: typing.Any = None) -> bytes:
         raise NotImplementedError()  # pragma: no cover
 
     def write_no_block(self, data: bytes) -> None:
         raise NotImplementedError()  # pragma: no cover
 
-    async def write(self, data: bytes, timeout: Timeout = None) -> None:
+    async def write(self, data: bytes, timeout: Timeout) -> None:
         raise NotImplementedError()  # pragma: no cover
 
     async def close(self) -> None:
index 6e0af89861915a540d07e2cf652f6a9c50652337..cf09b07924a56d5265d55c730d6e09f251da9106 100644 (file)
@@ -1,7 +1,7 @@
 import typing
 from types import TracebackType
 
-from ..config import CertTypes, TimeoutTypes, VerifyTypes
+from ..config import CertTypes, Timeout, VerifyTypes
 from ..models import (
     HeaderTypes,
     QueryParamTypes,
@@ -31,7 +31,7 @@ class Dispatcher:
         headers: HeaderTypes = None,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: Timeout = None,
     ) -> Response:
         request = Request(method, url, data=data, params=params, headers=headers)
         return await self.send(request, verify=verify, cert=cert, timeout=timeout)
@@ -41,7 +41,7 @@ class Dispatcher:
         request: Request,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: Timeout = None,
     ) -> Response:
         raise NotImplementedError()  # pragma: nocover
 
index c83a2af363d25cd43f62cd101a6e242f58ada998..ecb84a748bd1616ce3e46a3d905163ee621295f0 100644 (file)
@@ -3,14 +3,7 @@ import ssl
 import typing
 
 from ..concurrency.base import ConcurrencyBackend, lookup_backend
-from ..config import (
-    DEFAULT_TIMEOUT_CONFIG,
-    CertTypes,
-    SSLConfig,
-    Timeout,
-    TimeoutTypes,
-    VerifyTypes,
-)
+from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
 from ..models import Origin, Request, Response
 from ..utils import get_logger
 from .base import Dispatcher
@@ -31,7 +24,6 @@ class HTTPConnection(Dispatcher):
         verify: VerifyTypes = True,
         cert: CertTypes = None,
         trust_env: bool = None,
-        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         http2: bool = False,
         backend: typing.Union[str, ConcurrencyBackend] = "auto",
         release_func: typing.Optional[ReleaseCallback] = None,
@@ -39,7 +31,6 @@ class HTTPConnection(Dispatcher):
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
         self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
-        self.timeout = Timeout(timeout)
         self.http2 = http2
         self.backend = lookup_backend(backend)
         self.release_func = release_func
@@ -52,8 +43,10 @@ class HTTPConnection(Dispatcher):
         request: Request,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: Timeout = None,
     ) -> Response:
+        timeout = Timeout() if timeout is None else timeout
+
         if self.h11_connection is None and self.h2_connection is None:
             await self.connect(verify=verify, cert=cert, timeout=timeout)
 
@@ -66,13 +59,9 @@ class HTTPConnection(Dispatcher):
         return response
 
     async def connect(
-        self,
-        verify: VerifyTypes = None,
-        cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
+        self, timeout: Timeout, verify: VerifyTypes = None, cert: CertTypes = None,
     ) -> None:
         ssl = self.ssl.with_overrides(verify=verify, cert=cert)
-        timeout = self.timeout if timeout is None else Timeout(timeout)
 
         host = self.origin.host
         port = self.origin.port
index 0dddb2392504173de1c152efb86563fe5bc8174b..f11137fbd02ca18f5d13c1f84f51c6b804ef77e6 100644 (file)
@@ -1,15 +1,7 @@
 import typing
 
 from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend
-from ..config import (
-    DEFAULT_POOL_LIMITS,
-    DEFAULT_TIMEOUT_CONFIG,
-    CertTypes,
-    PoolLimits,
-    Timeout,
-    TimeoutTypes,
-    VerifyTypes,
-)
+from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
 from ..models import Origin, Request, Response
 from ..utils import get_logger
 from .base import Dispatcher
@@ -84,7 +76,6 @@ class ConnectionPool(Dispatcher):
         verify: VerifyTypes = True,
         cert: CertTypes = None,
         trust_env: bool = None,
-        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         http2: bool = False,
         backend: typing.Union[str, ConcurrencyBackend] = "auto",
@@ -92,7 +83,6 @@ class ConnectionPool(Dispatcher):
     ):
         self.verify = verify
         self.cert = cert
-        self.timeout = Timeout(timeout)
         self.pool_limits = pool_limits
         self.http2 = http2
         self.is_closed = False
@@ -121,7 +111,7 @@ class ConnectionPool(Dispatcher):
         request: Request,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: Timeout = None,
     ) -> Response:
         connection = await self.acquire_connection(
             origin=request.url.origin, timeout=timeout
@@ -138,23 +128,19 @@ class ConnectionPool(Dispatcher):
         return response
 
     async def acquire_connection(
-        self, origin: Origin, timeout: TimeoutTypes = None
+        self, origin: Origin, timeout: Timeout = None
     ) -> HTTPConnection:
         logger.trace(f"acquire_connection origin={origin!r}")
         connection = self.pop_connection(origin)
 
         if connection is None:
-            if timeout is None:
-                pool_timeout = self.timeout.pool_timeout
-            else:
-                pool_timeout = Timeout(timeout).pool_timeout
+            pool_timeout = None if timeout is None else timeout.pool_timeout
 
             await self.max_connections.acquire(timeout=pool_timeout)
             connection = HTTPConnection(
                 origin,
                 verify=self.verify,
                 cert=self.cert,
-                timeout=self.timeout,
                 http2=self.http2,
                 backend=self.backend,
                 release_func=self.release_connection,
index 89950c9b02b5b69d3121bf91f26d140753123d77..bbe2ac3f9e87ad8de261c5193cc7a11b489307ad 100644 (file)
@@ -3,7 +3,7 @@ import typing
 import h11
 
 from ..concurrency.base import BaseSocketStream, TimeoutFlag
-from ..config import Timeout, TimeoutTypes
+from ..config import Timeout
 from ..exceptions import ConnectionClosed, ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
@@ -40,8 +40,8 @@ class HTTP11Connection:
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
         self.timeout_flag = TimeoutFlag()
 
-    async def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
-        timeout = None if timeout is None else Timeout(timeout)
+    async def send(self, request: Request, timeout: Timeout = None) -> Response:
+        timeout = Timeout() if timeout is None else timeout
 
         await self._send_request(request, timeout)
         await self._send_request_body(request, timeout)
@@ -67,7 +67,7 @@ class HTTP11Connection:
             pass
         await self.stream.close()
 
-    async def _send_request(self, request: Request, timeout: Timeout = None) -> None:
+    async def _send_request(self, request: Request, timeout: Timeout) -> None:
         """
         Send the request method, URL, and headers to the network.
         """
@@ -83,9 +83,7 @@ class HTTP11Connection:
         event = h11.Request(method=method, target=target, headers=headers)
         await self._send_event(event, timeout)
 
-    async def _send_request_body(
-        self, request: Request, timeout: Timeout = None
-    ) -> None:
+    async def _send_request_body(self, request: Request, timeout: Timeout) -> None:
         """
         Send the request body to the network.
         """
@@ -108,7 +106,7 @@ class HTTP11Connection:
             # Once we've sent the request, we enable read timeouts.
             self.timeout_flag.set_read_timeouts()
 
-    async def _send_event(self, event: H11Event, timeout: Timeout = None) -> None:
+    async def _send_event(self, event: H11Event, timeout: Timeout) -> None:
         """
         Send a single `h11` event to the network, waiting for the data to
         drain before returning.
@@ -117,7 +115,7 @@ class HTTP11Connection:
         await self.stream.write(bytes_to_send, timeout)
 
     async def _receive_response(
-        self, timeout: Timeout = None
+        self, timeout: Timeout
     ) -> typing.Tuple[str, int, typing.List[typing.Tuple[bytes, bytes]]]:
         """
         Read the response status and headers from the network.
@@ -136,7 +134,7 @@ class HTTP11Connection:
         return http_version, event.status_code, event.headers
 
     async def _receive_response_data(
-        self, timeout: Timeout = None
+        self, timeout: Timeout
     ) -> typing.AsyncIterator[bytes]:
         """
         Read the response data from the network.
@@ -149,7 +147,7 @@ class HTTP11Connection:
                 assert isinstance(event, h11.EndOfMessage) or event is h11.PAUSED
                 break  # pragma: no cover
 
-    async def _receive_event(self, timeout: Timeout = None) -> H11Event:
+    async def _receive_event(self, timeout: Timeout) -> H11Event:
         """
         Read a single `h11` event, reading more data from the network if needed.
         """
index 226b1673f6492157e506efb64216b38f1e248244..8b73c6c66387b09ffa539da391954af97c071ef9 100644 (file)
@@ -12,7 +12,7 @@ from ..concurrency.base import (
     TimeoutFlag,
     lookup_backend,
 )
-from ..config import Timeout, TimeoutTypes
+from ..config import Timeout
 from ..exceptions import ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
@@ -38,8 +38,8 @@ class HTTP2Connection:
         self.initialized = False
         self.window_update_received = {}  # type: typing.Dict[int, BaseEvent]
 
-    async def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
-        timeout = None if timeout is None else Timeout(timeout)
+    async def send(self, request: Request, timeout: Timeout = None) -> Response:
+        timeout = Timeout() if timeout is None else timeout
 
         # Start sending the request.
         if not self.initialized:
@@ -97,7 +97,7 @@ class HTTP2Connection:
         self.stream.write_no_block(data_to_send)
         self.initialized = True
 
-    async def send_headers(self, request: Request, timeout: Timeout = None) -> int:
+    async def send_headers(self, request: Request, timeout: Timeout) -> int:
         stream_id = self.h2_state.get_next_available_stream_id()
         headers = [
             (b":method", request.method.encode("ascii")),
@@ -119,10 +119,7 @@ class HTTP2Connection:
         return stream_id
 
     async def send_request_data(
-        self,
-        stream_id: int,
-        stream: typing.AsyncIterator[bytes],
-        timeout: Timeout = None,
+        self, stream_id: int, stream: typing.AsyncIterator[bytes], timeout: Timeout,
     ) -> None:
         try:
             async for data in stream:
@@ -132,9 +129,7 @@ class HTTP2Connection:
             # Once we've sent the request we should enable read timeouts.
             self.timeout_flags[stream_id].set_read_timeouts()
 
-    async def send_data(
-        self, stream_id: int, data: bytes, timeout: Timeout = None
-    ) -> None:
+    async def send_data(self, stream_id: int, data: bytes, timeout: Timeout) -> None:
         while data:
             # The data will be divided into frames to send based on the flow control
             # window and the maximum frame size. Because the flow control window
@@ -157,14 +152,14 @@ class HTTP2Connection:
                 data_to_send = self.h2_state.data_to_send()
                 await self.stream.write(data_to_send, timeout)
 
-    async def end_stream(self, stream_id: int, timeout: Timeout = None) -> None:
+    async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
         logger.trace(f"end_stream stream_id={stream_id}")
         self.h2_state.end_stream(stream_id)
         data_to_send = self.h2_state.data_to_send()
         await self.stream.write(data_to_send, timeout)
 
     async def receive_response(
-        self, stream_id: int, timeout: Timeout = None
+        self, stream_id: int, timeout: Timeout
     ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
         """
         Read the response status and headers from the network.
@@ -188,7 +183,7 @@ class HTTP2Connection:
         return (status_code, headers)
 
     async def body_iter(
-        self, stream_id: int, timeout: Timeout = None
+        self, stream_id: int, timeout: Timeout
     ) -> typing.AsyncIterator[bytes]:
         while True:
             event = await self.receive_event(stream_id, timeout)
@@ -200,9 +195,7 @@ class HTTP2Connection:
             elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
                 break
 
-    async def receive_event(
-        self, stream_id: int, timeout: Timeout = None
-    ) -> h2.events.Event:
+    async def receive_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
         while not self.events[stream_id]:
             flag = self.timeout_flags[stream_id]
             data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag)
index d146c57daa16b1887db6e5a5b1c21425fa04bcbb..55577e38ee4f6542b357246ac5cd696aba87b6a4 100644 (file)
@@ -7,11 +7,10 @@ import h11
 from ..concurrency.base import ConcurrencyBackend
 from ..config import (
     DEFAULT_POOL_LIMITS,
-    DEFAULT_TIMEOUT_CONFIG,
     CertTypes,
     PoolLimits,
     SSLConfig,
-    TimeoutTypes,
+    Timeout,
     VerifyTypes,
 )
 from ..exceptions import ProxyError
@@ -45,7 +44,6 @@ class HTTPProxy(ConnectionPool):
         verify: VerifyTypes = True,
         cert: CertTypes = None,
         trust_env: bool = None,
-        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         http2: bool = False,
         backend: typing.Union[str, ConcurrencyBackend] = "auto",
@@ -54,7 +52,6 @@ class HTTPProxy(ConnectionPool):
         super(HTTPProxy, self).__init__(
             verify=verify,
             cert=cert,
-            timeout=timeout,
             pool_limits=pool_limits,
             backend=backend,
             trust_env=trust_env,
@@ -82,7 +79,7 @@ class HTTPProxy(ConnectionPool):
         return f"Basic {token}"
 
     async def acquire_connection(
-        self, origin: Origin, timeout: TimeoutTypes = None
+        self, origin: Origin, timeout: Timeout = None
     ) -> HTTPConnection:
         if self.should_forward_origin(origin):
             logger.trace(
@@ -93,9 +90,11 @@ class HTTPProxy(ConnectionPool):
             logger.trace(
                 f"tunnel_connection proxy_url={self.proxy_url!r} origin={origin!r}"
             )
-            return await self.tunnel_connection(origin)
+            return await self.tunnel_connection(origin, timeout)
 
-    async def tunnel_connection(self, origin: Origin) -> HTTPConnection:
+    async def tunnel_connection(
+        self, origin: Origin, timeout: Timeout = None
+    ) -> HTTPConnection:
         """Creates a new HTTPConnection via the CONNECT method
         usually reserved for proxying HTTPS connections.
         """
@@ -111,7 +110,8 @@ class HTTPProxy(ConnectionPool):
             connection.origin = origin
             self.active_connections.add(connection)
 
-            await self.tunnel_start_tls(origin, connection)
+            timeout = Timeout() if timeout is None else timeout
+            await self.tunnel_start_tls(origin, connection, timeout)
         else:
             self.active_connections.add(connection)
 
@@ -132,7 +132,6 @@ class HTTPProxy(ConnectionPool):
             self.proxy_url.origin,
             verify=self.verify,
             cert=self.cert,
-            timeout=self.timeout,
             backend=self.backend,
             http2=False,  # Short-lived 'connection'
             trust_env=self.trust_env,
@@ -163,9 +162,10 @@ class HTTPProxy(ConnectionPool):
         return connection
 
     async def tunnel_start_tls(
-        self, origin: Origin, connection: HTTPConnection
+        self, origin: Origin, connection: HTTPConnection, timeout: Timeout = None
     ) -> None:
         """Runs start_tls() on a TCP-tunneled connection"""
+        timeout = Timeout() if timeout is None else timeout
 
         # Store this information here so that we can transfer
         # it to the new internal connection object after
@@ -182,7 +182,6 @@ class HTTPProxy(ConnectionPool):
         # HTTP connection object and run start_tls()
         if origin.is_ssl:
             ssl_config = SSLConfig(cert=self.cert, verify=self.verify)
-            timeout = connection.timeout
             ssl_context = await connection.get_ssl_context(ssl_config)
             assert ssl_context is not None
 
@@ -225,7 +224,7 @@ class HTTPProxy(ConnectionPool):
         request: Request,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
+        timeout: Timeout = None,
     ) -> Response:
 
         if self.should_forward_origin(request.url.origin):
index c0c1ae0e83a169b73e32cee6ec34299b8f510a42..89af4d61eaae9839989296746598003fcb3efb7c 100644 (file)
@@ -46,7 +46,6 @@ def test_proxies_has_same_properties_as_dispatch():
     for prop in [
         "verify",
         "cert",
-        "timeout",
         "pool_limits",
     ]:
         assert getattr(pool, prop) == getattr(proxy, prop)