]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-tcp.cc
Merge pull request #13381 from rgacogne/ddist-clean-up-nghttp2-no-doh
[thirdparty/pdns.git] / pdns / dnsdist-tcp.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22
23 #include <thread>
24 #include <netinet/tcp.h>
25 #include <queue>
26
27 #include "dnsdist.hh"
28 #include "dnsdist-concurrent-connections.hh"
29 #include "dnsdist-dnsparser.hh"
30 #include "dnsdist-ecs.hh"
31 #include "dnsdist-nghttp2-in.hh"
32 #include "dnsdist-proxy-protocol.hh"
33 #include "dnsdist-rings.hh"
34 #include "dnsdist-tcp.hh"
35 #include "dnsdist-tcp-downstream.hh"
36 #include "dnsdist-downstream-connection.hh"
37 #include "dnsdist-tcp-upstream.hh"
38 #include "dnsdist-xpf.hh"
39 #include "dnsparser.hh"
40 #include "dolog.hh"
41 #include "gettime.hh"
42 #include "lock.hh"
43 #include "sstuff.hh"
44 #include "tcpiohandler.hh"
45 #include "tcpiohandler-mplexer.hh"
46 #include "threadname.hh"
47
48 /* TCP: the grand design.
49 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
50 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
51 we will not go there.
52
53 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
54 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
55 to guarantee performance.
56
57 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
58 So whenever an answer comes in, we know where it needs to go.
59
60 Let's start naively.
61 */
62
/* maximum number of queries accepted over a single incoming TCP connection,
   0 means no limit (enforced in canAcceptNewQueries()) */
size_t g_maxTCPQueriesPerConn{0};
/* maximum duration of an incoming TCP connection, 0 means no limit
   (checked via maxConnectionDurationReached()) */
size_t g_maxTCPConnectionDuration{0};

#ifdef __linux__
// On Linux this gives us 128k pending queries (default is 8192 queries),
// which should be enough to deal with huge spikes
size_t g_tcpInternalPipeBufferSize{1024*1024};
uint64_t g_maxTCPQueuedConnections{10000};
#else
size_t g_tcpInternalPipeBufferSize{0};
uint64_t g_maxTCPQueuedConnections{1000};
#endif

/* receive/send timeouts for incoming TCP connections
   (presumably in seconds, matching the usual dnsdist settings — confirm against the docs) */
int g_tcpRecvTimeout{2};
int g_tcpSendTimeout{2};
/* incremented when a dump of the internal TCP states is requested */
std::atomic<uint64_t> g_tcpStatesDumpRequested{0};

/* per-client-address count of concurrent incoming TCP connections, lock-protected */
LockGuarded<std::map<ComboAddress, size_t, ComboAddress::addressOnlyLessThan>> dnsdist::IncomingConcurrentTCPConnectionsManager::s_tcpClientsConcurrentConnectionsCount;
/* maximum number of concurrent incoming TCP connections per client address
   (presumably 0 means no limit — confirm in dnsdist-concurrent-connections.hh) */
size_t dnsdist::IncomingConcurrentTCPConnectionsManager::s_maxTCPConnectionsPerClient = 0;
82
83 IncomingTCPConnectionState::~IncomingTCPConnectionState()
84 {
85 dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(d_ci.remote);
86
87 if (d_ci.cs != nullptr) {
88 struct timeval now;
89 gettimeofday(&now, nullptr);
90
91 auto diff = now - d_connectionStartTime;
92 d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
93 }
94
95 // would have been done when the object is destroyed anyway,
96 // but that way we make sure it's done before the ConnectionInfo is destroyed,
97 // closing the descriptor, instead of relying on the declaration order of the objects in the class
98 d_handler.close();
99 }
100
101 dnsdist::Protocol IncomingTCPConnectionState::getProtocol() const
102 {
103 if (d_ci.cs->dohFrontend) {
104 return dnsdist::Protocol::DoH;
105 }
106 if (d_handler.isTLS()) {
107 return dnsdist::Protocol::DoT;
108 }
109 return dnsdist::Protocol::DoTCP;
110 }
111
/* drop all the connections to backends cached by this worker thread,
   returning whatever the manager's clear() reports (the number removed) */
size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
{
  return t_downstreamTCPConnectionsManager.clear();
}
116
117 std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now)
118 {
119 std::shared_ptr<TCPConnectionToBackend> downstream{nullptr};
120
121 downstream = getOwnedDownstreamConnection(ds, tlvs);
122
123 if (!downstream) {
124 /* we don't have a connection to this backend owned yet, let's get one (it might not be a fresh one, though) */
125 downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, ds, now, std::string());
126 if (ds->d_config.useProxyProtocol) {
127 registerOwnedDownstreamConnection(downstream);
128 }
129 }
130
131 return downstream;
132 }
133
134 static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates);
135
136 TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientState*> tcpAcceptStates): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
137 {
138 for (size_t idx = 0; idx < maxThreads; idx++) {
139 addTCPClientThread(tcpAcceptStates);
140 }
141 }
142
143 void TCPClientCollection::addTCPClientThread(std::vector<ClientState*>& tcpAcceptStates)
144 {
145 try {
146 auto [queryChannelSender, queryChannelReceiver] = pdns::channel::createObjectQueue<ConnectionInfo>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, g_tcpInternalPipeBufferSize);
147
148 auto [crossProtocolQueryChannelSender, crossProtocolQueryChannelReceiver] = pdns::channel::createObjectQueue<CrossProtocolQuery>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, g_tcpInternalPipeBufferSize);
149
150 auto [crossProtocolResponseChannelSender, crossProtocolResponseChannelReceiver] = pdns::channel::createObjectQueue<TCPCrossProtocolResponse>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, g_tcpInternalPipeBufferSize);
151
152 vinfolog("Adding TCP Client thread");
153
154 if (d_numthreads >= d_tcpclientthreads.size()) {
155 vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
156 return;
157 }
158
159 TCPWorkerThread worker(std::move(queryChannelSender), std::move(crossProtocolQueryChannelSender));
160
161 try {
162 std::thread t1(tcpClientThread, std::move(queryChannelReceiver), std::move(crossProtocolQueryChannelReceiver), std::move(crossProtocolResponseChannelReceiver), std::move(crossProtocolResponseChannelSender), tcpAcceptStates);
163 t1.detach();
164 }
165 catch (const std::runtime_error& e) {
166 errlog("Error creating a TCP thread: %s", e.what());
167 return;
168 }
169
170 d_tcpclientthreads.at(d_numthreads) = std::move(worker);
171 ++d_numthreads;
172 }
173 catch (const std::exception& e) {
174 errlog("Error creating TCP worker: %", e.what());
175 }
176 }
177
178 std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
179
180 static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
181 {
182 IOState result = IOState::Done;
183
184 while (state->active() && !state->d_queuedResponses.empty()) {
185 DEBUGLOG("queue size is "<<state->d_queuedResponses.size()<<", sending the next one");
186 TCPResponse resp = std::move(state->d_queuedResponses.front());
187 state->d_queuedResponses.pop_front();
188 state->d_state = IncomingTCPConnectionState::State::idle;
189 result = state->sendResponse(now, std::move(resp));
190 if (result != IOState::Done) {
191 return result;
192 }
193 }
194
195 state->d_state = IncomingTCPConnectionState::State::idle;
196 return IOState::Done;
197 }
198
/* book-keeping once a response has been fully written to the client:
   updates the in-flight counter and reports latency/response metrics */
