--- /dev/null
+import asyncio
+import pickle
+import ssl
+import struct
+from typing import Any, Optional, cast
+import dns
+import async_timeout
+
+from aioquic.quic.configuration import QuicConfiguration
+from aioquic.asyncio.client import connect
+from aioquic.asyncio.protocol import QuicConnectionProtocol
+from aioquic.quic.configuration import QuicConfiguration
+from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
+from aioquic.quic.logger import QuicFileLogger
+
+class DnsClientProtocol(QuicConnectionProtocol):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._ack_waiter: Any = None
+
+ def pack(self, data):
+ # serialize query
+ data = bytes(data)
+ data = struct.pack("!H", len(data)) + data
+ return data
+
+ async def query(self, query: dns.message) -> None:
+ data = self.pack(query.to_wire())
+ # send query and wait for answer
+ stream_id = self._quic.get_next_available_stream_id()
+ self._quic.send_stream_data(stream_id, data, end_stream=True)
+ waiter = self._loop.create_future()
+ self._ack_waiter = waiter
+ self.transmit()
+
+ return await asyncio.shield(waiter)
+
+ def quic_event_received(self, event: QuicEvent) -> None:
+ if self._ack_waiter is not None:
+ if isinstance(event, StreamDataReceived):
+ length = struct.unpack("!H", bytes(event.data[:2]))[0]
+ answer = dns.message.from_wire(event.data[2 : 2 + length], ignore_trailing=True)
+
+ waiter = self._ack_waiter
+ self._ack_waiter = None
+ waiter.set_result(answer)
+ if isinstance(event, StreamReset):
+ waiter = self._ack_waiter
+ self._ack_waiter = None
+ waiter.set_result(event)
+
+class BogusDnsClientProtocol(DnsClientProtocol):
+ def pack(self, data):
+ # serialize query
+ data = bytes(data)
+ data = struct.pack("!H", len(data) * 2) + data
+ return data
+
+
+async def async_quic_query(
+ configuration: QuicConfiguration,
+ host: str,
+ port: int,
+ query: dns.message,
+ timeout: float,
+ create_protocol=DnsClientProtocol
+) -> None:
+ print("Connecting to {}:{}".format(host, port))
+ async with connect(
+ host,
+ port,
+ configuration=configuration,
+ create_protocol=create_protocol,
+ ) as client:
+ client = cast(DnsClientProtocol, client)
+ print("Sending DNS query")
+ try:
+ async with async_timeout.timeout(timeout):
+ answer = await client.query(query)
+ return answer
+ except asyncio.TimeoutError as e:
+ return e
+
+class StreamResetError(Exception):
+ def __init__(self, error, message="Stream reset by peer"):
+ self.error = error
+ super().__init__(message)
+
+def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
+ configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+ if verify:
+ configuration.load_verify_locations(verify)
+ result = asyncio.run(
+ async_quic_query(
+ configuration=configuration,
+ host=host,
+ port=port,
+ query=query,
+ timeout=timeout,
+ create_protocol=DnsClientProtocol
+ )
+ )
+ if (isinstance(result, StreamReset)):
+ raise StreamResetError(result.error_code)
+ if (isinstance(result, asyncio.TimeoutError)):
+ raise TimeoutError()
+ return result
+
+def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
+ configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
+ if verify:
+ configuration.load_verify_locations(verify)
+ result = asyncio.run(
+ async_quic_query(
+ configuration=configuration,
+ host=host,
+ port=port,
+ query=query,
+ timeout=timeout,
+ create_protocol=BogusDnsClientProtocol
+ )
+ )
+ if (isinstance(result, StreamReset)):
+ raise StreamResetError(result.error_code)
+ if (isinstance(result, asyncio.TimeoutError)):
+ raise TimeoutError()
+ return result
import threading
import unittest
import dns
+import doqclient
+
from dnsdisttests import DNSDistTest, pickAvailablePort
def AsyncResponder(listenPath, responsePath):
sender = getattr(self, method)
try:
(receivedQuery, receivedResponse) = sender(query, response)
- except dns.exception.Timeout:
+ except doqclient.StreamResetError:
if not self._fromResponderQueue.empty():
receivedQuery = self._fromResponderQueue.get(True, 1.0)
receivedResponse = None
sender = getattr(self, method)
try:
(_, receivedResponse) = sender(query, response=None, useQueue=False)
- except dns.exception.Timeout:
+ except doqclient.StreamResetError:
receivedResponse = None
self.assertEqual(receivedResponse, None)
from dnsdisttests import DNSDistTest
from dnsdisttests import pickAvailablePort
+from doqclient import quic_bogus_query
+import doqclient
+
+class TestDOQBogus(DNSDistTest):
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _doqServerPort = pickAvailablePort()
+ _config_template = """
+ newServer{address="127.0.0.1:%d"}
+
+ addDOQLocal("127.0.0.1:%d", "%s", "%s")
+ """
+ _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey']
+ _verboseMode = True
+
+ def testDOQBogus(self):
+ """
+ DOQ: Test a bogus query (wrong packed length)
+ """
+ name = 'bogus.doq.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.id = 0
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+ expectedQuery.id = 0
+
+ try:
+ message = quic_bogus_query(query, '127.0.0.1', 2.0, self._doqServerPort, verify=self._caCert, server_hostname=self._serverName)
+ self.assertFalse(True)
+ except doqclient.StreamResetError as e :
+ self.assertEqual(e.error, 2);
class TestDOQ(DNSDistTest):
_serverKey = 'server.key'
dropped = False
try:
(_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName)
- # dns.quic doesn't seem to report correctly the quic error so the connection timeout
- except dns.exception.Timeout :
- dropped = True
- self.assertTrue(dropped)
+ self.assertTrue(False)
+ except doqclient.StreamResetError as e :
+ self.assertEqual(e.error, 5);
def testRefused(self):
"""
dropped = False
try:
(_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName)
- except dns.exception.Timeout :
- dropped = True
- self.assertTrue(dropped)
- # dns.quic doesn't seem to report correctly the quic error so the connection timeout
+ self.assertTrue(False)
+ except doqclient.StreamResetError as e :
+ self.assertEqual(e.error, 5);
class TestDOQWithCache(DNSDistTest):
_serverKey = 'server.key'