]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add NetworkOptions
authorTom Christie <tom.christie@krakentechnologies.ltd>
Fri, 12 Jan 2024 12:14:15 +0000 (12:14 +0000)
committerTom Christie <tom.christie@krakentechnologies.ltd>
Fri, 12 Jan 2024 12:14:15 +0000 (12:14 +0000)
httpx/__init__.py
httpx/_config.py
httpx/_transports/default.py
tests/test_config.py

index f61112f8b20e11be3395d6f9265082ad762a7638..1665b35a1f4a7770859b4791e0f531d53c111849 100644 (file)
@@ -2,7 +2,7 @@ from .__version__ import __description__, __title__, __version__
 from ._api import delete, get, head, options, patch, post, put, request, stream
 from ._auth import Auth, BasicAuth, DigestAuth, NetRCAuth
 from ._client import USE_CLIENT_DEFAULT, AsyncClient, Client
-from ._config import Limits, Proxy, Timeout, create_ssl_context
+from ._config import Limits, NetworkOptions, Proxy, Timeout, create_ssl_context
 from ._content import ByteStream
 from ._exceptions import (
     CloseError,
@@ -96,6 +96,7 @@ __all__ = [
     "MockTransport",
     "NetRCAuth",
     "NetworkError",
+    "NetworkOptions",
     "options",
     "patch",
     "PoolTimeout",
index 0cfd552e49b9c510fcf5c7466118dcd4ed574d5f..69c3c6ffa8924d059449406dc93e2d633f0bea70 100644 (file)
@@ -12,6 +12,13 @@ from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes
 from ._urls import URL
 from ._utils import get_ca_bundle_from_env
 
+
+SOCKET_OPTION = typing.Union[
+    typing.Tuple[int, int, int],
+    typing.Tuple[int, int, typing.Union[bytes, bytearray]],
+    typing.Tuple[int, int, None, int],
+]
+
 DEFAULT_CIPHERS = ":".join(
     [
         "ECDHE+AESGCM",
@@ -363,6 +370,37 @@ class Proxy:
         return f"Proxy({url_str}{auth_str}{headers_str})"
 
 
+class NetworkOptions:
+    def __init__(
+        self,
+        connection_retries: int = 0,
+        local_address: typing.Optional[str] = None,
+        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
+        uds: typing.Optional[str] = None,
+    ) -> None:
+        self.connection_retries = connection_retries
+        self.local_address = local_address
+        self.socket_options = socket_options
+        self.uds = uds
+
+    def __repr__(self) -> str:
+        defaults = {
+            "connection_retries": 0,
+            "local_address": None,
+            "socket_options": None,
+            "uds": None,
+        }
+        params = ", ".join(
+            [
+                f"{attr}={getattr(self, attr)!r}"
+                for attr, default in defaults.items()
+                if getattr(self, attr) != default
+            ]
+        )
+        return f"NetworkOptions({params})"
+
+
 DEFAULT_TIMEOUT_CONFIG = Timeout(timeout=5.0)
 DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20)
+DEFAULT_NETWORK_OPTIONS = NetworkOptions(connection_retries=0)
 DEFAULT_MAX_REDIRECTS = 20
index 14a087389a8ba910a6b4d88ef69527ba9eb335c0..7802026940742e943d274185b43e6fb0b11b74d3 100644 (file)
@@ -29,7 +29,14 @@ from types import TracebackType
 
 import httpcore
 
-from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context
+from .._config import (
+    DEFAULT_LIMITS,
+    DEFAULT_NETWORK_OPTIONS,
+    Proxy,
+    Limits,
+    NetworkOptions,
+    create_ssl_context,
+)
 from .._exceptions import (
     ConnectError,
     ConnectTimeout,
@@ -54,12 +61,6 @@ from .base import AsyncBaseTransport, BaseTransport
 T = typing.TypeVar("T", bound="HTTPTransport")
 A = typing.TypeVar("A", bound="AsyncHTTPTransport")
 
-SOCKET_OPTION = typing.Union[
-    typing.Tuple[int, int, int],
-    typing.Tuple[int, int, typing.Union[bytes, bytearray]],
-    typing.Tuple[int, int, None, int],
-]
-
 
 @contextlib.contextmanager
 def map_httpcore_exceptions() -> typing.Iterator[None]:
@@ -126,10 +127,7 @@ class HTTPTransport(BaseTransport):
         limits: Limits = DEFAULT_LIMITS,
         trust_env: bool = True,
         proxy: typing.Optional[ProxyTypes] = None,
-        uds: typing.Optional[str] = None,
-        local_address: typing.Optional[str] = None,
-        retries: int = 0,
-        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
+        network_options: NetworkOptions = DEFAULT_NETWORK_OPTIONS,
     ) -> None:
         ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
         proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
@@ -142,10 +140,10 @@ class HTTPTransport(BaseTransport):
                 keepalive_expiry=limits.keepalive_expiry,
                 http1=http1,
                 http2=http2,
-                uds=uds,
-                local_address=local_address,
-                retries=retries,
-                socket_options=socket_options,
+                uds=network_options.uds,
+                local_address=network_options.local_address,
+                retries=network_options.connection_retries,
+                socket_options=network_options.socket_options,
             )
         elif proxy.url.scheme in ("http", "https"):
             self._pool = httpcore.HTTPProxy(
@@ -164,7 +162,10 @@ class HTTPTransport(BaseTransport):
                 keepalive_expiry=limits.keepalive_expiry,
                 http1=http1,
                 http2=http2,
-                socket_options=socket_options,
+                uds=network_options.uds,
+                local_address=network_options.local_address,
+                retries=network_options.connection_retries,
+                socket_options=network_options.socket_options,
             )
         elif proxy.url.scheme == "socks5":
             try:
@@ -267,10 +268,7 @@ class AsyncHTTPTransport(AsyncBaseTransport):
         limits: Limits = DEFAULT_LIMITS,
         trust_env: bool = True,
         proxy: typing.Optional[ProxyTypes] = None,
-        uds: typing.Optional[str] = None,
-        local_address: typing.Optional[str] = None,
-        retries: int = 0,
-        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
+        network_options: NetworkOptions = DEFAULT_NETWORK_OPTIONS,
     ) -> None:
         ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
         proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
@@ -283,10 +281,10 @@ class AsyncHTTPTransport(AsyncBaseTransport):
                 keepalive_expiry=limits.keepalive_expiry,
                 http1=http1,
                 http2=http2,
-                uds=uds,
-                local_address=local_address,
-                retries=retries,
-                socket_options=socket_options,
+                uds=network_options.uds,
+                local_address=network_options.local_address,
+                retries=network_options.connection_retries,
+                socket_options=network_options.socket_options,
             )
         elif proxy.url.scheme in ("http", "https"):
             self._pool = httpcore.AsyncHTTPProxy(
@@ -304,7 +302,10 @@ class AsyncHTTPTransport(AsyncBaseTransport):
                 keepalive_expiry=limits.keepalive_expiry,
                 http1=http1,
                 http2=http2,
-                socket_options=socket_options,
+                uds=network_options.uds,
+                local_address=network_options.local_address,
+                retries=network_options.connection_retries,
+                socket_options=network_options.socket_options,
             )
         elif proxy.url.scheme == "socks5":
             try:
index 6f6ee4f575141ee2debefbf8342d3168064bb5e3..cef0aa37c2e0e2d2501612349bf74a0f2063121f 100644 (file)
@@ -221,3 +221,18 @@ def test_proxy_with_auth_from_url():
 def test_invalid_proxy_scheme():
     with pytest.raises(ValueError):
         httpx.Proxy("invalid://example.com")
+
+
+def test_network_options():
+    network_options = httpx.NetworkOptions()
+    assert repr(network_options) == "NetworkOptions()"
+
+    network_options = httpx.NetworkOptions(connection_retries=1)
+    assert repr(network_options) == "NetworkOptions(connection_retries=1)"
+
+    network_options = httpx.NetworkOptions(
+        connection_retries=1, local_address="0.0.0.0"
+    )
+    assert repr(network_options) == (
+        "NetworkOptions(connection_retries=1, local_address='0.0.0.0')"
+    )