]>
Commit | Line | Data |
---|---|---|
e7000cce CHB |
import asyncio
import pickle
import ssl
import struct
from typing import Any, Optional, cast

import async_timeout
import dns
import dns.message

from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
from aioquic.quic.logger import QuicFileLogger
15 | ||
class DnsClientProtocol(QuicConnectionProtocol):
    """QUIC client protocol that sends a DNS query and awaits its answer.

    The query is written on a fresh stream, framed by ``pack`` with a 2-byte
    big-endian length prefix. Only one query can be in flight at a time: the
    outcome is delivered through a single pending future (``_ack_waiter``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Future for the in-flight query; None when nothing is waiting.
        self._ack_waiter: Any = None
        # Stream carrying the in-flight query; events on other streams are ignored.
        self._answer_stream_id: Optional[int] = None
        # Answer bytes accumulated across StreamDataReceived events, because a
        # response is not guaranteed to arrive in a single event.
        self._answer_buffer = bytearray()

    def pack(self, data):
        """Return *data* as bytes prefixed with its 2-byte big-endian length."""
        data = bytes(data)
        return struct.pack("!H", len(data)) + data

    async def query(self, query: dns.message.Message) -> Any:
        """Send *query* and wait for the server's reply.

        Returns the parsed ``dns.message.Message`` answer, or the
        ``StreamReset`` event if the server reset the query stream.

        Raises:
            Exception: whatever ``dns.message.from_wire`` raised if the answer
                could not be parsed.
            EOFError: the stream ended before the announced answer length
                was received.
        """
        data = self.pack(query.to_wire())
        stream_id = self._quic.get_next_available_stream_id()
        self._answer_stream_id = stream_id
        self._answer_buffer = bytearray()
        waiter = self._loop.create_future()
        self._ack_waiter = waiter
        # end_stream=True closes our sending side right after the single query.
        self._quic.send_stream_data(stream_id, data, end_stream=True)
        self.transmit()
        # shield() stops a caller-side timeout from cancelling the waiter itself,
        # so the event handler never has to resolve a cancelled future.
        return await asyncio.shield(waiter)

    def quic_event_received(self, event: QuicEvent) -> None:
        """Route events for the in-flight query's stream to its waiter."""
        if self._ack_waiter is None:
            return
        if isinstance(event, StreamDataReceived):
            if event.stream_id == self._answer_stream_id:
                self._feed_answer_data(event.data, event.end_stream)
        elif isinstance(event, StreamReset):
            if event.stream_id == self._answer_stream_id:
                self._resolve_answer(result=event)

    def _feed_answer_data(self, data: bytes, end_stream: bool) -> None:
        """Buffer answer bytes; resolve the waiter once a full message is in."""
        self._answer_buffer += data
        buffered = self._answer_buffer
        if len(buffered) >= 2:
            (length,) = struct.unpack_from("!H", buffered)
            if len(buffered) >= 2 + length:
                wire = bytes(buffered[2 : 2 + length])
                try:
                    answer = dns.message.from_wire(wire, ignore_trailing=True)
                except Exception as exc:
                    # Forward parse failures to the awaiting caller instead of
                    # letting them escape the event callback (which would leave
                    # the waiter pending until the caller's timeout).
                    self._resolve_answer(exception=exc)
                else:
                    self._resolve_answer(result=answer)
                return
        if end_stream:
            self._resolve_answer(
                exception=EOFError(
                    "stream ended before the complete DNS answer was received"
                )
            )

    def _resolve_answer(
        self, result: Any = None, exception: Optional[BaseException] = None
    ) -> None:
        """Clear the pending waiter and complete it with *result* or *exception*."""
        waiter = self._ack_waiter
        self._ack_waiter = None
        self._answer_buffer = bytearray()
        if waiter is None or waiter.done():
            return
        if exception is not None:
            waiter.set_exception(exception)
        else:
            waiter.set_result(result)
51 | ||
class BogusDnsClientProtocol(DnsClientProtocol):
    """DnsClientProtocol variant that frames queries with a wrong length.

    The 2-byte big-endian prefix announces twice the real payload size, so the
    peer receives fewer bytes than advertised. ``quic_bogus_query`` uses this
    class; presumably it exists to exercise server-side error handling.
    """

    def pack(self, data):
        """Return *data* as bytes behind a prefix claiming double its length."""
        payload = bytes(data)
        advertised_length = 2 * len(payload)
        prefix = struct.pack("!H", advertised_length)
        return prefix + payload
58 | ||
59 | ||
async def async_quic_query(
    configuration: QuicConfiguration,
    host: str,
    port: int,
    query: dns.message.Message,
    timeout: float,
    create_protocol=DnsClientProtocol,
) -> Any:
    """Connect to host:port over QUIC, send *query*, and return the outcome.

    Args:
        configuration: client QUIC configuration passed to ``connect``.
        host, port: address of the DNS-over-QUIC server.
        query: the DNS message to send.
        timeout: seconds allowed for ``client.query`` to complete.
        create_protocol: protocol class to instantiate; it must provide an
            async ``query(query)`` method like ``DnsClientProtocol``.

    Returns:
        Whatever the protocol's ``query`` returned (the answer message or a
        ``StreamReset`` event). On timeout the ``asyncio.TimeoutError``
        instance is *returned*, not raised; ``quic_query`` converts it.
    """
    print("Connecting to {}:{}".format(host, port))
    async with connect(
        host,
        port,
        configuration=configuration,
        create_protocol=create_protocol,
    ) as client:
        client = cast(DnsClientProtocol, client)
        print("Sending DNS query")
        # NOTE(review): *timeout* only bounds the query itself; connection
        # setup inside connect() above is outside it.
        try:
            async with async_timeout.timeout(timeout):
                answer = await client.query(query)
                return answer
        except asyncio.TimeoutError as e:
            return e
83 | ||
class StreamResetError(Exception):
    """Signals that the server reset the stream carrying a DNS query.

    Attributes:
        error: the error code reported with the reset.
    """

    def __init__(self, error, message="Stream reset by peer"):
        super().__init__(message)
        self.error = error
88 | ||
def _run_quic_query(query, host, timeout, port, verify, server_hostname, create_protocol):
    """Run one DNS-over-QUIC exchange synchronously using *create_protocol*.

    Builds a client ``QuicConfiguration`` with the "doq" ALPN. When *verify*
    is truthy it is passed to ``load_verify_locations`` (a CA file path). When
    *server_hostname* is given it becomes the configuration's TLS server name
    instead of *host*.

    Raises:
        StreamResetError: the server reset the query stream.
        TimeoutError: no answer within *timeout* seconds.
    """
    configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
    if verify:
        configuration.load_verify_locations(verify)
    if server_hostname:
        # Previously this argument was accepted but silently ignored.
        configuration.server_name = server_hostname
    result = asyncio.run(
        async_quic_query(
            configuration=configuration,
            host=host,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=create_protocol,
        )
    )
    # async_quic_query reports failures as return values; turn them into raises.
    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        raise TimeoutError() from result
    return result


def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
    """Send *query* over DNS-over-QUIC and return the parsed answer.

    Raises StreamResetError if the server resets the stream and TimeoutError
    if no answer arrives within *timeout* seconds.
    """
    return _run_quic_query(
        query, host, timeout, port, verify, server_hostname, DnsClientProtocol
    )


def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
    """Like ``quic_query`` but frames the query with a wrong length prefix.

    Uses ``BogusDnsClientProtocol``; raises the same exceptions as
    ``quic_query``.
    """
    return _run_quic_query(
        query, host, timeout, port, verify, server_hostname, BogusDnsClientProtocol
    )