]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Only create httpx transports when needed. (#1130)
authorBrian Wellington <bwelling@xbill.org>
Fri, 13 Sep 2024 18:33:51 +0000 (11:33 -0700)
committerGitHub <noreply@github.com>
Fri, 13 Sep 2024 18:33:51 +0000 (11:33 -0700)
When a caller passes an httpx client to https(), there's no need to
create a transport object that's not used.

dns/asyncquery.py
dns/query.py

index b93e267dde884819c0669621139d072d79f2f908..0848b87724ef66176688394d006c410277eb0cd5 100644 (file)
@@ -597,7 +597,6 @@ async def https(
         raise ValueError("session parameter must be an httpx.AsyncClient")
 
     wire = q.to_wire()
-    transport = None
     headers = {"accept": "application/dns-message"}
 
     h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
@@ -611,20 +610,21 @@ async def https(
     else:
         local_address = source
         local_port = source_port
-    transport = backend.get_transport_class()(
-        local_address=local_address,
-        http1=h1,
-        http2=h2,
-        verify=verify,
-        local_port=local_port,
-        bootstrap_address=bootstrap_address,
-        resolver=resolver,
-        family=family,
-    )
 
     if client:
         cm: contextlib.AbstractAsyncContextManager = NullContext(client)
     else:
+        transport = backend.get_transport_class()(
+            local_address=local_address,
+            http1=h1,
+            http2=h2,
+            verify=verify,
+            local_port=local_port,
+            bootstrap_address=bootstrap_address,
+            resolver=resolver,
+            family=family,
+        )
+
         cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
 
     async with cm as the_client:
index 1880a3f35eea1455bb9f09cc2e98223f0777a7d8..b4acf680d43e91a33460de03a3ee684713017bf3 100644 (file)
@@ -486,7 +486,6 @@ def https(
         raise ValueError("session parameter must be an httpx.Client")
 
     wire = q.to_wire()
-    transport = None
     headers = {"accept": "application/dns-message"}
 
     h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
@@ -500,20 +499,21 @@ def https(
     else:
         local_address = the_source[0]
         local_port = the_source[1]
-    transport = _HTTPTransport(
-        local_address=local_address,
-        http1=h1,
-        http2=h2,
-        verify=verify,
-        local_port=local_port,
-        bootstrap_address=bootstrap_address,
-        resolver=resolver,
-        family=family,
-    )
 
     if session:
         cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
     else:
+        transport = _HTTPTransport(
+            local_address=local_address,
+            http1=h1,
+            http2=h2,
+            verify=verify,
+            local_port=local_port,
+            bootstrap_address=bootstrap_address,
+            resolver=resolver,
+            family=family,
+        )
+
         cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
     with cm as session:
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH