git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/doh3client.py
Merge pull request #13980 from karelbilek/d_xfr
[thirdparty/pdns.git] / regression-tests.dnsdist / doh3client.py
1 import base64
2 import asyncio
3 import pickle
4 import ssl
5 import struct
6 import dns
7 import time
8 import async_timeout
9
10 from collections import deque
11 from typing import BinaryIO, Callable, Deque, Dict, List, Optional, Union, cast
12 from urllib.parse import urlparse
13
14 import aioquic
15 from aioquic.asyncio.client import connect
16 from aioquic.asyncio.protocol import QuicConnectionProtocol
17 from aioquic.h0.connection import H0_ALPN, H0Connection
18 from aioquic.h3.connection import H3_ALPN, ErrorCode, H3Connection
19 from aioquic.h3.events import (
20 DataReceived,
21 H3Event,
22 HeadersReceived,
23 PushPromiseReceived,
24 )
25 from aioquic.quic.configuration import QuicConfiguration
26 from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
27 from aioquic.tls import CipherSuite, SessionTicket
28
29 from doqclient import StreamResetError
30
# Type alias covering both HTTP layer implementations that HttpClient may
# instantiate on top of its QUIC connection.
31 HttpConnection = Union[H0Connection, H3Connection]
32
class URL:
    """
    Pre-parsed URL holding the pieces used to build request pseudo-headers.

    Attributes:
        scheme: the URL scheme, as returned by urlparse().
        authority: the network location (host and optional port).
        full_path: the path (defaulting to "/") followed by "?query" when
            the URL carries a query string.
    """

    def __init__(self, url: str) -> None:
        parts = urlparse(url)

        # An empty path is normalised to the root path.
        path = parts.path if parts.path else "/"
        if parts.query:
            path = path + "?" + parts.query

        self.scheme = parts.scheme
        self.authority = parts.netloc
        self.full_path = path
42
43
class HttpRequest:
    """
    Plain container describing one HTTP request.

    Attributes:
        method: the HTTP method name.
        url: the parsed target URL.
        content: the request body (empty bytes when there is none).
        headers: extra request headers; a fresh empty dict when the caller
            passed None.
    """

    def __init__(
        self,
        method: str,
        url: URL,
        content: bytes = b"",
        headers: Optional[Dict] = None,
    ) -> None:
        self.method = method
        self.url = url
        self.content = content
        # Build a new dict per instance rather than sharing a default.
        self.headers = {} if headers is None else headers
