]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdistdist/doq.cc
Merge pull request #13387 from omoerbeek/rec-b-root-servers
[thirdparty/pdns.git] / pdns / dnsdistdist / doq.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22
23 #include "doq.hh"
24
25 #ifdef HAVE_DNS_OVER_QUIC
26 #include <quiche.h>
27
28 #include "dnsparser.hh"
29 #include "dolog.hh"
30 #include "iputils.hh"
31 #include "misc.hh"
32 #include "sodcrypto.hh"
33 #include "sstuff.hh"
34 #include "threadname.hh"
35
36 #include "dnsdist-ecs.hh"
37 #include "dnsdist-dnsparser.hh"
38 #include "dnsdist-proxy-protocol.hh"
39 #include "dnsdist-tcp.hh"
40 #include "dnsdist-random.hh"
41
/* Symmetric key used by mintToken() to encrypt the retry tokens and by
   validateToken() to decrypt them. Generated once, when the statics of
   this file are initialized. */
static std::string s_quicRetryTokenKey = newKey(false);

/* Congestion control algorithm names accepted by DOQFrontend::setup(),
   mapped to the matching quiche constant. */
std::map<const string, int> DOQFrontend::s_available_cc_algorithms = {
  {"reno", QUICHE_CC_RENO},
  {"cubic", QUICHE_CC_CUBIC},
  {"bbr", QUICHE_CC_BBR},
};

/* Owning handles for quiche objects, released through the corresponding
   quiche free function. */
using QuicheConnection = std::unique_ptr<quiche_conn, decltype(&quiche_conn_free)>;
using QuicheConfig = std::unique_ptr<quiche_config, decltype(&quiche_config_free)>;
52
53 class Connection
54 {
55 public:
56 Connection(const ComboAddress& peer, QuicheConnection&& conn) :
57 d_peer(peer), d_conn(std::move(conn))
58 {
59 }
60 Connection(const Connection&) = delete;
61 Connection(Connection&&) = default;
62 Connection& operator=(const Connection&) = delete;
63 Connection& operator=(Connection&&) = default;
64 ~Connection() = default;
65
66 ComboAddress d_peer;
67 QuicheConnection d_conn;
68 std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
69 };
70
71 static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
72
73 struct DOQServerConfig
74 {
75 DOQServerConfig(QuicheConfig&& config_, uint32_t internalPipeBufferSize) :
76 config(std::move(config_))
77 {
78 {
79 auto [sender, receiver] = pdns::channel::createObjectQueue<DOQUnit>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize);
80 d_responseSender = std::move(sender);
81 d_responseReceiver = std::move(receiver);
82 }
83 }
84 DOQServerConfig(const DOQServerConfig&) = delete;
85 DOQServerConfig(DOQServerConfig&&) = default;
86 DOQServerConfig& operator=(const DOQServerConfig&) = delete;
87 DOQServerConfig& operator=(DOQServerConfig&&) = default;
88 ~DOQServerConfig() = default;
89
90 using ConnectionsMap = std::map<PacketBuffer, Connection>;
91
92 LocalHolders holders;
93 ConnectionsMap d_connections;
94 QuicheConfig config;
95 ClientState* clientState{nullptr};
96 std::shared_ptr<DOQFrontend> df{nullptr};
97 pdns::channel::Sender<DOQUnit> d_responseSender;
98 pdns::channel::Receiver<DOQUnit> d_responseReceiver;
99 };
100
/* These defaulted special members might seem useless, but they have to be
   defined here, _after_ the definition of DOQServerConfig, so that
   DOQFrontend can hold a std::unique_ptr to that type (the constructor and
   destructor need the complete type). */
DOQFrontend::DOQFrontend() = default;
DOQFrontend::~DOQFrontend() = default;
106
/* Debug tracing. With '#if 0' DEBUGLOG() expands to nothing; switching it
   to '#if 1' makes DEBUGLOG() print to stderr and defines DEBUGLOG_ENABLED,
   which also enables the connection statistics dump in doqThread(). */
#if 0
#define DEBUGLOG_ENABLED
#define DEBUGLOG(x) std::cerr << x << std::endl;
#else
#define DEBUGLOG(x)
#endif

/* Size of the buffers used to build outgoing datagrams, also passed to
   quiche as the maximum size of an outgoing UDP payload. */
static constexpr size_t MAX_DATAGRAM_SIZE = 1200;
/* Length, in bytes, of the connection IDs generated by getCID(); also the
   connection ID length passed to quiche_header_info(). */
static constexpr size_t LOCAL_CONN_ID_LEN = 16;
116
117 class DOQTCPCrossQuerySender final : public TCPQuerySender
118 {
119 public:
120 DOQTCPCrossQuerySender() = default;
121
122 [[nodiscard]] bool active() const override
123 {
124 return true;
125 }
126
127 void handleResponse([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override
128 {
129 if (!response.d_idstate.doqu) {
130 return;
131 }
132
133 auto unit = std::move(response.d_idstate.doqu);
134 if (unit->dsc == nullptr) {
135 return;
136 }
137
138 unit->response = std::move(response.d_buffer);
139 unit->ids = std::move(response.d_idstate);
140 DNSResponse dnsResponse(unit->ids, unit->response, unit->downstream);
141
142 dnsheader cleartextDH{};
143 memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH));
144
145 if (!response.isAsync()) {
146
147 static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
148 static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
149
150 dnsResponse.ids.doqu = std::move(unit);
151
152 if (!processResponse(dnsResponse.ids.doqu->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) {
153 if (dnsResponse.ids.doqu) {
154
155 sendBackDOQUnit(std::move(dnsResponse.ids.doqu), "Response dropped by rules");
156 }
157 return;
158 }
159
160 if (dnsResponse.isAsynchronous()) {
161 return;
162 }
163
164 unit = std::move(dnsResponse.ids.doqu);
165 }
166
167 if (!unit->ids.selfGenerated) {
168 double udiff = unit->ids.queryRealTime.udiff();
169 vinfolog("Got answer from %s, relayed to %s (quic), took %f us", unit->downstream->d_config.remote.toStringWithPort(), unit->ids.origRemote.toStringWithPort(), udiff);
170
171 auto backendProtocol = unit->downstream->getProtocol();
172 if (backendProtocol == dnsdist::Protocol::DoUDP && unit->tcp) {
173 backendProtocol = dnsdist::Protocol::DoTCP;
174 }
175 handleResponseSent(unit->ids, udiff, unit->ids.origRemote, unit->downstream->d_config.remote, unit->response.size(), cleartextDH, backendProtocol, true);
176 }
177
178 ++dnsdist::metrics::g_stats.responses;
179 if (unit->ids.cs != nullptr) {
180 ++unit->ids.cs->responses;
181 }
182
183 sendBackDOQUnit(std::move(unit), "Cross-protocol response");
184 }
185
186 void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
187 {
188 return handleResponse(now, std::move(response));
189 }
190
191 void notifyIOError([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override
192 {
193 if (!response.d_idstate.doqu) {
194 return;
195 }
196
197 auto unit = std::move(response.d_idstate.doqu);
198 if (unit->dsc == nullptr) {
199 return;
200 }
201
202 /* this will signal an error */
203 unit->response.clear();
204 unit->ids = std::move(response.d_idstate);
205 sendBackDOQUnit(std::move(unit), "Cross-protocol error");
206 }
207 };
208
209 class DOQCrossProtocolQuery : public CrossProtocolQuery
210 {
211 public:
212 DOQCrossProtocolQuery(DOQUnitUniquePtr&& unit, bool isResponse)
213 {
214 if (isResponse) {
215 /* happens when a response becomes async */
216 query = InternalQuery(std::move(unit->response), std::move(unit->ids));
217 }
218 else {
219 /* we need to duplicate the query here because we might need
220 the existing query later if we get a truncated answer */
221 query = InternalQuery(PacketBuffer(unit->query), std::move(unit->ids));
222 }
223
224 /* it might have been moved when we moved unit->ids */
225 if (unit) {
226 query.d_idstate.doqu = std::move(unit);
227 }
228
229 /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload,
230 clearing query.d_proxyProtocolPayloadAdded and unit->proxyProtocolPayloadSize.
231 Leave it for now because we know that the onky case where the payload has been
232 added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT,
233 and we know the TCP/DoT code can handle it. */
234 query.d_proxyProtocolPayloadAdded = query.d_idstate.doqu->proxyProtocolPayloadSize > 0;
235 downstream = query.d_idstate.doqu->downstream;
236 }
237
238 void handleInternalError()
239 {
240 sendBackDOQUnit(std::move(query.d_idstate.doqu), "DOQ internal error");
241 }
242
243 std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
244 {
245 query.d_idstate.doqu->downstream = downstream;
246 return s_sender;
247 }
248
249 DNSQuestion getDQ() override
250 {
251 auto& ids = query.d_idstate;
252 DNSQuestion dnsQuestion(ids, query.d_buffer);
253 return dnsQuestion;
254 }
255
256 DNSResponse getDR() override
257 {
258 auto& ids = query.d_idstate;
259 DNSResponse dnsResponse(ids, query.d_buffer, downstream);
260 return dnsResponse;
261 }
262
263 DOQUnitUniquePtr&& releaseDU()
264 {
265 return std::move(query.d_idstate.doqu);
266 }
267
268 private:
269 static std::shared_ptr<DOQTCPCrossQuerySender> s_sender;
270 };
271
272 std::shared_ptr<DOQTCPCrossQuerySender> DOQCrossProtocolQuery::s_sender = std::make_shared<DOQTCPCrossQuerySender>();
273
/* DoQ application error codes, from rfc9250 section-4.3.
   They are passed, cast to uint64_t, to quiche_conn_stream_shutdown()
   when a stream has to be terminated. */
enum class DOQ_Error_Codes : uint64_t
{
  DOQ_NO_ERROR = 0,
  DOQ_INTERNAL_ERROR = 1,
  DOQ_PROTOCOL_ERROR = 2,
  DOQ_REQUEST_CANCELLED = 3,
  DOQ_EXCESSIVE_LOAD = 4,
  DOQ_UNSPECIFIED_ERROR = 5
};
284
285 static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, const PacketBuffer& response)
286 {
287 if (response.empty()) {
288 ++frontend.d_errorResponses;
289 quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
290 return;
291 }
292 ++frontend.d_validResponses;
293 auto responseSize = static_cast<uint16_t>(response.size());
294 const std::array<uint8_t, 2> sizeBytes = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
295 size_t pos = 0;
296 while (pos < sizeBytes.size()) {
297 auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &sizeBytes.at(pos), sizeBytes.size() - pos, false);
298 if (res < 0) {
299 quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
300 return;
301 }
302 pos += res;
303 }
304
305 pos = 0;
306 while (pos < response.size()) {
307 auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true);
308 if (res < 0) {
309 quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
310 return;
311 }
312 pos += res;
313 }
314 }
315
316 static void fillRandom(PacketBuffer& buffer, size_t size)
317 {
318 buffer.reserve(size);
319 while (size > 0) {
320 buffer.insert(buffer.end(), dnsdist::getRandomValue(std::numeric_limits<uint8_t>::max()));
321 --size;
322 }
323 }
324
325 void DOQFrontend::setup()
326 {
327 auto config = QuicheConfig(quiche_config_new(QUICHE_PROTOCOL_VERSION), quiche_config_free);
328 for (const auto& pair : d_tlsConfig.d_certKeyPairs) {
329 auto res = quiche_config_load_cert_chain_from_pem_file(config.get(), pair.d_cert.c_str());
330 if (res != 0) {
331 throw std::runtime_error("Error loading the server certificate: " + std::to_string(res));
332 }
333 if (pair.d_key) {
334 res = quiche_config_load_priv_key_from_pem_file(config.get(), pair.d_key->c_str());
335 if (res != 0) {
336 throw std::runtime_error("Error loading the server key: " + std::to_string(res));
337 }
338 }
339 }
340
341 {
342 constexpr std::array<uint8_t, 4> alpn{'\x03', 'd', 'o', 'q'};
343 auto res = quiche_config_set_application_protos(config.get(),
344 alpn.data(),
345 alpn.size());
346 if (res != 0) {
347 throw std::runtime_error("Error setting ALPN: " + std::to_string(res));
348 }
349 }
350
351 quiche_config_set_max_idle_timeout(config.get(), d_idleTimeout * 1000);
352 /* maximum size of an outgoing packet, which means the buffer we pass to quiche_conn_send() should be at least that big */
353 quiche_config_set_max_send_udp_payload_size(config.get(), MAX_DATAGRAM_SIZE);
354
355 // The number of concurrent remotely-initiated bidirectional streams to be open at any given time
356 // https://docs.rs/quiche/latest/quiche/struct.Config.html#method.set_initial_max_streams_bidi
357 // 0 means none will get accepted, that's why we have a default value of 65535
358 quiche_config_set_initial_max_streams_bidi(config.get(), d_maxInFlight);
359
360 // The number of bytes of incoming stream data to be buffered for each localy or remotely-initiated bidirectional stream
361 quiche_config_set_initial_max_stream_data_bidi_local(config.get(), 8192);
362 quiche_config_set_initial_max_stream_data_bidi_remote(config.get(), 8192);
363
364 // The number of total bytes of incoming stream data to be buffered for the whole connection
365 // https://docs.rs/quiche/latest/quiche/struct.Config.html#method.set_initial_max_data
366 quiche_config_set_initial_max_data(config.get(), 8192 * d_maxInFlight);
367 if (!d_keyLogFile.empty()) {
368 quiche_config_log_keys(config.get());
369 }
370
371 auto algo = DOQFrontend::s_available_cc_algorithms.find(d_ccAlgo);
372 if (algo != DOQFrontend::s_available_cc_algorithms.end()) {
373 quiche_config_set_cc_algorithm(config.get(), static_cast<enum quiche_cc_algorithm>(algo->second));
374 }
375
376 {
377 PacketBuffer resetToken;
378 fillRandom(resetToken, 16);
379 quiche_config_set_stateless_reset_token(config.get(), resetToken.data());
380 }
381
382 d_server_config = std::make_unique<DOQServerConfig>(std::move(config), d_internalPipeBufferSize);
383 }
384
385 static std::optional<PacketBuffer> getCID()
386 {
387 PacketBuffer buffer;
388
389 fillRandom(buffer, LOCAL_CONN_ID_LEN);
390
391 return buffer;
392 }
393
/* Size of the buffer handed to quiche_header_info() to receive the token
   sent by a client: the encrypted size of a nonce, a time-to-die, an IPv6
   address and the longest connection ID quiche supports. */
static constexpr size_t MAX_TOKEN_LEN = dnsdist::crypto::authenticated::getEncryptedSize(std::tuple_size<decltype(SodiumNonce::value)>{} /* nonce */ + sizeof(uint64_t) /* TTD */ + 16 /* IPv6 */ + QUICHE_MAX_CONN_ID_LEN);
395
396 static PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer)
397 {
398 try {
399 SodiumNonce nonce;
400 nonce.init();
401
402 const auto addrBytes = peer.toByteString();
403 // this token will be valid for 60s
404 const uint64_t ttd = time(nullptr) + 60U;
405 PacketBuffer plainTextToken;
406 plainTextToken.reserve(sizeof(ttd) + addrBytes.size() + dcid.size());
407 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
408 plainTextToken.insert(plainTextToken.end(), reinterpret_cast<const uint8_t*>(&ttd), reinterpret_cast<const uint8_t*>(&ttd) + sizeof(ttd));
409 plainTextToken.insert(plainTextToken.end(), addrBytes.begin(), addrBytes.end());
410 plainTextToken.insert(plainTextToken.end(), dcid.begin(), dcid.end());
411 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
412 const auto encryptedToken = sodEncryptSym(std::string_view(reinterpret_cast<const char*>(plainTextToken.data()), plainTextToken.size()), s_quicRetryTokenKey, nonce, false);
413 // a bit sad, let's see if we can do better later
414 auto encryptedTokenPacket = PacketBuffer(encryptedToken.begin(), encryptedToken.end());
415 encryptedTokenPacket.insert(encryptedTokenPacket.begin(), nonce.value.begin(), nonce.value.end());
416 return encryptedTokenPacket;
417 }
418 catch (const std::exception& exp) {
419 vinfolog("Error while minting DoQ token: %s", exp.what());
420 throw;
421 }
422 }
423
424 // returns the original destination ID if the token is valid, nothing otherwise
425 static std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const ComboAddress& peer)
426 {
427 try {
428 SodiumNonce nonce;
429 auto addrBytes = peer.toByteString();
430 const uint64_t now = time(nullptr);
431 const auto minimumSize = nonce.value.size() + sizeof(now) + addrBytes.size();
432 if (token.size() <= minimumSize) {
433 return std::nullopt;
434 }
435
436 memcpy(nonce.value.data(), token.data(), nonce.value.size());
437
438 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
439 auto cipher = std::string_view(reinterpret_cast<const char*>(&token.at(nonce.value.size())), token.size() - nonce.value.size());
440 auto plainText = sodDecryptSym(cipher, s_quicRetryTokenKey, nonce, false);
441
442 if (plainText.size() <= sizeof(now) + addrBytes.size()) {
443 return std::nullopt;
444 }
445
446 uint64_t ttd{0};
447 memcpy(&ttd, plainText.data(), sizeof(ttd));
448 if (ttd < now) {
449 return std::nullopt;
450 }
451
452 if (std::memcmp(&plainText.at(sizeof(ttd)), &*addrBytes.begin(), addrBytes.size()) != 0) {
453 return std::nullopt;
454 }
455 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
456 return PacketBuffer(plainText.begin() + (sizeof(ttd) + addrBytes.size()), plainText.end());
457 }
458 catch (const std::exception& exp) {
459 vinfolog("Error while validating DoQ token: %s", exp.what());
460 return std::nullopt;
461 }
462 }
463
464 static void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version)
465 {
466 auto newServerConnID = getCID();
467 if (!newServerConnID) {
468 return;
469 }
470
471 auto token = mintToken(serverConnID, peer);
472
473 PacketBuffer out(MAX_DATAGRAM_SIZE);
474 auto written = quiche_retry(clientConnID.data(), clientConnID.size(),
475 serverConnID.data(), serverConnID.size(),
476 newServerConnID->data(), newServerConnID->size(),
477 token.data(), token.size(),
478 version,
479 out.data(), out.size());
480
481 if (written < 0) {
482 DEBUGLOG("failed to create retry packet " << written);
483 return;
484 }
485
486 out.resize(written);
487 sock.sendTo(std::string(out.begin(), out.end()), peer);
488 }
489
490 static void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer)
491 {
492 PacketBuffer out(MAX_DATAGRAM_SIZE);
493
494 auto written = quiche_negotiate_version(clientConnID.data(), clientConnID.size(),
495 serverConnID.data(), serverConnID.size(),
496 out.data(), out.size());
497
498 if (written < 0) {
499 DEBUGLOG("failed to create vneg packet " << written);
500 return;
501 }
502 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
503 sock.sendTo(reinterpret_cast<const char*>(out.data()), written, peer);
504 }
505
506 static std::optional<std::reference_wrapper<Connection>> getConnection(DOQServerConfig::ConnectionsMap& connMap, const PacketBuffer& connID)
507 {
508 auto iter = connMap.find(connID);
509 if (iter == connMap.end()) {
510 return std::nullopt;
511 }
512 return iter->second;
513 }
514
515 static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description)
516 {
517 if (unit->dsc == nullptr) {
518 return;
519 }
520 try {
521 if (!unit->dsc->d_responseSender.send(std::move(unit))) {
522 ++dnsdist::metrics::g_stats.doqResponsePipeFull;
523 vinfolog("Unable to pass a %s to the DoQ worker thread because the pipe is full", description);
524 }
525 }
526 catch (const std::exception& e) {
527 vinfolog("Unable to pass a %s to the DoQ worker thread because we couldn't write to the pipe: %s", description, e.what());
528 }
529 }
530
531 static std::optional<std::reference_wrapper<Connection>> createConnection(DOQServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer)
532 {
533 auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(),
534 originalDestinationID.data(), originalDestinationID.size(),
535 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
536 reinterpret_cast<const struct sockaddr*>(&local),
537 local.getSocklen(),
538 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
539 reinterpret_cast<const struct sockaddr*>(&peer),
540 peer.getSocklen(),
541 config.config.get()),
542 quiche_conn_free);
543
544 if (config.df && !config.df->d_keyLogFile.empty()) {
545 quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_keyLogFile.c_str());
546 }
547
548 auto conn = Connection(peer, std::move(quicheConn));
549 auto pair = config.d_connections.emplace(serverSideID, std::move(conn));
550 return pair.first->second;
551 }
552
553 static void flushEgress(Socket& sock, Connection& conn)
554 {
555 std::array<uint8_t, MAX_DATAGRAM_SIZE> out{};
556 quiche_send_info send_info;
557
558 while (true) {
559 auto written = quiche_conn_send(conn.d_conn.get(), out.data(), out.size(), &send_info);
560 if (written == QUICHE_ERR_DONE) {
561 return;
562 }
563
564 if (written < 0) {
565 return;
566 }
567 // FIXME pacing (as send_info.at should tell us when to send the packet) ?
568 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
569 sock.sendTo(reinterpret_cast<const char*>(out.data()), written, conn.d_peer);
570 }
571 }
572
573 std::unique_ptr<CrossProtocolQuery> getDOQCrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion, bool isResponse)
574 {
575 if (!dnsQuestion.ids.doqu) {
576 throw std::runtime_error("Trying to create a DoQ cross protocol query without a valid DoQ unit");
577 }
578
579 auto unit = std::move(dnsQuestion.ids.doqu);
580 if (&dnsQuestion.ids != &unit->ids) {
581 unit->ids = std::move(dnsQuestion.ids);
582 }
583
584 unit->ids.origID = dnsQuestion.getHeader()->id;
585
586 if (!isResponse) {
587 if (unit->query.data() != dnsQuestion.getMutableData().data()) {
588 unit->query = std::move(dnsQuestion.getMutableData());
589 }
590 }
591 else {
592 if (unit->response.data() != dnsQuestion.getMutableData().data()) {
593 unit->response = std::move(dnsQuestion.getMutableData());
594 }
595 }
596
597 return std::make_unique<DOQCrossProtocolQuery>(std::move(unit), isResponse);
598 }
599
600 /*
601 We are not in the main DoQ thread but in the DoQ 'client' thread.
602 */
603 static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
604 {
605 const auto handleImmediateResponse = [](DOQUnitUniquePtr&& unit, [[maybe_unused]] const char* reason) {
606 DEBUGLOG("handleImmediateResponse() reason=" << reason);
607 auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
608 handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->response);
609 unit->ids.doqu.reset();
610 };
611
612 auto& ids = doqUnit->ids;
613 ids.doqu = std::move(doqUnit);
614 auto& unit = ids.doqu;
615 uint16_t queryId = 0;
616 ComboAddress remote;
617
618 try {
619
620 remote = unit->ids.origRemote;
621 DOQServerConfig* dsc = unit->dsc;
622 auto& holders = dsc->holders;
623 ClientState& clientState = *dsc->clientState;
624
625 if (unit->query.size() < sizeof(dnsheader)) {
626 ++dnsdist::metrics::g_stats.nonCompliantQueries;
627 ++clientState.nonCompliantQueries;
628 unit->response.clear();
629
630 handleImmediateResponse(std::move(unit), "DoQ non-compliant query");
631 return;
632 }
633
634 ++clientState.queries;
635 ++dnsdist::metrics::g_stats.queries;
636 unit->ids.queryRealTime.start();
637
638 {
639 /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
640 dnsheader_aligned dnsHeader(unit->query.data());
641
642 if (!checkQueryHeaders(dnsHeader.get(), clientState)) {
643 dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
644 header.rcode = RCode::ServFail;
645 header.qr = true;
646 return true;
647 });
648 unit->response = std::move(unit->query);
649
650 handleImmediateResponse(std::move(unit), "DoQ invalid headers");
651 return;
652 }
653
654 if (dnsHeader->qdcount == 0) {
655 dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
656 header.rcode = RCode::NotImp;
657 header.qr = true;
658 return true;
659 });
660 unit->response = std::move(unit->query);
661
662 handleImmediateResponse(std::move(unit), "DoQ empty query");
663 return;
664 }
665
666 queryId = ntohs(dnsHeader->id);
667 }
668
669 auto downstream = unit->downstream;
670 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
671 unit->ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), sizeof(dnsheader), false, &unit->ids.qtype, &unit->ids.qclass);
672 DNSQuestion dnsQuestion(unit->ids, unit->query);
673 dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) {
674 const uint16_t* flags = getFlagsFromDNSHeader(&header);
675 ids.origFlags = *flags;
676 return true;
677 });
678 unit->ids.cs = &clientState;
679
680 auto result = processQuery(dnsQuestion, holders, downstream);
681 if (result == ProcessQueryResult::Drop) {
682 handleImmediateResponse(std::move(unit), "DoQ dropped query");
683 return;
684 }
685 if (result == ProcessQueryResult::Asynchronous) {
686 return;
687 }
688 if (result == ProcessQueryResult::SendAnswer) {
689 if (unit->response.empty()) {
690 unit->response = std::move(unit->query);
691 }
692 if (unit->response.size() >= sizeof(dnsheader)) {
693 const dnsheader_aligned dnsHeader(unit->response.data());
694
695 handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dnsHeader, dnsdist::Protocol::DoQ, dnsdist::Protocol::DoQ, false);
696 }
697 handleImmediateResponse(std::move(unit), "DoQ self-answered response");
698 return;
699 }
700
701 ++dnsdist::metrics::g_stats.responses;
702 if (unit->ids.cs != nullptr) {
703 ++unit->ids.cs->responses;
704 }
705
706 if (result != ProcessQueryResult::PassToBackend) {
707 handleImmediateResponse(std::move(unit), "DoQ no backend available");
708 return;
709 }
710
711 if (downstream == nullptr) {
712 handleImmediateResponse(std::move(unit), "DoQ no backend available");
713 return;
714 }
715
716 unit->downstream = downstream;
717
718 std::string proxyProtocolPayload;
719 /* we need to do this _before_ creating the cross protocol query because
720 after that the buffer will have been moved */
721 if (downstream->d_config.useProxyProtocol) {
722 proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
723 }
724
725 unit->ids.origID = htons(queryId);
726 unit->tcp = true;
727
728 /* this moves unit->ids, careful! */
729 auto cpq = std::make_unique<DOQCrossProtocolQuery>(std::move(unit), false);
730 cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
731
732 if (downstream->passCrossProtocolQuery(std::move(cpq))) {
733 return;
734 }
735 // NOLINTNEXTLINE(bugprone-use-after-move): it was only moved if the call succeeded
736 unit = cpq->releaseDU();
737 handleImmediateResponse(std::move(unit), "DoQ internal error");
738 return;
739 }
740 catch (const std::exception& e) {
741 vinfolog("Got an error in DOQ question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
742 handleImmediateResponse(std::move(unit), "DoQ internal error");
743 return;
744 }
745 }
746
747 static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
748 {
749 try {
750 /* we only parse it there as a sanity check, we will parse it again later */
751 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
752 DNSPacketMangler mangler(reinterpret_cast<char*>(query.data()), query.size());
753 mangler.skipDomainName();
754 mangler.skipBytes(4);
755
756 auto unit = std::make_unique<DOQUnit>(std::move(query));
757 unit->dsc = &dsc;
758 unit->ids.origDest = local;
759 unit->ids.origRemote = remote;
760 unit->ids.protocol = dnsdist::Protocol::DoQ;
761 unit->serverConnID = serverConnID;
762 unit->streamID = streamID;
763
764 processDOQQuery(std::move(unit));
765 }
766 catch (const std::exception& exp) {
767 vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), exp.what());
768 }
769 }
770
771 static void flushResponses(pdns::channel::Receiver<DOQUnit>& receiver)
772 {
773 for (;;) {
774 try {
775 auto tmp = receiver.receive();
776 if (!tmp) {
777 return;
778 }
779
780 auto unit = std::move(*tmp);
781 auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
782 if (conn) {
783 handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->response);
784 }
785 }
786 catch (const std::exception& e) {
787 errlog("Error while processing response received over DoQ: %s", e.what());
788 }
789 catch (...) {
790 errlog("Unspecified error while processing response received over DoQ");
791 }
792 }
793 }
794
795 // this is the entrypoint from dnsdist.cc
796 void doqThread(ClientState* clientState)
797 {
798 try {
799 std::shared_ptr<DOQFrontend>& frontend = clientState->doqFrontend;
800
801 frontend->d_server_config->clientState = clientState;
802 frontend->d_server_config->df = clientState->doqFrontend;
803
804 setThreadName("dnsdist/doq");
805
806 Socket sock(clientState->udpFD);
807
808 PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
809 auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
810
811 auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
812 mplexer->addReadFD(sock.getHandle(), [](int, FDMultiplexer::funcparam_t&) {});
813 mplexer->addReadFD(responseReceiverFD, [](int, FDMultiplexer::funcparam_t&) {});
814 while (true) {
815 std::vector<int> readyFDs;
816 mplexer->getAvailableFDs(readyFDs, 500);
817
818 if (std::find(readyFDs.begin(), readyFDs.end(), sock.getHandle()) != readyFDs.end()) {
819 DEBUGLOG("Received datagram");
820 std::string bufferStr;
821 ComboAddress client;
822 sock.recvFrom(bufferStr, client);
823
824 uint32_t version{0};
825 uint8_t type{0};
826 std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> scid{};
827 size_t scid_len = scid.size();
828 std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> dcid{};
829 size_t dcid_len = dcid.size();
830 std::array<uint8_t, MAX_TOKEN_LEN> token{};
831 size_t token_len = token.size();
832
833 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
834 auto res = quiche_header_info(reinterpret_cast<const uint8_t*>(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN,
835 &version, &type,
836 scid.data(), &scid_len,
837 dcid.data(), &dcid_len,
838 token.data(), &token_len);
839 if (res != 0) {
840 DEBUGLOG("Error in quiche_header_info: " << res);
841 continue;
842 }
843
844 // destination connection ID, will have to be sent as original destination connection ID
845 PacketBuffer serverConnID(dcid.begin(), dcid.begin() + dcid_len);
846 // source connection ID, will have to be sent as destination connection ID
847 PacketBuffer clientConnID(scid.begin(), scid.begin() + scid_len);
848 auto conn = getConnection(frontend->d_server_config->d_connections, serverConnID);
849
850 if (!conn) {
851 DEBUGLOG("Connection not found");
852 if (!quiche_version_is_supported(version)) {
853 DEBUGLOG("Unsupported version");
854 ++frontend->d_doqUnsupportedVersionErrors;
855 handleVersionNegociation(sock, clientConnID, serverConnID, client);
856 continue;
857 }
858
859 if (token_len == 0) {
860 /* stateless retry */
861 DEBUGLOG("No token received");
862 handleStatelessRetry(sock, clientConnID, serverConnID, client, version);
863 continue;
864 }
865
866 PacketBuffer tokenBuf(token.begin(), token.begin() + token_len);
867 auto originalDestinationID = validateToken(tokenBuf, client);
868 if (!originalDestinationID) {
869 ++frontend->d_doqInvalidTokensReceived;
870 DEBUGLOG("Discarding invalid token");
871 continue;
872 }
873
874 DEBUGLOG("Creating a new connection");
875 conn = createConnection(*frontend->d_server_config, serverConnID, *originalDestinationID, clientState->local, client);
876 if (!conn) {
877 continue;
878 }
879 }
880 DEBUGLOG("Connection found");
881 quiche_recv_info recv_info = {
882 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
883 reinterpret_cast<struct sockaddr*>(&client),
884 client.getSocklen(),
885 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
886 reinterpret_cast<struct sockaddr*>(&clientState->local),
887 clientState->local.getSocklen(),
888 };
889
890 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
891 auto done = quiche_conn_recv(conn->get().d_conn.get(), reinterpret_cast<uint8_t*>(bufferStr.data()), bufferStr.size(), &recv_info);
892 if (done < 0) {
893 continue;
894 }
895
896 if (quiche_conn_is_established(conn->get().d_conn.get())) {
897 auto readable = std::unique_ptr<quiche_stream_iter, decltype(&quiche_stream_iter_free)>(quiche_conn_readable(conn->get().d_conn.get()), quiche_stream_iter_free);
898
899 uint64_t streamID = 0;
900 while (quiche_stream_iter_next(readable.get(), &streamID)) {
901 auto& streamBuffer = conn->get().d_streamBuffers[streamID];
902 auto existingLength = streamBuffer.size();
903 bool fin = false;
904 streamBuffer.resize(existingLength + 512);
905 auto received = quiche_conn_stream_recv(conn->get().d_conn.get(), streamID,
906 &streamBuffer.at(existingLength), 512,
907 &fin);
908 streamBuffer.resize(existingLength + received);
909 if (fin) {
910 if (streamBuffer.size() < (sizeof(uint16_t) + sizeof(dnsheader))) {
911 ++dnsdist::metrics::g_stats.nonCompliantQueries;
912 ++clientState->nonCompliantQueries;
913 quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
914 break;
915 }
916 uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1);
917 streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2);
918 if (payloadLength != streamBuffer.size()) {
919 ++dnsdist::metrics::g_stats.nonCompliantQueries;
920 ++clientState->nonCompliantQueries;
921 quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
922 break;
923 }
924 DEBUGLOG("Dispatching query");
925 doq_dispatch_query(*(frontend->d_server_config), std::move(streamBuffer), clientState->local, client, serverConnID, streamID);
926 conn->get().d_streamBuffers.erase(streamID);
927 }
928 }
929 }
930 else {
931 DEBUGLOG("Connection not established");
932 }
933 }
934
935 if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {
936 flushResponses(frontend->d_server_config->d_responseReceiver);
937 }
938
939 for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
940 quiche_conn_on_timeout(conn->second.d_conn.get());
941
942 flushEgress(sock, conn->second);
943
944 if (quiche_conn_is_closed(conn->second.d_conn.get())) {
945 #ifdef DEBUGLOG_ENABLED
946 quiche_stats stats;
947 quiche_path_stats path_stats;
948
949 quiche_conn_stats(conn->second.d_conn.get(), &stats);
950 quiche_conn_path_stats(conn->second.d_conn.get(), 0, &path_stats);
951
952 DEBUGLOG("Connection closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
953 #endif
954 conn = frontend->d_server_config->d_connections.erase(conn);
955 }
956 else {
957 ++conn;
958 }
959 }
960 }
961 }
962 catch (const std::exception& e) {
963 DEBUGLOG("Caught fatal error: " << e.what());
964 }
965 }
966
967 #endif /* HAVE_DNS_OVER_QUIC */