# Source: git.ipfire.org Git - thirdparty/pdns.git (blame view)
# File: regression-tests.dnsdist/doh3client.py
# Blame captured at commit: "Merge pull request #13670 from chbruyand/dnsdist-doq-acl"
import asyncio
import base64
import pickle
import ssl
import struct
import time
from collections import deque
from typing import BinaryIO, Callable, Deque, Dict, List, Optional, Union, cast
from urllib.parse import urlparse

import aioquic
import async_timeout
import dns
import dns.message
from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h0.connection import H0_ALPN, H0Connection
from aioquic.h3.connection import H3_ALPN, ErrorCode, H3Connection
from aioquic.h3.events import (
    DataReceived,
    H3Event,
    HeadersReceived,
    PushPromiseReceived,
)
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
from aioquic.tls import CipherSuite, SessionTicket

from doqclient import StreamResetError

# Either of the two HTTP-over-QUIC connection classes HttpClient may wrap.
HttpConnection = Union[H0Connection, H3Connection]
32
class URL:
    """Split a URL string into the pieces needed for HTTP pseudo-headers.

    Attributes:
        authority: The network location part of the URL (``netloc``).
        full_path: The path (``"/"`` when the URL has none), followed by
            ``"?" + query`` when a query string is present.
        scheme: The URL scheme.
    """

    def __init__(self, url: str) -> None:
        parsed = urlparse(url)

        self.authority = parsed.netloc
        # An empty path is normalised to "/".
        self.full_path = parsed.path or "/"
        if parsed.query:
            self.full_path += "?" + parsed.query
        self.scheme = parsed.scheme
42
43
class HttpRequest:
    """Plain container describing one HTTP request to send.

    Args:
        method: HTTP method name, e.g. ``"GET"`` or ``"POST"``.
        url: Parsed target of the request.
        content: Request body; empty bytes when there is none.
        headers: Extra header name/value pairs. ``None`` is replaced by a
            fresh empty dict so instances never share a default mapping.
    """

    def __init__(
        self,
        method: str,
        url: "URL",
        content: bytes = b"",
        headers: Optional[Dict] = None,
    ) -> None:
        if headers is None:
            headers = {}

        self.content = content
        self.headers = headers
        self.method = method
        self.url = url