59
class HttpClient(QuicConnectionProtocol):
    """
    QUIC protocol that issues HTTP requests and collects their responses.

    The HTTP layer is an H0Connection when the first configured ALPN
    protocol starts with "hq-", and an H3Connection otherwise. Each request
    is sent on its own stream; the events received on that stream are
    gathered and delivered to the caller once the stream ends or is reset.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Events received for promised pushes, keyed by push ID.
        self.pushes: Dict[int, Deque[H3Event]] = {}
        self._http: Optional[HttpConnection] = None
        # Events gathered so far for each in-flight request, keyed by stream ID.
        self._request_events: Dict[int, Deque[H3Event]] = {}
        # Future resolved with the gathered events, keyed by stream ID.
        self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {}
        # http_event_received() looks streams up in this mapping. Nothing in
        # this class registers a websocket, but the attribute has to exist so
        # that events for untracked streams do not raise AttributeError.
        self._websockets: Dict[int, object] = {}

        if self._quic.configuration.alpn_protocols[0].startswith("hq-"):
            self._http = H0Connection(self._quic)
        else:
            self._http = H3Connection(self._quic)

    async def get(self, url: str, headers: Optional[Dict] = None) -> Deque[H3Event]:
        """
        Perform a GET request.
        """
        return await self._request(
            HttpRequest(method="GET", url=URL(url), headers=headers)
        )

    async def post(
        self, url: str, data: bytes, headers: Optional[Dict] = None
    ) -> Deque[H3Event]:
        """
        Perform a POST request.
        """
        return await self._request(
            HttpRequest(method="POST", url=URL(url), content=data, headers=headers)
        )

    def http_event_received(self, event: H3Event) -> None:
        """
        Dispatch an HTTP event to the request, websocket or push it belongs to.

        When the event ends the stream of a pending request, the waiter for
        that request is resolved with all the events gathered for it.
        """
        if isinstance(event, (HeadersReceived, DataReceived)):
            stream_id = event.stream_id
            if stream_id in self._request_events:
                # http
                self._request_events[stream_id].append(event)
                if event.stream_ended:
                    events = self._request_events.pop(stream_id)
                    # The waiter may already be gone or resolved; never raise
                    # from inside the protocol's event callback.
                    request_waiter = self._request_waiter.pop(stream_id, None)
                    if request_waiter is not None and not request_waiter.done():
                        request_waiter.set_result(events)

            elif stream_id in self._websockets:
                # websocket
                websocket = self._websockets[stream_id]
                websocket.http_event_received(event)

            elif event.push_id in self.pushes:
                # push
                self.pushes[event.push_id].append(event)

        elif isinstance(event, PushPromiseReceived):
            self.pushes[event.push_id] = deque()
            self.pushes[event.push_id].append(event)

    def quic_event_received(self, event: QuicEvent) -> None:
        """
        Handle a QUIC event, then forward it to the HTTP layer.

        A StreamReset resolves the waiter of the matching pending request
        with a deque holding only the reset event. Resets for streams that
        have no pending request are ignored here.
        """
        if isinstance(event, StreamReset):
            # Drop the partial response so that late events for this stream
            # are no longer treated as belonging to a pending request.
            self._request_events.pop(event.stream_id, None)
            waiter = self._request_waiter.pop(event.stream_id, None)
            if waiter is not None and not waiter.done():
                waiter.set_result(deque([event]))

        #  pass event to the HTTP layer
        if self._http is not None:
            for http_event in self._http.handle_event(event):
                self.http_event_received(http_event)

    async def _request(self, request: HttpRequest) -> Deque[H3Event]:
        """
        Send `request` on a new stream and wait for the events it produces.

        Returns the deque of events the waiter was resolved with: the
        response events, or a single StreamReset if the stream was reset.
        """
        stream_id = self._quic.get_next_available_stream_id()
        self._http.send_headers(
            stream_id=stream_id,
            headers=[
                (b":method", request.method.encode()),
                (b":scheme", request.url.scheme.encode()),
                (b":authority", request.url.authority.encode()),
                (b":path", request.url.full_path.encode()),
            ]
            + [(k.encode(), v.encode()) for (k, v) in request.headers.items()],
            # Without a body the headers frame closes the stream.
            end_stream=not request.content,
        )
        if request.content:
            self._http.send_data(
                stream_id=stream_id, data=request.content, end_stream=True
            )

        waiter = self._loop.create_future()
        self._request_events[stream_id] = deque()
        self._request_waiter[stream_id] = waiter
        self.transmit()

        # Shield the future so that cancelling the caller (for example on a
        # timeout) does not cancel the waiter the event callbacks resolve.
        return await asyncio.shield(waiter)
150
151
async def perform_http_request(
    client: HttpClient,
    url: str,
    data: Optional[bytes],
    include: bool,
    output_dir: Optional[str],
) -> Union[bytes, StreamReset]:
    """
    Send one DNS-over-HTTP request and return the response payload.

    A POST with content type "application/dns-message" is issued when `data`
    is not None, a GET otherwise.

    Args:
        client: the connected HTTP client used to send the request.
        url: the target URL.
        data: the request body, or None to perform a GET.
        include: not used by this function; kept for interface compatibility.
        output_dir: not used by this function; kept for interface
            compatibility.

    Returns:
        The concatenated data of all DataReceived events, or the StreamReset
        event if the stream was reset.
    """
    if data is not None:
        http_events = await client.post(
            url,
            data=data,
            headers={
                "content-length": str(len(data)),
                "content-type": "application/dns-message",
            },
        )
    else:
        http_events = await client.get(url)

    chunks = []
    for http_event in http_events:
        if isinstance(http_event, StreamReset):
            # A reset supersedes any data gathered so far; return it at once
            # instead of trying to append later payload to it.
            return http_event
        if isinstance(http_event, DataReceived):
            chunks.append(http_event.data)
    return b"".join(chunks)
183
184
async def async_h3_query(
    configuration: QuicConfiguration,
    baseurl: str,
    port: int,
    query: "dns.message.Message",
    timeout: float,
    post: bool,
    create_protocol=HttpClient,
) -> Union[bytes, StreamReset, asyncio.TimeoutError]:
    """
    Connect to 127.0.0.1:`port` and send `query` as a DoH3 request.

    Args:
        configuration: the QUIC configuration used for the connection.
        baseurl: the URL of the DoH endpoint.
        port: the UDP port to connect to.
        query: the DNS message to send; only its to_wire() method is used.
        timeout: number of seconds allowed for the HTTP exchange. The
            connection setup itself happens outside of this timeout.
        post: send the query as a POST body when true; otherwise as the
            unpadded URL-safe base64 "dns" parameter of a GET.
        create_protocol: the protocol factory handed to connect().

    Returns:
        Whatever perform_http_request() returned (response bytes or a
        StreamReset event), or the asyncio.TimeoutError instance when the
        exchange timed out. The timeout is returned, not raised.
    """
    wire = query.to_wire()
    if post:
        url = baseurl
        data = wire
    else:
        encoded = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
        url = "{}?dns={}".format(baseurl, encoded)
        data = None

    async with connect(
        "127.0.0.1",
        port,
        configuration=configuration,
        create_protocol=create_protocol,
    ) as client:
        client = cast(HttpClient, client)

        try:
            async with async_timeout.timeout(timeout):
                return await perform_http_request(
                    client=client,
                    url=url,
                    data=data,
                    include=False,
                    output_dir=None,
                )
        except asyncio.TimeoutError as e:
            return e
220
221
def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False):
    """
    Send `query` over DoH3 and return the parsed DNS response.

    Args:
        query: the DNS message to send.
        baseurl: the URL of the DoH endpoint.
        timeout: number of seconds allowed for the HTTP exchange.
        port: the UDP port to connect to.
        verify: optional CA file loaded into the QUIC configuration to
            verify the server certificate.
        server_hostname: currently not used by this function.
        post: send the query with POST instead of GET when true.

    Returns:
        The response parsed with dns.message.from_wire().

    Raises:
        StreamResetError: the server reset the request stream.
        TimeoutError: no response was received within `timeout`.
    """
    # `import dns` alone does not guarantee that the dns.message submodule is
    # loaded; import it explicitly rather than relying on another module
    # having done so.
    import dns.message

    configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
    if verify:
        configuration.load_verify_locations(verify)

    result = asyncio.run(
        async_h3_query(
            configuration=configuration,
            baseurl=baseurl,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=HttpClient,
            post=post,
        )
    )

    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        # Keep the original timeout as the cause for easier debugging.
        raise TimeoutError() from result
    return dns.message.from_wire(result)