* `hstspreload` - determines whether an IDNA-encoded host should only be accessed via HTTPS.
* `idna` - Internationalized domain name support.
* `rfc3986` - URL parsing & normalization.
+* `sniffio` - Async library autodetection.
* `brotlipy` - Decoding for "brotli" compressed responses. *(Optional)*
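For context, "async library autodetection" here means asking sniffio which event loop the caller is running under. A minimal sketch, not part of this change:

```python
import asyncio

import sniffio


async def main() -> None:
    # Reports the async library driving the current task:
    # "asyncio" here, "trio" when run under trio.run().
    print(sniffio.current_async_library())


asyncio.run(main())
```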
A huge amount of credit is due to `requests` for the API layout that
import hstspreload
from .auth import BasicAuth
-from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import ConcurrencyBackend
from .config import (
DEFAULT_MAX_REDIRECTS,
* **app** - *(optional)* An ASGI application to send requests to,
rather than sending actual network requests.
* **backend** - *(optional)* A concurrency backend to use when issuing
- async requests.
+ async requests. Either 'auto', 'asyncio', 'trio', or a `ConcurrencyBackend`
+ instance. Defaults to 'auto', for autodetection.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
* **uds** - *(optional)* A path to a Unix domain socket to connect through.
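To illustrate the `backend` argument described above, here is a rough usage sketch. It assumes the `httpx.Client` entry point and the `httpx.concurrency.asyncio.AsyncioBackend` import path already used elsewhere in this diff:

```python
import httpx
from httpx.concurrency.asyncio import AsyncioBackend

client = httpx.Client()                            # "auto": backend detected lazily, at request time
named = httpx.Client(backend="asyncio")            # force a backend by name
explicit = httpx.Client(backend=AsyncioBackend())  # or pass an instance directly
```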
base_url: URLTypes = None,
dispatch: Dispatcher = None,
app: typing.Callable = None,
- backend: ConcurrencyBackend = None,
+ backend: typing.Union[str, ConcurrencyBackend] = "auto",
trust_env: bool = True,
uds: str = None,
):
- if backend is None:
- backend = AsyncioBackend()
-
if app is not None:
- dispatch = ASGIDispatch(app=app, backend=backend)
+ dispatch = ASGIDispatch(app=app)
if dispatch is None:
dispatch = ConnectionPool(
self.max_redirects = max_redirects
self.trust_env = trust_env
self.dispatch = dispatch
- self.concurrency_backend = backend
self.netrc = NetRCInfo()
if proxies is None and trust_env:
timeout: TimeoutTypes,
http_2: bool,
pool_limits: PoolLimits,
- backend: ConcurrencyBackend,
+ backend: typing.Union[str, ConcurrencyBackend],
trust_env: bool,
) -> typing.Dict[str, Dispatcher]:
def _proxy_from_url(url: URLTypes) -> Dispatcher:
--- /dev/null
+import ssl
+import typing
+
+import sniffio
+
+from ..config import PoolLimits, TimeoutConfig
+from .base import (
+ BaseBackgroundManager,
+ BaseEvent,
+ BasePoolSemaphore,
+ BaseSocketStream,
+ ConcurrencyBackend,
+ lookup_backend,
+)
+
+
+class AutoBackend(ConcurrencyBackend):
+ @property
+ def backend(self) -> ConcurrencyBackend:
+ if not hasattr(self, "_backend_implementation"):
+ backend = sniffio.current_async_library()
+ if backend not in ("asyncio", "trio"):
+ raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
+ self._backend_implementation = lookup_backend(backend)
+ return self._backend_implementation
+
+ async def open_tcp_stream(
+ self,
+ hostname: str,
+ port: int,
+ ssl_context: typing.Optional[ssl.SSLContext],
+ timeout: TimeoutConfig,
+ ) -> BaseSocketStream:
+ return await self.backend.open_tcp_stream(hostname, port, ssl_context, timeout)
+
+ async def open_uds_stream(
+ self,
+ path: str,
+ hostname: typing.Optional[str],
+ ssl_context: typing.Optional[ssl.SSLContext],
+ timeout: TimeoutConfig,
+ ) -> BaseSocketStream:
+ return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
+
+ def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+ return self.backend.get_semaphore(limits)
+
+ async def run_in_threadpool(
+ self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+ ) -> typing.Any:
+ return await self.backend.run_in_threadpool(func, *args, **kwargs)
+
+ def create_event(self) -> BaseEvent:
+ return self.backend.create_event()
+
+ def background_manager(
+ self, coroutine: typing.Callable, *args: typing.Any
+ ) -> BaseBackgroundManager:
+ return self.backend.background_manager(coroutine, *args)
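Assuming the new module lands at `httpx/concurrency/auto.py` alongside `base.py` (as the relative imports suggest), detection is deferred until the `backend` property is first read from inside a running event loop. A rough sketch of the intended behaviour, assuming trio is installed:

```python
import trio

from httpx.concurrency.auto import AutoBackend


async def main() -> None:
    backend = AutoBackend()
    # First attribute access triggers sniffio detection; under trio.run()
    # this should resolve to the trio implementation.
    print(type(backend.backend).__name__)


trio.run(main)
```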
from ..config import PoolLimits, TimeoutConfig
+def lookup_backend(
+ backend: typing.Union[str, "ConcurrencyBackend"] = "auto"
+) -> "ConcurrencyBackend":
+ if not isinstance(backend, str):
+ return backend
+
+ if backend == "auto":
+ from .auto import AutoBackend
+
+ return AutoBackend()
+ elif backend == "asyncio":
+ from .asyncio import AsyncioBackend
+
+ return AsyncioBackend()
+ elif backend == "trio":
+ from .trio import TrioBackend
+
+ return TrioBackend()
+
+ raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}")
+
+
class TimeoutFlag:
"""
A timeout flag holds a state of either read-timeout or write-timeout mode.
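The `lookup_backend` helper added to `base.py` above keeps the string handling in one place: names map to backend classes, while instances pass through untouched. A quick illustrative sketch of the expected behaviour (not a test from this change):

```python
from httpx.concurrency.asyncio import AsyncioBackend
from httpx.concurrency.base import lookup_backend

assert isinstance(lookup_backend("asyncio"), AsyncioBackend)
assert type(lookup_backend()).__name__ == "AutoBackend"  # "auto" is the default

backend = AsyncioBackend()
assert lookup_backend(backend) is backend  # instances are returned as-is
```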
import typing
-from ..concurrency.asyncio import AsyncioBackend
-from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import Request, Response
from .base import Dispatcher
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
- backend: ConcurrencyBackend = None,
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
- self.backend = AsyncioBackend() if backend is None else backend
async def send(
self,
import ssl
import typing
-from ..concurrency.asyncio import AsyncioBackend
-from ..concurrency.base import ConcurrencyBackend
+from ..concurrency.base import ConcurrencyBackend, lookup_backend
from ..config import (
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
trust_env: bool = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
http_2: bool = False,
- backend: ConcurrencyBackend = None,
+ backend: typing.Union[str, ConcurrencyBackend] = "auto",
release_func: typing.Optional[ReleaseCallback] = None,
uds: typing.Optional[str] = None,
):
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
self.timeout = TimeoutConfig(timeout)
self.http_2 = http_2
- self.backend = AsyncioBackend() if backend is None else backend
+ self.backend = lookup_backend(backend)
self.release_func = release_func
self.uds = uds
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
if http_version == "HTTP/2":
self.h2_connection = HTTP2Connection(
- stream, self.backend, on_release=on_release
+ stream, backend=self.backend, on_release=on_release
)
else:
assert http_version == "HTTP/1.1"
- self.h11_connection = HTTP11Connection(
- stream, self.backend, on_release=on_release
- )
+ self.h11_connection = HTTP11Connection(stream, on_release=on_release)
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:
import typing
-from ..concurrency.asyncio import AsyncioBackend
-from ..concurrency.base import ConcurrencyBackend
+from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend
from ..config import (
DEFAULT_POOL_LIMITS,
DEFAULT_TIMEOUT_CONFIG,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http_2: bool = False,
- backend: ConcurrencyBackend = None,
+ backend: typing.Union[str, ConcurrencyBackend] = "auto",
uds: typing.Optional[str] = None,
):
self.verify = verify
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
- self.backend = AsyncioBackend() if backend is None else backend
- self.max_connections = self.backend.get_semaphore(pool_limits)
+ self.backend = lookup_backend(backend)
+
+ @property
+ def max_connections(self) -> BasePoolSemaphore:
+ # We do this lazily, to make sure backend autodetection always
+ # runs within an async context.
+ if not hasattr(self, "_max_connections"):
+ self._max_connections = self.backend.get_semaphore(self.pool_limits)
+ return self._max_connections
@property
def num_connections(self) -> int:
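The lazy `max_connections` property above exists because sniffio can only detect the running async library from inside a coroutine; creating the semaphore eagerly in `__init__` would trigger detection from synchronous code. A minimal illustration of that constraint, using sniffio directly rather than httpx:

```python
import asyncio

import sniffio

try:
    sniffio.current_async_library()  # no event loop is running yet
except sniffio.AsyncLibraryNotFoundError:
    print("detection fails outside an async context")


async def main() -> None:
    print(sniffio.current_async_library())  # "asyncio": detection works in here


asyncio.run(main())
```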
import h11
-from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseSocketStream, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectionClosed, ProtocolError
from ..models import Request, Response
def __init__(
self,
stream: BaseSocketStream,
- backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.stream = stream
- self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
self.timeout_flag = TimeoutFlag()
BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
+ lookup_backend,
)
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ProtocolError
def __init__(
self,
stream: BaseSocketStream,
- backend: ConcurrencyBackend,
+ backend: typing.Union[str, ConcurrencyBackend] = "auto",
on_release: typing.Callable = None,
):
self.stream = stream
- self.backend = backend
+ self.backend = lookup_backend(backend)
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
import enum
+import typing
from base64 import b64encode
import h11
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http_2: bool = False,
- backend: ConcurrencyBackend = None,
+ backend: typing.Union[str, ConcurrencyBackend] = "auto",
):
super(HTTPProxy, self).__init__(
)
else:
assert http_version == "HTTP/1.1"
- connection.h11_connection = HTTP11Connection(
- stream, self.backend, on_release=on_release
- )
+ connection.h11_connection = HTTP11Connection(stream, on_release=on_release)
def should_forward_origin(self, origin: Origin) -> bool:
"""Determines if the given origin should
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = httpx,httpxprof,tests
-known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,tqdm,trio,trustme,uvicorn
+known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,tqdm,trio,trustme,uvicorn
line_length = 88
multi_line_output = 3
zip_safe=False,
install_requires=[
"certifi",
+ "hstspreload",
"chardet==3.*",
"h11==0.8.*",
"h2==3.*",
- "hstspreload>=2019.8.27",
"idna==2.*",
"rfc3986==1.*",
+ "sniffio==1.*",
],
classifiers=[
"Development Status :: 3 - Alpha",
"cert",
"timeout",
"pool_limits",
- "backend",
]:
assert getattr(pool, prop) == getattr(proxy, prop)