/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsdist-rings.hh"
#include "dnsdist-xpf.hh"

#include "dnsparser.hh"
#include "ednsoptions.hh"
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
#include "tcpiohandler.hh"
#include "threadname.hh"
#include <thread>
#include <atomic>
#include <netinet/tcp.h>

#include "sstuff.hh"

using std::thread;
using std::atomic;

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are at most 65k bytes large.
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR);
   initially we will not go there.

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
   This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, and to forward messages to/from them
   without ever queueing, so that whenever an answer comes in, we know where it needs to go.

   Let's start naively.
*/

static std::mutex tcpClientsCountMutex;
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
static const size_t g_maxCachedConnectionsPerDownstream = 20;
uint64_t g_maxTCPQueuedConnections{1000};
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};
size_t g_maxTCPConnectionsPerClient{0};
uint16_t g_downstreamTCPCleanupInterval{60};
bool g_useTCPSinglePipe{false};

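/* Opens a non-blocking TCP socket to the given backend, binding it to the
   configured source address/interface if any, and retrying until ds->retries
   is exceeded. When TCP Fast Open is enabled the connect() is deferred to the
   first write, which will be done with MSG_FASTOPEN. */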
static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
{
  std::unique_ptr<Socket> result;

  do {
    vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
    try {
      result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
      if (!IsAnyAddress(ds->sourceAddr)) {
        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
#ifdef IP_BIND_ADDRESS_NO_PORT
        if (ds->ipBindAddrNoPort) {
          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
        }
#endif
#ifdef SO_BINDTODEVICE
        if (!ds->sourceItfName.empty()) {
          int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, ds->sourceItfName.c_str(), ds->sourceItfName.length());
          if (res != 0) {
            vinfolog("Error setting up the interface on backend TCP socket '%s': %s", ds->getNameWithAddr(), stringerror());
          }
        }
#endif
        result->bind(ds->sourceAddr, false);
      }
      result->setNonBlocking();
#ifdef MSG_FASTOPEN
      if (!ds->tcpFastOpen) {
        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
      }
#else
      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
#endif /* MSG_FASTOPEN */
      return result;
    }
    catch(const std::runtime_error& e) {
      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
      downstreamFailures++;
      if (downstreamFailures > ds->retries) {
        throw;
      }
    }
  } while(downstreamFailures <= ds->retries);

  return nullptr;
}

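/* Owns one TCP connection to a backend: increments the backend's current
   connection count on creation and, on destruction, decrements it and
   reports the number of queries and the connection duration to the
   backend's TCP metrics. */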
class TCPConnectionToBackend
{
public:
  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen)
  {
    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
    ++d_ds->tcpCurrentConnections;
  }

  ~TCPConnectionToBackend()
  {
    if (d_ds && d_socket) {
      --d_ds->tcpCurrentConnections;
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
    }
  }

  int getHandle() const
  {
    if (!d_socket) {
      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
    }

    return d_socket->getHandle();
  }

  const ComboAddress& getRemote() const
  {
    return d_ds->remote;
  }

  bool isFresh() const
  {
    return d_fresh;
  }

  void incQueries()
  {
    ++d_queries;
  }

  void setReused()
  {
    d_fresh = false;
  }

  void disableFastOpen()
  {
    d_enableFastOpen = false;
  }

  bool isFastOpenEnabled()
  {
    return d_enableFastOpen;
  }

  bool canBeReused() const
  {
    /* we can't reuse a connection where a proxy protocol payload has been sent,
       since:
       - it cannot be reused for a different client
       - we might have different TLV values for each query
    */
    if (d_ds && d_ds->useProxyProtocol) {
      return false;
    }
    return true;
  }

  bool matches(const std::shared_ptr<DownstreamState>& ds) const
  {
    if (!ds || !d_ds) {
      return false;
    }
    return ds == d_ds;
  }

private:
  std::unique_ptr<Socket> d_socket{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  struct timeval d_connectionStartTime;
  uint64_t d_queries{0};
  bool d_fresh{true};
  bool d_enableFastOpen{false};
};

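/* Per-worker-thread pool of idle backend connections, keyed by the backend's
   address. getConnectionToDownstream() pops a cached connection if one exists,
   otherwise it opens a new one; releaseDownstreamConnection() puts reusable
   connections back, capped at g_maxCachedConnectionsPerDownstream per backend. */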
static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;

static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
{
  std::unique_ptr<TCPConnectionToBackend> result;

  const auto& it = t_downstreamConnections.find(ds->remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (!list.empty()) {
      result = std::move(list.front());
      list.pop_front();
      result->setReused();
      return result;
    }
  }

  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
}

static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
{
  if (conn == nullptr) {
    return;
  }

  if (!conn->canBeReused()) {
    conn.reset();
    return;
  }

  const auto& remote = conn->getRemote();
  const auto& it = t_downstreamConnections.find(remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
      /* too many connections queued already */
      conn.reset();
      return;
    }
    list.push_back(std::move(conn));
  }
  else {
    t_downstreamConnections[remote].push_back(std::move(conn));
  }
}

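/* Move-only holder for an accepted client connection: closes the socket and
   decrements the frontend's current connection count when it goes away. */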
struct ConnectionInfo
{
  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
  {
  }
  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
  {
    rhs.cs = nullptr;
    rhs.fd = -1;
  }

  ConnectionInfo(const ConnectionInfo& rhs) = delete;
  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;

  ConnectionInfo& operator=(ConnectionInfo&& rhs)
  {
    remote = rhs.remote;
    cs = rhs.cs;
    rhs.cs = nullptr;
    fd = rhs.fd;
    rhs.fd = -1;
    return *this;
  }

  ~ConnectionInfo()
  {
    if (fd != -1) {
      close(fd);
      fd = -1;
    }
    if (cs) {
      --cs->tcpCurrentConnections;
    }
  }

  ComboAddress remote;
  ClientState* cs{nullptr};
  int fd{-1};
};

void tcpClientThread(int pipefd);

static void decrementTCPClientCount(const ComboAddress& client)
{
  if (g_maxTCPConnectionsPerClient) {
    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
    tcpClientsCount[client]--;
    if (tcpClientsCount[client] == 0) {
      tcpClientsCount.erase(client);
    }
  }
}

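/* Spawns a new TCP worker thread, creating the pipe over which accepted
   connections will be handed to it, unless all workers share a single pipe. */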
void TCPClientCollection::addTCPClientThread()
{
  int pipefds[2] = { -1, -1};

  vinfolog("Adding TCP Client thread");

  if (d_useSinglePipe) {
    pipefds[0] = d_singlePipe[0];
    pipefds[1] = d_singlePipe[1];
  }
  else {
    if (pipe(pipefds) < 0) {
      errlog("Error creating the TCP thread communication pipe: %s", stringerror());
      return;
    }

    if (!setNonBlocking(pipefds[0])) {
      int err = errno;
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
      return;
    }

    if (!setNonBlocking(pipefds[1])) {
      int err = errno;
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
      return;
    }
  }

  {
    std::lock_guard<std::mutex> lock(d_mutex);

    if (d_numthreads >= d_tcpclientthreads.capacity()) {
      warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    try {
      thread t1(tcpClientThread, pipefds[0]);
      t1.detach();
    }
    catch(const std::runtime_error& e) {
      /* the thread creation failed, don't leak */
      errlog("Error creating a TCP thread: %s", e.what());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    d_tcpclientthreads.push_back(pipefds[1]);
    ++d_numthreads;
  }
}

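/* Drops cached backend connections whose socket is no longer usable,
   erasing backend entries that end up with an empty list. */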
static void cleanupClosedTCPConnections()
{
  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
        ++connIt;
      }
      else {
        connIt = dsIt->second.erase(connIt);
      }
    }

    if (!dsIt->second.empty()) {
      ++dsIt;
    }
    else {
      dsIt = t_downstreamConnections.erase(dsIt);
    }
  }
}

/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
   Updates pos every time a successful read occurs,
   throws a std::runtime_error in case of an IO error,
   and returns Done when toRead bytes have been read, or NeedRead if the IO
   operation would block.
*/
// XXX could probably be implemented as a TCPIOHandler
static IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
{
  if (buffer.size() < (pos + toRead)) {
    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
  }

  size_t got = 0;
  do {
    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
    if (res == 0) {
      throw runtime_error("EOF while reading message");
    }
    if (res < 0) {
      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
        return IOState::NeedRead;
      }
      else {
        throw std::runtime_error(std::string("Error while reading message: ") + stringerror());
      }
    }

    pos += static_cast<size_t>(res);
    got += static_cast<size_t>(res);
  }
  while (got < toRead);

  return IOState::Done;
}

std::unique_ptr<TCPClientCollection> g_tcpclientthreads;

class TCPClientThreadData
{
public:
  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
  {
  }

  LocalHolders holders;
  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
};

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);

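/* Tracks one incoming TCP (or DoT) connection as it moves through the states
   of the State enum below: TLS handshake, reading the two-byte query size,
   reading the query itself, sending it to the selected backend, reading the
   response size and response, and finally sending the response to the client,
   possibly looping back to read further queries on the same connection. */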
class IncomingTCPConnectionState
{
public:
  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
  {
    d_ids.origDest.reset();
    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
    socklen_t socklen = d_ids.origDest.getSocklen();
    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
      d_ids.origDest = d_ci.cs->local;
    }
  }

  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;

  ~IncomingTCPConnectionState()
  {
    decrementTCPClientCount(d_ci.remote);
    if (d_ci.cs != nullptr) {
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
    }

    if (d_ds != nullptr) {
      if (d_outstanding) {
        --d_ds->outstanding;
        d_outstanding = false;
      }

      if (d_downstreamConnection) {
        try {
          if (d_lastIOState == IOState::NeedRead) {
            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
          }
          else if (d_lastIOState == IOState::NeedWrite) {
            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
          }
        }
        catch(const FDMultiplexerException& e) {
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
        catch(const std::runtime_error& e) {
          /* might be thrown by getHandle() */
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
      }
    }

    try {
      if (d_lastIOState == IOState::NeedRead) {
        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeReadFD(d_ci.fd);
      }
      else if (d_lastIOState == IOState::NeedWrite) {
        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeWriteFD(d_ci.fd);
      }
    }
    catch(const FDMultiplexerException& e) {
      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
    }
  }

  void resetForNewQuery()
  {
    d_buffer.resize(sizeof(uint16_t));
    d_currentPos = 0;
    d_querySize = 0;
    d_responseSize = 0;
    d_downstreamFailures = 0;
    d_state = State::readingQuerySize;
    d_lastIOState = IOState::Done;
    d_selfGeneratedResponse = false;
  }

  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
      return boost::none;
    }

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
        return now;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
        now.tv_sec += remaining;
        return now;
      }
    }

    now.tv_sec += g_tcpRecvTimeout;
    return now;
  }

  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendReadTTD() called without any backend selected");
    }
    if (d_ds->tcpRecvTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpRecvTimeout;

    return res;
  }

  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
        return res;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
        res.tv_sec += remaining;
        return res;
      }
    }

    res.tv_sec += g_tcpSendTimeout;
    return res;
  }

  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendWriteTTD() called without any backend selected");
    }
    if (d_ds->tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpSendTimeout;

    return res;
  }

  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
  {
    if (maxConnectionDuration) {
      time_t curtime = now.tv_sec;
      unsigned int elapsed = 0;
      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
        elapsed = curtime - d_connectionStartTime.tv_sec;
      }
      if (elapsed >= maxConnectionDuration) {
        return true;
      }
      d_remainingTime = maxConnectionDuration - elapsed;
    }

    return false;
  }

  void dump() const
  {
    static std::mutex s_mutex;

    struct timeval now;
    gettimeofday(&now, 0);

    {
      std::lock_guard<std::mutex> lock(s_mutex);
      fprintf(stderr, "State is %p\n", this);
      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
      if (d_state > State::doingHandshake) {
        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuery) {
        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
      }
      if (d_state > State::sendingQueryToBackend) {
        cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
      }
      if (d_state > State::readingResponseFromBackend) {
        cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
      }
    }
  }

  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };

  std::vector<uint8_t> d_buffer;
  std::vector<uint8_t> d_responseBuffer;
  TCPClientThreadData& d_threadData;
  IDState d_ids;
  ConnectionInfo d_ci;
  TCPIOHandler d_handler;
  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  dnsheader d_cleartextDH;
  struct timeval d_connectionStartTime;
  struct timeval d_handshakeDoneTime;
  struct timeval d_firstQuerySizeReadTime;
  struct timeval d_querySizeReadTime;
  struct timeval d_queryReadTime;
  struct timeval d_querySentTime;
  struct timeval d_responseReadTime;
  size_t d_currentPos{0};
  size_t d_queriesCount{0};
  unsigned int d_remainingTime{0};
  uint16_t d_querySize{0};
  uint16_t d_responseSize{0};
  uint16_t d_downstreamFailures{0};
  State d_state{State::doingHandshake};
  IOState d_lastIOState{IOState::Done};
  bool d_readingFirstQuery{true};
  bool d_outstanding{false};
  bool d_firstResponsePacket{true};
  bool d_isXFR{false};
  bool d_xfrStarted{false};
  bool d_selfGeneratedResponse{false};
  bool d_proxyProtocolPayloadAdded{false};
  bool d_proxyProtocolPayloadHasTLV{false};
};

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);

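/* Called once a full response has been written to the client: updates the
   rings and rcode metrics, then either resumes reading from the backend (XFR),
   terminates the connection (limits reached), or resets the state to read the
   next query from the client. */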
static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);

  if (state->d_isXFR && state->d_downstreamConnection) {
    /* we need to resume reading from the backend! */
    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
    state->d_currentPos = 0;
    handleDownstreamIO(state, now);
    return;
  }

  if (state->d_selfGeneratedResponse == false && state->d_ds) {
    /* if we had no downstream server selected this was a self-answered response,
       but cache hits have a selected server as well, so we rely on the explicit flag */
    struct timespec answertime;
    gettime(&answertime);
    double udiff = state->d_ids.sentTime.udiff();
    g_rings.insertResponse(answertime, state->d_ci.remote, state->d_ids.qname, state->d_ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), state->d_cleartextDH, state->d_ds->remote);
    vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", state->d_ds->remote.toStringWithPort(), state->d_ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);
  }

  switch (state->d_cleartextDH.rcode) {
  case RCode::NXDomain:
    ++g_stats.frontendNXDomain;
    break;
  case RCode::ServFail:
    ++g_stats.servfailResponses;
    ++g_stats.frontendServFail;
    break;
  case RCode::NoError:
    ++g_stats.frontendNoError;
    break;
  }

  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
    return;
  }

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    return;
  }

  state->resetForNewQuery();

  handleIO(state, now);
}

static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);

  state->d_currentPos = 0;

  handleIO(state, now);
}

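/* Validates and post-processes a response read from the backend (matching it
   against the query, applying response rules, accounting) before handing it
   to sendResponse(). */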
static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_responseSize < sizeof(dnsheader) || !state->d_ds) {
    return;
  }

  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
  unsigned int consumed;
  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
    return;
  }
  state->d_firstResponsePacket = false;

  if (state->d_outstanding) {
    --state->d_ds->outstanding;
    state->d_outstanding = false;
  }

  auto dh = reinterpret_cast<struct dnsheader*>(response);
  uint16_t addRoom = 0;
  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
  if (dr.dnsCryptQuery) {
    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
  }

  memcpy(&state->d_cleartextDH, dr.dh, sizeof(state->d_cleartextDH));

  std::vector<uint8_t> rewrittenResponse;
  size_t responseSize = state->d_responseBuffer.size();
  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
    return;
  }

  if (!rewrittenResponse.empty()) {
    /* responseSize has been updated as well but we don't really care since it will match
       the capacity of rewrittenResponse anyway */
    state->d_responseBuffer = std::move(rewrittenResponse);
    state->d_responseSize = state->d_responseBuffer.size();
  } else {
    /* the size might have been updated (shrunk, if we removed the whole OPT RR, for example) */
    state->d_responseBuffer.resize(state->d_responseSize);
  }

  if (state->d_isXFR && !state->d_xfrStarted) {
    /* don't bother parsing the content of the response for now */
    state->d_xfrStarted = true;
    ++g_stats.responses;
    ++state->d_ci.cs->responses;
    ++state->d_ds->responses;
  }

  if (!state->d_isXFR) {
    ++g_stats.responses;
    ++state->d_ci.cs->responses;
    ++state->d_ds->responses;
  }

  sendResponse(state, now);
}

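/* Gets a (possibly pooled) connection to the selected backend, adding a proxy
   protocol payload to the query if the backend requires one, and starts the
   IO needed to send the query. */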
static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  auto ds = state->d_ds;
  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
  state->d_currentPos = 0;
  state->d_firstResponsePacket = true;

  if (state->d_xfrStarted) {
    /* sorry, but we are not going to resume an XFR if we have already sent some packets
       to the client */
    return;
  }

  if (!state->d_downstreamConnection) {
    if (state->d_downstreamFailures < state->d_ds->retries) {
      try {
        state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
      }
      catch (const std::runtime_error& e) {
        state->d_downstreamConnection.reset();
      }
    }

    if (!state->d_downstreamConnection) {
      ++ds->tcpGaveUp;
      ++state->d_ci.cs->tcpGaveUp;
      vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
      return;
    }

    if (ds->useProxyProtocol && !state->d_proxyProtocolPayloadAdded) {
      /* we know there are no TLV values to add, otherwise we would not have tried
         to reuse the connection and d_proxyProtocolPayloadAdded would be true already */
      addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, std::vector<ProxyProtocolValue>());
      state->d_proxyProtocolPayloadAdded = true;
    }
  }

  vinfolog("Got query for %s|%s from %s (%s), relayed to %s", state->d_ids.qname.toLogString(), QType(state->d_ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), ds->getName());

  handleDownstreamIO(state, now);
  return;
}

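/* Parses a query that has been fully read from the client, runs it through
   DNSCrypt handling and the rule/action chains, and either answers it
   directly or prepends the length prefix (and proxy protocol payload, if
   needed) before passing it to a backend. */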
static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_querySize < sizeof(dnsheader)) {
    ++g_stats.nonCompliantQueries;
    return;
  }

  state->d_readingFirstQuery = false;
  state->d_proxyProtocolPayloadAdded = false;
  ++state->d_queriesCount;
  ++state->d_ci.cs->queries;
  ++g_stats.queries;

  if (state->d_handler.isTLS()) {
    auto tlsVersion = state->d_handler.getTLSVersion();
    switch (tlsVersion) {
    case LibsslTLSVersion::TLS10:
      ++state->d_ci.cs->tls10queries;
      break;
    case LibsslTLSVersion::TLS11:
      ++state->d_ci.cs->tls11queries;
      break;
    case LibsslTLSVersion::TLS12:
      ++state->d_ci.cs->tls12queries;
      break;
    case LibsslTLSVersion::TLS13:
      ++state->d_ci.cs->tls13queries;
      break;
    default:
      ++state->d_ci.cs->tlsUnknownqueries;
    }
  }

  /* we need an accurate ("real") value for the response and
     to store in the IDState, but not for insertion into the
     rings for example */
  struct timespec queryRealTime;
  gettime(&queryRealTime, true);

  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
  if (dnsCryptResponse) {
    state->d_responseBuffer = std::move(*dnsCryptResponse);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  const auto& dh = reinterpret_cast<dnsheader*>(query);
  if (!checkQueryHeaders(dh)) {
    return;
  }

  uint16_t qtype, qclass;
  unsigned int consumed = 0;
  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
  dq.dnsCryptQuery = std::move(dnsCryptQuery);
  dq.sni = state->d_handler.getServerNameIndication();

  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
  if (state->d_isXFR) {
    dq.skipCache = true;
  }

  state->d_ds.reset();
  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);

  if (result == ProcessQueryResult::Drop) {
    return;
  }

  if (result == ProcessQueryResult::SendAnswer) {
    state->d_selfGeneratedResponse = true;
    state->d_buffer.resize(dq.len);
    state->d_responseBuffer = std::move(state->d_buffer);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
    return;
  }

  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));

  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
  dq.len = dq.len + 2;
  dq.dh = reinterpret_cast<dnsheader*>(&state->d_buffer.at(0));
  dq.size = state->d_buffer.size();
  state->d_buffer.resize(dq.len);

  if (state->d_ds->useProxyProtocol) {
    /* if we ever sent a TLV over a connection, we can never go back */
    if (!state->d_proxyProtocolPayloadHasTLV) {
      state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
    }

    if (state->d_downstreamConnection && !state->d_proxyProtocolPayloadHasTLV && state->d_downstreamConnection->matches(state->d_ds)) {
      /* we have an existing connection, on which we already sent a Proxy Protocol header with no values
         (if the previous query had had TLV values we would have reset the connection afterwards),
         so let's reuse it as long as we still don't have any values */
      state->d_proxyProtocolPayloadAdded = false;
    }
    else {
      state->d_downstreamConnection.reset();
      addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector<ProxyProtocolValue>());
      state->d_proxyProtocolPayloadAdded = true;
    }
  }

  sendQueryToBackend(state, now);
}

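/* Reconciles the IO state we want (read/write/done) with the state currently
   registered in the multiplexer for this FD, adding or removing the FD and
   updating its TTD as needed. */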
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
{
  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;

  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
    state->d_threadData.mplexer->removeReadFD(fd);
    //cerr<<__func__<<": remove read FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }
  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
    state->d_threadData.mplexer->removeWriteFD(fd);
    //cerr<<__func__<<": remove write FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }

  if (iostate == IOState::NeedRead) {
    if (state->d_lastIOState == IOState::NeedRead) {
      if (ttd) {
        /* let's update the TTD ! */
        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
      }
      return;
    }

    state->d_lastIOState = IOState::NeedRead;
    //cerr<<__func__<<": add read FD "<<fd<<endl;
    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::NeedWrite) {
    if (state->d_lastIOState == IOState::NeedWrite) {
      return;
    }

    state->d_lastIOState = IOState::NeedWrite;
    //cerr<<__func__<<": add write FD "<<fd<<endl;
    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::Done) {
    state->d_lastIOState = IOState::Done;
  }
}

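/* Drives the IO with the backend through the sendingQueryToBackend,
   readingResponseSizeFromBackend and readingResponseFromBackend states;
   if the connection to the backend dies, it retries the query over a fresh
   connection via sendQueryToBackend(). */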
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }

  int fd = state->d_downstreamConnection->getHandle();
  IOState iostate = IOState::Done;
  bool connectionDied = false;

  try {
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      int socketFlags = 0;
#ifdef MSG_FASTOPEN
      if (state->d_downstreamConnection->isFastOpenEnabled()) {
        socketFlags |= MSG_FASTOPEN;
      }
#endif /* MSG_FASTOPEN */

      size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, socketFlags);
      if (sent == state->d_buffer.size()) {
        /* request sent ! */
        state->d_downstreamConnection->incQueries();
        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
        state->d_currentPos = 0;
        state->d_querySentTime = now;
        iostate = IOState::NeedRead;
        if (!state->d_isXFR && !state->d_outstanding) {
          /* don't bother with the outstanding count for XFR queries */
          ++state->d_ds->outstanding;
          state->d_outstanding = true;
        }
      }
      else {
        state->d_currentPos += sent;
        iostate = IOState::NeedWrite;
        /* disable fast open on partial write */
        state->d_downstreamConnection->disableFastOpen();
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
      // then we need to allocate a new buffer (new because we might need to re-send the query if the
      // backend dies on us).
      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
      // should very likely be a TCPIOHandler d_downstreamHandler
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);

        if (state->d_isXFR) {
          /* Don't reuse the TCP connection after an {A,I}XFR */
          /* but don't reset it either, we will need to read more messages */
        }
        else {
          /* if we did not send a Proxy Protocol header, let's pool the connection */
          if (state->d_ds && state->d_ds->useProxyProtocol == false) {
            releaseDownstreamConnection(std::move(state->d_downstreamConnection));
          }
          else {
            if (state->d_proxyProtocolPayloadHasTLV) {
              /* we sent a Proxy Protocol header with TLV values, so we can't reuse it */
              state->d_downstreamConnection.reset();
            }
            else {
              /* if we did but there were no TLV values, let's try to reuse it but only
                 for this incoming connection */
            }
          }
        }
        fd = -1;

        state->d_responseReadTime = now;
        try {
          handleResponse(state, now);
        }
        catch (const std::exception& e) {
          vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", state->d_ds ? state->d_ds->getName() : "unknown", state->d_ci.remote.toStringWithPort(), e.what());
        }
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      ++state->d_ds->tcpDiedSendingQuery;
    }
    else {
      ++state->d_ds->tcpDiedReadingResponse;
    }

    /* don't increase this counter when reusing connections */
    if (state->d_downstreamConnection && state->d_downstreamConnection->isFresh()) {
      ++state->d_downstreamFailures;
    }

    if (state->d_outstanding) {
      state->d_outstanding = false;

      if (state->d_ds != nullptr) {
        --state->d_ds->outstanding;
      }
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
    connectionDied = true;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
  }

  if (connectionDied) {
    state->d_downstreamConnection.reset();
    sendQueryToBackend(state, now);
  }
}

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }
  if (fd != state->d_downstreamConnection->getHandle()) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
  }

  struct timeval now;
  gettimeofday(&now, 0);
  handleDownstreamIO(state, now);
}

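/* Drives the IO with the client through the doingHandshake, readingQuerySize,
   readingQuery and sendingResponse states, enforcing the maximum connection
   duration and dropping the connection on IO errors. */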
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  int fd = state->d_ci.fd;
  IOState iostate = IOState::Done;

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
    return;
  }

  try {
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
      iostate = state->d_handler.tryHandshake();
      if (iostate == IOState::Done) {
        if (state->d_handler.isTLS()) {
          if (!state->d_handler.hasTLSSessionBeenResumed()) {
            ++state->d_ci.cs->tlsNewSessions;
          }
          else {
            ++state->d_ci.cs->tlsResumptions;
          }
          if (state->d_handler.getResumedFromInactiveTicketKey()) {
            ++state->d_ci.cs->tlsInactiveTicketKey;
          }
          if (state->d_handler.getUnknownTicketKey()) {
            ++state->d_ci.cs->tlsUnknownTicketKey;
          }
        }

        state->d_handshakeDoneTime = now;
        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingQuery;
        state->d_querySizeReadTime = now;
        if (state->d_queriesCount == 0) {
          state->d_firstQuerySizeReadTime = now;
        }
        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
        if (state->d_querySize < sizeof(dnsheader)) {
          /* go away */
          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
          return;
        }

        /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
           or to add ECS without allocating a new buffer */
        state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
        handleQuery(state, now);
        return;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
      if (iostate == IOState::Done) {
        handleResponseSent(state, now);
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      ++state->d_ci.cs->tcpDiedReadingQuery;
    }
    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      ++state->d_ci.cs->tcpDiedSendingResponse;
    }

    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
    }
    else {
      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
  }
}

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (fd != state->d_ci.fd) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
  }
  struct timeval now;
  gettimeofday(&now, 0);

  handleIO(state, now);
}

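/* Invoked when the acceptor thread writes a ConnectionInfo pointer to this
   worker's pipe: takes ownership of the pointer and starts handling the
   connection. */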
static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
{
  auto threadData = boost::any_cast<TCPClientThreadData*>(param);

  ConnectionInfo* citmp{nullptr};

  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
  if (got == 0) {
    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }
  else if (got == -1) {
    if (errno == EAGAIN || errno == EINTR) {
      return;
    }
    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + stringerror());
  }
  else if (got != sizeof(citmp)) {
    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }

  try {
    g_tcpclientthreads->decrementQueuedCount();

    struct timeval now;
    gettimeofday(&now, 0);
    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
    delete citmp;
    citmp = nullptr;

    /* let's update the remaining time */
    state->d_remainingTime = g_maxTCPConnectionDuration;

    handleIO(state, now);
  }
  catch(...) {
    delete citmp;
    citmp = nullptr;
    throw;
  }
}

void tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  setThreadName("dnsdist/tcpClie");

  TCPClientThreadData data;

  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
  struct timeval now;
  gettimeofday(&now, 0);
  time_t lastTCPCleanup = now.tv_sec;
  time_t lastTimeoutScan = now.tv_sec;

  for (;;) {
    data.mplexer->run(&now);

    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
      cleanupClosedTCPConnections();
      lastTCPCleanup = now.tv_sec;
    }

    if (now.tv_sec > lastTimeoutScan) {
      lastTimeoutScan = now.tv_sec;
      auto expiredReadConns = data.mplexer->getTimeouts(now, false);
      for(const auto& conn : expiredReadConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpReadTimeouts;
        }
        data.mplexer->removeReadFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }

      auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
      for(const auto& conn : expiredWriteConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpWriteTimeouts;
        }
        data.mplexer->removeWriteFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }
    }
  }
}

/* spawn as many of these as required, they call accept() on a socket on which they will accept queries, and
   they will hand off to worker threads & spawn more of them if required
*/
void tcpAcceptorThread(void* p)
{
  setThreadName("dnsdist/tcpAcce");
  ClientState* cs = (ClientState*) p;
  bool tcpClientCountIncremented = false;
  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  if(!g_tcpclientthreads->hasReachedMaxThreads()) {
    g_tcpclientthreads->addTCPClientThread();
  }

  auto acl = g_ACL.getLocal();
  for(;;) {
    bool queuedCounterIncremented = false;
    std::unique_ptr<ConnectionInfo> ci;
    tcpClientCountIncremented = false;
    try {
      socklen_t remlen = remote.getSocklen();
      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo(cs));
#ifdef HAVE_ACCEPT4
      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
#else
      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
#endif
      ++cs->tcpCurrentConnections;

      if(ci->fd < 0) {
        throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
      }

      if(!acl->match(remote)) {
        ++g_stats.aclDrops;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

#ifndef HAVE_ACCEPT4
      if (!setNonBlocking(ci->fd)) {
        continue;
      }
#endif
      setTCPNoDelay(ci->fd); // disable Nagle's algorithm
      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      if (g_maxTCPConnectionsPerClient) {
        std::lock_guard<std::mutex> lock(tcpClientsCountMutex);

        if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
          vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
          continue;
        }
        tcpClientsCount[remote]++;
        tcpClientCountIncremented = true;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        queuedCounterIncremented = true;
        auto tmp = ci.release();
        try {
          writen2WithTimeout(pipe, &tmp, sizeof(tmp), 0);
        }
        catch(...) {
          delete tmp;
          tmp = nullptr;
          throw;
        }
      }
      else {
        g_tcpclientthreads->decrementQueuedCount();
        queuedCounterIncremented = false;
        if(tcpClientCountIncremented) {
          decrementTCPClientCount(remote);
        }
      }
    }
    catch(const std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(tcpClientCountIncremented) {
        decrementTCPClientCount(remote);
      }
      if (queuedCounterIncremented) {
        g_tcpclientthreads->decrementQueuedCount();
      }
    }
    catch(...){}
  }
}