]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/doqclient.py
Merge pull request #13592 from rgacogne/qname-suffix-rule
[thirdparty/pdns.git] / regression-tests.dnsdist / doqclient.py
CommitLineData
e7000cce
CHB
import asyncio
import pickle
import ssl
import struct
from typing import Any, Optional, cast

import async_timeout
import dns
import dns.message

from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
from aioquic.quic.logger import QuicFileLogger
15
class DnsClientProtocol(QuicConnectionProtocol):
    """aioquic protocol that sends one DNS-over-QUIC query and awaits its reply.

    Messages are framed with a 2-byte big-endian length prefix (see ``pack``).
    The coroutine blocked in ``query`` receives either the parsed
    ``dns.message.Message`` answer or, if the peer resets the stream, the raw
    ``StreamReset`` event.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Future awaited by query(); None while no query is outstanding.
        self._ack_waiter: Any = None
        # Reply bytes received so far: a reply larger than one QUIC packet is
        # delivered across several StreamDataReceived events.
        self._response_buffer = b""

    def pack(self, data):
        """Return ``data`` as bytes prefixed with its 2-byte big-endian length."""
        data = bytes(data)
        return struct.pack("!H", len(data)) + data

    async def query(self, query: "dns.message.Message") -> Any:
        """Send ``query`` on a fresh stream and wait for the outcome.

        Returns the parsed ``dns.message.Message`` answer, or the ``StreamReset``
        event if the server reset the stream. Raises the exception produced by
        ``dns.message.from_wire`` if the reply could not be parsed.
        """
        data = self.pack(query.to_wire())
        stream_id = self._quic.get_next_available_stream_id()
        # end_stream=True: the whole query is sent at once and our side is closed.
        self._quic.send_stream_data(stream_id, data, end_stream=True)
        waiter = self._loop.create_future()
        self._ack_waiter = waiter
        self._response_buffer = b""
        self.transmit()

        # shield(): a cancellation of this await (e.g. the caller's timeout)
        # does not cancel the inner future that quic_event_received completes.
        return await asyncio.shield(waiter)

    def quic_event_received(self, event: QuicEvent) -> None:
        """Route QUIC events to the outstanding query, if there is one."""
        if self._ack_waiter is None:
            return
        if isinstance(event, StreamDataReceived):
            self._on_stream_data(event)
        elif isinstance(event, StreamReset):
            self._finish(event)

    def _on_stream_data(self, event: StreamDataReceived) -> None:
        """Accumulate reply bytes; parse and resolve once the message is complete."""
        self._response_buffer += bytes(event.data)
        buffered = self._response_buffer
        have_prefix = len(buffered) >= 2
        length = struct.unpack("!H", buffered[:2])[0] if have_prefix else 0
        complete = have_prefix and len(buffered) >= 2 + length
        if not complete and not event.end_stream:
            # Partial reply: wait for the next chunk on this stream.
            return
        try:
            answer = dns.message.from_wire(buffered[2 : 2 + length], ignore_trailing=True)
        except Exception as exc:
            # Hand a malformed or truncated reply to query()'s caller; raising
            # here would leave query() waiting until its caller's timeout.
            self._finish(error=exc)
            return
        self._finish(answer)

    def _finish(self, result: Any = None, error: Optional[BaseException] = None) -> None:
        """Detach the pending future, reset per-query state and complete it."""
        waiter = self._ack_waiter
        self._ack_waiter = None
        self._response_buffer = b""
        if waiter.done():
            # Defensive: completing an already-done future raises InvalidStateError.
            return
        if error is not None:
            waiter.set_exception(error)
        else:
            waiter.set_result(result)
45 self._ack_waiter = None
46 waiter.set_result(answer)
47 if isinstance(event, StreamReset):
48 waiter = self._ack_waiter
49 self._ack_waiter = None
50 waiter.set_result(event)
51
52class BogusDnsClientProtocol(DnsClientProtocol):
53 def pack(self, data):
54 # serialize query
55 data = bytes(data)
56 data = struct.pack("!H", len(data) * 2) + data
57 return data
58
59
60async def async_quic_query(
61 configuration: QuicConfiguration,
62 host: str,
63 port: int,
64 query: dns.message,
65 timeout: float,
66 create_protocol=DnsClientProtocol
67) -> None:
68 print("Connecting to {}:{}".format(host, port))
69 async with connect(
70 host,
71 port,
72 configuration=configuration,
73 create_protocol=create_protocol,
74 ) as client:
75 client = cast(DnsClientProtocol, client)
76 print("Sending DNS query")
77 try:
78 async with async_timeout.timeout(timeout):
79 answer = await client.query(query)
80 return answer
81 except asyncio.TimeoutError as e:
82 return e
83
84class StreamResetError(Exception):
85 def __init__(self, error, message="Stream reset by peer"):
86 self.error = error
87 super().__init__(message)
88
89def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
90 configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
91 if verify:
92 configuration.load_verify_locations(verify)
93 result = asyncio.run(
94 async_quic_query(
95 configuration=configuration,
96 host=host,
97 port=port,
98 query=query,
99 timeout=timeout,
100 create_protocol=DnsClientProtocol
101 )
102 )
103 if (isinstance(result, StreamReset)):
104 raise StreamResetError(result.error_code)
105 if (isinstance(result, asyncio.TimeoutError)):
106 raise TimeoutError()
107 return result
108
109def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
110 configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
111 if verify:
112 configuration.load_verify_locations(verify)
113 result = asyncio.run(
114 async_quic_query(
115 configuration=configuration,
116 host=host,
117 port=port,
118 query=query,
119 timeout=timeout,
120 create_protocol=BogusDnsClientProtocol
121 )
122 )
123 if (isinstance(result, StreamReset)):
124 raise StreamResetError(result.error_code)
125 if (isinstance(result, asyncio.TimeoutError)):
126 raise TimeoutError()
127 return result