void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse)
{
  if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) {
    /* a zone transfer answer may consist of multiple messages for a single
       query (see the design note at the top of this file), so skip the
       per-message accounting below */
    return;
  }

  --d_currentQueriesCount;

  const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds;
  if (currentResponse.d_idstate.selfGenerated == false && ds) {
    /* the answer actually came from a backend: report the full latency */
    const auto& ids = currentResponse.d_idstate;
    double udiff = ids.queryRealTime.udiff();
    vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), currentResponse.d_buffer.size(), udiff);

    auto backendProtocol = ds->getProtocol();
    if (backendProtocol == dnsdist::Protocol::DoUDP && !currentResponse.d_idstate.forwardedOverUDP) {
      /* the backend is nominally UDP but this query went over TCP */
      backendProtocol = dnsdist::Protocol::DoTCP;
    }
    ::handleResponseSent(ids, udiff, d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true);
  } else {
    /* self-generated answer (or no backend): no backend address or latency to report */
    const auto& ids = currentResponse.d_idstate;
    ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
  }

  /* release the buffer and the backend connection reference right away */
  currentResponse.d_buffer.clear();
  currentResponse.d_connection.reset();
}
226
227 static void prependSizeToTCPQuery(PacketBuffer& buffer, size_t proxyProtocolPayloadSize)
228 {
229 if (buffer.size() <= proxyProtocolPayloadSize) {
230 throw std::runtime_error("The payload size is smaller or equal to the buffer size");
231 }
232
233 uint16_t queryLen = proxyProtocolPayloadSize > 0 ? (buffer.size() - proxyProtocolPayloadSize) : buffer.size();
234 const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256) };
235 /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
236 that could occur if we had to deal with the size during the processing,
237 especially alignment issues */
238 buffer.insert(buffer.begin() + proxyProtocolPayloadSize, sizeBytes, sizeBytes + 2);
239 }
240
/* decide whether another query may be read from this connection, refusing when:
   an error already occurred, the per-connection in-flight limit is reached,
   the per-connection total query limit is reached, or the connection has
   lived longer than the configured maximum duration */
bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now)
{
  if (d_hadErrors) {
    DEBUGLOG("not accepting new queries because we encountered some error during the processing already");
    return false;
  }

  // for DoH, this is already handled by the underlying library
  if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
    DEBUGLOG("not accepting new queries because we already have "<<d_currentQueriesCount<<" out of "<<d_ci.cs->d_maxInFlightQueriesPerConn);
    return false;
  }

  /* 0 means no limit */
  if (g_maxTCPQueriesPerConn && d_queriesCount > g_maxTCPQueriesPerConn) {
    vinfolog("not accepting new queries from %s because it reached the maximum number of queries per conn (%d / %d)", d_ci.remote.toStringWithPort(), d_queriesCount, g_maxTCPQueriesPerConn);
    return false;
  }

  if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("not accepting new queries from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
    return false;
  }

  return true;
}
266
267 void IncomingTCPConnectionState::resetForNewQuery()
268 {
269 d_buffer.clear();
270 d_currentPos = 0;
271 d_querySize = 0;
272 d_state = State::waitingForQuery;
273 }
274
275 std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs)
276 {
277 auto it = d_ownedConnectionsToBackend.find(ds);
278 if (it == d_ownedConnectionsToBackend.end()) {
279 DEBUGLOG("no owned connection found for "<<ds->getName());
280 return nullptr;
281 }
282
283 for (auto& conn : it->second) {
284 if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) {
285 DEBUGLOG("Got one owned connection accepting more for "<<ds->getName());
286 conn->setReused();
287 return conn;
288 }
289 DEBUGLOG("not accepting more for "<<ds->getName());
290 }
291
292 return nullptr;
293 }
294
/* record a backend connection as owned by this incoming connection
   (done for proxy-protocol backends, see getDownstreamConnection()) */
void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn)
{
  d_ownedConnectionsToBackend[conn->getDS()].push_front(conn);
}
299
/* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response)
{
  d_state = State::sendingResponse;

  /* the two-byte length field required by DNS over TCP */
  uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
  const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2);
  d_currentPos = 0;
  d_currentResponse = std::move(response);

  try {
    auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
    if (iostate == IOState::Done) {
      DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
      handleResponseSent(d_currentResponse);
      return iostate;
    } else {
      /* partial write: remember that we blocked so callers avoid issuing
         a speculative read before the socket becomes writable again */
      d_lastIOBlocked = true;
      DEBUGLOG("partial write");
      return iostate;
    }
  }
  catch (const std::exception& e) {
    /* write failed: account for it and tear the client connection down */
    vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
    DEBUGLOG("Closing TCP client connection: "<<e.what());
    ++d_ci.cs->tcpDiedSendingResponse;

    terminateClientConnection();

    return IOState::Done;
  }
}
336
/* tear down the client side of this connection: drop queued responses,
   release the backend connections we own, and close (or defer the closing
   of, for pending TLS async jobs) the connection handler */
void IncomingTCPConnectionState::terminateClientConnection()
{
  DEBUGLOG("terminating client connection");
  d_queuedResponses.clear();
  /* we have already released idle connections that could be reused,
     we don't care about the ones still waiting for responses */
  for (auto& backend : d_ownedConnectionsToBackend) {
    for (auto& conn : backend.second) {
      conn->release();
    }
  }
  d_ownedConnectionsToBackend.clear();

  /* meaning we will no longer be 'active' when the backend
     response or timeout comes in */
  d_ioState.reset();

  /* if we do have remaining async descriptors associated with this TLS
     connection, we need to defer the destruction of the TLS object until
     the engine has reported back, otherwise we have a use-after-free.. */
  auto afds = d_handler.getAsyncFDs();
  if (afds.empty()) {
    d_handler.close();
  }
  else {
    /* we might already be waiting, but we might also not because sometimes we have already been
       notified via the descriptor, not received Async again, but the async job still exists.. */
    auto state = shared_from_this();
    for (const auto fd : afds) {
      try {
        state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state);
      }
      catch (...) {
      }
    }

  }
}
375
/* queue a response on this connection and, when the connection is idle or
   waiting for a query, immediately try to flush the queue and refresh the
   I/O state, since nothing else would wake us up to do it */
