]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdistdist/doh3.cc
dnsdist: Use the correct source IP for outgoing QUIC datagrams
[thirdparty/pdns.git] / pdns / dnsdistdist / doh3.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22
23 #include "doh3.hh"
24
25 #ifdef HAVE_DNS_OVER_HTTP3
26 #include <quiche.h>
27
28 #include "dolog.hh"
29 #include "iputils.hh"
30 #include "misc.hh"
31 #include "sstuff.hh"
32 #include "threadname.hh"
33 #include "base64.hh"
34
35 #include "dnsdist-dnsparser.hh"
36 #include "dnsdist-ecs.hh"
37 #include "dnsdist-proxy-protocol.hh"
38 #include "dnsdist-tcp.hh"
39 #include "dnsdist-random.hh"
40
41 #include "doq-common.hh"
42
43 #if 0
44 #define DEBUGLOG_ENABLED
45 #define DEBUGLOG(x) std::cerr << x << std::endl;
46 #else
47 #define DEBUGLOG(x)
48 #endif
49
50 using namespace dnsdist::doq;
51
52 using h3_headers_t = std::map<std::string, std::string>;
53
54 class H3Connection
55 {
56 public:
57 H3Connection(const ComboAddress& peer, const ComboAddress& localAddr, QuicheConfig config, QuicheConnection&& conn) :
58 d_peer(peer), d_localAddr(localAddr), d_conn(std::move(conn)), d_config(std::move(config))
59 {
60 }
61 H3Connection(const H3Connection&) = delete;
62 H3Connection(H3Connection&&) = default;
63 H3Connection& operator=(const H3Connection&) = delete;
64 H3Connection& operator=(H3Connection&&) = default;
65 ~H3Connection() = default;
66
67 ComboAddress d_peer;
68 ComboAddress d_localAddr;
69 QuicheConnection d_conn;
70 QuicheConfig d_config;
71 QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
72 // buffer request headers by streamID
73 std::unordered_map<uint64_t, h3_headers_t> d_headersBuffers;
74 std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
75 std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
76 };
77
78 static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
79
80 struct DOH3ServerConfig
81 {
82 DOH3ServerConfig(QuicheConfig&& config_, QuicheHTTP3Config&& http3config_, uint32_t internalPipeBufferSize) :
83 config(std::move(config_)), http3config(std::move(http3config_))
84 {
85 {
86 auto [sender, receiver] = pdns::channel::createObjectQueue<DOH3Unit>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize);
87 d_responseSender = std::move(sender);
88 d_responseReceiver = std::move(receiver);
89 }
90 }
91 DOH3ServerConfig(const DOH3ServerConfig&) = delete;
92 DOH3ServerConfig(DOH3ServerConfig&&) = default;
93 DOH3ServerConfig& operator=(const DOH3ServerConfig&) = delete;
94 DOH3ServerConfig& operator=(DOH3ServerConfig&&) = default;
95 ~DOH3ServerConfig() = default;
96
97 using ConnectionsMap = std::map<PacketBuffer, H3Connection>;
98
99 LocalHolders holders;
100 ConnectionsMap d_connections;
101 QuicheConfig config;
102 QuicheHTTP3Config http3config;
103 ClientState* clientState{nullptr};
104 std::shared_ptr<DOH3Frontend> df{nullptr};
105 pdns::channel::Sender<DOH3Unit> d_responseSender;
106 pdns::channel::Receiver<DOH3Unit> d_responseReceiver;
107 };
108
109 /* these might seem useless, but they are needed because
110 they need to be declared _after_ the definition of DOH3ServerConfig
111 so that we can use a unique_ptr in DOH3Frontend */
112 DOH3Frontend::DOH3Frontend() = default;
113 DOH3Frontend::~DOH3Frontend() = default;
114
115 class DOH3TCPCrossQuerySender final : public TCPQuerySender
116 {
117 public:
118 DOH3TCPCrossQuerySender() = default;
119
120 [[nodiscard]] bool active() const override
121 {
122 return true;
123 }
124
125 void handleResponse([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override
126 {
127 if (!response.d_idstate.doh3u) {
128 return;
129 }
130
131 auto unit = std::move(response.d_idstate.doh3u);
132 if (unit->dsc == nullptr) {
133 return;
134 }
135
136 unit->response = std::move(response.d_buffer);
137 unit->ids = std::move(response.d_idstate);
138 DNSResponse dnsResponse(unit->ids, unit->response, unit->downstream);
139
140 dnsheader cleartextDH{};
141 memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH));
142
143 if (!response.isAsync()) {
144
145 static thread_local LocalStateHolder<vector<dnsdist::rules::ResponseRuleAction>> localRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::ResponseRules).getLocal();
146 static thread_local LocalStateHolder<vector<dnsdist::rules::ResponseRuleAction>> localCacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChainHolder(dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules).getLocal();
147
148 dnsResponse.ids.doh3u = std::move(unit);
149
150 if (!processResponse(dnsResponse.ids.doh3u->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) {
151 if (dnsResponse.ids.doh3u) {
152
153 sendBackDOH3Unit(std::move(dnsResponse.ids.doh3u), "Response dropped by rules");
154 }
155 return;
156 }
157
158 if (dnsResponse.isAsynchronous()) {
159 return;
160 }
161
162 unit = std::move(dnsResponse.ids.doh3u);
163 }
164
165 if (!unit->ids.selfGenerated) {
166 double udiff = unit->ids.queryRealTime.udiff();
167 vinfolog("Got answer from %s, relayed to %s (DoH3, %d bytes), took %f us", unit->downstream->d_config.remote.toStringWithPort(), unit->ids.origRemote.toStringWithPort(), unit->response.size(), udiff);
168
169 auto backendProtocol = unit->downstream->getProtocol();
170 if (backendProtocol == dnsdist::Protocol::DoUDP && unit->tcp) {
171 backendProtocol = dnsdist::Protocol::DoTCP;
172 }
173 handleResponseSent(unit->ids, udiff, unit->ids.origRemote, unit->downstream->d_config.remote, unit->response.size(), cleartextDH, backendProtocol, true);
174 }
175
176 ++dnsdist::metrics::g_stats.responses;
177 if (unit->ids.cs != nullptr) {
178 ++unit->ids.cs->responses;
179 }
180
181 sendBackDOH3Unit(std::move(unit), "Cross-protocol response");
182 }
183
184 void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
185 {
186 return handleResponse(now, std::move(response));
187 }
188
189 void notifyIOError([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override
190 {
191 if (!response.d_idstate.doh3u) {
192 return;
193 }
194
195 auto unit = std::move(response.d_idstate.doh3u);
196 if (unit->dsc == nullptr) {
197 return;
198 }
199
200 /* this will signal an error */
201 unit->response.clear();
202 unit->ids = std::move(response.d_idstate);
203 sendBackDOH3Unit(std::move(unit), "Cross-protocol error");
204 }
205 };
206
207 class DOH3CrossProtocolQuery : public CrossProtocolQuery
208 {
209 public:
210 DOH3CrossProtocolQuery(DOH3UnitUniquePtr&& unit, bool isResponse)
211 {
212 if (isResponse) {
213 /* happens when a response becomes async */
214 query = InternalQuery(std::move(unit->response), std::move(unit->ids));
215 }
216 else {
217 /* we need to duplicate the query here because we might need
218 the existing query later if we get a truncated answer */
219 query = InternalQuery(PacketBuffer(unit->query), std::move(unit->ids));
220 }
221
222 /* it might have been moved when we moved unit->ids */
223 if (unit) {
224 query.d_idstate.doh3u = std::move(unit);
225 }
226
227 /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload,
228 clearing query.d_proxyProtocolPayloadAdded and unit->proxyProtocolPayloadSize.
229 Leave it for now because we know that the onky case where the payload has been
230 added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT,
231 and we know the TCP/DoT code can handle it. */
232 query.d_proxyProtocolPayloadAdded = query.d_idstate.doh3u->proxyProtocolPayloadSize > 0;
233 downstream = query.d_idstate.doh3u->downstream;
234 }
235
236 void handleInternalError()
237 {
238 sendBackDOH3Unit(std::move(query.d_idstate.doh3u), "DOH3 internal error");
239 }
240
241 std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
242 {
243 query.d_idstate.doh3u->downstream = downstream;
244 return s_sender;
245 }
246
247 DNSQuestion getDQ() override
248 {
249 auto& ids = query.d_idstate;
250 DNSQuestion dnsQuestion(ids, query.d_buffer);
251 return dnsQuestion;
252 }
253
254 DNSResponse getDR() override
255 {
256 auto& ids = query.d_idstate;
257 DNSResponse dnsResponse(ids, query.d_buffer, downstream);
258 return dnsResponse;
259 }
260
261 DOH3UnitUniquePtr&& releaseDU()
262 {
263 return std::move(query.d_idstate.doh3u);
264 }
265
266 private:
267 static std::shared_ptr<DOH3TCPCrossQuerySender> s_sender;
268 };
269
270 std::shared_ptr<DOH3TCPCrossQuerySender> DOH3CrossProtocolQuery::s_sender = std::make_shared<DOH3TCPCrossQuerySender>();
271
272 static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, PacketBuffer& response)
273 {
274 size_t pos = 0;
275 while (pos < response.size()) {
276 // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
277 auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
278 streamID, &response.at(pos), response.size() - pos, true);
279 if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
280 response.erase(response.begin(), response.begin() + static_cast<ssize_t>(pos));
281 return false;
282 }
283 if (res < 0) {
284 // Shutdown with internal error code
285 quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
286 return true;
287 }
288 pos += res;
289 }
290
291 return true;
292 }
293
294 static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
295 {
296 std::string status = std::to_string(statusCode);
297 std::string lenStr = std::to_string(len);
298 std::array<quiche_h3_header, 3> headers{
299 (quiche_h3_header){
300 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
301 .name = reinterpret_cast<const uint8_t*>(":status"),
302 .name_len = sizeof(":status") - 1,
303 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
304 .value = reinterpret_cast<const uint8_t*>(status.data()),
305 .value_len = status.size(),
306 },
307 (quiche_h3_header){
308 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
309 .name = reinterpret_cast<const uint8_t*>("content-length"),
310 .name_len = sizeof("content-length") - 1,
311 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
312 .value = reinterpret_cast<const uint8_t*>(lenStr.data()),
313 .value_len = lenStr.size(),
314 },
315 (quiche_h3_header){
316 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
317 .name = reinterpret_cast<const uint8_t*>("content-type"),
318 .name_len = sizeof("content-type") - 1,
319 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
320 .value = reinterpret_cast<const uint8_t*>("application/dns-message"),
321 .value_len = sizeof("application/dns-message") - 1,
322 },
323 };
324 auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
325 streamID, headers.data(),
326 // do not include content-type header info if there is no content
327 (len > 0 && statusCode == 200U ? headers.size() : headers.size() - 1),
328 len == 0);
329 if (returnValue != 0) {
330 /* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
331 quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
332 return;
333 }
334
335 if (len == 0) {
336 return;
337 }
338
339 size_t pos = 0;
340 while (pos < len) {
341 // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
342 auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
343 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
344 streamID, const_cast<uint8_t*>(body) + pos, len - pos, true);
345 if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
346 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
347 conn.d_streamOutBuffers[streamID] = PacketBuffer(body + pos, body + len);
348 return;
349 }
350 if (res < 0) {
351 // Shutdown with internal error code
352 quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
353 return;
354 }
355 pos += res;
356 }
357 }
358
359 static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
360 {
361 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
362 h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
363 }
364
365 static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
366 {
367 if (statusCode == 200) {
368 ++frontend.d_validResponses;
369 }
370 else {
371 ++frontend.d_errorResponses;
372 }
373 if (response.empty()) {
374 quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
375 }
376 else {
377 h3_send_response(conn, streamID, statusCode, &response.at(0), response.size());
378 }
379 }
380
381 void DOH3Frontend::setup()
382 {
383 auto config = QuicheConfig(quiche_config_new(QUICHE_PROTOCOL_VERSION), quiche_config_free);
384 d_quicheParams.d_alpn = std::string(DOH3_ALPN.begin(), DOH3_ALPN.end());
385 configureQuiche(config, d_quicheParams, true);
386
387 auto http3config = QuicheHTTP3Config(quiche_h3_config_new(), quiche_h3_config_free);
388
389 d_server_config = std::make_unique<DOH3ServerConfig>(std::move(config), std::move(http3config), d_internalPipeBufferSize);
390 }
391
392 void DOH3Frontend::reloadCertificates()
393 {
394 auto config = QuicheConfig(quiche_config_new(QUICHE_PROTOCOL_VERSION), quiche_config_free);
395 d_quicheParams.d_alpn = std::string(DOH3_ALPN.begin(), DOH3_ALPN.end());
396 configureQuiche(config, d_quicheParams, true);
397 std::atomic_store_explicit(&d_server_config->config, std::move(config), std::memory_order_release);
398 }
399
400 static std::optional<std::reference_wrapper<H3Connection>> getConnection(DOH3ServerConfig::ConnectionsMap& connMap, const PacketBuffer& connID)
401 {
402 auto iter = connMap.find(connID);
403 if (iter == connMap.end()) {
404 return std::nullopt;
405 }
406 return iter->second;
407 }
408
409 static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description)
410 {
411 if (unit->dsc == nullptr) {
412 return;
413 }
414 try {
415 if (!unit->dsc->d_responseSender.send(std::move(unit))) {
416 ++dnsdist::metrics::g_stats.doh3ResponsePipeFull;
417 vinfolog("Unable to pass a %s to the DoH3 worker thread because the pipe is full", description);
418 }
419 }
420 catch (const std::exception& e) {
421 vinfolog("Unable to pass a %s to the DoH3 worker thread because we couldn't write to the pipe: %s", description, e.what());
422 }
423 }
424
425 static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& localAddr, const ComboAddress& peer)
426 {
427 auto quicheConfig = std::atomic_load_explicit(&config.config, std::memory_order_acquire);
428 auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(),
429 originalDestinationID.data(), originalDestinationID.size(),
430 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
431 reinterpret_cast<const struct sockaddr*>(&localAddr),
432 localAddr.getSocklen(),
433 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
434 reinterpret_cast<const struct sockaddr*>(&peer),
435 peer.getSocklen(),
436 quicheConfig.get()),
437 quiche_conn_free);
438
439 if (config.df && !config.df->d_quicheParams.d_keyLogFile.empty()) {
440 quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_quicheParams.d_keyLogFile.c_str());
441 }
442
443 auto conn = H3Connection(peer, localAddr, std::move(quicheConfig), std::move(quicheConn));
444 auto pair = config.d_connections.emplace(serverSideID, std::move(conn));
445 return pair.first->second;
446 }
447
448 std::unique_ptr<CrossProtocolQuery> getDOH3CrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion, bool isResponse)
449 {
450 if (!dnsQuestion.ids.doh3u) {
451 throw std::runtime_error("Trying to create a DoH3 cross protocol query without a valid DoH3 unit");
452 }
453
454 auto unit = std::move(dnsQuestion.ids.doh3u);
455 if (&dnsQuestion.ids != &unit->ids) {
456 unit->ids = std::move(dnsQuestion.ids);
457 }
458
459 unit->ids.origID = dnsQuestion.getHeader()->id;
460
461 if (!isResponse) {
462 if (unit->query.data() != dnsQuestion.getMutableData().data()) {
463 unit->query = std::move(dnsQuestion.getMutableData());
464 }
465 }
466 else {
467 if (unit->response.data() != dnsQuestion.getMutableData().data()) {
468 unit->response = std::move(dnsQuestion.getMutableData());
469 }
470 }
471
472 return std::make_unique<DOH3CrossProtocolQuery>(std::move(unit), isResponse);
473 }
474
475 static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
476 {
477 const auto handleImmediateResponse = [](DOH3UnitUniquePtr&& unit, [[maybe_unused]] const char* reason) {
478 DEBUGLOG("handleImmediateResponse() reason=" << reason);
479 auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
480 handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
481 unit->ids.doh3u.reset();
482 };
483
484 auto& ids = doh3Unit->ids;
485 ids.doh3u = std::move(doh3Unit);
486 auto& unit = ids.doh3u;
487 uint16_t queryId = 0;
488 ComboAddress remote;
489
490 try {
491
492 remote = unit->ids.origRemote;
493 DOH3ServerConfig* dsc = unit->dsc;
494 auto& holders = dsc->holders;
495 ClientState& clientState = *dsc->clientState;
496
497 if (!holders.acl->match(remote)) {
498 vinfolog("Query from %s (DoH3) dropped because of ACL", remote.toStringWithPort());
499 ++dnsdist::metrics::g_stats.aclDrops;
500 unit->response.clear();
501
502 unit->status_code = 403;
503 handleImmediateResponse(std::move(unit), "DoH3 query dropped because of ACL");
504 return;
505 }
506
507 if (unit->query.size() < sizeof(dnsheader)) {
508 ++dnsdist::metrics::g_stats.nonCompliantQueries;
509 ++clientState.nonCompliantQueries;
510 unit->response.clear();
511
512 unit->status_code = 400;
513 handleImmediateResponse(std::move(unit), "DoH3 non-compliant query");
514 return;
515 }
516
517 ++clientState.queries;
518 ++dnsdist::metrics::g_stats.queries;
519 unit->ids.queryRealTime.start();
520
521 {
522 /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
523 dnsheader_aligned dnsHeader(unit->query.data());
524
525 if (!checkQueryHeaders(*dnsHeader, clientState)) {
526 dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
527 header.rcode = RCode::ServFail;
528 header.qr = true;
529 return true;
530 });
531 unit->response = std::move(unit->query);
532
533 unit->status_code = 400;
534 handleImmediateResponse(std::move(unit), "DoH3 invalid headers");
535 return;
536 }
537
538 if (dnsHeader->qdcount == 0) {
539 dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
540 header.rcode = RCode::NotImp;
541 header.qr = true;
542 return true;
543 });
544 unit->response = std::move(unit->query);
545
546 unit->status_code = 400;
547 handleImmediateResponse(std::move(unit), "DoH3 empty query");
548 return;
549 }
550
551 queryId = ntohs(dnsHeader->id);
552 }
553
554 auto downstream = unit->downstream;
555 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
556 unit->ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), sizeof(dnsheader), false, &unit->ids.qtype, &unit->ids.qclass);
557 DNSQuestion dnsQuestion(unit->ids, unit->query);
558 dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) {
559 const uint16_t* flags = getFlagsFromDNSHeader(&header);
560 ids.origFlags = *flags;
561 return true;
562 });
563 unit->ids.cs = &clientState;
564
565 auto result = processQuery(dnsQuestion, holders, downstream);
566 if (result == ProcessQueryResult::Drop) {
567 unit->status_code = 403;
568 handleImmediateResponse(std::move(unit), "DoH3 dropped query");
569 return;
570 }
571 if (result == ProcessQueryResult::Asynchronous) {
572 return;
573 }
574 if (result == ProcessQueryResult::SendAnswer) {
575 if (unit->response.empty()) {
576 unit->response = std::move(unit->query);
577 }
578 if (unit->response.size() >= sizeof(dnsheader)) {
579 const dnsheader_aligned dnsHeader(unit->response.data());
580
581 handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dnsHeader, dnsdist::Protocol::DoH3, dnsdist::Protocol::DoH3, false);
582 }
583 handleImmediateResponse(std::move(unit), "DoH3 self-answered response");
584 return;
585 }
586
587 ++dnsdist::metrics::g_stats.responses;
588 if (unit->ids.cs != nullptr) {
589 ++unit->ids.cs->responses;
590 }
591
592 if (result != ProcessQueryResult::PassToBackend) {
593 unit->status_code = 500;
594 handleImmediateResponse(std::move(unit), "DoH3 no backend available");
595 return;
596 }
597
598 if (downstream == nullptr) {
599 unit->status_code = 502;
600 handleImmediateResponse(std::move(unit), "DoH3 no backend available");
601 return;
602 }
603
604 unit->downstream = downstream;
605
606 std::string proxyProtocolPayload;
607 /* we need to do this _before_ creating the cross protocol query because
608 after that the buffer will have been moved */
609 if (downstream->d_config.useProxyProtocol) {
610 proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
611 }
612
613 unit->ids.origID = htons(queryId);
614 unit->tcp = true;
615
616 /* this moves unit->ids, careful! */
617 auto cpq = std::make_unique<DOH3CrossProtocolQuery>(std::move(unit), false);
618 cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
619
620 if (downstream->passCrossProtocolQuery(std::move(cpq))) {
621 return;
622 }
623 // NOLINTNEXTLINE(bugprone-use-after-move): it was only moved if the call succeeded
624 unit = cpq->releaseDU();
625 unit->status_code = 500;
626 handleImmediateResponse(std::move(unit), "DoH3 internal error");
627 return;
628 }
629 catch (const std::exception& e) {
630 vinfolog("Got an error in DOH3 question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
631 unit->status_code = 500;
632 handleImmediateResponse(std::move(unit), "DoH3 internal error");
633 return;
634 }
635 }
636
637 static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
638 {
639 try {
640 auto unit = std::make_unique<DOH3Unit>(std::move(query));
641 unit->dsc = &dsc;
642 unit->ids.origDest = local;
643 unit->ids.origRemote = remote;
644 unit->ids.protocol = dnsdist::Protocol::DoH3;
645 unit->serverConnID = serverConnID;
646 unit->streamID = streamID;
647
648 processDOH3Query(std::move(unit));
649 }
650 catch (const std::exception& exp) {
651 vinfolog("Had error handling DoH3 DNS packet from %s: %s", remote.toStringWithPort(), exp.what());
652 }
653 }
654
655 static void flushResponses(pdns::channel::Receiver<DOH3Unit>& receiver)
656 {
657 for (;;) {
658 try {
659 auto tmp = receiver.receive();
660 if (!tmp) {
661 return;
662 }
663
664 auto unit = std::move(*tmp);
665 auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
666 if (conn) {
667 handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
668 }
669 }
670 catch (const std::exception& e) {
671 errlog("Error while processing response received over DoH3: %s", e.what());
672 }
673 catch (...) {
674 errlog("Unspecified error while processing response received over DoH3");
675 }
676 }
677 }
678
679 static void flushStalledResponses(H3Connection& conn)
680 {
681 for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) {
682 const auto streamID = streamIt->first;
683 auto& response = streamIt->second;
684 if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) {
685 if (tryWriteResponse(conn, streamID, response)) {
686 streamIt = conn.d_streamOutBuffers.erase(streamIt);
687 continue;
688 }
689 }
690 ++streamIt;
691 }
692 }
693
694 static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, const uint64_t streamID, quiche_h3_event* event)
695 {
696 auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
697 DEBUGLOG(msg);
698 ++dnsdist::metrics::g_stats.nonCompliantQueries;
699 ++clientState.nonCompliantQueries;
700 ++frontend.d_errorResponses;
701 h3_send_response(conn, streamID, 400, msg);
702 };
703
704 auto& headers = conn.d_headersBuffers.at(streamID);
705 // Callback result. Any value other than 0 will interrupt further header processing.
706 int cbresult = quiche_h3_event_for_each_header(
707 event,
708 [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int {
709 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
710 std::string_view key(reinterpret_cast<char*>(name), name_len);
711 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
712 std::string_view content(reinterpret_cast<char*>(value), value_len);
713 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
714 auto* headersptr = reinterpret_cast<h3_headers_t*>(argp);
715 headersptr->emplace(key, content);
716 return 0;
717 },
718 &headers);
719
720 #ifdef DEBUGLOG_ENABLED
721 DEBUGLOG("Processed headers of stream " << streamID);
722 for (const auto& [key, value] : headers) {
723 DEBUGLOG(" " << key << ": " << value);
724 }
725 #endif
726 if (cbresult != 0 || headers.count(":method") == 0) {
727 handleImmediateError("Unable to process query headers");
728 return;
729 }
730
731 if (headers.at(":method") == "GET") {
732 if (headers.count(":path") == 0 || headers.at(":path").empty()) {
733 handleImmediateError("Path not found");
734 return;
735 }
736 const auto& path = headers.at(":path");
737 auto payload = dnsdist::doh::getPayloadFromPath(path);
738 if (!payload) {
739 handleImmediateError("Unable to find the DNS parameter");
740 return;
741 }
742 if (payload->size() < sizeof(dnsheader)) {
743 handleImmediateError("DoH3 non-compliant query");
744 return;
745 }
746 DEBUGLOG("Dispatching GET query");
747 doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID);
748 conn.d_streamBuffers.erase(streamID);
749 conn.d_headersBuffers.erase(streamID);
750 return;
751 }
752
753 if (headers.at(":method") == "POST") {
754 if (!quiche_h3_event_headers_has_body(event)) {
755 handleImmediateError("Empty POST query");
756 }
757 return;
758 }
759
760 handleImmediateError("Unsupported HTTP method");
761 }
762
763 static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, const uint64_t streamID, quiche_h3_event* event, PacketBuffer& buffer)
764 {
765 auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
766 DEBUGLOG(msg);
767 ++dnsdist::metrics::g_stats.nonCompliantQueries;
768 ++clientState.nonCompliantQueries;
769 ++frontend.d_errorResponses;
770 h3_send_response(conn, streamID, 400, msg);
771 };
772 auto& headers = conn.d_headersBuffers.at(streamID);
773
774 if (headers.at(":method") != "POST") {
775 handleImmediateError("DATA frame for non-POST method");
776 return;
777 }
778
779 if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
780 handleImmediateError("Unsupported content-type");
781 return;
782 }
783
784 buffer.resize(std::numeric_limits<uint16_t>::max());
785 auto& streamBuffer = conn.d_streamBuffers[streamID];
786
787 while (true) {
788 buffer.resize(std::numeric_limits<uint16_t>::max());
789 ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
790 conn.d_conn.get(), streamID,
791 buffer.data(), buffer.size());
792
793 if (len <= 0) {
794 break;
795 }
796
797 buffer.resize(static_cast<size_t>(len));
798 streamBuffer.insert(streamBuffer.end(), buffer.begin(), buffer.end());
799 }
800
801 if (!quiche_conn_stream_finished(conn.d_conn.get(), streamID)) {
802 return;
803 }
804
805 if (streamBuffer.size() < sizeof(dnsheader)) {
806 conn.d_streamBuffers.erase(streamID);
807 handleImmediateError("DoH3 non-compliant query");
808 return;
809 }
810
811 DEBUGLOG("Dispatching POST query");
812 doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
813 conn.d_headersBuffers.erase(streamID);
814 conn.d_streamBuffers.erase(streamID);
815 }
816
817 static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, PacketBuffer& buffer)
818 {
819 while (true) {
820 quiche_h3_event* event{nullptr};
821 // Processes HTTP/3 data received from the peer
822 const int64_t streamID = quiche_h3_conn_poll(conn.d_http3.get(),
823 conn.d_conn.get(),
824 &event);
825
826 if (streamID < 0) {
827 break;
828 }
829 conn.d_headersBuffers.try_emplace(streamID, h3_headers_t{});
830
831 switch (quiche_h3_event_type(event)) {
832 case QUICHE_H3_EVENT_HEADERS: {
833 processH3HeaderEvent(clientState, frontend, conn, client, serverConnID, streamID, event);
834 break;
835 }
836 case QUICHE_H3_EVENT_DATA: {
837 processH3DataEvent(clientState, frontend, conn, client, serverConnID, streamID, event, buffer);
838 break;
839 }
840 case QUICHE_H3_EVENT_FINISHED:
841 case QUICHE_H3_EVENT_RESET:
842 case QUICHE_H3_EVENT_PRIORITY_UPDATE:
843 case QUICHE_H3_EVENT_GOAWAY:
844 break;
845 }
846
847 quiche_h3_event_free(event);
848 }
849 }
850
851 static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientState, Socket& sock, PacketBuffer& buffer)
852 {
853 // destination connection ID, will have to be sent as original destination connection ID
854 PacketBuffer serverConnID;
855 // source connection ID, will have to be sent as destination connection ID
856 PacketBuffer clientConnID;
857 PacketBuffer tokenBuf;
858 while (true) {
859 ComboAddress client;
860 ComboAddress localAddr;
861 client.sin4.sin_family = clientState.local.sin4.sin_family;
862 localAddr.sin4.sin_family = clientState.local.sin4.sin_family;
863 buffer.resize(4096);
864 if (!dnsdist::doq::recvAsync(sock, buffer, client, localAddr)) {
865 return;
866 }
867 if (localAddr.sin4.sin_family == 0) {
868 localAddr = clientState.local;
869 }
870 else {
871 /* we don't get the port, only the address */
872 localAddr.sin4.sin_port = clientState.local.sin4.sin_port;
873 }
874
875 DEBUGLOG("Received DoH3 datagram of size " << buffer.size() << " from " << client.toStringWithPort());
876
877 uint32_t version{0};
878 uint8_t type{0};
879 std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> scid{};
880 size_t scid_len = scid.size();
881 std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> dcid{};
882 size_t dcid_len = dcid.size();
883 std::array<uint8_t, MAX_TOKEN_LEN> token{};
884 size_t token_len = token.size();
885
886 auto res = quiche_header_info(buffer.data(), buffer.size(), LOCAL_CONN_ID_LEN,
887 &version, &type,
888 scid.data(), &scid_len,
889 dcid.data(), &dcid_len,
890 token.data(), &token_len);
891 if (res != 0) {
892 DEBUGLOG("Error in quiche_header_info: " << res);
893 continue;
894 }
895
896 serverConnID.assign(dcid.begin(), dcid.begin() + dcid_len);
897 // source connection ID, will have to be sent as destination connection ID
898 clientConnID.assign(scid.begin(), scid.begin() + scid_len);
899 auto conn = getConnection(frontend.d_server_config->d_connections, serverConnID);
900
901 if (!conn) {
902 DEBUGLOG("Connection not found");
903 if (type != static_cast<uint8_t>(DOQ_Packet_Types::QUIC_PACKET_TYPE_INITIAL)) {
904 DEBUGLOG("Packet is not initial");
905 continue;
906 }
907
908 if (!quiche_version_is_supported(version)) {
909 DEBUGLOG("Unsupported version");
910 ++frontend.d_doh3UnsupportedVersionErrors;
911 handleVersionNegociation(sock, clientConnID, serverConnID, client, localAddr, buffer);
912 continue;
913 }
914
915 if (token_len == 0) {
916 /* stateless retry */
917 DEBUGLOG("No token received");
918 handleStatelessRetry(sock, clientConnID, serverConnID, client, localAddr, version, buffer);
919 continue;
920 }
921
922 tokenBuf.assign(token.begin(), token.begin() + token_len);
923 auto originalDestinationID = validateToken(tokenBuf, client);
924 if (!originalDestinationID) {
925 ++frontend.d_doh3InvalidTokensReceived;
926 DEBUGLOG("Discarding invalid token");
927 continue;
928 }
929
930 DEBUGLOG("Creating a new connection");
931 conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, localAddr, client);
932 if (!conn) {
933 continue;
934 }
935 }
936 DEBUGLOG("Connection found");
937 quiche_recv_info recv_info = {
938 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
939 reinterpret_cast<struct sockaddr*>(&client),
940 client.getSocklen(),
941 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
942 reinterpret_cast<struct sockaddr*>(&localAddr),
943 localAddr.getSocklen(),
944 };
945
946 auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info);
947 if (done < 0) {
948 continue;
949 }
950
951 if (quiche_conn_is_established(conn->get().d_conn.get()) || quiche_conn_is_in_early_data(conn->get().d_conn.get())) {
952 DEBUGLOG("Connection is established");
953
954 if (!conn->get().d_http3) {
955 conn->get().d_http3 = QuicheHTTP3Connection(quiche_h3_conn_new_with_transport(conn->get().d_conn.get(), frontend.d_server_config->http3config.get()),
956 quiche_h3_conn_free);
957 if (!conn->get().d_http3) {
958 continue;
959 }
960 DEBUGLOG("Successfully created HTTP/3 connection");
961 }
962
963 processH3Events(clientState, frontend, conn->get(), client, serverConnID, buffer);
964
965 flushEgress(sock, conn->get().d_conn, client, localAddr, buffer);
966 }
967 else {
968 DEBUGLOG("Connection not established");
969 }
970 }
971 }
972
973 // this is the entrypoint from dnsdist.cc
974 void doh3Thread(ClientState* clientState)
975 {
976 try {
977 std::shared_ptr<DOH3Frontend>& frontend = clientState->doh3Frontend;
978
979 frontend->d_server_config->clientState = clientState;
980 frontend->d_server_config->df = clientState->doh3Frontend;
981
982 setThreadName("dnsdist/doh3");
983
984 Socket sock(clientState->udpFD);
985 sock.setNonBlocking();
986
987 auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
988
989 auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
990 mplexer->addReadFD(sock.getHandle(), [](int, FDMultiplexer::funcparam_t&) {});
991 mplexer->addReadFD(responseReceiverFD, [](int, FDMultiplexer::funcparam_t&) {});
992 std::vector<int> readyFDs;
993 PacketBuffer buffer(4096);
994 while (true) {
995 readyFDs.clear();
996 mplexer->getAvailableFDs(readyFDs, 500);
997
998 try {
999 if (std::find(readyFDs.begin(), readyFDs.end(), sock.getHandle()) != readyFDs.end()) {
1000 handleSocketReadable(*frontend, *clientState, sock, buffer);
1001 }
1002
1003 if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {
1004 flushResponses(frontend->d_server_config->d_responseReceiver);
1005 }
1006
1007 for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
1008 quiche_conn_on_timeout(conn->second.d_conn.get());
1009
1010 flushEgress(sock, conn->second.d_conn, conn->second.d_peer, conn->second.d_localAddr, buffer);
1011
1012 if (quiche_conn_is_closed(conn->second.d_conn.get())) {
1013 #ifdef DEBUGLOG_ENABLED
1014 quiche_stats stats;
1015 quiche_path_stats path_stats;
1016
1017 quiche_conn_stats(conn->second.d_conn.get(), &stats);
1018 quiche_conn_path_stats(conn->second.d_conn.get(), 0, &path_stats);
1019
1020 DEBUGLOG("Connection (DoH3) closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
1021 #endif
1022 conn = frontend->d_server_config->d_connections.erase(conn);
1023 }
1024 else {
1025 flushStalledResponses(conn->second);
1026 ++conn;
1027 }
1028 }
1029 }
1030 catch (const std::exception& exp) {
1031 vinfolog("Caught exception in the main DoH3 thread: %s", exp.what());
1032 }
1033 catch (...) {
1034 vinfolog("Unknown exception in the main DoH3 thread");
1035 }
1036 }
1037 }
1038 catch (const std::exception& e) {
1039 DEBUGLOG("Caught fatal error in the main DoH3 thread: " << e.what());
1040 }
1041 }
1042
1043 #endif /* HAVE_DNS_OVER_HTTP3 */