]>
Commit | Line | Data |
---|---|---|
e7000cce CHB |
1 | import asyncio |
2 | import pickle | |
3 | import ssl | |
4 | import struct | |
5 | from typing import Any, Optional, cast | |
6 | import dns | |
9ec97c74 | 7 | import dns.message |
e7000cce CHB |
8 | import async_timeout |
9 | ||
10 | from aioquic.quic.configuration import QuicConfiguration | |
11 | from aioquic.asyncio.client import connect | |
12 | from aioquic.asyncio.protocol import QuicConnectionProtocol | |
13 | from aioquic.quic.configuration import QuicConfiguration | |
14 | from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset | |
15 | from aioquic.quic.logger import QuicFileLogger | |
16 | ||
class DnsClientProtocol(QuicConnectionProtocol):
    """QUIC protocol that sends one DNS query at a time and awaits its answer.

    Each DNS message on the wire is prefixed with its length as a 2-byte
    big-endian integer (see ``pack``). Only one query may be outstanding at
    a time: a single future (``_ack_waiter``) tracks it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Future for the in-flight query; resolved by quic_event_received
        # with the parsed answer or a StreamReset event. None when idle.
        self._ack_waiter: Any = None
        # Stream carrying the in-flight query, and the response bytes received
        # on it so far (a response may be split over several events).
        self._query_stream_id: Optional[int] = None
        self._response_buffer = b""

    def pack(self, data):
        """Return *data* as bytes prefixed with its 2-byte big-endian length."""
        data = bytes(data)
        data = struct.pack("!H", len(data)) + data
        return data

    async def query(self, query: dns.message.Message) -> Any:
        """Send *query* on a new QUIC stream and wait for the server's reply.

        Returns the parsed ``dns.message.Message`` answer, or the
        ``StreamReset`` event if the peer reset the stream instead.
        """
        data = self.pack(query.to_wire())
        # send query (and close our side of the stream) and wait for answer
        stream_id = self._quic.get_next_available_stream_id()
        self._quic.send_stream_data(stream_id, data, end_stream=True)
        waiter = self._loop.create_future()
        self._ack_waiter = waiter
        self._query_stream_id = stream_id
        self._response_buffer = b""
        self.transmit()

        # shield() keeps a caller-side cancellation (e.g. a timeout) from
        # cancelling the waiter that quic_event_received resolves later.
        return await asyncio.shield(waiter)

    def quic_event_received(self, event: QuicEvent) -> None:
        """Feed stream events for the in-flight query into its waiter."""
        waiter = self._ack_waiter
        if waiter is None or waiter.done():
            return
        # Ignore events that do not belong to the query's stream
        # (connection-level events have no stream_id at all).
        if getattr(event, "stream_id", None) != self._query_stream_id:
            return

        if isinstance(event, StreamDataReceived):
            self._response_buffer += event.data
            if len(self._response_buffer) < 2:
                # Length prefix not fully received yet.
                return
            (length,) = struct.unpack("!H", self._response_buffer[:2])
            if len(self._response_buffer) < 2 + length:
                # Response split across several STREAM frames; wait for more.
                # If the stream ends before the message is complete, the
                # waiter stays pending and the caller's timeout applies.
                return
            answer = dns.message.from_wire(
                self._response_buffer[2 : 2 + length], ignore_trailing=True
            )
            self._ack_waiter = None
            waiter.set_result(answer)
        elif isinstance(event, StreamReset):
            self._ack_waiter = None
            waiter.set_result(event)
52 | ||
class BogusDnsClientProtocol(DnsClientProtocol):
    """DnsClientProtocol variant that frames queries with a wrong length.

    The 2-byte length prefix announces twice the real payload size, so the
    peer receives a malformed message (presumably to exercise the server's
    error handling -- confirm against the tests that use it).
    """

    def pack(self, data):
        """Return *data* as bytes behind a length prefix of ``2 * len(data)``."""
        payload = bytes(data)
        bogus_length = 2 * len(payload)
        return struct.pack("!H", bogus_length) + payload
59 | ||
60 | ||
async def async_quic_query(
    configuration: QuicConfiguration,
    host: str,
    port: int,
    query: dns.message.Message,
    timeout: float,
    create_protocol=DnsClientProtocol
) -> "tuple[Any, Optional[int]]":
    """Connect to *host*:*port* over QUIC, send *query* and await the reply.

    Args:
        configuration: client QUIC configuration used for the connection.
        host: server address or name.
        port: server UDP port.
        query: DNS message to send.
        timeout: seconds allowed for the query itself. The connection
            handshake performed by ``connect`` is NOT covered by it.
        create_protocol: protocol class to instantiate (DnsClientProtocol or
            a subclass such as BogusDnsClientProtocol).

    Returns:
        ``(answer, serial)``, where *answer* is what the protocol's ``query``
        resolved with (a DNS message or a ``StreamReset`` event) and *serial*
        is the serial number of the server's certificate. On timeout, the
        ``asyncio.TimeoutError`` instance is returned as ``(error, None)``
        instead of being raised.
    """
    print("Connecting to {}:{}".format(host, port))
    async with connect(
        host,
        port,
        configuration=configuration,
        create_protocol=create_protocol,
    ) as client:
        client = cast(DnsClientProtocol, client)
        print("Sending DNS query")
        try:
            async with async_timeout.timeout(timeout):
                answer = await client.query(query)
        except asyncio.TimeoutError as e:
            return (e, None)
        # NOTE: reaches into aioquic private attributes to get the peer cert.
        return (answer, client._quic.tls._peer_certificate.serial_number)
e7000cce CHB |
84 | |
class StreamResetError(Exception):
    """Raised when the peer resets the QUIC stream carrying a query.

    Attributes:
        error: the error code taken from the peer's stream reset.
    """

    def __init__(self, error, message="Stream reset by peer"):
        super().__init__(message)
        self.error = error
89 | ||
def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
    """Send *query* over DNS-over-QUIC (ALPN "doq") and return the answer.

    Args:
        query: DNS message to send.
        host: server address.
        timeout: seconds allowed for the query once the connection is up.
        port: server UDP port.
        verify: if truthy, passed to ``QuicConfiguration.load_verify_locations``.
        server_hostname: if given, set as the configuration's ``server_name``
            (TLS SNI / expected certificate name) instead of being ignored.

    Returns:
        ``(answer, serial)``: the DNS answer and the server certificate's
        serial number.

    Raises:
        StreamResetError: the server reset the query stream.
        TimeoutError: no answer arrived within *timeout*.
    """
    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
    if verify:
        configuration.load_verify_locations(verify)
    if server_hostname is not None:
        # Previously accepted but silently ignored.
        configuration.server_name = server_hostname
    (result, serial) = asyncio.run(
        async_quic_query(
            configuration=configuration,
            host=host,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=DnsClientProtocol
        )
    )
    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        raise TimeoutError() from result
    return (result, serial)
e7000cce CHB |
109 | |
def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
    """Send *query* over DoQ with a deliberately wrong length prefix.

    Uses BogusDnsClientProtocol, whose length prefix is twice the real
    payload size. Arguments are the same as for ``quic_query``;
    *server_hostname*, if given, is set as the configuration's ``server_name``.

    Returns:
        Whatever the protocol's ``query`` resolved with (the certificate
        serial number is discarded).

    Raises:
        StreamResetError: the server reset the query stream.
        TimeoutError: no answer arrived within *timeout*.
    """
    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
    if verify:
        configuration.load_verify_locations(verify)
    if server_hostname is not None:
        # Previously accepted but silently ignored.
        configuration.server_name = server_hostname
    (result, _) = asyncio.run(
        async_quic_query(
            configuration=configuration,
            host=host,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=BogusDnsClientProtocol
        )
    )
    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        raise TimeoutError() from result
    return result