From: Tom Christie Date: Thu, 23 May 2019 15:21:00 +0000 (+0100) Subject: Work on bringing API into parity with `requests`. (#76) X-Git-Tag: 0.3.1~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=95740415db889e68255bf4839ab158dc54e02785;p=thirdparty%2Fhttpx.git Work on bringing API into parity with `requests`. (#76) * Finesse timeout argument. * Drop unused imports * Add 'cert' and 'verify' arguments --- diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 6d83e16f..8d443963 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,7 +1,14 @@ from .api import delete, get, head, options, patch, post, put, request from .client import AsyncClient, Client from .concurrency import AsyncioBackend -from .config import PoolLimits, SSLConfig, TimeoutConfig +from .config import ( + CertTypes, + PoolLimits, + SSLConfig, + TimeoutConfig, + TimeoutTypes, + VerifyTypes, +) from .dispatch.connection import HTTPConnection from .dispatch.connection_pool import ConnectionPool from .exceptions import ( diff --git a/httpcore/api.py b/httpcore/api.py index 8e242c74..33d68c5e 100644 --- a/httpcore/api.py +++ b/httpcore/api.py @@ -1,7 +1,7 @@ import typing from .client import Client -from .config import SSLConfig, TimeoutConfig +from .config import CertTypes, TimeoutTypes, VerifyTypes from .models import ( AuthTypes, CookieTypes, @@ -17,16 +17,19 @@ def request( method: str, url: URLTypes, *, + params: QueryParamTypes = None, data: RequestData = b"", json: typing.Any = None, - params: QueryParamTypes = None, headers: HeaderTypes = None, cookies: CookieTypes = None, - stream: bool = False, + # files auth: AuthTypes = None, + timeout: TimeoutTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + # proxies + cert: CertTypes = None, + verify: VerifyTypes = True, + stream: bool = False, ) -> SyncResponse: with Client() as client: return client.request( @@ -40,7 +43,8 @@ def request( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) @@ -54,8 +58,9 @@ def get( stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = None, ) -> SyncResponse: return request( "GET", @@ -65,7 +70,8 @@ def get( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) @@ -79,8 +85,9 @@ def options( stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = None, ) -> SyncResponse: return request( "OPTIONS", @@ -90,7 +97,8 @@ def options( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) @@ -104,8 +112,9 @@ def head( stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = False, #  Note: Differs to usual default. - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = None, ) -> SyncResponse: return request( "HEAD", @@ -115,7 +124,8 @@ def head( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) @@ -131,8 +141,9 @@ def post( stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = None, ) -> SyncResponse: return request( "POST", @@ -144,7 +155,8 @@ def post( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) @@ -160,8 +172,9 @@ def put( stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = None, ) -> SyncResponse: return request( "PUT", @@ -173,7 +186,8 @@ def put( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) @@ -189,8 +203,9 @@ def patch( stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = None, ) -> SyncResponse: return request( "PATCH", @@ -202,7 +217,8 @@ def patch( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) @@ -218,8 +234,9 @@ def delete( stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = None, ) -> SyncResponse: return request( "DELETE", @@ -231,6 +248,7 @@ def delete( stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + cert=cert, + verify=verify, timeout=timeout, ) diff --git a/httpcore/client.py b/httpcore/client.py index cc31e844..39fa0aa5 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -6,11 +6,11 @@ from .auth import HTTPBasicAuth from .config import ( DEFAULT_MAX_REDIRECTS, DEFAULT_POOL_LIMITS, - DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, + CertTypes, PoolLimits, - SSLConfig, - TimeoutConfig, + TimeoutTypes, + VerifyTypes, ) from .dispatch.connection_pool import ConnectionPool from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects @@ -37,8 +37,9 @@ class AsyncClient: self, auth: AuthTypes = None, cookies: CookieTypes = None, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + verify: VerifyTypes = True, + cert: CertTypes = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, max_redirects: int = DEFAULT_MAX_REDIRECTS, dispatch: Dispatcher = None, @@ -46,7 +47,11 @@ class AsyncClient: ): if dispatch is None: dispatch = ConnectionPool( - ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend + verify=verify, + cert=cert, + timeout=timeout, + pool_limits=pool_limits, + backend=backend, ) self.auth = auth @@ -64,8 +69,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: return await self.request( "GET", @@ -76,7 +82,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -90,8 +97,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: return await self.request( "OPTIONS", @@ -102,7 +110,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -116,8 +125,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = False, #  Note: Differs to usual default. - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: return await self.request( "HEAD", @@ -128,7 +138,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -144,8 +155,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: return await self.request( "POST", @@ -158,7 +170,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -174,8 +187,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: return await self.request( "PUT", @@ -188,7 +202,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -204,8 +219,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: return await self.request( "PATCH", @@ -218,7 +234,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -234,8 +251,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: return await self.request( "DELETE", @@ -248,7 +266,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -265,8 +284,9 @@ class AsyncClient: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> Response: request = Request( method, @@ -283,7 +303,8 @@ class AsyncClient: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) return response @@ -306,9 +327,10 @@ class AsyncClient: *, stream: bool = False, auth: AuthTypes = None, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, allow_redirects: bool = True, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> Response: if auth is None: auth = self.auth @@ -325,7 +347,8 @@ class AsyncClient: response = await self.send_handling_redirects( request, stream=stream, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, allow_redirects=allow_redirects, ) @@ -336,8 +359,9 @@ class AsyncClient: request: Request, *, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, allow_redirects: bool = True, history: typing.List[Response] = None, ) -> Response: @@ -353,7 +377,7 @@ class AsyncClient: raise RedirectLoop() response = await self.dispatch.send( - request, stream=stream, ssl=ssl, timeout=timeout + request, stream=stream, verify=verify, cert=cert, timeout=timeout ) response.history = list(history) self.cookies.extract_cookies(response) @@ -366,13 +390,14 @@ class AsyncClient: else: async def send_next() -> Response: - nonlocal request, response, ssl, allow_redirects, timeout, history + nonlocal request, response, verify, cert, allow_redirects, timeout, history request = self.build_redirect_request(request, response) response = await self.send_handling_redirects( request, stream=stream, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, history=history, ) @@ -474,8 +499,9 @@ class Client: def __init__( self, auth: AuthTypes = None, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + cert: CertTypes = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, max_redirects: int = DEFAULT_MAX_REDIRECTS, dispatch: Dispatcher = None, @@ -483,7 +509,8 @@ class Client: ) -> None: self._client = AsyncClient( auth=auth, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, pool_limits=pool_limits, max_redirects=max_redirects, @@ -509,8 +536,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: request = Request( method, @@ -527,7 +555,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) return response @@ -542,8 +571,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: return self.request( "GET", @@ -553,7 +583,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -567,8 +598,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: return self.request( "OPTIONS", @@ -578,7 +610,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -592,8 +625,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = False, #  Note: Differs to usual default. - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: return self.request( "HEAD", @@ -603,7 +637,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -619,8 +654,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: return self.request( "POST", @@ -632,7 +668,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -648,8 +685,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: return self.request( "PUT", @@ -661,7 +699,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -677,8 +716,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: return self.request( "PATCH", @@ -690,7 +730,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -706,8 +747,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: return self.request( "DELETE", @@ -719,7 +761,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) @@ -733,8 +776,9 @@ class Client: stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> SyncResponse: response = self._loop.run_until_complete( self._client.send( @@ -742,7 +786,8 @@ class Client: stream=stream, auth=auth, allow_redirects=allow_redirects, - ssl=ssl, + verify=verify, + cert=cert, timeout=timeout, ) ) diff --git a/httpcore/concurrency.py b/httpcore/concurrency.py index fb20d7c5..0c1d3409 100644 --- a/httpcore/concurrency.py +++ b/httpcore/concurrency.py @@ -22,9 +22,6 @@ from .interfaces import ( Protocol, ) -OptionalTimeout = typing.Optional[TimeoutConfig] - - SSL_MONKEY_PATCH_APPLIED = False @@ -56,7 +53,7 @@ class Reader(BaseReader): self.stream_reader = stream_reader self.timeout = timeout - async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes: + async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes: if timeout is None: timeout = self.timeout @@ -78,7 +75,7 @@ class Writer(BaseWriter): def write_no_block(self, data: bytes) -> None: self.stream_writer.write(data) # pragma: nocover - async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None: + async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None: if not data: return diff --git a/httpcore/config.py b/httpcore/config.py index 82fd125f..5b3c3131 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -5,18 +5,17 @@ import typing import certifi +CertTypes = typing.Union[str, typing.Tuple[str, str]] +VerifyTypes = typing.Union[str, bool] +TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"] + class SSLConfig: """ SSL Configuration. """ - def __init__( - self, - *, - cert: typing.Union[None, str, typing.Tuple[str, str]] = None, - verify: typing.Union[str, bool] = True, - ): + def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True): self.cert = cert self.verify = verify @@ -31,6 +30,15 @@ class SSLConfig: class_name = self.__class__.__name__ return f"{class_name}(cert={self.cert}, verify={self.verify})" + def with_overrides( + self, cert: CertTypes = None, verify: VerifyTypes = None + ) -> "SSLConfig": + cert = self.cert if cert is None else cert + verify = self.verify if verify is None else verify + if (cert == self.cert) and (verify == self.verify): + return self + return SSLConfig(cert=cert, verify=verify) + async def load_ssl_context(self) -> ssl.SSLContext: if not hasattr(self, "ssl_context"): if not self.verify: @@ -109,25 +117,33 @@ class TimeoutConfig: def __init__( self, - timeout: float = None, + timeout: TimeoutTypes = None, *, connect_timeout: float = None, read_timeout: float = None, write_timeout: float = None, ): - if timeout is not None: + if timeout is None: + self.connect_timeout = connect_timeout + self.read_timeout = read_timeout + self.write_timeout = write_timeout + else: # Specified as a single timeout value assert connect_timeout is None assert read_timeout is None assert write_timeout is None - connect_timeout = timeout - read_timeout = timeout - write_timeout = timeout - - self.timeout = timeout - self.connect_timeout = connect_timeout - self.read_timeout = read_timeout - self.write_timeout = write_timeout + if isinstance(timeout, TimeoutConfig): + self.connect_timeout = timeout.connect_timeout + self.read_timeout = timeout.read_timeout + self.write_timeout = timeout.write_timeout + elif isinstance(timeout, tuple): + self.connect_timeout = timeout[0] + self.read_timeout = timeout[1] + self.write_timeout = timeout[2] + else: + self.connect_timeout = timeout + self.read_timeout = timeout + self.write_timeout = timeout def __eq__(self, other: typing.Any) -> bool: return ( @@ -139,8 +155,8 @@ class TimeoutConfig: def __repr__(self) -> str: class_name = self.__class__.__name__ - if self.timeout is not None: - return f"{class_name}(timeout={self.timeout})" + if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1: + return f"{class_name}(timeout={self.connect_timeout})" return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})" diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index 053a9980..60214333 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -8,8 +8,11 @@ from ..concurrency import AsyncioBackend from ..config import ( DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, + CertTypes, SSLConfig, TimeoutConfig, + TimeoutTypes, + VerifyTypes, ) from ..exceptions import ConnectTimeout from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol @@ -25,14 +28,15 @@ class HTTPConnection(Dispatcher): def __init__( self, origin: typing.Union[str, Origin], - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + verify: VerifyTypes = True, + cert: CertTypes = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, backend: ConcurrencyBackend = None, release_func: typing.Optional[ReleaseCallback] = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin - self.ssl = ssl - self.timeout = timeout + self.ssl = SSLConfig(cert=cert, verify=verify) + self.timeout = TimeoutConfig(timeout) self.backend = AsyncioBackend() if backend is None else backend self.release_func = release_func self.h11_connection = None # type: typing.Optional[HTTP11Connection] @@ -42,11 +46,12 @@ class HTTPConnection(Dispatcher): self, request: Request, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> Response: if self.h11_connection is None and self.h2_connection is None: - await self.connect(ssl=ssl, timeout=timeout) + await self.connect(verify=verify, cert=cert, timeout=timeout) if self.h2_connection is not None: response = await self.h2_connection.send( @@ -61,12 +66,13 @@ class HTTPConnection(Dispatcher): return response async def connect( - self, ssl: SSLConfig = None, timeout: TimeoutConfig = None + self, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> None: - if ssl is None: - ssl = self.ssl - if timeout is None: - timeout = self.timeout + ssl = self.ssl.with_overrides(verify=verify, cert=cert) + timeout = self.timeout if timeout is None else TimeoutConfig(timeout) host = self.origin.host port = self.origin.port diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index ba92acd0..e7cefbd7 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -4,11 +4,11 @@ from ..concurrency import AsyncioBackend from ..config import ( DEFAULT_CA_BUNDLE_PATH, DEFAULT_POOL_LIMITS, - DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, + CertTypes, PoolLimits, - SSLConfig, - TimeoutConfig, + TimeoutTypes, + VerifyTypes, ) from ..decoders import ACCEPT_ENCODING from ..exceptions import PoolTimeout @@ -81,12 +81,14 @@ class ConnectionPool(Dispatcher): def __init__( self, *, - ssl: SSLConfig = DEFAULT_SSL_CONFIG, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + verify: VerifyTypes = True, + cert: CertTypes = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, backend: ConcurrencyBackend = None, ): - self.ssl = ssl + self.verify = verify + self.cert = cert self.timeout = timeout self.pool_limits = pool_limits self.is_closed = False @@ -105,13 +107,14 @@ class ConnectionPool(Dispatcher): self, request: Request, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> Response: connection = await self.acquire_connection(request.url.origin) try: response = await connection.send( - request, stream=stream, ssl=ssl, timeout=timeout + request, stream=stream, verify=verify, cert=cert, timeout=timeout ) except BaseException as exc: self.active_connections.remove(connection) @@ -128,7 +131,8 @@ class ConnectionPool(Dispatcher): await self.max_connections.acquire() connection = HTTPConnection( origin, - ssl=self.ssl, + verify=self.verify, + cert=self.cert, timeout=self.timeout, backend=self.backend, release_func=self.release_connection, diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index fc5f34fc..4308f64a 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -2,12 +2,7 @@ import typing import h11 -from ..config import ( - DEFAULT_SSL_CONFIG, - DEFAULT_TIMEOUT_CONFIG, - SSLConfig, - TimeoutConfig, -) +from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter, Dispatcher from ..models import Request, Response @@ -22,8 +17,6 @@ H11Event = typing.Union[ ] -OptionalTimeout = typing.Optional[TimeoutConfig] - # Callback signature: async def callback() -> None # In practice the callback will be a functools partial, which binds # the `ConnectionPool.release_connection(conn: HTTPConnection)` method. @@ -45,8 +38,10 @@ class HTTP11Connection: self.h11_state = h11.Connection(our_role=h11.CLIENT) async def send( - self, request: Request, stream: bool = False, timeout: TimeoutConfig = None + self, request: Request, stream: bool = False, timeout: TimeoutTypes = None ) -> Response: + timeout = None if timeout is None else TimeoutConfig(timeout) + #  Start sending the request. method = request.method.encode("ascii") target = request.url.full_path.encode("ascii") @@ -97,18 +92,20 @@ class HTTP11Connection: self.h11_state.send(event) await self.writer.close() - async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]: + async def _body_iter( + self, timeout: TimeoutConfig = None + ) -> typing.AsyncIterator[bytes]: event = await self._receive_event(timeout) while isinstance(event, h11.Data): yield event.data event = await self._receive_event(timeout) assert isinstance(event, h11.EndOfMessage) - async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None: + async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None: data = self.h11_state.send(event) await self.writer.write(data, timeout) - async def _receive_event(self, timeout: OptionalTimeout) -> H11Event: + async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event: event = self.h11_state.next_event() while event is h11.NEED_DATA: diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index 301a36c4..bb1857f3 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -4,18 +4,11 @@ import typing import h2.connection import h2.events -from ..config import ( - DEFAULT_SSL_CONFIG, - DEFAULT_TIMEOUT_CONFIG, - SSLConfig, - TimeoutConfig, -) +from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter, Dispatcher from ..models import Request, Response -OptionalTimeout = typing.Optional[TimeoutConfig] - class HTTP2Connection: READ_NUM_BYTES = 4096 @@ -31,8 +24,10 @@ class HTTP2Connection: self.initialized = False async def send( - self, request: Request, stream: bool = False, timeout: TimeoutConfig = None + self, request: Request, stream: bool = False, timeout: TimeoutTypes = None ) -> Response: + timeout = None if timeout is None else TimeoutConfig(timeout) + #  Start sending the request. if not self.initialized: self.initiate_connection() @@ -89,7 +84,9 @@ class HTTP2Connection: self.writer.write_no_block(data_to_send) self.initialized = True - async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int: + async def send_headers( + self, request: Request, timeout: TimeoutConfig = None + ) -> int: stream_id = self.h2_state.get_next_available_stream_id() headers = [ (b":method", request.method.encode("ascii")), @@ -103,19 +100,19 @@ class HTTP2Connection: return stream_id async def send_data( - self, stream_id: int, data: bytes, timeout: OptionalTimeout + self, stream_id: int, data: bytes, timeout: TimeoutConfig = None ) -> None: self.h2_state.send_data(stream_id, data) data_to_send = self.h2_state.data_to_send() await self.writer.write(data_to_send, timeout) - async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None: + async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None: self.h2_state.end_stream(stream_id) data_to_send = self.h2_state.data_to_send() await self.writer.write(data_to_send, timeout) async def body_iter( - self, stream_id: int, timeout: OptionalTimeout + self, stream_id: int, timeout: TimeoutConfig = None ) -> typing.AsyncIterator[bytes]: while True: event = await self.receive_event(stream_id, timeout) @@ -125,7 +122,7 @@ class HTTP2Connection: break async def receive_event( - self, stream_id: int, timeout: OptionalTimeout + self, stream_id: int, timeout: TimeoutConfig = None ) -> h2.events.Event: while not self.events[stream_id]: data = await self.reader.read(self.READ_NUM_BYTES, timeout) diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 20b17239..f2e846be 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -3,7 +3,7 @@ import ssl import typing from types import TracebackType -from .config import PoolLimits, SSLConfig, TimeoutConfig +from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes from .models import ( URL, Headers, @@ -15,8 +15,6 @@ from .models import ( URLTypes, ) -OptionalTimeout = typing.Optional[TimeoutConfig] - class Protocol(str, enum.Enum): HTTP_11 = "HTTP/1.1" @@ -41,12 +39,15 @@ class Dispatcher: params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None ) -> Response: request = Request(method, url, data=data, params=params, headers=headers) self.prepare_request(request) - response = await self.send(request, stream=stream, ssl=ssl, timeout=timeout) + response = await self.send( + request, stream=stream, verify=verify, cert=cert, timeout=timeout + ) return response def prepare_request(self, request: Request) -> None: @@ -56,8 +57,9 @@ class Dispatcher: self, request: Request, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> Response: raise NotImplementedError() # pragma: nocover @@ -83,7 +85,7 @@ class BaseReader: backend, or for stand-alone test cases. """ - async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes: + async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes: raise NotImplementedError() # pragma: no cover @@ -97,7 +99,7 @@ class BaseWriter: def write_no_block(self, data: bytes) -> None: raise NotImplementedError() # pragma: no cover - async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None: + async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None: raise NotImplementedError() # pragma: no cover async def close(self) -> None: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 8a79a50a..1d2b9723 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -4,12 +4,13 @@ import pytest from httpcore import ( URL, + CertTypes, Client, Dispatcher, Request, Response, - SSLConfig, - TimeoutConfig, + TimeoutTypes, + VerifyTypes, ) @@ -18,8 +19,9 @@ class MockDispatch(Dispatcher): self, request: Request, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> Response: body = json.dumps({"auth": request.headers.get("Authorization")}).encode() return Response(200, content=body, request=request) diff --git a/tests/client/test_cookies.py b/tests/client/test_cookies.py index 59e70b03..a21f5c13 100644 --- a/tests/client/test_cookies.py +++ b/tests/client/test_cookies.py @@ -5,13 +5,14 @@ import pytest from httpcore import ( URL, + CertTypes, Client, Cookies, Dispatcher, Request, Response, - SSLConfig, - TimeoutConfig, + TimeoutTypes, + VerifyTypes, ) @@ -20,8 +21,9 @@ class MockDispatch(Dispatcher): self, request: Request, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> Response: if request.url.path.startswith("/echo_cookies"): body = json.dumps({"cookies": request.headers.get("Cookie")}).encode() diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 0edad78a..c3b384dc 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -6,14 +6,15 @@ import pytest from httpcore import ( URL, AsyncClient, + CertTypes, Dispatcher, RedirectBodyUnavailable, RedirectLoop, Request, Response, - SSLConfig, - TimeoutConfig, + TimeoutTypes, TooManyRedirects, + VerifyTypes, codes, ) @@ -23,8 +24,9 @@ class MockDispatch(Dispatcher): self, request: Request, stream: bool = False, - ssl: SSLConfig = None, - timeout: TimeoutConfig = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, ) -> Response: if request.url.path == "/redirect_301": status_code = codes.MOVED_PERMANENTLY diff --git a/tests/dispatch/test_connections.py b/tests/dispatch/test_connections.py index 2edf3ada..f323f55d 100644 --- a/tests/dispatch/test_connections.py +++ b/tests/dispatch/test_connections.py @@ -6,25 +6,35 @@ from httpcore import HTTPConnection, Request, SSLConfig @pytest.mark.asyncio async def test_get(server): conn = HTTPConnection(origin="http://127.0.0.1:8000/") - request = Request("GET", "http://127.0.0.1:8000/") - request.prepare() - response = await conn.send(request) + response = await conn.request("GET", "http://127.0.0.1:8000/") assert response.status_code == 200 assert response.content == b"Hello, world!" @pytest.mark.asyncio -async def test_https_get(https_server): - http = HTTPConnection(origin="https://127.0.0.1:8001/", ssl=SSLConfig(verify=False)) - response = await http.request("GET", "https://127.0.0.1:8001/") +async def test_post(server): + conn = HTTPConnection(origin="http://127.0.0.1:8000/") + response = await conn.request("GET", "http://127.0.0.1:8000/", data=b"Hello, world!") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_https_get_with_ssl_defaults(https_server): + """ + An HTTPS request, with default SSL configuration set on the client. + """ + conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False) + response = await conn.request("GET", "https://127.0.0.1:8001/") assert response.status_code == 200 assert response.content == b"Hello, world!" @pytest.mark.asyncio -async def test_post(server): - conn = HTTPConnection(origin="http://127.0.0.1:8000/") - request = Request("GET", "http://127.0.0.1:8000/", data=b"Hello, world!") - request.prepare() - response = await conn.send(request) +async def test_https_get_with_sll_overrides(https_server): + """ + An HTTPS request, with SSL configuration set on the request. + """ + conn = HTTPConnection(origin="https://127.0.0.1:8001/") + response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False) assert response.status_code == 200 + assert response.content == b"Hello, world!" diff --git a/tests/test_config.py b/tests/test_config.py index bf12edb2..4ee6d6e7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -94,3 +94,13 @@ def test_timeout_eq(): def test_limits_eq(): limits = httpcore.PoolLimits(hard_limit=100) assert limits == httpcore.PoolLimits(hard_limit=100) + + +def test_timeout_from_tuple(): + timeout = httpcore.TimeoutConfig(timeout=(5.0, 5.0, 5.0)) + assert timeout == httpcore.TimeoutConfig(timeout=5.0) + + +def test_timeout_from_config_instance(): + timeout = httpcore.TimeoutConfig(timeout=(5.0)) + assert httpcore.TimeoutConfig(timeout) == httpcore.TimeoutConfig(timeout=5.0)