void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response, bool fromBackend)
{
  // queue response
  state->d_queuedResponses.emplace_back(std::move(response));
  DEBUGLOG("queueing response, state is "<<(int)state->d_state<<", queue size is now "<<state->d_queuedResponses.size());

  // when the response comes from a backend, there is a real possibility that we are currently
  // idle, and thus not trying to send the response right away would make our ref count go to 0.
  // Even if we are waiting for a query, we will not wake up before the new query arrives or a
  // timeout occurs
  if (state->d_state == State::idle ||
      state->d_state == State::waitingForQuery) {
    auto iostate = sendQueuedResponses(state, now);

    if (iostate == IOState::Done && state->active()) {
      /* everything sent: either go back to reading queries or go idle */
      if (state->canAcceptNewQueries(now)) {
        state->resetForNewQuery();
        state->d_state = State::waitingForQuery;
        iostate = IOState::NeedRead;
      }
      else {
        state->d_state = State::idle;
      }
    }

    // for the same reason we need to update the state right away, nobody will do that for us
    if (state->active()) {
      updateIO(state, iostate, now);
      // if we have not finished reading every available byte, we _need_ to do an actual read
      // attempt before waiting for the socket to become readable again, because if there is
      // buffered data available the socket might never become readable again.
      // This is true as soon as we deal with TLS because TLS records are processed one by
      // one and might not match what we see at the application layer, so data might already
      // be available in the TLS library's buffers. This is especially true when OpenSSL's
      // read-ahead mode is enabled because then it buffers even more than one SSL record
      // for performance reasons.
      if (fromBackend && !state->d_lastIOBlocked) {
        state->handleIO();
      }
    }
  }
}
418
419 void IncomingTCPConnectionState::handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param)
420 {
421 auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
422
423 /* If we are here, the async jobs for this SSL* are finished
424 so we should be able to remove all FDs */
425 auto afds = state->d_handler.getAsyncFDs();
426 for (const auto afd : afds) {
427 try {
428 state->d_threadData.mplexer->removeReadFD(afd);
429 }
430 catch (...) {
431 }
432 }
433
434 if (state->active()) {
435 /* and now we restart our own I/O state machine */
436 state->handleIO();
437 }
438 else {
439 /* we were only waiting for the engine to come back,
440 to prevent a use-after-free */
441 state->d_handler.close();
442 }
443 }
444
445 void IncomingTCPConnectionState::updateIO(std::shared_ptr<IncomingTCPConnectionState>& state, IOState newState, const struct timeval& now)
446 {
447 if (newState == IOState::Async) {
448 auto fds = state->d_handler.getAsyncFDs();
449 for (const auto fd : fds) {
450 state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state);
451 }
452 state->d_ioState->update(IOState::Done, handleIOCallback, state);
453 }
454 else {
455 state->d_ioState->update(newState, handleIOCallback, state, newState == IOState::NeedWrite ? state->getClientWriteTTD(now) : state->getClientReadTTD(now));
456 }
457 }
458
/* called from the backend code when a new response has been received */
void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
{
  /* responses received on another thread are handed over to the worker
     thread actually owning this connection */
  if (std::this_thread::get_id() != d_creatorThreadID) {
    handleCrossProtocolResponse(now, std::move(response));
    return;
  }

  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();

  if (!response.isAsync() && response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
    // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway
    if (!response.d_connection->willBeReusable(true)) {
      // if it can't be reused even by us, well
      const auto connIt = state->d_ownedConnectionsToBackend.find(response.d_connection->getDS());
      if (connIt != state->d_ownedConnectionsToBackend.end()) {
        auto& list = connIt->second;

        /* find this exact connection in our owned list, release it and drop it */
        for (auto it = list.begin(); it != list.end(); ++it) {
          if (*it == response.d_connection) {
            try {
              response.d_connection->release();
            }
            catch (const std::exception& e) {
              vinfolog("Error releasing connection: %s", e.what());
            }
            list.erase(it);
            break;
          }
        }
      }
    }
  }

  /* too short to even contain a DNS header: drop the client connection */
  if (response.d_buffer.size() < sizeof(dnsheader)) {
    state->terminateClientConnection();
    return;
  }

  if (!response.isAsync()) {
    try {
      auto& ids = response.d_idstate;
      // out-parameter filled in by responseContentMatches below
      unsigned int qnameWireLength;
      std::shared_ptr<DownstreamState> ds = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
      if (!ds || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, ds, qnameWireLength)) {
        state->terminateClientConnection();
        return;
      }

      if (ds) {
        ++ds->responses;
      }

      DNSResponse dr(ids, response.d_buffer, ds);
      dr.d_incomingTCPState = state;

      /* snapshot the DNS header before processResponse() gets a chance to
         modify the buffer */
      memcpy(&response.d_cleartextDH, dr.getHeader().get(), sizeof(response.d_cleartextDH));

      if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) {
        state->terminateClientConnection();
        return;
      }

      if (dr.isAsynchronous()) {
        /* we are done for now */
        return;
      }
    }
    catch (const std::exception& e) {
      vinfolog("Unexpected exception while handling response from backend: %s", e.what());
      state->terminateClientConnection();
      return;
    }
  }

  ++dnsdist::metrics::g_stats.responses;
  ++state->d_ci.cs->responses;

  queueResponse(state, now, std::move(response), true);
}
539
/* a response bundled with the incoming TCP connection state it belongs to and
   the timestamp it was received at, so it can be passed over a channel to the
   TCP worker thread owning that connection */
struct TCPCrossProtocolResponse
{
  TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now): d_response(std::move(response)), d_state(state), d_now(now)
  {
  }

  TCPResponse d_response;
  std::shared_ptr<IncomingTCPConnectionState> d_state;
  struct timeval d_now;
};
550
551 class TCPCrossProtocolQuery : public CrossProtocolQuery
552 {
553 public:
554 TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
555 {
556 }
557
558 ~TCPCrossProtocolQuery()
559 {
560 }
561
562 std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
563 {
564 return d_sender;
565 }
566
567 DNSQuestion getDQ() override
568 {
569 auto& ids = query.d_idstate;
570 DNSQuestion dq(ids, query.d_buffer);
571 dq.d_incomingTCPState = d_sender;
572 return dq;
573 }
574
575 DNSResponse getDR() override
576 {
577 auto& ids = query.d_idstate;
578 DNSResponse dr(ids, query.d_buffer, downstream);
579 dr.d_incomingTCPState = d_sender;
580 return dr;
581 }
582
583 private:
584 std::shared_ptr<IncomingTCPConnectionState> d_sender;
585 };
586
/* wrap the current query and its state into a cross-protocol query, keeping a
   shared reference to this connection so the answer can be routed back */
std::unique_ptr<CrossProtocolQuery> IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds)
{
  return std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(state), ds, shared_from_this());
}
591
592 std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq)
593 {
594 auto state = dq.getIncomingTCPState();
595 if (!state) {
596 throw std::runtime_error("Trying to create a TCP cross protocol query without a valid TCP state");
597 }
598
599 dq.ids.origID = dq.getHeader()->id;
600 return std::make_unique<TCPCrossProtocolQuery>(std::move(dq.getMutableData()), std::move(dq.ids), nullptr, std::move(state));
601 }
602
603 void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
604 {
605 std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
606 try {
607 auto ptr = std::make_unique<TCPCrossProtocolResponse>(std::move(response), state, now);
608 if (!state->d_threadData.crossProtocolResponseSender.send(std::move(ptr))) {
609 ++dnsdist::metrics::g_stats.tcpCrossProtocolResponsePipeFull;
610 vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
611 }
612 }
613 catch (const std::exception& e) {
614 vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
615 }
616 }
617
/* process one query received on this connection:
   - sanity checks (minimal size, header checks, qdcount)
   - per-frontend and global accounting (including TLS version counters)
   - DNSCrypt handling
   - rule processing via processQuery()
   - forwarding to a backend (owned/pooled TCP connection, cross-protocol for
     DoH backends, or UDP-first) or queueing a self-generated answer.
   `streamID` is only set for DoH (HTTP/2) streams. */
IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::handleQuery(PacketBuffer&& queryIn, const struct timeval& now, std::optional<int32_t> streamID)
{
  auto query = std::move(queryIn);
  if (query.size() < sizeof(dnsheader)) {
    ++dnsdist::metrics::g_stats.nonCompliantQueries;
    ++d_ci.cs->nonCompliantQueries;
    return QueryProcessingResult::TooSmall;
  }

  ++d_queriesCount;
  ++d_ci.cs->queries;
  ++dnsdist::metrics::g_stats.queries;

  /* per-TLS-version query counters for the frontend */
  if (d_handler.isTLS()) {
    auto tlsVersion = d_handler.getTLSVersion();
    switch (tlsVersion) {
    case LibsslTLSVersion::TLS10:
      ++d_ci.cs->tls10queries;
      break;
    case LibsslTLSVersion::TLS11:
      ++d_ci.cs->tls11queries;
      break;
    case LibsslTLSVersion::TLS12:
      ++d_ci.cs->tls12queries;
      break;
    case LibsslTLSVersion::TLS13:
      ++d_ci.cs->tls13queries;
      break;
    default:
      ++d_ci.cs->tlsUnknownqueries;
    }
  }

  auto state = shared_from_this();
  InternalQueryState ids;
  ids.origDest = d_proxiedDestination;
  ids.origRemote = d_proxiedRemote;
  ids.cs = d_ci.cs;
  ids.queryRealTime.start();
  if (streamID) {
    ids.d_streamID = *streamID;
  }

  auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
  if (dnsCryptResponse) {
    /* NOTE(review): the queued TCPResponse is default-constructed and the
       payload returned by checkDNSCryptQuery() does not appear to be attached
       to it — verify this is intended */
    TCPResponse response;
    d_state = State::idle;
    ++d_currentQueriesCount;
    queueResponse(state, now, std::move(response), false);
    return QueryProcessingResult::SelfAnswered;
  }

  {
    /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
    const dnsheader_aligned dh(query.data());
    if (!checkQueryHeaders(dh.get(), *d_ci.cs)) {
      return QueryProcessingResult::InvalidHeaders;
    }

    if (dh->qdcount == 0) {
      /* no question: self-answer with NotImp */
      TCPResponse response;
      auto queryID = dh->id;
      dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
        header.rcode = RCode::NotImp;
        header.qr = true;
        return true;
      });
      response.d_idstate = std::move(ids);
      response.d_idstate.origID = queryID;
      response.d_idstate.selfGenerated = true;
      response.d_buffer = std::move(query);
      d_state = State::idle;
      ++d_currentQueriesCount;
      queueResponse(state, now, std::move(response), false);
      return QueryProcessingResult::SelfAnswered;
    }
  }

  ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
  ids.protocol = getProtocol();
  if (ids.dnsCryptQuery) {
    ids.protocol = dnsdist::Protocol::DNSCryptTCP;
  }

  DNSQuestion dq(ids, query);
  /* save the original flags so they can be restored later if needed */
  dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&ids](dnsheader& header) {
    const uint16_t* flags = getFlagsFromDNSHeader(&header);
    ids.origFlags = *flags;
    return true;
  });
  dq.d_incomingTCPState = state;
  dq.sni = d_handler.getServerNameIndication();

  if (d_proxyProtocolValues) {
    /* we need to copy them, because the next queries received on that connection will
       need to get the _unaltered_ values */
    dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*d_proxyProtocolValues);
  }

  if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) {
    dq.ids.skipCache = true;
  }

  if (forwardViaUDPFirst()) {
    // if there was no EDNS, we add it with a large buffer size
    // so we can use UDP to talk to the backend.
    const dnsheader_aligned dh(query.data());
    if (!dh->arcount) {
      if (addEDNS(query, 4096, false, 4096, 0)) {
        dq.ids.ednsAdded = true;
      }
    }
  }

  if (streamID) {
    auto unit = getDOHUnit(*streamID);
    dq.ids.du = std::move(unit);
  }

  std::shared_ptr<DownstreamState> ds;
  auto result = processQuery(dq, d_threadData.holders, ds);

  if (result == ProcessQueryResult::Asynchronous) {
    /* we are done for now */
    ++d_currentQueriesCount;
    return QueryProcessingResult::Asynchronous;
  }

  if (streamID) {
    restoreDOHUnit(std::move(dq.ids.du));
  }

  if (result == ProcessQueryResult::Drop) {
    return QueryProcessingResult::Dropped;
  }

  // the buffer might have been invalidated by now
  uint16_t queryID;
  {
    const auto dh = dq.getHeader();
    queryID = dh->id;
  }

  if (result == ProcessQueryResult::SendAnswer) {
    /* the rules produced an answer directly: queue it back to the client */
    TCPResponse response;
    {
      const auto dh = dq.getHeader();
      memcpy(&response.d_cleartextDH, dh.get(), sizeof(response.d_cleartextDH));
    }
    response.d_idstate = std::move(ids);
    response.d_idstate.origID = queryID;
    response.d_idstate.selfGenerated = true;
    response.d_idstate.cs = d_ci.cs;
    response.d_buffer = std::move(query);

    d_state = State::idle;
    ++d_currentQueriesCount;
    queueResponse(state, now, std::move(response), false);
    return QueryProcessingResult::SelfAnswered;
  }

  if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
    return QueryProcessingResult::NoBackend;
  }

  dq.ids.origID = queryID;

  ++d_currentQueriesCount;

  std::string proxyProtocolPayload;
  if (ds->isDoH()) {
    /* DoH backend: hand the query over as a cross-protocol query */
    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), ds->getNameWithAddr());

    /* we need to do this _before_ creating the cross protocol query because
       after that the buffer will have been moved */
    if (ds->d_config.useProxyProtocol) {
      proxyProtocolPayload = getProxyProtocolPayload(dq);
    }

    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(ids), ds, state);
    cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);

    ds->passCrossProtocolQuery(std::move(cpq));
    return QueryProcessingResult::Forwarded;
  }
  else if (!ds->isTCPOnly() && forwardViaUDPFirst()) {
    /* NOTE(review): streamID is dereferenced here without checking it holds
       a value; presumably forwardViaUDPFirst() is only true for DoH where a
       stream ID always exists — confirm */
    auto unit = getDOHUnit(*streamID);
    dq.ids.du = std::move(unit);
    if (assignOutgoingUDPQueryToBackend(ds, queryID, dq, query)) {
      return QueryProcessingResult::Forwarded;
    }
    restoreDOHUnit(std::move(dq.ids.du));
    // fallback to the normal flow
  }

  prependSizeToTCPQuery(query, 0);

  auto downstreamConnection = getDownstreamConnection(ds, dq.proxyProtocolValues, now);

  if (ds->d_config.useProxyProtocol) {
    /* if we ever sent a TLV over a connection, we can never go back */
    if (!d_proxyProtocolPayloadHasTLV) {
      d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
    }

    proxyProtocolPayload = getProxyProtocolPayload(dq);
  }

  if (dq.proxyProtocolValues) {
    downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues));
  }

  TCPQuery tcpquery(std::move(query), std::move(ids));
  tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload);

  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), ds->getNameWithAddr());
  std::shared_ptr<TCPQuerySender> incoming = state;
  downstreamConnection->queueQuery(incoming, std::move(tcpquery));
  return QueryProcessingResult::Forwarded;
}
838
839 void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
840 {
841 auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
842 if (fd != conn->d_handler.getDescriptor()) {
843 throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor()));
844 }
845
846 conn->handleIO();
847 }
848
849 void IncomingTCPConnectionState::handleHandshakeDone(const struct timeval& now)
850 {
851 if (d_handler.isTLS()) {
852 if (!d_handler.hasTLSSessionBeenResumed()) {
853 ++d_ci.cs->tlsNewSessions;
854 }
855 else {
856 ++d_ci.cs->tlsResumptions;
857 }
858 if (d_handler.getResumedFromInactiveTicketKey()) {
859 ++d_ci.cs->tlsInactiveTicketKey;
860 }
861 if (d_handler.getUnknownTicketKey()) {
862 ++d_ci.cs->tlsUnknownTicketKey;
863 }
864 }
865
866 d_handshakeDoneTime = now;
867 }
868
/* Read and parse a proxy protocol header from the client connection.
   Returns Done once a complete, valid header has been consumed (any TLVs
   are stored into d_proxyProtocolValues), Error on an invalid header or a
   parsing failure, and Reading when more data is still needed (IO blocked). */
IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::handleProxyProtocolPayload()
{
  do {
    DEBUGLOG("reading proxy protocol header");
    /* last argument: when the proxy payload is expected outside of the TLS
       tunnel, presumably tells the handler to read from the raw socket
       before the TLS handshake — TODO confirm against tryRead()'s signature */
    auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed, false, isProxyPayloadOutsideTLS());
    if (iostate == IOState::Done) {
      d_buffer.resize(d_currentPos);
      /* return value protocol (as used below): 0 -> invalid header,
         < 0 -> incomplete, -remaining more bytes needed, > 0 -> complete */
      ssize_t remaining = isProxyHeaderComplete(d_buffer);
      if (remaining == 0) {
        vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", d_ci.remote.toStringWithPort());
        ++dnsdist::metrics::g_stats.proxyProtocolInvalid;
        return ProxyProtocolResult::Error;
      }
      else if (remaining < 0) {
        /* grow both the expected byte count and the buffer, then loop again */
        d_proxyProtocolNeed += -remaining;
        d_buffer.resize(d_currentPos + d_proxyProtocolNeed);
        /* we need to keep reading, since we might have buffered data */
      }
      else {
        /* proxy header received */
        std::vector<ProxyProtocolValue> proxyProtocolValues;
        if (!handleProxyProtocol(d_ci.remote, true, *d_threadData.holders.acl, d_buffer, d_proxiedRemote, d_proxiedDestination, proxyProtocolValues)) {
          vinfolog("Error handling the Proxy Protocol received from TCP client %s", d_ci.remote.toStringWithPort());
          return ProxyProtocolResult::Error;
        }

        if (!proxyProtocolValues.empty()) {
          d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
        }

        return ProxyProtocolResult::Done;
      }
    }
    else {
      /* the read would block: remember it so the caller stops driving us */
      d_lastIOBlocked = true;
    }
  }
  while (active() && !d_lastIOBlocked);

  return ProxyProtocolResult::Reading;
}
910
911 IOState IncomingTCPConnectionState::handleHandshake(const struct timeval& now)
912 {
913 DEBUGLOG("doing handshake");
914 auto iostate = d_handler.tryHandshake();
915 if (iostate == IOState::Done) {
916 DEBUGLOG("handshake done");
917 handleHandshakeDone(now);
918
919 if (!isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
920 d_state = State::readingProxyProtocolHeader;
921 d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
922 d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
923 }
924 else {
925 d_state = State::readingQuerySize;
926 }
927 }
928 else {
929 d_lastIOBlocked = true;
930 }
931
932 return iostate;
933 }
934
935 void IncomingTCPConnectionState::handleIO()
936 {
937 // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
938 // even though the underlying socket is not ready, so we need to actually ask for the data first
939 IOState iostate = IOState::Done;
940 struct timeval now;
941 gettimeofday(&now, nullptr);
942
943 do {
944 iostate = IOState::Done;
945 IOStateGuard ioGuard(d_ioState);
946
947 if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
948 vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
949 // will be handled by the ioGuard
950 //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
951 return;
952 }
953
954 d_lastIOBlocked = false;
955
956 try {
957 if (d_state == State::starting) {
958 if (isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
959 d_state = State::readingProxyProtocolHeader;
960 d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
961 d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
962 }
963 else {
964 d_state = State::doingHandshake;
965 }
966 }
967
968 if (d_state == State::doingHandshake) {
969 iostate = handleHandshake(now);
970 }
971
972 if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) {
973 auto status = handleProxyProtocolPayload();
974 if (status == ProxyProtocolResult::Done) {
975 if (isProxyPayloadOutsideTLS()) {
976 d_state = State::doingHandshake;
977 iostate = handleHandshake(now);
978 }
979 else {
980 d_state = State::readingQuerySize;
981 d_buffer.resize(sizeof(uint16_t));
982 d_currentPos = 0;
983 d_proxyProtocolNeed = 0;
984 }
985 }
986 else if (status == ProxyProtocolResult::Error) {
987 iostate = IOState::Done;
988 }
989 else {
990 iostate = IOState::NeedRead;
991 }
992 }
993
994 if (!d_lastIOBlocked && (d_state == State::waitingForQuery ||
995 d_state == State::readingQuerySize)) {
996 DEBUGLOG("reading query size");
997 d_buffer.resize(sizeof(uint16_t));
998 iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t));
999 if (d_currentPos > 0) {
1000 /* if we got at least one byte, we can't go around sending responses */
1001 d_state = State::readingQuerySize;
1002 }
1003
1004 if (iostate == IOState::Done) {
1005 DEBUGLOG("query size received");
1006 d_state = State::readingQuery;
1007 d_querySizeReadTime = now;
1008 if (d_queriesCount == 0) {
1009 d_firstQuerySizeReadTime = now;
1010 }
1011 d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1);
1012 if (d_querySize < sizeof(dnsheader)) {
1013 /* go away */
1014 terminateClientConnection();
1015 return;
1016 }
1017
1018 /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
1019 or to add ECS without allocating a new buffer */
1020 d_buffer.resize(std::max(d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
1021 d_currentPos = 0;
1022 }
1023 else {
1024 d_lastIOBlocked = true;
1025 }
1026 }
1027
1028 if (!d_lastIOBlocked && d_state == State::readingQuery) {
1029 DEBUGLOG("reading query");
1030 iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize);
1031 if (iostate == IOState::Done) {
1032 DEBUGLOG("query received");
1033 d_buffer.resize(d_querySize);
1034
1035 d_state = State::idle;
1036 auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt);
1037 switch (processingResult) {
1038 case QueryProcessingResult::TooSmall:
1039 /* fall-through */
1040 case QueryProcessingResult::InvalidHeaders:
1041 /* fall-through */
1042 case QueryProcessingResult::Dropped:
1043 /* fall-through */
1044 case QueryProcessingResult::NoBackend:
1045 terminateClientConnection();
1046 break;
1047 default:
1048 break;
1049 }
1050
1051 /* the state might have been updated in the meantime, we don't want to override it
1052 in that case */
1053 if (active() && d_state != State::idle) {
1054 if (d_ioState->isWaitingForRead()) {
1055 iostate = IOState::NeedRead;
1056 }
1057 else if (d_ioState->isWaitingForWrite()) {
1058 iostate = IOState::NeedWrite;
1059 }
1060 else {
1061 iostate = IOState::Done;
1062 }
1063 }
1064 }
1065 else {
1066 d_lastIOBlocked = true;
1067 }
1068 }
1069
1070 if (!d_lastIOBlocked && d_state == State::sendingResponse) {
1071 DEBUGLOG("sending response");
1072 iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
1073 if (iostate == IOState::Done) {
1074 DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
1075 handleResponseSent(d_currentResponse);
1076 d_state = State::idle;
1077 }
1078 else {
1079 d_lastIOBlocked = true;
1080 }
1081 }
1082
1083 if (active() &&
1084 !d_lastIOBlocked &&
1085 iostate == IOState::Done &&
1086 (d_state == State::idle ||
1087 d_state == State::waitingForQuery))
1088 {
1089 // try sending queued responses
1090 DEBUGLOG("send responses, if any");
1091 auto state = shared_from_this();
1092 iostate = sendQueuedResponses(state, now);
1093
1094 if (!d_lastIOBlocked && active() && iostate == IOState::Done) {
1095 // if the query has been passed to a backend, or dropped, and the responses have been sent,
1096 // we can start reading again
1097 if (canAcceptNewQueries(now)) {
1098 resetForNewQuery();
1099 iostate = IOState::NeedRead;
1100 }
1101 else {
1102 d_state = State::idle;
1103 iostate = IOState::Done;
1104 }
1105 }
1106 }
1107
1108 if (d_state != State::idle &&
1109 d_state != State::doingHandshake &&
1110 d_state != State::readingProxyProtocolHeader &&
1111 d_state != State::waitingForQuery &&
1112 d_state != State::readingQuerySize &&
1113 d_state != State::readingQuery &&
1114 d_state != State::sendingResponse) {
1115 vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(d_state));
1116 }
1117 }
1118 catch (const std::exception& e) {
1119 /* most likely an EOF because the other end closed the connection,
1120 but it might also be a real IO error or something else.
1121 Let's just drop the connection
1122 */
1123 if (d_state == State::idle ||
1124 d_state == State::waitingForQuery) {
1125 /* no need to increase any counters in that case, the client is simply done with us */
1126 }
1127 else if (d_state == State::doingHandshake ||
1128 d_state != State::readingProxyProtocolHeader ||
1129 d_state == State::waitingForQuery ||
1130 d_state == State::readingQuerySize ||
1131 d_state == State::readingQuery) {
1132 ++d_ci.cs->tcpDiedReadingQuery;
1133 }
1134 else if (d_state == State::sendingResponse) {
1135 /* unlikely to happen here, the exception should be handled in sendResponse() */
1136 ++d_ci.cs->tcpDiedSendingResponse;
1137 }
1138
1139 if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) {
1140 DEBUGLOG("Got an exception while handling TCP query: "<<e.what());
1141 vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), e.what());
1142 }
1143 else {
1144 vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
1145 DEBUGLOG("Closing TCP client connection: "<<e.what());
1146 }
1147 /* remove this FD from the IO multiplexer */
1148 terminateClientConnection();
1149 }
1150
1151 if (!active()) {
1152 DEBUGLOG("state is no longer active");
1153 return;
1154 }
1155
1156 auto state = shared_from_this();
1157 if (iostate == IOState::Done) {
1158 d_ioState->update(iostate, handleIOCallback, state);
1159 }
1160 else {
1161 updateIO(state, iostate, now);
1162 }
1163 ioGuard.release();
1164 }
1165 while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !d_lastIOBlocked);
1166 }
1167
1168 void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPResponse&& response)
1169 {
1170 if (std::this_thread::get_id() != d_creatorThreadID) {
1171 /* empty buffer will signal an IO error */
1172 response.d_buffer.clear();
1173 handleCrossProtocolResponse(now, std::move(response));
1174 return;
1175 }
1176
1177 std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
1178 --state->d_currentQueriesCount;
1179 state->d_hadErrors = true;
1180
1181 if (state->d_state == State::sendingResponse) {
1182 /* if we have responses to send, let's do that first */
1183 }
1184 else if (!state->d_queuedResponses.empty()) {
1185 /* stop reading and send what we have */
1186 try {
1187 auto iostate = sendQueuedResponses(state, now);
1188
1189 if (state->active() && iostate != IOState::Done) {
1190 // we need to update the state right away, nobody will do that for us
1191 updateIO(state, iostate, now);
1192 }
1193 }
1194 catch (const std::exception& e) {
1195 vinfolog("Exception in notifyIOError: %s", e.what());
1196 }
1197 }
1198 else {
1199 // the backend code already tried to reconnect if it was possible
1200 state->terminateClientConnection();
1201 }
1202 }
1203
1204 void IncomingTCPConnectionState::handleXFRResponse(const struct timeval& now, TCPResponse&& response)
1205 {
1206 if (std::this_thread::get_id() != d_creatorThreadID) {
1207 handleCrossProtocolResponse(now, std::move(response));
1208 return;
1209 }
1210
1211 std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
1212 queueResponse(state, now, std::move(response), true);
1213 }
1214
1215 void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write)
1216 {
1217 vinfolog("Timeout while %s TCP client %s", (write ? "writing to" : "reading from"), state->d_ci.remote.toStringWithPort());
1218 DEBUGLOG("client timeout");
1219 DEBUGLOG("Processed "<<state->d_queriesCount<<" queries, current count is "<<state->d_currentQueriesCount<<", "<<state->d_ownedConnectionsToBackend.size()<<" owned connections, "<<state->d_queuedResponses.size()<<" response queued");
1220
1221 if (write || state->d_currentQueriesCount == 0) {
1222 ++state->d_ci.cs->tcpClientTimeouts;
1223 state->d_ioState.reset();
1224 }
1225 else {
1226 DEBUGLOG("Going idle");
1227 /* we still have some queries in flight, let's just stop reading for now */
1228 state->d_state = State::idle;
1229 state->d_ioState->update(IOState::Done, handleIOCallback, state);
1230 }
1231 }
1232
1233 static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
1234 {
1235 auto threadData = boost::any_cast<TCPClientThreadData*>(param);
1236
1237 std::unique_ptr<ConnectionInfo> citmp{nullptr};
1238 try {
1239 auto tmp = threadData->queryReceiver.receive();
1240 if (!tmp) {
1241 return;
1242 }
1243 citmp = std::move(*tmp);
1244 }
1245 catch (const std::exception& e) {
1246 throw std::runtime_error("Error while reading from the TCP query channel: " + std::string(e.what()));
1247 }
1248
1249 g_tcpclientthreads->decrementQueuedCount();
1250
1251 struct timeval now;
1252 gettimeofday(&now, nullptr);
1253
1254 if (citmp->cs->dohFrontend) {
1255 #if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1256 auto state = std::make_shared<IncomingHTTP2Connection>(std::move(*citmp), *threadData, now);
1257 state->handleIO();
1258 #endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1259 }
1260 else {
1261 auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
1262 state->handleIO();
1263 }
1264 }
1265
1266 static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
1267 {
1268 auto threadData = boost::any_cast<TCPClientThreadData*>(param);
1269
1270 std::unique_ptr<CrossProtocolQuery> cpq{nullptr};
1271 try {
1272 auto tmp = threadData->crossProtocolQueryReceiver.receive();
1273 if (!tmp) {
1274 return;
1275 }
1276 cpq = std::move(*tmp);
1277 }
1278 catch (const std::exception& e) {
1279 throw std::runtime_error("Error while reading from the TCP cross-protocol channel: " + std::string(e.what()));
1280 }
1281
1282 struct timeval now;
1283 gettimeofday(&now, nullptr);
1284
1285 std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
1286 auto query = std::move(cpq->query);
1287 auto downstreamServer = std::move(cpq->downstream);
1288
1289 try {
1290 auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
1291
1292 prependSizeToTCPQuery(query.d_buffer, query.d_idstate.d_proxyProtocolPayloadSize);
1293
1294 vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr());
1295
1296 downstream->queueQuery(tqs, std::move(query));
1297 }
1298 catch (...) {
1299 tqs->notifyIOError(now, std::move(query));
1300 }
1301 }
1302
1303 static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param)
1304 {
1305 auto threadData = boost::any_cast<TCPClientThreadData*>(param);
1306
1307 std::unique_ptr<TCPCrossProtocolResponse> cpr{nullptr};
1308 try {
1309 auto tmp = threadData->crossProtocolResponseReceiver.receive();
1310 if (!tmp) {
1311 return;
1312 }
1313 cpr = std::move(*tmp);
1314 }
1315 catch (const std::exception& e) {
1316 throw std::runtime_error("Error while reading from the TCP cross-protocol response: " + std::string(e.what()));
1317 }
1318
1319 auto response = std::move(*cpr);
1320
1321 try {
1322 if (response.d_response.d_buffer.empty()) {
1323 response.d_state->notifyIOError(response.d_now, std::move(response.d_response));
1324 }
1325 else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) {
1326 response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
1327 }
1328 else {
1329 response.d_state->handleResponse(response.d_now, std::move(response.d_response));
1330 }
1331 }
1332 catch (...) {
1333 /* no point bubbling up from there */
1334 }
1335 }
1336
/* Everything needed to accept connections on one listening socket: the
   owning frontend, the local address it is bound to, the ACL to check new
   connections against, and the listening descriptor itself.
   NOTE: member order matters, the struct is aggregate-initialized
   positionally (TCPAcceptorParam{*state, addr, acl, socket}) below. */
