]>
Commit | Line | Data |
---|---|---|
4f0b10a9 CHB |
1 | import base64 |
2 | import asyncio | |
3 | import pickle | |
4 | import ssl | |
5 | import struct | |
6 | import dns | |
7 | import time | |
8 | import async_timeout | |
9 | ||
10 | from collections import deque | |
11 | from typing import BinaryIO, Callable, Deque, Dict, List, Optional, Union, cast | |
12 | from urllib.parse import urlparse | |
13 | ||
14 | import aioquic | |
15 | from aioquic.asyncio.client import connect | |
16 | from aioquic.asyncio.protocol import QuicConnectionProtocol | |
17 | from aioquic.h0.connection import H0_ALPN, H0Connection | |
18 | from aioquic.h3.connection import H3_ALPN, ErrorCode, H3Connection | |
19 | from aioquic.h3.events import ( | |
20 | DataReceived, | |
21 | H3Event, | |
22 | HeadersReceived, | |
23 | PushPromiseReceived, | |
24 | ) | |
25 | from aioquic.quic.configuration import QuicConfiguration | |
ac70190e | 26 | from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset |
4f0b10a9 | 27 | from aioquic.tls import CipherSuite, SessionTicket |
ac70190e RG |
28 | |
29 | from doqclient import StreamResetError | |
d0439b42 | 30 | |
4f0b10a9 CHB |
# The HTTP layer HttpClient drives over QUIC: H0Connection when the first
# ALPN protocol starts with "hq-", otherwise H3Connection.
HttpConnection = Union[H0Connection, H3Connection]
32 | ||
class URL:
    """Split an absolute URL into the pieces an HTTP request line needs.

    Attributes:
        authority: the ``netloc`` part (host and optional port).
        full_path: request target made of the path (``"/"`` when empty),
            any ``;params`` and any ``?query``.
        scheme: the URL scheme, e.g. ``"https"``.
    """

    def __init__(self, url: str) -> None:
        parsed = urlparse(url)

        self.authority = parsed.netloc
        self.full_path = parsed.path or "/"
        # urlparse() splits ";params" off the last path segment into a separate
        # field; re-attach it so it is not silently dropped from the target.
        if parsed.params:
            self.full_path += ";" + parsed.params
        if parsed.query:
            self.full_path += "?" + parsed.query
        self.scheme = parsed.scheme
42 | ||
43 | ||
class HttpRequest:
    """Value object describing a single outgoing HTTP request.

    When no ``headers`` mapping is supplied, each instance gets its own empty
    dict; a mapping that is supplied is stored as-is (not copied).
    """

    def __init__(
        self,
        method: str,
        url: URL,
        content: bytes = b"",
        headers: Optional[Dict] = None,
    ) -> None:
        self.content = content
        self.headers = {} if headers is None else headers
        self.method = method
        self.url = url
