From: Charles-Henri Bruyand Date: Wed, 27 Sep 2023 12:57:16 +0000 (+0200) Subject: doq: make sure connection is properly reset if necessary in the tests X-Git-Tag: rec-5.0.0-alpha2~6^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e7000cceb88a32b82e241a7a53fdd81439b90e0e;p=thirdparty%2Fpdns.git doq: make sure connection is properly reset if necessary in the tests --- diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 10f7ea89e1..f1206bb9a1 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -748,7 +748,6 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const DNSPacketMangler mangler(reinterpret_cast(query.data()), query.size()); mangler.skipDomainName(); mangler.skipBytes(4); - // Should we ensure message id is 0 ? auto unit = std::make_unique(std::move(query)); unit->dsc = &dsc; diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 97f8b72dcb..f2e5e31a99 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -28,6 +28,8 @@ import h2.config import pycurl from io import BytesIO +from doqclient import quic_query + from eqdnsmessage import AssertEqualDNSMessageMixin from proxyprotocol import ProxyProtocol @@ -1111,7 +1113,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): else: cls._toResponderQueue.put(response, True, timeout) - message = dns.query.quic(query, '127.0.0.1', timeout, port, verify=caFile, connection=connection, server_hostname=serverName) + message = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName) receivedQuery = None diff --git a/regression-tests.dnsdist/doqclient.py b/regression-tests.dnsdist/doqclient.py new file mode 100644 index 0000000000..94fa7bdc8e --- /dev/null +++ b/regression-tests.dnsdist/doqclient.py @@ -0,0 +1,127 @@ +import asyncio +import pickle +import ssl +import struct +from typing import Any, Optional, cast +import dns +import async_timeout + +from aioquic.quic.configuration import QuicConfiguration +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset +from aioquic.quic.logger import QuicFileLogger + +class DnsClientProtocol(QuicConnectionProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ack_waiter: Any = None + + def pack(self, data): + # serialize query + data = bytes(data) + data = struct.pack("!H", len(data)) + data + return data + + async def query(self, query: dns.message) -> None: + data = self.pack(query.to_wire()) + # send query and wait for answer + stream_id = self._quic.get_next_available_stream_id() + self._quic.send_stream_data(stream_id, data, end_stream=True) + waiter = self._loop.create_future() + self._ack_waiter = waiter + self.transmit() + + return await asyncio.shield(waiter) + + def quic_event_received(self, event: QuicEvent) -> None: + if self._ack_waiter is not None: + if isinstance(event, StreamDataReceived): + length = struct.unpack("!H", bytes(event.data[:2]))[0] + answer = dns.message.from_wire(event.data[2 : 2 + length], ignore_trailing=True) + + waiter = self._ack_waiter + self._ack_waiter = None + waiter.set_result(answer) + if isinstance(event, StreamReset): + waiter = self._ack_waiter + self._ack_waiter = None + waiter.set_result(event) + +class BogusDnsClientProtocol(DnsClientProtocol): + def pack(self, data): + # serialize query + data = bytes(data) + data = struct.pack("!H", len(data) * 2) + data + return data + + +async def async_quic_query( + configuration: QuicConfiguration, + host: str, + port: int, + query: dns.message, + timeout: float, + create_protocol=DnsClientProtocol +) -> None: + print("Connecting to {}:{}".format(host, port)) + async with connect( + host, + port, + configuration=configuration, + create_protocol=create_protocol, + ) as client: + client = cast(DnsClientProtocol, client) + print("Sending DNS query") + try: + async with async_timeout.timeout(timeout): + answer = await client.query(query) + return answer + except asyncio.TimeoutError as e: + return e + +class StreamResetError(Exception): + def __init__(self, error, message="Stream reset by peer"): + self.error = error + super().__init__(message) + +def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None): + configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True) + if verify: + configuration.load_verify_locations(verify) + result = asyncio.run( + async_quic_query( + configuration=configuration, + host=host, + port=port, + query=query, + timeout=timeout, + create_protocol=DnsClientProtocol + ) + ) + if (isinstance(result, StreamReset)): + raise StreamResetError(result.error_code) + if (isinstance(result, asyncio.TimeoutError)): + raise TimeoutError() + return result + +def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None): + configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True) + if verify: + configuration.load_verify_locations(verify) + result = asyncio.run( + async_quic_query( + configuration=configuration, + host=host, + port=port, + query=query, + timeout=timeout, + create_protocol=BogusDnsClientProtocol + ) + ) + if (isinstance(result, StreamReset)): + raise StreamResetError(result.error_code) + if (isinstance(result, asyncio.TimeoutError)): + raise TimeoutError() + return result diff --git a/regression-tests.dnsdist/requirements.txt b/regression-tests.dnsdist/requirements.txt index 4c6b1020bd..13ce6de184 100644 --- a/regression-tests.dnsdist/requirements.txt +++ b/regression-tests.dnsdist/requirements.txt @@ -12,3 +12,4 @@ lmdb>=0.95 cdbx==0.1.2 h2>=4.0.0 aioquic +async_timeout diff --git a/regression-tests.dnsdist/test_Async.py b/regression-tests.dnsdist/test_Async.py index e4b8e41c65..34ebf018c9 100644 --- a/regression-tests.dnsdist/test_Async.py +++ b/regression-tests.dnsdist/test_Async.py @@ -6,6 +6,8 @@ import sys import threading import unittest import dns +import doqclient + from dnsdisttests import DNSDistTest, pickAvailablePort def AsyncResponder(listenPath, responsePath): @@ -284,7 +286,7 @@ class AsyncTests(object): sender = getattr(self, method) try: (receivedQuery, receivedResponse) = sender(query, response) - except dns.exception.Timeout: + except doqclient.StreamResetError: if not self._fromResponderQueue.empty(): receivedQuery = self._fromResponderQueue.get(True, 1.0) receivedResponse = None @@ -323,7 +325,7 @@ class AsyncTests(object): sender = getattr(self, method) try: (_, receivedResponse) = sender(query, response=None, useQueue=False) - except dns.exception.Timeout: + except doqclient.StreamResetError: receivedResponse = None self.assertEqual(receivedResponse, None) diff --git a/regression-tests.dnsdist/test_DOQ.py b/regression-tests.dnsdist/test_DOQ.py index 150838eb33..9a87b62255 100644 --- a/regression-tests.dnsdist/test_DOQ.py +++ b/regression-tests.dnsdist/test_DOQ.py @@ -4,6 +4,38 @@ import clientsubnetoption from dnsdisttests import DNSDistTest from dnsdisttests import pickAvailablePort +from doqclient import quic_bogus_query +import doqclient + +class TestDOQBogus(DNSDistTest): + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _doqServerPort = pickAvailablePort() + _config_template = """ + newServer{address="127.0.0.1:%d"} + + addDOQLocal("127.0.0.1:%d", "%s", "%s") + """ + _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey'] + _verboseMode = True + + def testDOQBogus(self): + """ + DOQ: Test a bogus query (wrong packed length) + """ + name = 'bogus.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.id = 0 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 0 + + try: + message = quic_bogus_query(query, '127.0.0.1', 2.0, self._doqServerPort, verify=self._caCert, server_hostname=self._serverName) + self.assertFalse(True) + except doqclient.StreamResetError as e : + self.assertEqual(e.error, 2); class TestDOQ(DNSDistTest): _serverKey = 'server.key' @@ -87,10 +119,9 @@ class TestDOQ(DNSDistTest): dropped = False try: (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName) - # dns.quic doesn't seem to report correctly the quic error so the connection timeout - except dns.exception.Timeout : - dropped = True - self.assertTrue(dropped) + self.assertTrue(False) + except doqclient.StreamResetError as e : + self.assertEqual(e.error, 5); def testRefused(self): """ @@ -134,10 +165,9 @@ class TestDOQ(DNSDistTest): dropped = False try: (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName) - except dns.exception.Timeout : - dropped = True - self.assertTrue(dropped) - # dns.quic doesn't seem to report correctly the quic error so the connection timeout + self.assertTrue(False) + except doqclient.StreamResetError as e : + self.assertEqual(e.error, 5); class TestDOQWithCache(DNSDistTest): _serverKey = 'server.key'