struct TCPAcceptorParam
{
  ClientState& cs;
  ComboAddress local;
  LocalStateHolder<NetmaskGroup>& acl;
  int socket{-1};
};
1344
1345 static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData);
1346
/* Body of a TCP worker thread: owns an FD multiplexer watching (1) the
   channels over which other threads hand us new client connections,
   cross-protocol queries and cross-protocol responses, (2) in single
   acceptor mode, the listening sockets themselves, and (3) the client and
   backend sockets registered by the connection state machines. Also
   periodically expires timed-out connections and serves debug state dumps. */
static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates)
{
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  setThreadName("dnsdist/tcpClie");

  try {
    TCPClientThreadData data;
    data.crossProtocolResponseSender = std::move(crossProtocolResponseSender);
    data.queryReceiver = std::move(queryReceiver);
    data.crossProtocolQueryReceiver = std::move(crossProtocolQueryReceiver);
    data.crossProtocolResponseReceiver = std::move(crossProtocolResponseReceiver);

    /* wake up whenever one of the inter-thread channels has something for us */
    data.mplexer->addReadFD(data.queryReceiver.getDescriptor(), handleIncomingTCPQuery, &data);
    data.mplexer->addReadFD(data.crossProtocolQueryReceiver.getDescriptor(), handleCrossProtocolQuery, &data);
    data.mplexer->addReadFD(data.crossProtocolResponseReceiver.getDescriptor(), handleCrossProtocolResponse, &data);

    /* only used in single acceptor mode for now */
    auto acl = g_ACL.getLocal();
    std::vector<TCPAcceptorParam> acceptParams;
    acceptParams.reserve(tcpAcceptStates.size());

    /* one acceptor entry per listening address (including additional ones) */
    for (auto& state : tcpAcceptStates) {
      acceptParams.emplace_back(TCPAcceptorParam{*state, state->local, acl, state->tcpFD});
      for (const auto& [addr, socket] : state->d_additionalAddresses) {
        acceptParams.emplace_back(TCPAcceptorParam{*state, addr, acl, socket});
      }
    }

    /* capturing &data is safe: the lambda never outlives this function */
    auto acceptCallback = [&data](int socket, FDMultiplexer::funcparam_t& funcparam) {
      auto acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam);
      acceptNewConnection(*acceptorParam, &data);
    };

    for (size_t idx = 0; idx < acceptParams.size(); idx++) {
      const auto& param = acceptParams.at(idx);
      setNonBlocking(param.socket);
      data.mplexer->addReadFD(param.socket, acceptCallback, &param);
    }

    struct timeval now;
    gettimeofday(&now, nullptr);
    time_t lastTimeoutScan = now.tv_sec;

    /* event loop: run() dispatches ready callbacks and refreshes 'now' */
    for (;;) {
      data.mplexer->run(&now);

      try {
        t_downstreamTCPConnectionsManager.cleanupClosedConnections(now);

        /* scan for timed-out descriptors at most once per second */
        if (now.tv_sec > lastTimeoutScan) {
          lastTimeoutScan = now.tv_sec;
          auto expiredReadConns = data.mplexer->getTimeouts(now, false);
          for (const auto& cbData : expiredReadConns) {
            /* dispatch on the concrete type stored in the boost::any parameter */
            if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
              auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
              if (cbData.first == state->d_handler.getDescriptor()) {
                vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
                state->handleTimeout(state, false);
              }
            }
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
            else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
              auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second);
              if (cbData.first == state->d_handler.getDescriptor()) {
                vinfolog("Timeout (read) from remote H2 client %s", state->d_ci.remote.toStringWithPort());
                std::shared_ptr<IncomingTCPConnectionState> parentState = state;
                state->handleTimeout(parentState, false);
              }
            }
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
            else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
              auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
              vinfolog("Timeout (read) from remote backend %s", conn->getBackendName());
              conn->handleTimeout(now, false);
            }
          }

          /* same dispatch for write-side timeouts */
          auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
          for (const auto& cbData : expiredWriteConns) {
            if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
              auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
              if (cbData.first == state->d_handler.getDescriptor()) {
                vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
                state->handleTimeout(state, true);
              }
            }
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
            else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
              auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second);
              if (cbData.first == state->d_handler.getDescriptor()) {
                vinfolog("Timeout (write) from remote H2 client %s", state->d_ci.remote.toStringWithPort());
                std::shared_ptr<IncomingTCPConnectionState> parentState = state;
                state->handleTimeout(parentState, true);
              }
            }
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
            else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
              auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
              vinfolog("Timeout (write) from remote backend %s", conn->getBackendName());
              conn->handleTimeout(now, true);
            }
          }

          /* debug facility: dump all watched descriptors and their states */
          if (g_tcpStatesDumpRequested > 0) {
            /* just to keep things clean in the output, debug only */
            static std::mutex s_lock;
            std::lock_guard<decltype(s_lock)> lck(s_lock);
            if (g_tcpStatesDumpRequested > 0) {
              /* no race here, we took the lock so it can only be increased in the meantime */
              --g_tcpStatesDumpRequested;
              infolog("Dumping the TCP states, as requested:");
              data.mplexer->runForAllWatchedFDs([](bool isRead, int fd, const FDMultiplexer::funcparam_t& param, struct timeval ttd)
              {
                struct timeval lnow;
                gettimeofday(&lnow, nullptr);
                if (ttd.tv_sec > 0) {
                  infolog("- Descriptor %d is in %s state, TTD in %d", fd, (isRead ? "read" : "write"), (ttd.tv_sec-lnow.tv_sec));
                }
                else {
                  infolog("- Descriptor %d is in %s state, no TTD set", fd, (isRead ? "read" : "write"));
                }

                if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
                  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
                  infolog(" - %s", state->toString());
                }
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
                else if (param.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
                  auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
                  infolog(" - %s", state->toString());
                }
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
                else if (param.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
                  auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param);
                  infolog(" - %s", conn->toString());
                }
                else if (param.type() == typeid(TCPClientThreadData*)) {
                  infolog(" - Worker thread pipe");
                }
              });
              infolog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount());
            }
          }
        }
      }
      catch (const std::exception& e) {
        /* housekeeping failures should not kill the worker thread */
        warnlog("Error in TCP worker thread: %s", e.what());
      }
    }
  }
  catch (const std::exception& e) {
    errlog("Fatal error in TCP worker thread: %s", e.what());
  }
}
1503
/* Accept one connection on the given listening socket, run it through the
   ACL and the various connection limits, and hand it over: to a worker
   thread (threadData == nullptr, multi-acceptor mode) or handled directly
   on this thread (single acceptor mode). Dropped connections simply return,
   letting ConnectionInfo's destructor undo the accounting. */
