]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/doqclient.py
Merge pull request #13923 from rgacogne/ddist-xfr-response-chain
[thirdparty/pdns.git] / regression-tests.dnsdist / doqclient.py
CommitLineData
e7000cce
CHB
1import asyncio
2import pickle
3import ssl
4import struct
5from typing import Any, Optional, cast
6import dns
9ec97c74 7import dns.message
e7000cce
CHB
8import async_timeout
9
10from aioquic.quic.configuration import QuicConfiguration
11from aioquic.asyncio.client import connect
12from aioquic.asyncio.protocol import QuicConnectionProtocol
13from aioquic.quic.configuration import QuicConfiguration
14from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
15from aioquic.quic.logger import QuicFileLogger
16
class DnsClientProtocol(QuicConnectionProtocol):
    """aioquic client protocol that sends one DNS-over-QUIC query at a time.

    Every DNS message on a stream is prefixed with its length as a two-byte,
    network-order integer (see ``pack``).  ``query`` resolves to the parsed
    ``dns.message.Message`` answer, or to the ``StreamReset`` event when the
    server resets the stream instead of answering.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Future resolved by quic_event_received() for the in-flight query;
        # None when no query is waiting for an outcome.
        self._ack_waiter: Any = None
        # Answer bytes received so far for the in-flight query: a large answer
        # may be delivered across several StreamDataReceived events.
        self._recv_buffer: bytearray = bytearray()

    def pack(self, data):
        """Return ``data`` prefixed with its length as a big-endian 16-bit integer."""
        data = bytes(data)
        data = struct.pack("!H", len(data)) + data
        return data

    async def query(self, query: dns.message.Message) -> Any:
        """Send ``query`` on a new stream and wait for its outcome.

        The query is written with ``end_stream=True``, closing our sending
        side of the stream.  Returns the parsed answer message, or the
        ``StreamReset`` event if the server reset the stream.  Raises
        ``struct.error`` / ``dns.exception.DNSException`` if the answer cannot
        be decoded.
        """
        data = self.pack(query.to_wire())
        # send query and wait for answer
        stream_id = self._quic.get_next_available_stream_id()
        self._quic.send_stream_data(stream_id, data, end_stream=True)
        waiter = self._loop.create_future()
        self._ack_waiter = waiter
        self._recv_buffer = bytearray()
        self.transmit()

        # shield() keeps a cancellation of the awaiting task (e.g. by a
        # timeout) from cancelling the waiter future itself.
        return await asyncio.shield(waiter)

    def _finish(self, waiter) -> None:
        """Detach ``waiter`` and discard any buffered answer bytes."""
        self._ack_waiter = None
        self._recv_buffer = bytearray()

    def quic_event_received(self, event: QuicEvent) -> None:
        """Resolve the pending query from stream data or a stream reset."""
        waiter = self._ack_waiter
        if waiter is None:
            return

        if isinstance(event, StreamDataReceived):
            self._recv_buffer += event.data
            buffered = self._recv_buffer
            complete = (
                len(buffered) >= 2
                and len(buffered) >= 2 + struct.unpack_from("!H", buffered)[0]
            )
            if not complete and not event.end_stream:
                # Part of the answer is still in flight; wait for more data.
                return

            self._finish(waiter)
            try:
                (length,) = struct.unpack_from("!H", buffered)
                answer = dns.message.from_wire(
                    bytes(buffered[2 : 2 + length]), ignore_trailing=True
                )
            except (struct.error, dns.exception.DNSException) as exc:
                # Report a truncated or malformed answer to the caller instead
                # of raising out of the event callback and leaving it waiting.
                waiter.set_exception(exc)
            else:
                waiter.set_result(answer)
        elif isinstance(event, StreamReset):
            self._finish(waiter)
            waiter.set_result(event)
52
class BogusDnsClientProtocol(DnsClientProtocol):
    """DoQ client protocol that frames its query with a wrong length prefix.

    The two-byte prefix announces twice the real payload size, so for any
    non-empty query the stream ends before the advertised number of bytes
    has been sent.
    """

    def pack(self, data):
        """Return ``data`` prefixed with double its real length (big-endian 16-bit)."""
        payload = bytes(data)
        announced_length = len(payload) * 2
        header = struct.pack("!H", announced_length)
        return header + payload
59
60
async def async_quic_query(
    configuration: QuicConfiguration,
    host: str,
    port: int,
    query: dns.message.Message,
    timeout: float,
    create_protocol=DnsClientProtocol
) -> "tuple[Any, Optional[int]]":
    """Connect to ``host``:``port`` over QUIC, send ``query`` and collect the outcome.

    Returns a ``(result, serial)`` pair:

    * when ``client.query`` completes, ``result`` is what it resolved to
      (the answer message, or a ``StreamReset`` event) and ``serial`` is the
      serial number of the certificate presented by the peer;
    * when no outcome arrives within ``timeout`` seconds, ``result`` is the
      ``asyncio.TimeoutError`` instance and ``serial`` is ``None``.

    Only the query itself is bounded by ``timeout``; connection establishment
    in ``connect`` happens outside of it.
    ``create_protocol`` selects the protocol class (e.g. a subclass that
    frames queries differently).
    """
    print("Connecting to {}:{}".format(host, port))
    async with connect(
        host,
        port,
        configuration=configuration,
        create_protocol=create_protocol,
    ) as client:
        client = cast(DnsClientProtocol, client)
        print("Sending DNS query")
        try:
            async with async_timeout.timeout(timeout):
                answer = await client.query(query)
                # NOTE: relies on aioquic private attributes to reach the
                # peer certificate; may need updating with aioquic upgrades.
                return (answer, client._quic.tls._peer_certificate.serial_number)
        except asyncio.TimeoutError as e:
            return (e, None)
e7000cce
CHB
84
class StreamResetError(Exception):
    """Raised when the server resets the query stream instead of answering."""

    def __init__(self, error, message="Stream reset by peer"):
        super().__init__(message)
        # Error code taken from the StreamReset event that ended the stream.
        self.error = error
89
def _run_quic_query(query, host, timeout, port, verify, server_hostname, create_protocol):
    """Run one DoQ query using ``create_protocol`` and return ``(result, serial)``.

    Raises ``StreamResetError`` if the server reset the stream, and
    ``TimeoutError`` if no outcome arrived within ``timeout`` seconds.
    """
    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
    if verify:
        # presumably a CA bundle path; passed straight to aioquic
        configuration.load_verify_locations(verify)
    if server_hostname:
        # TLS server name (SNI) for the handshake; this argument used to be
        # accepted but silently ignored.
        configuration.server_name = server_hostname
    result, serial = asyncio.run(
        async_quic_query(
            configuration=configuration,
            host=host,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=create_protocol
        )
    )
    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        raise TimeoutError() from result
    return result, serial


def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
    """Send ``query`` over DNS-over-QUIC and return ``(answer, certificate_serial)``.

    Raises ``StreamResetError`` if the server reset the stream, and
    ``TimeoutError`` if no answer arrived within ``timeout`` seconds.
    """
    return _run_quic_query(query, host, timeout, port, verify, server_hostname,
                           DnsClientProtocol)


def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
    """Send ``query`` with a deliberately wrong length prefix and return the answer.

    Uses ``BogusDnsClientProtocol``; raises ``StreamResetError`` or
    ``TimeoutError`` exactly like ``quic_query``.  The certificate serial is
    discarded.
    """
    result, _ = _run_quic_query(query, host, timeout, port, verify, server_hostname,
                                BogusDnsClientProtocol)
    return result