]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
doq: make sure connection is properly reset if necessary in the tests
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Wed, 27 Sep 2023 12:57:16 +0000 (14:57 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Oct 2023 11:38:10 +0000 (13:38 +0200)
pdns/dnsdistdist/doq.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/doqclient.py [new file with mode: 0644]
regression-tests.dnsdist/requirements.txt
regression-tests.dnsdist/test_Async.py
regression-tests.dnsdist/test_DOQ.py

index 10f7ea89e15369926403d6c77bed395cff56b8f8..f1206bb9a115569320a7ede05cd371986e77118e 100644 (file)
@@ -748,7 +748,6 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const
     DNSPacketMangler mangler(reinterpret_cast<char*>(query.data()), query.size());
     mangler.skipDomainName();
     mangler.skipBytes(4);
-    // Should we ensure message id is 0 ?
 
     auto unit = std::make_unique<DOQUnit>(std::move(query));
     unit->dsc = &dsc;
index 97f8b72dcbb63a898a68563078b61415fa8af918..f2e5e31a99dd454806c24717fa1d14857cb5d441 100644 (file)
@@ -28,6 +28,8 @@ import h2.config
 import pycurl
 from io import BytesIO
 
+from doqclient import quic_query
+
 from eqdnsmessage import AssertEqualDNSMessageMixin
 from proxyprotocol import ProxyProtocol
 
@@ -1111,7 +1113,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             else:
                 cls._toResponderQueue.put(response, True, timeout)
 
-        message = dns.query.quic(query, '127.0.0.1', timeout, port, verify=caFile, connection=connection, server_hostname=serverName)
+        message = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)
 
         receivedQuery = None
 
diff --git a/regression-tests.dnsdist/doqclient.py b/regression-tests.dnsdist/doqclient.py
new file mode 100644 (file)
index 0000000..94fa7bd
--- /dev/null
@@ -0,0 +1,127 @@
+import asyncio
+import pickle
+import ssl
+import struct
+from typing import Any, Optional, cast
+import dns
+import async_timeout
+
+from aioquic.quic.configuration import QuicConfiguration
+from aioquic.asyncio.client import connect
+from aioquic.asyncio.protocol import QuicConnectionProtocol
+from aioquic.quic.configuration import QuicConfiguration
+from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
+from aioquic.quic.logger import QuicFileLogger
+
+class DnsClientProtocol(QuicConnectionProtocol):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._ack_waiter: Any = None
+
+    def pack(self, data):
+        # serialize query
+        data = bytes(data)
+        data = struct.pack("!H", len(data)) + data
+        return data
+
+    async def query(self, query: dns.message) -> None:
+        data = self.pack(query.to_wire())
+        # send query and wait for answer
+        stream_id = self._quic.get_next_available_stream_id()
+        self._quic.send_stream_data(stream_id, data, end_stream=True)
+        waiter = self._loop.create_future()
+        self._ack_waiter = waiter
+        self.transmit()
+
+        return await asyncio.shield(waiter)
+
+    def quic_event_received(self, event: QuicEvent) -> None:
+        if self._ack_waiter is not None:
+            if isinstance(event, StreamDataReceived):
+                length = struct.unpack("!H", bytes(event.data[:2]))[0]
+                answer = dns.message.from_wire(event.data[2 : 2 + length], ignore_trailing=True)
+
+                waiter = self._ack_waiter
+                self._ack_waiter = None
+                waiter.set_result(answer)
+            if isinstance(event, StreamReset):
+                waiter = self._ack_waiter
+                self._ack_waiter = None
+                waiter.set_result(event)
+
+class BogusDnsClientProtocol(DnsClientProtocol):
+    def pack(self, data):
+        # serialize query
+        data = bytes(data)
+        data = struct.pack("!H", len(data) * 2) + data
+        return data
+
+
+async def async_quic_query(
+    configuration: QuicConfiguration,
+    host: str,
+    port: int,
+    query: dns.message,
+    timeout: float,
+    create_protocol=DnsClientProtocol
+) -> None:
+    print("Connecting to {}:{}".format(host, port))
+    async with connect(
+        host,
+        port,
+        configuration=configuration,
+        create_protocol=create_protocol,
+    ) as client:
+        client = cast(DnsClientProtocol, client)
+        print("Sending DNS query")
+        try:
+            async with async_timeout.timeout(timeout):
+                answer = await client.query(query)
+                return answer
+        except asyncio.TimeoutError as e:
+            return e
+
+class StreamResetError(Exception):
+    def __init__(self, error, message="Stream reset by peer"):
+        self.error = error
+        super().__init__(message)
+
+def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
+    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+    if verify:
+        configuration.load_verify_locations(verify)
+    result = asyncio.run(
+        async_quic_query(
+            configuration=configuration,
+            host=host,
+            port=port,
+            query=query,
+            timeout=timeout,
+            create_protocol=DnsClientProtocol
+        )
+    )
+    if (isinstance(result, StreamReset)):
+        raise StreamResetError(result.error_code)
+    if (isinstance(result, asyncio.TimeoutError)):
+        raise TimeoutError()
+    return result
+
+def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
+    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+    if verify:
+        configuration.load_verify_locations(verify)
+    result = asyncio.run(
+        async_quic_query(
+            configuration=configuration,
+            host=host,
+            port=port,
+            query=query,
+            timeout=timeout,
+            create_protocol=BogusDnsClientProtocol
+        )
+    )
+    if (isinstance(result, StreamReset)):
+        raise StreamResetError(result.error_code)
+    if (isinstance(result, asyncio.TimeoutError)):
+        raise TimeoutError()
+    return result
index 4c6b1020bdab2d939f83f544ab89df5618eac95d..13ce6de1840cb2b70a864b7be811c6a5bc446cea 100644 (file)
@@ -12,3 +12,4 @@ lmdb>=0.95
 cdbx==0.1.2
 h2>=4.0.0
 aioquic
