]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Work on bringing API into parity with `requests`. (#76)
authorTom Christie <tom@tomchristie.com>
Thu, 23 May 2019 15:21:00 +0000 (16:21 +0100)
committerGitHub <noreply@github.com>
Thu, 23 May 2019 15:21:00 +0000 (16:21 +0100)
* Finesse timeout argument.

* Drop unused imports

* Add 'cert' and 'verify' arguments

15 files changed:
httpcore/__init__.py
httpcore/api.py
httpcore/client.py
httpcore/concurrency.py
httpcore/config.py
httpcore/dispatch/connection.py
httpcore/dispatch/connection_pool.py
httpcore/dispatch/http11.py
httpcore/dispatch/http2.py
httpcore/interfaces.py
tests/client/test_auth.py
tests/client/test_cookies.py
tests/client/test_redirects.py
tests/dispatch/test_connections.py
tests/test_config.py

index 6d83e16f20b9ea4051bbd7ed1662a64b94657029..8d443963cd7b1384548951262cb694ecaa520f2f 100644 (file)
@@ -1,7 +1,14 @@
 from .api import delete, get, head, options, patch, post, put, request
 from .client import AsyncClient, Client
 from .concurrency import AsyncioBackend
-from .config import PoolLimits, SSLConfig, TimeoutConfig
+from .config import (
+    CertTypes,
+    PoolLimits,
+    SSLConfig,
+    TimeoutConfig,
+    TimeoutTypes,
+    VerifyTypes,
+)
 from .dispatch.connection import HTTPConnection
 from .dispatch.connection_pool import ConnectionPool
 from .exceptions import (
index 8e242c74b176e81708ef3042344c579815f9214c..33d68c5e77832fcb66558bb9d667c44b3700015b 100644 (file)
@@ -1,7 +1,7 @@
 import typing
 
 from .client import Client
-from .config import SSLConfig, TimeoutConfig
+from .config import CertTypes, TimeoutTypes, VerifyTypes
 from .models import (
     AuthTypes,
     CookieTypes,
@@ -17,16 +17,19 @@ def request(
     method: str,
     url: URLTypes,
     *,
+    params: QueryParamTypes = None,
     data: RequestData = b"",
     json: typing.Any = None,
-    params: QueryParamTypes = None,
     headers: HeaderTypes = None,
     cookies: CookieTypes = None,
-    stream: bool = False,
+    # files
     auth: AuthTypes = None,
+    timeout: TimeoutTypes = None,
     allow_redirects: bool = True,
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    # proxies
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    stream: bool = False,
 ) -> SyncResponse:
     with Client() as client:
         return client.request(
@@ -40,7 +43,8 @@ def request(
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            cert=cert,
+            verify=verify,
             timeout=timeout,
         )
 
@@ -54,8 +58,9 @@ def get(
     stream: bool = False,
     auth: AuthTypes = None,
     allow_redirects: bool = True,
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    timeout: TimeoutTypes = None,
 ) -> SyncResponse:
     return request(
         "GET",
@@ -65,7 +70,8 @@ def get(
         stream=stream,
         auth=auth,
         allow_redirects=allow_redirects,
-        ssl=ssl,
+        cert=cert,
+        verify=verify,
         timeout=timeout,
     )
 
@@ -79,8 +85,9 @@ def options(
     stream: bool = False,
     auth: AuthTypes = None,
     allow_redirects: bool = True,
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    timeout: TimeoutTypes = None,
 ) -> SyncResponse:
     return request(
         "OPTIONS",
@@ -90,7 +97,8 @@ def options(
         stream=stream,
         auth=auth,
         allow_redirects=allow_redirects,
-        ssl=ssl,
+        cert=cert,
+        verify=verify,
         timeout=timeout,
     )
 
@@ -104,8 +112,9 @@ def head(
     stream: bool = False,
     auth: AuthTypes = None,
     allow_redirects: bool = False,  #  Note: Differs to usual default.
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    timeout: TimeoutTypes = None,
 ) -> SyncResponse:
     return request(
         "HEAD",
@@ -115,7 +124,8 @@ def head(
         stream=stream,
         auth=auth,
         allow_redirects=allow_redirects,
-        ssl=ssl,
+        cert=cert,
+        verify=verify,
         timeout=timeout,
     )
 
@@ -131,8 +141,9 @@ def post(
     stream: bool = False,
     auth: AuthTypes = None,
     allow_redirects: bool = True,
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    timeout: TimeoutTypes = None,
 ) -> SyncResponse:
     return request(
         "POST",
@@ -144,7 +155,8 @@ def post(
         stream=stream,
         auth=auth,
         allow_redirects=allow_redirects,
-        ssl=ssl,
+        cert=cert,
+        verify=verify,
         timeout=timeout,
     )
 
@@ -160,8 +172,9 @@ def put(
     stream: bool = False,
     auth: AuthTypes = None,
     allow_redirects: bool = True,
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    timeout: TimeoutTypes = None,
 ) -> SyncResponse:
     return request(
         "PUT",
@@ -173,7 +186,8 @@ def put(
         stream=stream,
         auth=auth,
         allow_redirects=allow_redirects,
-        ssl=ssl,
+        cert=cert,
+        verify=verify,
         timeout=timeout,
     )
 
@@ -189,8 +203,9 @@ def patch(
     stream: bool = False,
     auth: AuthTypes = None,
     allow_redirects: bool = True,
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    timeout: TimeoutTypes = None,
 ) -> SyncResponse:
     return request(
         "PATCH",
@@ -202,7 +217,8 @@ def patch(
         stream=stream,
         auth=auth,
         allow_redirects=allow_redirects,
-        ssl=ssl,
+        cert=cert,
+        verify=verify,
         timeout=timeout,
     )
 
@@ -218,8 +234,9 @@ def delete(
     stream: bool = False,
     auth: AuthTypes = None,
     allow_redirects: bool = True,
-    ssl: SSLConfig = None,
-    timeout: TimeoutConfig = None,
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    timeout: TimeoutTypes = None,
 ) -> SyncResponse:
     return request(
         "DELETE",
@@ -231,6 +248,7 @@ def delete(
         stream=stream,
         auth=auth,
         allow_redirects=allow_redirects,
-        ssl=ssl,
+        cert=cert,
+        verify=verify,
         timeout=timeout,
     )
index cc31e844046c0dd67b9a77721cbd67d29dfd57a0..39fa0aa5b42d16735d6462d59aad967377d6f43c 100644 (file)
@@ -6,11 +6,11 @@ from .auth import HTTPBasicAuth
 from .config import (
     DEFAULT_MAX_REDIRECTS,
     DEFAULT_POOL_LIMITS,
-    DEFAULT_SSL_CONFIG,
     DEFAULT_TIMEOUT_CONFIG,
+    CertTypes,
     PoolLimits,
-    SSLConfig,
-    TimeoutConfig,
+    TimeoutTypes,
+    VerifyTypes,
 )
 from .dispatch.connection_pool import ConnectionPool
 from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
@@ -37,8 +37,9 @@ class AsyncClient:
         self,
         auth: AuthTypes = None,
         cookies: CookieTypes = None,
-        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
-        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        verify: VerifyTypes = True,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
         dispatch: Dispatcher = None,
@@ -46,7 +47,11 @@ class AsyncClient:
     ):
         if dispatch is None:
             dispatch = ConnectionPool(
-                ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
+                verify=verify,
+                cert=cert,
+                timeout=timeout,
+                pool_limits=pool_limits,
+                backend=backend,
             )
 
         self.auth = auth
@@ -64,8 +69,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         return await self.request(
             "GET",
@@ -76,7 +82,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -90,8 +97,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         return await self.request(
             "OPTIONS",
@@ -102,7 +110,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -116,8 +125,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = False,  #  Note: Differs to usual default.
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         return await self.request(
             "HEAD",
@@ -128,7 +138,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -144,8 +155,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         return await self.request(
             "POST",
@@ -158,7 +170,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -174,8 +187,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         return await self.request(
             "PUT",
@@ -188,7 +202,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -204,8 +219,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         return await self.request(
             "PATCH",
@@ -218,7 +234,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -234,8 +251,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         return await self.request(
             "DELETE",
@@ -248,7 +266,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -265,8 +284,9 @@ class AsyncClient:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         request = Request(
             method,
@@ -283,7 +303,8 @@ class AsyncClient:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
         return response
@@ -306,9 +327,10 @@ class AsyncClient:
         *,
         stream: bool = False,
         auth: AuthTypes = None,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
         allow_redirects: bool = True,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         if auth is None:
             auth = self.auth
@@ -325,7 +347,8 @@ class AsyncClient:
         response = await self.send_handling_redirects(
             request,
             stream=stream,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
             allow_redirects=allow_redirects,
         )
@@ -336,8 +359,9 @@ class AsyncClient:
         request: Request,
         *,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
         allow_redirects: bool = True,
         history: typing.List[Response] = None,
     ) -> Response:
@@ -353,7 +377,7 @@ class AsyncClient:
                 raise RedirectLoop()
 
             response = await self.dispatch.send(
-                request, stream=stream, ssl=ssl, timeout=timeout
+                request, stream=stream, verify=verify, cert=cert, timeout=timeout
             )
             response.history = list(history)
             self.cookies.extract_cookies(response)
@@ -366,13 +390,14 @@ class AsyncClient:
             else:
 
                 async def send_next() -> Response:
-                    nonlocal request, response, ssl, allow_redirects, timeout, history
+                    nonlocal request, response, verify, cert, allow_redirects, timeout, history
                     request = self.build_redirect_request(request, response)
                     response = await self.send_handling_redirects(
                         request,
                         stream=stream,
                         allow_redirects=allow_redirects,
-                        ssl=ssl,
+                        verify=verify,
+                        cert=cert,
                         timeout=timeout,
                         history=history,
                     )
@@ -474,8 +499,9 @@ class Client:
     def __init__(
         self,
         auth: AuthTypes = None,
-        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
-        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        cert: CertTypes = None,
+        verify: VerifyTypes = True,
+        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
         dispatch: Dispatcher = None,
@@ -483,7 +509,8 @@ class Client:
     ) -> None:
         self._client = AsyncClient(
             auth=auth,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
             pool_limits=pool_limits,
             max_redirects=max_redirects,
@@ -509,8 +536,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         request = Request(
             method,
@@ -527,7 +555,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
         return response
@@ -542,8 +571,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         return self.request(
             "GET",
@@ -553,7 +583,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -567,8 +598,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         return self.request(
             "OPTIONS",
@@ -578,7 +610,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -592,8 +625,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = False,  #  Note: Differs to usual default.
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         return self.request(
             "HEAD",
@@ -603,7 +637,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -619,8 +654,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         return self.request(
             "POST",
@@ -632,7 +668,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -648,8 +685,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         return self.request(
             "PUT",
@@ -661,7 +699,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -677,8 +716,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         return self.request(
             "PATCH",
@@ -690,7 +730,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -706,8 +747,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         return self.request(
             "DELETE",
@@ -719,7 +761,8 @@ class Client:
             stream=stream,
             auth=auth,
             allow_redirects=allow_redirects,
-            ssl=ssl,
+            verify=verify,
+            cert=cert,
             timeout=timeout,
         )
 
@@ -733,8 +776,9 @@ class Client:
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> SyncResponse:
         response = self._loop.run_until_complete(
             self._client.send(
@@ -742,7 +786,8 @@ class Client:
                 stream=stream,
                 auth=auth,
                 allow_redirects=allow_redirects,
-                ssl=ssl,
+                verify=verify,
+                cert=cert,
                 timeout=timeout,
             )
         )
index fb20d7c5f18aea6a480779b3fbf8ba27f44b9cfb..0c1d3409eb8152ea2fbeb49e9ece0762ccbbae93 100644 (file)
@@ -22,9 +22,6 @@ from .interfaces import (
     Protocol,
 )
 
-OptionalTimeout = typing.Optional[TimeoutConfig]
-
-
 SSL_MONKEY_PATCH_APPLIED = False
 
 
@@ -56,7 +53,7 @@ class Reader(BaseReader):
         self.stream_reader = stream_reader
         self.timeout = timeout
 
-    async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
+    async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
         if timeout is None:
             timeout = self.timeout
 
@@ -78,7 +75,7 @@ class Writer(BaseWriter):
     def write_no_block(self, data: bytes) -> None:
         self.stream_writer.write(data)  # pragma: nocover
 
-    async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
+    async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
         if not data:
             return
 
index 82fd125ff4fe018c1933fa3c147c4ac9570b3f4e..5b3c31316259d127f389234dd33ae88866f6c461 100644 (file)
@@ -5,18 +5,17 @@ import typing
 
 import certifi
 
+CertTypes = typing.Union[str, typing.Tuple[str, str]]
+VerifyTypes = typing.Union[str, bool]
+TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
+
 
 class SSLConfig:
     """
     SSL Configuration.
     """
 
-    def __init__(
-        self,
-        *,
-        cert: typing.Union[None, str, typing.Tuple[str, str]] = None,
-        verify: typing.Union[str, bool] = True,
-    ):
+    def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True):
         self.cert = cert
         self.verify = verify
 
@@ -31,6 +30,15 @@ class SSLConfig:
         class_name = self.__class__.__name__
         return f"{class_name}(cert={self.cert}, verify={self.verify})"
 
+    def with_overrides(
+        self, cert: CertTypes = None, verify: VerifyTypes = None
+    ) -> "SSLConfig":
+        cert = self.cert if cert is None else cert
+        verify = self.verify if verify is None else verify
+        if (cert == self.cert) and (verify == self.verify):
+            return self
+        return SSLConfig(cert=cert, verify=verify)
+
     async def load_ssl_context(self) -> ssl.SSLContext:
         if not hasattr(self, "ssl_context"):
             if not self.verify:
@@ -109,25 +117,33 @@ class TimeoutConfig:
 
     def __init__(
         self,
-        timeout: float = None,
+        timeout: TimeoutTypes = None,
         *,
         connect_timeout: float = None,
         read_timeout: float = None,
         write_timeout: float = None,
     ):
-        if timeout is not None:
+        if timeout is None:
+            self.connect_timeout = connect_timeout
+            self.read_timeout = read_timeout
+            self.write_timeout = write_timeout
+        else:
             # Specified as a single timeout value
             assert connect_timeout is None
             assert read_timeout is None
             assert write_timeout is None
-            connect_timeout = timeout
-            read_timeout = timeout
-            write_timeout = timeout
-
-        self.timeout = timeout
-        self.connect_timeout = connect_timeout
-        self.read_timeout = read_timeout
-        self.write_timeout = write_timeout
+            if isinstance(timeout, TimeoutConfig):
+                self.connect_timeout = timeout.connect_timeout
+                self.read_timeout = timeout.read_timeout
+                self.write_timeout = timeout.write_timeout
+            elif isinstance(timeout, tuple):
+                self.connect_timeout = timeout[0]
+                self.read_timeout = timeout[1]
+                self.write_timeout = timeout[2]
+            else:
+                self.connect_timeout = timeout
+                self.read_timeout = timeout
+                self.write_timeout = timeout
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
@@ -139,8 +155,8 @@ class TimeoutConfig:
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
-        if self.timeout is not None:
-            return f"{class_name}(timeout={self.timeout})"
+        if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1:
+            return f"{class_name}(timeout={self.connect_timeout})"
         return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
 
 
index 053a998081b6fe4d88c6b4fcae2af2d475948218..60214333fe990599bc044d662637071a11713a4e 100644 (file)
@@ -8,8 +8,11 @@ from ..concurrency import AsyncioBackend
 from ..config import (
     DEFAULT_SSL_CONFIG,
     DEFAULT_TIMEOUT_CONFIG,
+    CertTypes,
     SSLConfig,
     TimeoutConfig,
+    TimeoutTypes,
+    VerifyTypes,
 )
 from ..exceptions import ConnectTimeout
 from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
@@ -25,14 +28,15 @@ class HTTPConnection(Dispatcher):
     def __init__(
         self,
         origin: typing.Union[str, Origin],
-        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
-        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        verify: VerifyTypes = True,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         backend: ConcurrencyBackend = None,
         release_func: typing.Optional[ReleaseCallback] = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
-        self.ssl = ssl
-        self.timeout = timeout
+        self.ssl = SSLConfig(cert=cert, verify=verify)
+        self.timeout = TimeoutConfig(timeout)
         self.backend = AsyncioBackend() if backend is None else backend
         self.release_func = release_func
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
@@ -42,11 +46,12 @@ class HTTPConnection(Dispatcher):
         self,
         request: Request,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         if self.h11_connection is None and self.h2_connection is None:
-            await self.connect(ssl=ssl, timeout=timeout)
+            await self.connect(verify=verify, cert=cert, timeout=timeout)
 
         if self.h2_connection is not None:
             response = await self.h2_connection.send(
@@ -61,12 +66,13 @@ class HTTPConnection(Dispatcher):
         return response
 
     async def connect(
-        self, ssl: SSLConfig = None, timeout: TimeoutConfig = None
+        self,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> None:
-        if ssl is None:
-            ssl = self.ssl
-        if timeout is None:
-            timeout = self.timeout
+        ssl = self.ssl.with_overrides(verify=verify, cert=cert)
+        timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
 
         host = self.origin.host
         port = self.origin.port
index ba92acd057c1f6077f51590243ef19c52249858b..e7cefbd7e4d27677badeb25b8a30b914ecdd035c 100644 (file)
@@ -4,11 +4,11 @@ from ..concurrency import AsyncioBackend
 from ..config import (
     DEFAULT_CA_BUNDLE_PATH,
     DEFAULT_POOL_LIMITS,
-    DEFAULT_SSL_CONFIG,
     DEFAULT_TIMEOUT_CONFIG,
+    CertTypes,
     PoolLimits,
-    SSLConfig,
-    TimeoutConfig,
+    TimeoutTypes,
+    VerifyTypes,
 )
 from ..decoders import ACCEPT_ENCODING
 from ..exceptions import PoolTimeout
@@ -81,12 +81,14 @@ class ConnectionPool(Dispatcher):
     def __init__(
         self,
         *,
-        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
-        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        verify: VerifyTypes = True,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         backend: ConcurrencyBackend = None,
     ):
-        self.ssl = ssl
+        self.verify = verify
+        self.cert = cert
         self.timeout = timeout
         self.pool_limits = pool_limits
         self.is_closed = False
@@ -105,13 +107,14 @@ class ConnectionPool(Dispatcher):
         self,
         request: Request,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         connection = await self.acquire_connection(request.url.origin)
         try:
             response = await connection.send(
-                request, stream=stream, ssl=ssl, timeout=timeout
+                request, stream=stream, verify=verify, cert=cert, timeout=timeout
             )
         except BaseException as exc:
             self.active_connections.remove(connection)
@@ -128,7 +131,8 @@ class ConnectionPool(Dispatcher):
             await self.max_connections.acquire()
             connection = HTTPConnection(
                 origin,
-                ssl=self.ssl,
+                verify=self.verify,
+                cert=self.cert,
                 timeout=self.timeout,
                 backend=self.backend,
                 release_func=self.release_connection,
index fc5f34fca8ba444e9d6b788ad519d90e5203fbe9..4308f64a3a4432e7218541aa8316f7d51bd79ac8 100644 (file)
@@ -2,12 +2,7 @@ import typing
 
 import h11
 
-from ..config import (
-    DEFAULT_SSL_CONFIG,
-    DEFAULT_TIMEOUT_CONFIG,
-    SSLConfig,
-    TimeoutConfig,
-)
+from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
 from ..interfaces import BaseReader, BaseWriter, Dispatcher
 from ..models import Request, Response
@@ -22,8 +17,6 @@ H11Event = typing.Union[
 ]
 
 
-OptionalTimeout = typing.Optional[TimeoutConfig]
-
 # Callback signature: async def callback() -> None
 # In practice the callback will be a functools partial, which binds
 # the `ConnectionPool.release_connection(conn: HTTPConnection)` method.
@@ -45,8 +38,10 @@ class HTTP11Connection:
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
 
     async def send(
-        self, request: Request, stream: bool = False, timeout: TimeoutConfig = None
+        self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
     ) -> Response:
+        timeout = None if timeout is None else TimeoutConfig(timeout)
+
         #  Start sending the request.
         method = request.method.encode("ascii")
         target = request.url.full_path.encode("ascii")
@@ -97,18 +92,20 @@ class HTTP11Connection:
         self.h11_state.send(event)
         await self.writer.close()
 
-    async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
+    async def _body_iter(
+        self, timeout: TimeoutConfig = None
+    ) -> typing.AsyncIterator[bytes]:
         event = await self._receive_event(timeout)
         while isinstance(event, h11.Data):
             yield event.data
             event = await self._receive_event(timeout)
         assert isinstance(event, h11.EndOfMessage)
 
-    async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None:
+    async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
         data = self.h11_state.send(event)
         await self.writer.write(data, timeout)
 
-    async def _receive_event(self, timeout: OptionalTimeout) -> H11Event:
+    async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
         event = self.h11_state.next_event()
 
         while event is h11.NEED_DATA:
index 301a36c4df1be14a89556b8a351797af7ab5bcab..bb1857f307fa5f2d3fcd66de653139524c1b0186 100644 (file)
@@ -4,18 +4,11 @@ import typing
 import h2.connection
 import h2.events
 
-from ..config import (
-    DEFAULT_SSL_CONFIG,
-    DEFAULT_TIMEOUT_CONFIG,
-    SSLConfig,
-    TimeoutConfig,
-)
+from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
 from ..interfaces import BaseReader, BaseWriter, Dispatcher
 from ..models import Request, Response
 
-OptionalTimeout = typing.Optional[TimeoutConfig]
-
 
 class HTTP2Connection:
     READ_NUM_BYTES = 4096
@@ -31,8 +24,10 @@ class HTTP2Connection:
         self.initialized = False
 
     async def send(
-        self, request: Request, stream: bool = False, timeout: TimeoutConfig = None
+        self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
     ) -> Response:
+        timeout = None if timeout is None else TimeoutConfig(timeout)
+
         #  Start sending the request.
         if not self.initialized:
             self.initiate_connection()
@@ -89,7 +84,9 @@ class HTTP2Connection:
         self.writer.write_no_block(data_to_send)
         self.initialized = True
 
-    async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int:
+    async def send_headers(
+        self, request: Request, timeout: TimeoutConfig = None
+    ) -> int:
         stream_id = self.h2_state.get_next_available_stream_id()
         headers = [
             (b":method", request.method.encode("ascii")),
@@ -103,19 +100,19 @@ class HTTP2Connection:
         return stream_id
 
     async def send_data(
-        self, stream_id: int, data: bytes, timeout: OptionalTimeout
+        self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
     ) -> None:
         self.h2_state.send_data(stream_id, data)
         data_to_send = self.h2_state.data_to_send()
         await self.writer.write(data_to_send, timeout)
 
-    async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None:
+    async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
         self.h2_state.end_stream(stream_id)
         data_to_send = self.h2_state.data_to_send()
         await self.writer.write(data_to_send, timeout)
 
     async def body_iter(
-        self, stream_id: int, timeout: OptionalTimeout
+        self, stream_id: int, timeout: TimeoutConfig = None
     ) -> typing.AsyncIterator[bytes]:
         while True:
             event = await self.receive_event(stream_id, timeout)
@@ -125,7 +122,7 @@ class HTTP2Connection:
                 break
 
     async def receive_event(
-        self, stream_id: int, timeout: OptionalTimeout
+        self, stream_id: int, timeout: TimeoutConfig = None
     ) -> h2.events.Event:
         while not self.events[stream_id]:
             data = await self.reader.read(self.READ_NUM_BYTES, timeout)
index 20b1723928db0bfaa46b75620730ec61f1297f42..f2e846be3150bc2787e8da464b89c849d33da7bd 100644 (file)
@@ -3,7 +3,7 @@ import ssl
 import typing
 from types import TracebackType
 
-from .config import PoolLimits, SSLConfig, TimeoutConfig
+from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
 from .models import (
     URL,
     Headers,
@@ -15,8 +15,6 @@ from .models import (
     URLTypes,
 )
 
-OptionalTimeout = typing.Optional[TimeoutConfig]
-
 
 class Protocol(str, enum.Enum):
     HTTP_11 = "HTTP/1.1"
@@ -41,12 +39,15 @@ class Dispatcher:
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None
     ) -> Response:
         request = Request(method, url, data=data, params=params, headers=headers)
         self.prepare_request(request)
-        response = await self.send(request, stream=stream, ssl=ssl, timeout=timeout)
+        response = await self.send(
+            request, stream=stream, verify=verify, cert=cert, timeout=timeout
+        )
         return response
 
     def prepare_request(self, request: Request) -> None:
@@ -56,8 +57,9 @@ class Dispatcher:
         self,
         request: Request,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         raise NotImplementedError()  # pragma: nocover
 
@@ -83,7 +85,7 @@ class BaseReader:
     backend, or for stand-alone test cases.
     """
 
-    async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
+    async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
         raise NotImplementedError()  # pragma: no cover
 
 
@@ -97,7 +99,7 @@ class BaseWriter:
     def write_no_block(self, data: bytes) -> None:
         raise NotImplementedError()  # pragma: no cover
 
-    async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
+    async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
         raise NotImplementedError()  # pragma: no cover
 
     async def close(self) -> None:
index 8a79a50a7180215039ff2dd9cb6d00172613434c..1d2b97239c82d486c3418a4e9d5aa7a01f32ba42 100644 (file)
@@ -4,12 +4,13 @@ import pytest
 
 from httpcore import (
     URL,
+    CertTypes,
     Client,
     Dispatcher,
     Request,
     Response,
-    SSLConfig,
-    TimeoutConfig,
+    TimeoutTypes,
+    VerifyTypes,
 )
 
 
@@ -18,8 +19,9 @@ class MockDispatch(Dispatcher):
         self,
         request: Request,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
         return Response(200, content=body, request=request)
index 59e70b036044f9d31a890a28704ce9996797cb4b..a21f5c134f878f7a3ed22099bb1eec5b72ef49d2 100644 (file)
@@ -5,13 +5,14 @@ import pytest
 
 from httpcore import (
     URL,
+    CertTypes,
     Client,
     Cookies,
     Dispatcher,
     Request,
     Response,
-    SSLConfig,
-    TimeoutConfig,
+    TimeoutTypes,
+    VerifyTypes,
 )
 
 
@@ -20,8 +21,9 @@ class MockDispatch(Dispatcher):
         self,
         request: Request,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         if request.url.path.startswith("/echo_cookies"):
             body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
index 0edad78a9a17229e80e2498e629186a85d9acb24..c3b384dc954053c191f2c4b273dadab0d7a24bf3 100644 (file)
@@ -6,14 +6,15 @@ import pytest
 from httpcore import (
     URL,
     AsyncClient,
+    CertTypes,
     Dispatcher,
     RedirectBodyUnavailable,
     RedirectLoop,
     Request,
     Response,
-    SSLConfig,
-    TimeoutConfig,
+    TimeoutTypes,
     TooManyRedirects,
+    VerifyTypes,
     codes,
 )
 
@@ -23,8 +24,9 @@ class MockDispatch(Dispatcher):
         self,
         request: Request,
         stream: bool = False,
-        ssl: SSLConfig = None,
-        timeout: TimeoutConfig = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
     ) -> Response:
         if request.url.path == "/redirect_301":
             status_code = codes.MOVED_PERMANENTLY
index 2edf3ada12ed63773c9da7dc7a8fe059c2b946cd..f323f55df9500ce03c2e84818ec52523326440be 100644 (file)
@@ -6,25 +6,35 @@ from httpcore import HTTPConnection, Request, SSLConfig
 @pytest.mark.asyncio
 async def test_get(server):
     conn = HTTPConnection(origin="http://127.0.0.1:8000/")
-    request = Request("GET", "http://127.0.0.1:8000/")
-    request.prepare()
-    response = await conn.send(request)
+    response = await conn.request("GET", "http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.content == b"Hello, world!"
 
 
 @pytest.mark.asyncio
-async def test_https_get(https_server):
-    http = HTTPConnection(origin="https://127.0.0.1:8001/", ssl=SSLConfig(verify=False))
-    response = await http.request("GET", "https://127.0.0.1:8001/")
+async def test_post(server):
+    conn = HTTPConnection(origin="http://127.0.0.1:8000/")
+    response = await conn.request("GET", "http://127.0.0.1:8000/", data=b"Hello, world!")
+    assert response.status_code == 200
+
+
+@pytest.mark.asyncio
+async def test_https_get_with_ssl_defaults(https_server):
+    """
+    An HTTPS request, with default SSL configuration set on the client.
+    """
+    conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False)
+    response = await conn.request("GET", "https://127.0.0.1:8001/")
     assert response.status_code == 200
     assert response.content == b"Hello, world!"
 
 
 @pytest.mark.asyncio
-async def test_post(server):
-    conn = HTTPConnection(origin="http://127.0.0.1:8000/")
-    request = Request("GET", "http://127.0.0.1:8000/", data=b"Hello, world!")
-    request.prepare()
-    response = await conn.send(request)
+async def test_https_get_with_sll_overrides(https_server):
+    """
+    An HTTPS request, with SSL configuration set on the request.
+    """
+    conn = HTTPConnection(origin="https://127.0.0.1:8001/")
+    response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
     assert response.status_code == 200
+    assert response.content == b"Hello, world!"
index bf12edb27df66ac984762e0901e1c374422fa0c7..4ee6d6e78eed6b1f1d5f7d74bfd78257a9527389 100644 (file)
@@ -94,3 +94,13 @@ def test_timeout_eq():
 def test_limits_eq():
     limits = httpcore.PoolLimits(hard_limit=100)
     assert limits == httpcore.PoolLimits(hard_limit=100)
+
+
+def test_timeout_from_tuple():
+    timeout = httpcore.TimeoutConfig(timeout=(5.0, 5.0, 5.0))
+    assert timeout == httpcore.TimeoutConfig(timeout=5.0)
+
+
+def test_timeout_from_config_instance():
+    timeout = httpcore.TimeoutConfig(timeout=(5.0))
+    assert httpcore.TimeoutConfig(timeout) == httpcore.TimeoutConfig(timeout=5.0)