]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Replace h3 parameter with http_version. (#1068)
authorBrian Wellington <bwelling@xbill.org>
Thu, 21 Mar 2024 12:20:46 +0000 (05:20 -0700)
committerGitHub <noreply@github.com>
Thu, 21 Mar 2024 12:20:46 +0000 (05:20 -0700)
This allows more flexibility; clients can specify which http version
they want, or use the default.

dns/asyncquery.py
dns/nameserver.py
dns/query.py
tests/test_async.py
tests/test_doh.py

index e3003b1f5296a9759777e11ac8b890e06909925f..f7d4df44cf4cdf19a314eb282b586f256a649603 100644 (file)
@@ -41,6 +41,7 @@ from dns.query import (
     BadResponse,
     NoDOH,
     NoDOQ,
+    HTTPVersion,
     UDPMode,
     _check_status,
     _compute_times,
@@ -533,7 +534,7 @@ async def https(
     bootstrap_address: Optional[str] = None,
     resolver: Optional["dns.asyncresolver.Resolver"] = None,
     family: int = socket.AF_UNSPEC,
-    h3: bool = False,
+    http_version: HTTPVersion = HTTPVersion.DEFAULT,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -559,7 +560,7 @@ async def https(
     else:
         url = where
 
-    if h3:
+    if http_version == HTTPVersion.H3 or (http_version == HTTPVersion.DEFAULT and not have_doh):
         if bootstrap_address is None:
             parsed = urllib.parse.urlparse(url)
             resolver = _maybe_get_resolver(resolver)
@@ -595,6 +596,9 @@ async def https(
     transport = None
     headers = {"accept": "application/dns-message"}
 
+    h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
+    h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
+
     backend = dns.asyncbackend.get_default_backend()
 
     if source is None:
@@ -605,8 +609,8 @@ async def https(
         local_port = source_port
     transport = backend.get_transport_class()(
         local_address=local_address,
-        http1=True,
-        http2=True,
+        http1=h1,
+        http2=h2,
         verify=verify,
         local_port=local_port,
         bootstrap_address=bootstrap_address,
@@ -618,7 +622,7 @@ async def https(
         cm: contextlib.AbstractAsyncContextManager = NullContext(client)
     else:
         cm = httpx.AsyncClient(
-            http1=True, http2=True, verify=verify, transport=transport
+            http1=h1, http2=h2, verify=verify, transport=transport
         )
 
     async with cm as the_client:
index e8068e7e456cf8fe66455f3ec5663a00da4f0631..b02a239b3c5886a47aeeb09b68c235c59fb91b95 100644 (file)
@@ -168,14 +168,14 @@ class DoHNameserver(Nameserver):
         bootstrap_address: Optional[str] = None,
         verify: Union[bool, str] = True,
         want_get: bool = False,
-        h3: bool = False,
+        http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
     ):
         super().__init__()
         self.url = url
         self.bootstrap_address = bootstrap_address
         self.verify = verify
         self.want_get = want_get
-        self.h3 = h3
+        self.http_version = http_version
 
     def kind(self):
         return "DoH"
@@ -216,7 +216,7 @@ class DoHNameserver(Nameserver):
             ignore_trailing=ignore_trailing,
             verify=self.verify,
             post=(not self.want_get),
-            h3=self.h3,
+            http_version=self.http_version,
         )
 
     async def async_query(
@@ -241,7 +241,7 @@ class DoHNameserver(Nameserver):
             ignore_trailing=ignore_trailing,
             verify=self.verify,
             post=(not self.want_get),
-            h3=self.h3,
+            http_version=self.http_version,
         )
 
 
index bfd6908c392f60bd559e8741bd4fc378a015d731..f3907c6fee4181b332a8848110029b461fe47c3a 100644 (file)
@@ -351,6 +351,22 @@ def _maybe_get_resolver(
     return resolver
 
 
+class HTTPVersion(enum.IntEnum):
+    """Which version of HTTP should be used?
+
+    DEFAULT will select the first version from the list [2, 1.1, 3] that
+    is available.
+    """
+
+    DEFAULT = 0
+    HTTP_1 = 1
+    H1 = 1
+    HTTP_2 = 2
+    H2 = 2
+    HTTP_3 = 3
+    H3 = 3
+
+
 def https(
     q: dns.message.Message,
     where: str,
@@ -367,7 +383,7 @@ def https(
     verify: Union[bool, str] = True,
     resolver: Optional["dns.resolver.Resolver"] = None,
     family: int = socket.AF_UNSPEC,
-    h3: bool = False,
+    http_version: HTTPVersion = HTTPVersion.DEFAULT,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -417,7 +433,7 @@ def https(
     *family*, an ``int``, the address family.  If socket.AF_UNSPEC (the default), both A
     and AAAA records will be retrieved.
 
-    *h3*, a ``bool``.  If ``True``, use HTTP/3 otherwise use HTTP/2 or HTTP/1.1.
+    *http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
 
     Returns a ``dns.message.Message``.
     """
@@ -433,7 +449,7 @@ def https(
     else:
         url = where
 
-    if h3:
+    if http_version == HTTPVersion.H3 or (http_version == HTTPVersion.DEFAULT and not have_doh):
         if bootstrap_address is None:
             parsed = urllib.parse.urlparse(url)
             resolver = _maybe_get_resolver(resolver)
@@ -469,6 +485,9 @@ def https(
     transport = None
     headers = {"accept": "application/dns-message"}
 
+    h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
+    h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
+
     # set source port and source address
 
     if the_source is None:
@@ -479,8 +498,8 @@ def https(
         local_port = the_source[1]
     transport = _HTTPTransport(
         local_address=local_address,
-        http1=True,
-        http2=True,
+        http1=h1,
+        http2=h2,
         verify=verify,
         local_port=local_port,
         bootstrap_address=bootstrap_address,
@@ -491,7 +510,7 @@ def https(
     if session:
         cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
     else:
-        cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
+        cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
     with cm as session:
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
         # GET and POST examples
index e1cb8610f46c2a9c1fb33b5d385184be17e4c11b..f0c227de08e884fa742b251f6dc00c452b7a80d6 100644 (file)
@@ -570,7 +570,7 @@ class AsyncTests(unittest.TestCase):
                 post=False,
                 timeout=4,
                 family=family,
-                h3=True,
+                http_version=dns.asyncquery.HTTPVersion.H3,
             )
             self.assertTrue(q.is_response(r))
 
@@ -587,7 +587,7 @@ class AsyncTests(unittest.TestCase):
                 post=True,
                 timeout=4,
                 family=family,
-                h3=True,
+                http_version=dns.asyncquery.HTTPVersion.H3,
             )
             self.assertTrue(q.is_response(r))
 
@@ -603,7 +603,7 @@ class AsyncTests(unittest.TestCase):
                 nameserver_ip,
                 post=False,
                 timeout=4,
-                h3=True,
+                http_version=dns.asyncquery.HTTPVersion.H3,
             )
             self.assertTrue(q.is_response(r))
 
index 692b2d670d890bd1aa16c24998244726c88673ff..900a3fae414ea611da7485da01094958a3751605 100644 (file)
@@ -203,7 +203,7 @@ class DNSOverHTTP3TestCase(unittest.TestCase):
             post=False,
             timeout=4,
             family=family,
-            h3=True,
+            http_version=dns.query.HTTPVersion.H3,
         )
         self.assertTrue(q.is_response(r))
 
@@ -216,7 +216,7 @@ class DNSOverHTTP3TestCase(unittest.TestCase):
             post=True,
             timeout=4,
             family=family,
-            h3=True,
+            http_version=dns.query.HTTPVersion.H3,
         )
         self.assertTrue(q.is_response(r))
 
@@ -233,7 +233,7 @@ class DNSOverHTTP3TestCase(unittest.TestCase):
                 nameserver_ip,
                 post=False,
                 timeout=4,
-                h3=True,
+                http_version=dns.query.HTTPVersion.H3,
             )
             self.assertTrue(q.is_response(r))
         if resolver_v6_addresses:
@@ -244,7 +244,7 @@ class DNSOverHTTP3TestCase(unittest.TestCase):
                 nameserver_ip,
                 post=False,
                 timeout=4,
-                h3=True,
+                http_version=dns.query.HTTPVersion.H3,
             )
             self.assertTrue(q.is_response(r))