]>
git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/doh3client.py
10 from collections
import deque
11 from typing
import BinaryIO
, Callable
, Deque
, Dict
, List
, Optional
, Union
, cast
12 from urllib
.parse
import urlparse
15 from aioquic
.asyncio
.client
import connect
16 from aioquic
.asyncio
.protocol
import QuicConnectionProtocol
17 from aioquic
.h0
.connection
import H0_ALPN
, H0Connection
18 from aioquic
.h3
.connection
import H3_ALPN
, ErrorCode
, H3Connection
19 from aioquic
.h3
.events
import (
25 from aioquic
.quic
.configuration
import QuicConfiguration
26 from aioquic
.quic
.events
import QuicEvent
, StreamDataReceived
, StreamReset
27 from aioquic
.tls
import CipherSuite
, SessionTicket
29 from doqclient
import StreamResetError
31 HttpConnection
= Union
[H0Connection
, H3Connection
]
34 def __init__(self
, url
: str) -> None:
35 parsed
= urlparse(url
)
37 self
.authority
= parsed
.netloc
38 self
.full_path
= parsed
.path
or "/"
40 self
.full_path
+= "?" + parsed
.query
41 self
.scheme
= parsed
.scheme
50 headers
: Optional
[Dict
] = None,
55 self
.content
= content
56 self
.headers
= headers
60 class HttpClient(QuicConnectionProtocol
):
61 def __init__(self
, *args
, **kwargs
) -> None:
62 super().__init
__(*args
, **kwargs
)
64 self
.pushes
: Dict
[int, Deque
[H3Event
]] = {}
65 self
._http
: Optional
[HttpConnection
] = None
66 self
._request
_events
: Dict
[int, Deque
[H3Event
]] = {}
67 self
._request
_waiter
: Dict
[int, asyncio
.Future
[Deque
[H3Event
]]] = {}
69 if self
._quic
.configuration
.alpn_protocols
[0].startswith("hq-"):
70 self
._http
= H0Connection(self
._quic
)
72 self
._http
= H3Connection(self
._quic
)
74 async def get(self
, url
: str, headers
: Optional
[Dict
] = None) -> Deque
[H3Event
]:
76 Perform a GET request.
78 return await self
._request
(
79 HttpRequest(method
="GET", url
=URL(url
), headers
=headers
)
83 self
, url
: str, data
: bytes
, headers
: Optional
[Dict
] = None
86 Perform a POST request.
88 return await self
._request
(
89 HttpRequest(method
="POST", url
=URL(url
), content
=data
, headers
=headers
)
93 def http_event_received(self
, event
: H3Event
) -> None:
94 if isinstance(event
, (HeadersReceived
, DataReceived
)):
95 stream_id
= event
.stream_id
96 if stream_id
in self
._request
_events
:
98 self
._request
_events
[event
.stream_id
].append(event
)
99 if event
.stream_ended
:
100 request_waiter
= self
._request
_waiter
.pop(stream_id
)
101 request_waiter
.set_result(self
._request
_events
.pop(stream_id
))
103 elif stream_id
in self
._websockets
:
105 websocket
= self
._websockets
[stream_id
]
106 websocket
.http_event_received(event
)
108 elif event
.push_id
in self
.pushes
:
110 self
.pushes
[event
.push_id
].append(event
)
112 elif isinstance(event
, PushPromiseReceived
):
113 self
.pushes
[event
.push_id
] = deque()
114 self
.pushes
[event
.push_id
].append(event
)
116 def quic_event_received(self
, event
: QuicEvent
) -> None:
117 if isinstance(event
, StreamReset
):
118 waiter
= self
._request
_waiter
.pop(event
.stream_id
)
119 waiter
.set_result([event
])
121 # pass event to the HTTP layer
122 if self
._http
is not None:
123 for http_event
in self
._http
.handle_event(event
):
124 self
.http_event_received(http_event
)
126 async def _request(self
, request
: HttpRequest
) -> Deque
[H3Event
]:
127 stream_id
= self
._quic
.get_next_available_stream_id()
128 self
._http
.send_headers(
131 (b
":method", request
.method
.encode()),
132 (b
":scheme", request
.url
.scheme
.encode()),
133 (b
":authority", request
.url
.authority
.encode()),
134 (b
":path", request
.url
.full_path
.encode()),
136 + [(k
.encode(), v
.encode()) for (k
, v
) in request
.headers
.items()],
137 end_stream
=not request
.content
,
140 self
._http
.send_data(
141 stream_id
=stream_id
, data
=request
.content
, end_stream
=True
144 waiter
= self
._loop
.create_future()
145 self
._request
_events
[stream_id
] = deque()
146 self
._request
_waiter
[stream_id
] = waiter
149 return await asyncio
.shield(waiter
)
152 async def perform_http_request(
155 data
: Optional
[bytes
],
157 output_dir
: Optional
[str],
162 http_events
= await client
.post(
166 "content-length": str(len(data
)),
167 "content-type": "application/dns-message",
172 http_events
= await client
.get(url
)
174 elapsed
= time
.time() - start
177 for http_event
in http_events
:
178 if isinstance(http_event
, DataReceived
):
179 result
+= http_event
.data
180 if isinstance(http_event
, StreamReset
):
185 async def async_h3_query(
186 configuration
: QuicConfiguration
,
192 create_protocol
=HttpClient
,
197 url
= "{}?dns={}".format(baseurl
, base64
.urlsafe_b64encode(query
.to_wire()).decode('UTF8').rstrip('='))
201 configuration
=configuration
,
202 create_protocol
=create_protocol
,
204 client
= cast(HttpClient
, client
)
207 async with async_timeout
.timeout(timeout
):
209 answer
= await perform_http_request(
212 data
=query
.to_wire() if post
else None,
218 except asyncio
.TimeoutError
as e
:
222 def doh3_query(query
, baseurl
, timeout
=2, port
=853, verify
=None, server_hostname
=None, post
=False):
223 configuration
= QuicConfiguration(alpn_protocols
=H3_ALPN
, is_client
=True)
225 configuration
.load_verify_locations(verify
)
227 result
= asyncio
.run(
229 configuration
=configuration
,
234 create_protocol
=HttpClient
,
239 if (isinstance(result
, StreamReset
)):
240 raise StreamResetError(result
.error_code
)
241 if (isinstance(result
, asyncio
.TimeoutError
)):
243 return dns
.message
.from_wire(result
)