]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/doqclient.py
dnsdist: add a test for new protobuf field httpVersion
[thirdparty/pdns.git] / regression-tests.dnsdist / doqclient.py
1 import asyncio
2 import pickle
3 import ssl
4 import struct
5 from typing import Any, Optional, cast
6 import dns
7 import async_timeout
8
9 from aioquic.quic.configuration import QuicConfiguration
10 from aioquic.asyncio.client import connect
11 from aioquic.asyncio.protocol import QuicConnectionProtocol
12 from aioquic.quic.configuration import QuicConfiguration
13 from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
14 from aioquic.quic.logger import QuicFileLogger
15
class DnsClientProtocol(QuicConnectionProtocol):
    """aioquic protocol that sends one DNS-over-QUIC query at a time.

    The future of the outstanding query is kept in ``self._ack_waiter`` and is
    resolved from :meth:`quic_event_received` with either the decoded answer or
    the ``StreamReset`` event.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Future of the outstanding query, or None when nothing is pending.
        self._ack_waiter: Any = None
        # Partial response bytes keyed by stream ID: a single response may be
        # delivered across several StreamDataReceived events.
        self._buffers = {}

    def pack(self, data):
        """Serialize *data* as a 2-byte big-endian length prefix followed by the payload."""
        data = bytes(data)
        return struct.pack("!H", len(data)) + data

    async def query(self, query: "dns.message.Message") -> Any:
        """Send *query* on a new stream and wait for the server's reaction.

        Returns the decoded answer (a ``dns.message.Message``), or the
        ``StreamReset`` event if the server reset the stream. An error raised
        while decoding the answer is re-raised here.
        """
        data = self.pack(query.to_wire())
        stream_id = self._quic.get_next_available_stream_id()
        # end_stream=True: the whole query goes out on this stream, which is
        # then closed for writing.
        self._quic.send_stream_data(stream_id, data, end_stream=True)
        waiter = self._loop.create_future()
        self._ack_waiter = waiter
        self.transmit()
        # shield() keeps a caller-side timeout from cancelling the waiter
        # itself, so quic_event_received() can still resolve it later.
        return await asyncio.shield(waiter)

    def quic_event_received(self, event: QuicEvent) -> None:
        """Resolve the pending query future from incoming QUIC stream events."""
        if self._ack_waiter is None:
            return
        if isinstance(event, StreamDataReceived):
            buffer = self._buffers.pop(event.stream_id, b"") + bytes(event.data)
            if not event.end_stream and not self._is_complete(buffer):
                # Only part of the response has arrived so far: wait for more.
                self._buffers[event.stream_id] = buffer
                return
            waiter = self._ack_waiter
            self._ack_waiter = None
            try:
                answer = self._parse_answer(buffer)
            except Exception as exc:
                # Report decode failures to the awaiting query() instead of
                # raising inside aioquic's event processing and hanging it.
                waiter.set_exception(exc)
            else:
                waiter.set_result(answer)
        elif isinstance(event, StreamReset):
            self._buffers.pop(event.stream_id, None)
            waiter = self._ack_waiter
            self._ack_waiter = None
            waiter.set_result(event)

    @staticmethod
    def _is_complete(buffer):
        """Return True once *buffer* holds the length prefix and the full message."""
        return len(buffer) >= 2 and len(buffer) >= 2 + struct.unpack_from("!H", buffer)[0]

    @staticmethod
    def _parse_answer(buffer):
        """Decode the length-prefixed DNS message at the start of *buffer*."""
        if len(buffer) < 2:
            raise ValueError("DoQ response is shorter than its 2-byte length prefix")
        length = struct.unpack_from("!H", buffer)[0]
        return dns.message.from_wire(buffer[2 : 2 + length], ignore_trailing=True)
51
class BogusDnsClientProtocol(DnsClientProtocol):
    """DnsClientProtocol variant that advertises a wrong query length.

    The 2-byte length prefix claims twice the real payload size, so the
    server sees the query stream end before the announced message does.
    """

    def pack(self, data):
        payload = bytes(data)
        bogus_length = 2 * len(payload)
        return struct.pack("!H", bogus_length) + payload
58
59
async def async_quic_query(
    configuration: QuicConfiguration,
    host: str,
    port: int,
    query: "dns.message.Message",
    timeout: float,
    create_protocol=DnsClientProtocol
) -> Any:
    """Connect to host:port over QUIC, send *query* and return the outcome.

    Returns whatever the protocol's ``query()`` produced (the decoded answer,
    or the ``StreamReset`` event). If no answer arrives within *timeout*
    seconds, the ``asyncio.TimeoutError`` instance is *returned* rather than
    raised, so that the synchronous wrappers can translate it.

    NOTE: *timeout* bounds only the query exchange, not the QUIC handshake
    performed by ``connect()``.
    """
    print("Connecting to {}:{}".format(host, port))
    async with connect(
        host,
        port,
        configuration=configuration,
        create_protocol=create_protocol,
    ) as client:
        client = cast(DnsClientProtocol, client)
        print("Sending DNS query")
        try:
            async with async_timeout.timeout(timeout):
                return await client.query(query)
        except asyncio.TimeoutError as e:
            return e
83
class StreamResetError(Exception):
    """Raised when the server resets the QUIC stream carrying a query.

    The ``error`` attribute holds the error code that accompanied the reset.
    """

    def __init__(self, error, message="Stream reset by peer"):
        super().__init__(message)
        self.error = error
88
def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None,
               create_protocol=DnsClientProtocol):
    """Send *query* over DNS-over-QUIC and return the answer.

    :param query: DNS message to send.
    :param host: address of the DoQ server.
    :param timeout: seconds to wait for the answer once connected.
    :param port: port of the DoQ server.
    :param verify: optional CA location used to verify the server certificate.
    :param server_hostname: server name to use for TLS (SNI and certificate
        checks); when None, ``configuration.server_name`` is left unset and
        aioquic's ``connect()`` is expected to fall back to *host*.
    :param create_protocol: protocol class driving the exchange.
    :raises StreamResetError: if the server reset the query stream.
    :raises TimeoutError: if no answer arrived within *timeout* seconds.
    """
    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
    if verify:
        configuration.load_verify_locations(verify)
    if server_hostname:
        # Previously accepted but ignored; honour it for SNI/verification.
        configuration.server_name = server_hostname
    result = asyncio.run(
        async_quic_query(
            configuration=configuration,
            host=host,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=create_protocol,
        )
    )
    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        raise TimeoutError() from result
    return result


def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
    """Like quic_query(), but the query's length prefix announces twice its real size.

    Used to exercise the server's handling of a malformed DoQ query; the
    exchange is driven by BogusDnsClientProtocol.
    """
    return quic_query(
        query,
        host=host,
        timeout=timeout,
        port=port,
        verify=verify,
        server_hostname=server_hostname,
        create_protocol=BogusDnsClientProtocol,
    )