static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData)
{
  auto& cs = param.cs;
  auto& acl = param.acl;
  /* for DoH, the ACL may have to be checked later, against the address in the
     X-Forwarded-For header, so only check it here when the frontend is
     configured for an early drop on the connection's own address */
  const bool checkACL = !cs.dohFrontend || (!cs.dohFrontend->d_trustForwardedForHeader && cs.dohFrontend->d_earlyACLDrop);
  const int socket = param.socket;
  bool tcpClientCountIncremented = false;
  ComboAddress remote;
  remote.sin4.sin_family = param.local.sin4.sin_family;

  /* redundant re-initialization, kept as-is */
  tcpClientCountIncremented = false;
  try {
    socklen_t remlen = remote.getSocklen();
    ConnectionInfo ci(&cs);
#ifdef HAVE_ACCEPT4
    /* accept4 lets us get a non-blocking socket in one call */
    ci.fd = accept4(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
#else
    ci.fd = accept(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
#endif
    // will be decremented when the ConnectionInfo object is destroyed, no matter the reason
    auto concurrentConnections = ++cs.tcpCurrentConnections;

    if (ci.fd < 0) {
      throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
    }

    if (checkACL && !acl->match(remote)) {
      ++dnsdist::metrics::g_stats.aclDrops;
      vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
      return;
    }

    /* per-frontend concurrent connections limit */
    if (cs.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > cs.d_tcpConcurrentConnectionsLimit) {
      vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort());
      return;
    }

    /* track the high-water mark for metrics */
    if (concurrentConnections > cs.tcpMaxConcurrentConnections.load()) {
      cs.tcpMaxConcurrentConnections.store(concurrentConnections);
    }

#ifndef HAVE_ACCEPT4
    if (!setNonBlocking(ci.fd)) {
      return;
    }
#endif

    setTCPNoDelay(ci.fd); // disable NAGLE

    /* global limit on connections waiting to be picked up by a worker */
    if (g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
      vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
      return;
    }

    /* per-client-address concurrent connections limit */
    if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) {
      vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
      return;
    }
    tcpClientCountIncremented = true;

    vinfolog("Got TCP connection from %s", remote.toStringWithPort());

    ci.remote = remote;

    if (threadData == nullptr) {
      /* multi-acceptor mode: dispatch to a worker thread; on failure, undo
         the per-client accounting */
      if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(ci)))) {
        if (tcpClientCountIncremented) {
          dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
        }
      }
    }
    else {
      /* single acceptor mode: handle the connection on this thread */
      struct timeval now;
      gettimeofday(&now, nullptr);

      if (ci.cs->dohFrontend) {
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
        auto state = std::make_shared<IncomingHTTP2Connection>(std::move(ci), *threadData, now);
        state->handleIO();
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
      }
      else {
        auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now);
        state->handleIO();
      }
    }
  }
  catch (const std::exception& e) {
    errlog("While reading a TCP question: %s", e.what());
    if (tcpClientCountIncremented) {
      dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
    }
  }
  catch (...){}
}
1599
1600 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
1601 they will hand off to worker threads & spawn more of them if required
1602 */
1603 #ifndef USE_SINGLE_ACCEPTOR_THREAD
1604 void tcpAcceptorThread(std::vector<ClientState*> states)
1605 {
1606 setThreadName("dnsdist/tcpAcce");
1607
1608 auto acl = g_ACL.getLocal();
1609 std::vector<TCPAcceptorParam> params;
1610 params.reserve(states.size());
1611
1612 for (auto& state : states) {
1613 params.emplace_back(TCPAcceptorParam{*state, state->local, acl, state->tcpFD});
1614 for (const auto& [addr, socket] : state->d_additionalAddresses) {
1615 params.emplace_back(TCPAcceptorParam{*state, addr, acl, socket});
1616 }
1617 }
1618
1619 if (params.size() == 1) {
1620 while (true) {
1621 acceptNewConnection(params.at(0), nullptr);
1622 }
1623 }
1624 else {
1625 auto acceptCallback = [](int socket, FDMultiplexer::funcparam_t& funcparam) {
1626 auto acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam);
1627 acceptNewConnection(*acceptorParam, nullptr);
1628 };
1629
1630 auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(params.size()));
1631 for (size_t idx = 0; idx < params.size(); idx++) {
1632 const auto& param = params.at(idx);
1633 mplexer->addReadFD(param.socket, acceptCallback, &param);
1634 }
1635
1636 struct timeval tv;
1637 while (true) {
1638 mplexer->run(&tv, -1);
1639 }
1640 }
1641 }
1642 #endif