59
class HttpClient(QuicConnectionProtocol):
    """QUIC client protocol that sends HTTP requests and collects the replies.

    The HTTP layer is an ``H0Connection`` when the first configured ALPN
    protocol starts with ``"hq-"`` and an ``H3Connection`` otherwise.  Each
    request gets its own stream; the events received on that stream are
    accumulated until the stream ends (or is reset) and are then delivered
    through a per-stream future.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Events received for pushes, keyed by push ID.
        self.pushes: Dict[int, Deque[H3Event]] = {}
        self._http: Optional[HttpConnection] = None
        # Events accumulated so far for each in-flight request, by stream ID.
        self._request_events: Dict[int, Deque[H3Event]] = {}
        # Future resolved once the matching request finishes, by stream ID.
        self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {}
        # http_event_received() looks streams up in this mapping.  Nothing in
        # this class populates it, but it must exist so that an event on an
        # untracked stream does not raise AttributeError.
        self._websockets: Dict[int, object] = {}

        if self._quic.configuration.alpn_protocols[0].startswith("hq-"):
            self._http = H0Connection(self._quic)
        else:
            self._http = H3Connection(self._quic)

    async def get(self, url: str, headers: Optional[Dict] = None) -> Deque[H3Event]:
        """
        Perform a GET request.
        """
        return await self._request(
            HttpRequest(method="GET", url=URL(url), headers=headers)
        )

    async def post(
        self, url: str, data: bytes, headers: Optional[Dict] = None
    ) -> Deque[H3Event]:
        """
        Perform a POST request.
        """
        return await self._request(
            HttpRequest(method="POST", url=URL(url), content=data, headers=headers)
        )

    def http_event_received(self, event: H3Event) -> None:
        """Route one HTTP-layer event to the request, websocket or push it belongs to."""
        if isinstance(event, (HeadersReceived, DataReceived)):
            stream_id = event.stream_id
            if stream_id in self._request_events:
                # http
                self._request_events[stream_id].append(event)
                if event.stream_ended:
                    events = self._request_events.pop(stream_id)
                    request_waiter = self._request_waiter.pop(stream_id, None)
                    # The waiter may already be gone or resolved (e.g. after
                    # a stream reset); never resolve a future twice.
                    if request_waiter is not None and not request_waiter.done():
                        request_waiter.set_result(events)

            elif stream_id in self._websockets:
                # websocket
                websocket = self._websockets[stream_id]
                websocket.http_event_received(event)

            elif event.push_id in self.pushes:
                # push
                self.pushes[event.push_id].append(event)

        elif isinstance(event, PushPromiseReceived):
            self.pushes[event.push_id] = deque()
            self.pushes[event.push_id].append(event)

    def quic_event_received(self, event: QuicEvent) -> None:
        """Handle a QUIC event, then forward it to the HTTP layer.

        A ``StreamReset`` resolves the pending request on that stream with a
        one-element list containing the reset event, so the caller can tell
        a reset apart from a normal response.
        """
        if isinstance(event, StreamReset):
            # A reset may arrive for a stream with no pending request (already
            # completed, or never ours): tolerate the missing entry instead of
            # raising KeyError, and drop the partial events for that stream.
            self._request_events.pop(event.stream_id, None)
            waiter = self._request_waiter.pop(event.stream_id, None)
            if waiter is not None and not waiter.done():
                waiter.set_result([event])

        # pass event to the HTTP layer
        if self._http is not None:
            for http_event in self._http.handle_event(event):
                self.http_event_received(http_event)

    async def _request(self, request: HttpRequest) -> Deque[H3Event]:
        """Send ``request`` on a new stream and wait for its events.

        Returns whatever the stream's future is resolved with: the deque of
        received HTTP events, or a list holding a ``StreamReset`` event.
        """
        stream_id = self._quic.get_next_available_stream_id()
        self._http.send_headers(
            stream_id=stream_id,
            headers=[
                (b":method", request.method.encode()),
                (b":scheme", request.url.scheme.encode()),
                (b":authority", request.url.authority.encode()),
                (b":path", request.url.full_path.encode()),
            ]
            + [(k.encode(), v.encode()) for (k, v) in request.headers.items()],
            # With no body, the headers frame also closes the stream.
            end_stream=not request.content,
        )
        if request.content:
            self._http.send_data(
                stream_id=stream_id, data=request.content, end_stream=True
            )

        waiter = self._loop.create_future()
        self._request_events[stream_id] = deque()
        self._request_waiter[stream_id] = waiter
        self.transmit()

        # shield() keeps the stored future from being cancelled when the
        # awaiting caller is cancelled (e.g. by a timeout).
        return await asyncio.shield(waiter)
150
151
async def perform_http_request(
    client: HttpClient,
    url: str,
    data: Optional[bytes],
    include: bool,
    output_dir: Optional[str],
) -> Union[bytes, StreamReset]:
    """Send one request through ``client`` and return the response body.

    Args:
        client: Connected HTTP client used to send the request.
        url: Target URL.
        data: Body to POST with content type ``application/dns-message``;
            when ``None`` a GET request is sent instead.
        include: Unused; kept so existing callers keep working.
        output_dir: Unused; kept so existing callers keep working.

    Returns:
        The concatenated payload of all ``DataReceived`` events, or the
        ``StreamReset`` event itself when the stream was reset.
    """
    if data is not None:
        http_events = await client.post(
            url,
            data=data,
            headers={
                "content-length": str(len(data)),
                "content-type": "application/dns-message",
            },
        )
    else:
        http_events = await client.get(url)

    chunks = []
    for http_event in http_events:
        # A reset replaces the response entirely: hand it back to the caller
        # rather than trying to mix it with payload bytes.
        if isinstance(http_event, StreamReset):
            return http_event
        if isinstance(http_event, DataReceived):
            chunks.append(http_event.data)
    return b"".join(chunks)
ac70190e
RG
183
184
4f0b10a9
CHB
async def async_h3_query(
    configuration: QuicConfiguration,
    baseurl: str,
    port: int,
    query: "dns.message.Message",
    timeout: float,
    post: bool,
    create_protocol=HttpClient,
) -> Union[bytes, StreamReset, asyncio.TimeoutError]:
    """Send ``query`` over HTTP on a fresh QUIC connection to 127.0.0.1.

    Args:
        configuration: QUIC configuration used for the connection.
        baseurl: URL of the DoH endpoint.  It is only used to build the
            request; the connection itself always targets ``127.0.0.1``.
        port: UDP port to connect to.
        query: DNS message to send (serialised with ``to_wire()``).
        timeout: Maximum time, in seconds, allowed for the request.
        post: Send the wire-format query as a POST body when true; otherwise
            send a GET with the query base64url-encoded (padding stripped)
            in the ``dns`` URL parameter.
        create_protocol: Protocol factory handed to ``connect``.

    Returns:
        The raw response bytes, the ``StreamReset`` event if the stream was
        reset, or the ``asyncio.TimeoutError`` instance if the request timed
        out.  The timeout is returned rather than raised; the caller is
        expected to inspect the result.
    """
    wire = query.to_wire()

    url = baseurl
    if not post:
        encoded = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
        url = "{}?dns={}".format(baseurl, encoded)

    async with connect(
        "127.0.0.1",
        port,
        configuration=configuration,
        create_protocol=create_protocol,
    ) as client:
        client = cast(HttpClient, client)

        try:
            async with async_timeout.timeout(timeout):
                return await perform_http_request(
                    client=client,
                    url=url,
                    data=wire if post else None,
                    include=False,
                    output_dir=None,
                )
        except asyncio.TimeoutError as e:
            return e
220
d0439b42
CHB
221
def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False):
    """Send a DNS query over HTTP/3 and return the parsed DNS response.

    Args:
        query: DNS message to send.
        baseurl: URL of the DoH endpoint.
        timeout: Maximum time, in seconds, allowed for the request.
        port: UDP port the server listens on.
        verify: Optional CA file loaded as trust anchor for verification.
        server_hostname: Accepted for signature compatibility; currently not
            used by this function.
        post: Use POST instead of GET when true.

    Returns:
        The response parsed with ``dns.message.from_wire``.

    Raises:
        StreamResetError: The server reset the stream; carries the error code.
        TimeoutError: No response was received within ``timeout``.
    """
    configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
    if verify:
        configuration.load_verify_locations(verify)

    result = asyncio.run(
        async_h3_query(
            configuration=configuration,
            baseurl=baseurl,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=HttpClient,
            post=post,
        )
    )

    # async_h3_query reports failures as return values; turn them into
    # exceptions here.
    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        raise TimeoutError() from result
    return dns.message.from_wire(result)