59 | ||
class HttpClient(QuicConnectionProtocol):
    """QUIC protocol that issues HTTP requests and collects their responses.

    Each request is sent on a new stream. Events received on that stream are
    buffered until the stream ends (or is reset) and are then handed back to
    the awaiting caller as one batch.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Server-push events, keyed by push ID.
        self.pushes: Dict[int, Deque[H3Event]] = {}
        self._http: Optional[HttpConnection] = None
        # Events buffered per in-flight request stream.
        self._request_events: Dict[int, Deque[H3Event]] = {}
        # Futures resolved once a request stream ends or is reset.
        self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {}
        # Stream ID -> handler exposing ``http_event_received(event)``.
        # Nothing in this class registers entries; it must exist so that
        # events on unknown streams fall through instead of raising.
        self._websockets: Dict[int, object] = {}

        if self._quic.configuration.alpn_protocols[0].startswith("hq-"):
            self._http = H0Connection(self._quic)
        else:
            self._http = H3Connection(self._quic)

    async def get(self, url: str, headers: Optional[Dict] = None) -> Deque[H3Event]:
        """
        Perform a GET request and return the events of its response stream.
        """
        return await self._request(
            HttpRequest(method="GET", url=URL(url), headers=headers)
        )

    async def post(
        self, url: str, data: bytes, headers: Optional[Dict] = None
    ) -> Deque[H3Event]:
        """
        Perform a POST request with body *data* and return its response events.
        """
        return await self._request(
            HttpRequest(method="POST", url=URL(url), content=data, headers=headers)
        )

    def http_event_received(self, event: H3Event) -> None:
        """Route an HTTP-layer event to its request, handler or push buffer."""
        if isinstance(event, (HeadersReceived, DataReceived)):
            stream_id = event.stream_id
            if stream_id in self._request_events:
                # response to one of our requests
                self._request_events[stream_id].append(event)
                if event.stream_ended:
                    events = self._request_events.pop(stream_id)
                    waiter = self._request_waiter.pop(stream_id, None)
                    if waiter is not None and not waiter.done():
                        waiter.set_result(events)

            elif stream_id in self._websockets:
                self._websockets[stream_id].http_event_received(event)

            elif event.push_id in self.pushes:
                # push
                self.pushes[event.push_id].append(event)

        elif isinstance(event, PushPromiseReceived):
            self.pushes[event.push_id] = deque([event])

    def quic_event_received(self, event: QuicEvent) -> None:
        """Handle a QUIC event: resolve reset requests, then feed the HTTP layer."""
        if isinstance(event, StreamReset):
            # Only streams carrying one of our requests have a waiter; a reset
            # of any other stream (or of a request that already completed)
            # must not raise KeyError inside this protocol callback.
            self._request_events.pop(event.stream_id, None)
            waiter = self._request_waiter.pop(event.stream_id, None)
            if waiter is not None and not waiter.done():
                # Resolved (not failed) so the caller finds the StreamReset
                # among the returned events.
                waiter.set_result(deque([event]))

        # pass event to the HTTP layer
        if self._http is not None:
            for http_event in self._http.handle_event(event):
                self.http_event_received(http_event)

    async def _request(self, request: HttpRequest) -> Deque[H3Event]:
        """Send *request* on a new stream and wait for its complete response."""
        stream_id = self._quic.get_next_available_stream_id()
        self._http.send_headers(
            stream_id=stream_id,
            headers=[
                (b":method", request.method.encode()),
                (b":scheme", request.url.scheme.encode()),
                (b":authority", request.url.authority.encode()),
                (b":path", request.url.full_path.encode()),
            ]
            + [(k.encode(), v.encode()) for (k, v) in request.headers.items()],
            end_stream=not request.content,
        )
        if request.content:
            self._http.send_data(
                stream_id=stream_id, data=request.content, end_stream=True
            )

        # Register the stream before transmit() so it is known by the time
        # any response event is processed.
        waiter = self._loop.create_future()
        self._request_events[stream_id] = deque()
        self._request_waiter[stream_id] = waiter
        self.transmit()

        try:
            # shield(): cancelling the caller does not cancel the future the
            # event handlers resolve.
            return await asyncio.shield(waiter)
        finally:
            # No-ops after normal completion (handlers already popped these);
            # drops the bookkeeping when the caller was cancelled.
            self._request_events.pop(stream_id, None)
            self._request_waiter.pop(stream_id, None)
150 | ||
151 | ||
async def perform_http_request(
    client: HttpClient,
    url: str,
    data: Optional[bytes],
    include: bool,
    output_dir: Optional[str],
) -> Union[bytes, StreamReset]:
    """Send one DNS-over-HTTP request and return the response body.

    When *data* is not None it is POSTed as an ``application/dns-message``
    body; otherwise a GET for *url* is issued.

    Args:
        client: connected HttpClient used to send the request.
        url: request URL, passed through unchanged.
        data: wire-format DNS message to POST, or None to issue a GET.
        include: unused; kept for signature compatibility.
        output_dir: unused; kept for signature compatibility.

    Returns:
        The concatenated DATA payloads of the response, or the StreamReset
        event if the request stream was reset.
    """
    if data is not None:
        http_events = await client.post(
            url,
            data=data,
            headers={
                "content-length": str(len(data)),
                "content-type": "application/dns-message",
            },
        )
    else:
        http_events = await client.get(url)

    chunks = []
    for http_event in http_events:
        if isinstance(http_event, StreamReset):
            # The request was aborted; report the reset instead of a body.
            return http_event
        if isinstance(http_event, DataReceived):
            chunks.append(http_event.data)
    return b"".join(chunks)
ac70190e RG |
183 | |
184 | ||
4f0b10a9 CHB |
async def async_h3_query(
    configuration: QuicConfiguration,
    baseurl: str,
    port: int,
    query: "dns.message.Message",
    timeout: float,
    post: bool,
    create_protocol=HttpClient,
    host: str = "127.0.0.1",
) -> Union[bytes, StreamReset, asyncio.TimeoutError]:
    """Send *query* over DNS-over-HTTP/3 and return the raw outcome.

    Failures are returned rather than raised so the synchronous wrapper can
    translate them.

    Args:
        configuration: QUIC client configuration.
        baseurl: DoH endpoint URL, e.g. ``https://host/dns-query``.
        port: UDP port to connect to.
        query: DNS message to send (anything with ``to_wire()``).
        timeout: seconds allowed for the whole exchange, connection set-up
            included.
        post: send the query as a POST body instead of a ``dns`` GET parameter.
        create_protocol: protocol class handed to ``connect``.
        host: address to connect to. The host part of *baseurl* is not used
            for connecting; it only ends up in the ``:authority`` header.

    Returns:
        The response body bytes, the StreamReset event if the request stream
        was reset, or the asyncio.TimeoutError instance on timeout.
    """
    wire = query.to_wire()
    if post:
        url = baseurl
    else:
        # GET form: unpadded base64url of the wire message in "dns=".
        url = "{}?dns={}".format(
            baseurl, base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
        )

    answer = None
    try:
        # The deadline also covers connect(), which by default waits for the
        # QUIC handshake to complete before yielding the client.
        async with async_timeout.timeout(timeout):
            async with connect(
                host,
                port,
                configuration=configuration,
                create_protocol=create_protocol,
            ) as client:
                client = cast(HttpClient, client)
                answer = await perform_http_request(
                    client=client,
                    url=url,
                    data=wire if post else None,
                    include=False,
                    output_dir=None,
                )
    except asyncio.TimeoutError as e:
        # A deadline that expires while the connection is being closed must
        # not discard an answer that already arrived.
        if answer is None:
            return e
    return answer
220 | ||
d0439b42 CHB |
221 | |
def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False):
    """Send *query* to a DNS-over-HTTP/3 server and return the parsed reply.

    Blocking wrapper around async_h3_query. It starts its own event loop with
    asyncio.run(), so it cannot be called from inside a running loop.

    Args:
        query: DNS message to send.
        baseurl: DoH endpoint URL, e.g. ``https://host/dns-query``.
        timeout: seconds allowed for the exchange.
        port: UDP port of the server (note the default is 853).
        verify: CA file/path to trust; when falsy, the configuration's
            default trust settings are left in place.
        server_hostname: if given, set as the configuration's ``server_name``
            (TLS server name); otherwise aioquic's default applies.
        post: use POST instead of GET.

    Returns:
        The reply parsed with ``dns.message.from_wire``.

    Raises:
        StreamResetError: the server reset the request stream.
        TimeoutError: no answer within *timeout* seconds.
    """
    # `import dns` alone does not load the dns.message submodule.
    import dns.message

    configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
    if verify:
        configuration.load_verify_locations(verify)
    if server_hostname is not None:
        configuration.server_name = server_hostname

    result = asyncio.run(
        async_h3_query(
            configuration=configuration,
            baseurl=baseurl,
            port=port,
            query=query,
            timeout=timeout,
            create_protocol=HttpClient,
            post=post,
        )
    )

    if isinstance(result, StreamReset):
        raise StreamResetError(result.error_code)
    if isinstance(result, asyncio.TimeoutError):
        raise TimeoutError() from result
    return dns.message.from_wire(result)