+async_timeout
index e4b8e41c651e5c4ce89e15ad8d792ca1da4e9b0a..34ebf018c90ba02e01c9f2669c6259fdad199de1 100644 (file)
@@ -6,6 +6,8 @@ import sys
 import threading
 import unittest
 import dns
+import doqclient
+
 from dnsdisttests import DNSDistTest, pickAvailablePort
 
 def AsyncResponder(listenPath, responsePath):
@@ -284,7 +286,7 @@ class AsyncTests(object):
                 sender = getattr(self, method)
                 try:
                     (receivedQuery, receivedResponse) = sender(query, response)
-                except dns.exception.Timeout:
+                except doqclient.StreamResetError:
                     if not self._fromResponderQueue.empty():
                         receivedQuery = self._fromResponderQueue.get(True, 1.0)
                     receivedResponse = None
@@ -323,7 +325,7 @@ class AsyncTests(object):
             sender = getattr(self, method)
             try:
                 (_, receivedResponse) = sender(query, response=None, useQueue=False)
-            except dns.exception.Timeout:
+            except doqclient.StreamResetError:
                 receivedResponse = None
             self.assertEqual(receivedResponse, None)
 
index 150838eb339b2086e358fba35dba9b512799d458..9a87b62255fc274ea858c2597fc49c80d547dd41 100644 (file)
@@ -4,6 +4,38 @@ import clientsubnetoption
 
 from dnsdisttests import DNSDistTest
 from dnsdisttests import pickAvailablePort
+from doqclient import quic_bogus_query
+import doqclient
+
+class TestDOQBogus(DNSDistTest):
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _doqServerPort = pickAvailablePort()
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    addDOQLocal("127.0.0.1:%d", "%s", "%s")
+    """
+    _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey']
+    _verboseMode = True
+
+    def testDOQBogus(self):
+        """
+        DOQ: Test a bogus query (wrong packed length)
+        """
+        name = 'bogus.doq.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.id = 0
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery.id = 0
+
+        try:
+            message = quic_bogus_query(query, '127.0.0.1', 2.0, self._doqServerPort, verify=self._caCert, server_hostname=self._serverName)
+            self.assertFalse(True)
+        except doqclient.StreamResetError as e :
+            self.assertEqual(e.error, 2);
 
 class TestDOQ(DNSDistTest):
     _serverKey = 'server.key'
@@ -87,10 +119,9 @@ class TestDOQ(DNSDistTest):
         dropped = False
         try:
             (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName)
-            # dns.quic doesn't seem to report correctly the quic error so the connection timeout
-        except dns.exception.Timeout :
-            dropped = True
-        self.assertTrue(dropped)
+            self.assertTrue(False)
+        except doqclient.StreamResetError as e :
+            self.assertEqual(e.error, 5);
 
     def testRefused(self):
         """
@@ -134,10 +165,9 @@ class TestDOQ(DNSDistTest):
         dropped = False
         try:
             (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName)
-        except dns.exception.Timeout :
-            dropped = True
-        self.assertTrue(dropped)
-            # dns.quic doesn't seem to report correctly the quic error so the connection timeout
+            self.assertTrue(False)
+        except doqclient.StreamResetError as e :
+            self.assertEqual(e.error, 5);
 
 class TestDOQWithCache(DNSDistTest):
     _serverKey = 'server.key'