/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsdist-rings.hh"
#include "dnsdist-xpf.hh"

#include "dnsparser.hh"
#include "ednsoptions.hh"
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
#include "tcpiohandler.hh"
#include "threadname.hh"
#include <thread>
#include <atomic>
#include <netinet/tcp.h>

#include "sstuff.hh"

using std::thread;
using std::atomic;

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
   An answer might theoretically consist of multiple messages (for example in the case of AXFR),
   but initially we will not go there.

   In a sense there is a strong symmetry between UDP and TCP once a connection to a downstream has been set up.
   This symmetry is broken by head-of-line blocking within TCP, though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, to forward messages to/from them,
   and never to queue. That way, whenever an answer comes in, we know where it needs to go.

   Let's start naively.
*/

static std::mutex tcpClientsCountMutex;
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
static const size_t g_maxCachedConnectionsPerDownstream = 20;
uint64_t g_maxTCPQueuedConnections{1000};
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};
size_t g_maxTCPConnectionsPerClient{0};
uint16_t g_downstreamTCPCleanupInterval{60};
bool g_useTCPSinglePipe{false};

static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
{
  std::unique_ptr<Socket> result;

  do {
    vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
    try {
      result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
      if (!IsAnyAddress(ds->sourceAddr)) {
        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
#ifdef IP_BIND_ADDRESS_NO_PORT
        if (ds->ipBindAddrNoPort) {
          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
        }
#endif
#ifdef SO_BINDTODEVICE
        if (!ds->sourceItfName.empty()) {
          int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, ds->sourceItfName.c_str(), ds->sourceItfName.length());
          if (res != 0) {
            vinfolog("Error setting up the interface on backend TCP socket '%s': %s", ds->getNameWithAddr(), stringerror());
          }
        }
#endif
        result->bind(ds->sourceAddr, false);
      }
      result->setNonBlocking();
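      /* when TCP Fast Open is enabled we deliberately skip the connect() below:
         the connection will instead be established by the first sendmsg() call
         carrying the MSG_FASTOPEN flag, saving one round-trip */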
#ifdef MSG_FASTOPEN
      if (!ds->tcpFastOpen) {
        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
      }
#else
      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
#endif /* MSG_FASTOPEN */
      return result;
    }
    catch(const std::runtime_error& e) {
      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
      downstreamFailures++;
      if (downstreamFailures > ds->retries) {
        throw;
      }
    }
  } while(downstreamFailures <= ds->retries);

  return nullptr;
}

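/* Wraps a TCP connection to a backend, tracking how many queries were sent
   over it and whether it is fresh or was reused from the per-thread cache. */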
class TCPConnectionToBackend
{
public:
  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen)
  {
    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
    ++d_ds->tcpCurrentConnections;
  }

  ~TCPConnectionToBackend()
  {
    if (d_ds && d_socket) {
      --d_ds->tcpCurrentConnections;
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
    }
  }

  int getHandle() const
  {
    if (!d_socket) {
      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
    }

    return d_socket->getHandle();
  }

  const ComboAddress& getRemote() const
  {
    return d_ds->remote;
  }

  bool isFresh() const
  {
    return d_fresh;
  }

  void incQueries()
  {
    ++d_queries;
  }

  void setReused()
  {
    d_fresh = false;
  }

  void disableFastOpen()
  {
    d_enableFastOpen = false;
  }

  bool isFastOpenEnabled()
  {
    return d_enableFastOpen;
  }

private:
  std::unique_ptr<Socket> d_socket{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  struct timeval d_connectionStartTime;
  uint64_t d_queries{0};
  bool d_fresh{true};
  bool d_enableFastOpen{false};
};

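/* Per-thread cache of idle connections to each backend, so that a worker
   thread can reuse an established connection instead of reconnecting for
   every query. */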
static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;

static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
{
  std::unique_ptr<TCPConnectionToBackend> result;

  const auto& it = t_downstreamConnections.find(ds->remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (!list.empty()) {
      result = std::move(list.front());
      list.pop_front();
      result->setReused();
      return result;
    }
  }

  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
}

static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
{
  if (conn == nullptr) {
    return;
  }

  const auto& remote = conn->getRemote();
  const auto& it = t_downstreamConnections.find(remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
      /* too many connections queued already */
      conn.reset();
      return;
    }
    list.push_back(std::move(conn));
  }
  else {
    t_downstreamConnections[remote].push_back(std::move(conn));
  }
}

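/* RAII holder for an accepted client connection: closes the file descriptor
   and decrements the per-frontend connection counter on destruction. */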
struct ConnectionInfo
{
  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
  {
  }
  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
  {
    rhs.cs = nullptr;
    rhs.fd = -1;
  }

  ConnectionInfo(const ConnectionInfo& rhs) = delete;
  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;

  ConnectionInfo& operator=(ConnectionInfo&& rhs)
  {
    remote = rhs.remote;
    cs = rhs.cs;
    rhs.cs = nullptr;
    fd = rhs.fd;
    rhs.fd = -1;
    return *this;
  }

  ~ConnectionInfo()
  {
    if (fd != -1) {
      close(fd);
      fd = -1;
    }
    if (cs) {
      --cs->tcpCurrentConnections;
    }
  }

  ComboAddress remote;
  ClientState* cs{nullptr};
  int fd{-1};
};

void tcpClientThread(int pipefd);

static void decrementTCPClientCount(const ComboAddress& client)
{
  if (g_maxTCPConnectionsPerClient) {
    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
    tcpClientsCount[client]--;
    if (tcpClientsCount[client] == 0) {
      tcpClientsCount.erase(client);
    }
  }
}

void TCPClientCollection::addTCPClientThread()
{
  int pipefds[2] = { -1, -1};

  vinfolog("Adding TCP Client thread");

  if (d_useSinglePipe) {
    pipefds[0] = d_singlePipe[0];
    pipefds[1] = d_singlePipe[1];
  }
  else {
    if (pipe(pipefds) < 0) {
      errlog("Error creating the TCP thread communication pipe: %s", stringerror());
      return;
    }

    if (!setNonBlocking(pipefds[0])) {
      int err = errno;
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
      return;
    }

    if (!setNonBlocking(pipefds[1])) {
      int err = errno;
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
      return;
    }
  }

  {
    std::lock_guard<std::mutex> lock(d_mutex);

    if (d_numthreads >= d_tcpclientthreads.capacity()) {
      warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    try {
      thread t1(tcpClientThread, pipefds[0]);
      t1.detach();
    }
    catch(const std::runtime_error& e) {
      /* the thread creation failed, don't leak */
      errlog("Error creating a TCP thread: %s", e.what());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    d_tcpclientthreads.push_back(pipefds[1]);
    ++d_numthreads;
  }
}

static void cleanupClosedTCPConnections()
{
  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
        ++connIt;
      }
      else {
        connIt = dsIt->second.erase(connIt);
      }
    }

    if (!dsIt->second.empty()) {
      ++dsIt;
    }
    else {
      dsIt = t_downstreamConnections.erase(dsIt);
    }
  }
}

/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
   Updates pos every time a successful read occurs,
   throws a std::runtime_error in case of an IO error,
   and returns Done when toRead bytes have been read, or NeedRead/NeedWrite
   if the IO operation would block.
*/
// XXX could probably be implemented as a TCPIOHandler
IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
{
  if (buffer.size() < (pos + toRead)) {
    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
  }

  size_t got = 0;
  do {
    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
    if (res == 0) {
      throw runtime_error("EOF while reading message");
    }
    if (res < 0) {
      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
        return IOState::NeedRead;
      }
      else {
        throw std::runtime_error(std::string("Error while reading message: ") + stringerror());
      }
    }

    pos += static_cast<size_t>(res);
    got += static_cast<size_t>(res);
  }
  while (got < toRead);

  return IOState::Done;
}

std::unique_ptr<TCPClientCollection> g_tcpclientthreads;

class TCPClientThreadData
{
public:
  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
  {
  }

  LocalHolders holders;
  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
};

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);

class IncomingTCPConnectionState
{
public:
  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
  {
    d_ids.origDest.reset();
    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
    socklen_t socklen = d_ids.origDest.getSocklen();
    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
      d_ids.origDest = d_ci.cs->local;
    }
  }

  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;

  ~IncomingTCPConnectionState()
  {
    decrementTCPClientCount(d_ci.remote);
    if (d_ci.cs != nullptr) {
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
    }

    if (d_ds != nullptr) {
      if (d_outstanding) {
        --d_ds->outstanding;
        d_outstanding = false;
      }

      if (d_downstreamConnection) {
        try {
          if (d_lastIOState == IOState::NeedRead) {
            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
          }
          else if (d_lastIOState == IOState::NeedWrite) {
            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
          }
        }
        catch(const FDMultiplexerException& e) {
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
        catch(const std::runtime_error& e) {
          /* might be thrown by getHandle() */
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
      }
    }

    try {
      if (d_lastIOState == IOState::NeedRead) {
        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeReadFD(d_ci.fd);
      }
      else if (d_lastIOState == IOState::NeedWrite) {
        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeWriteFD(d_ci.fd);
      }
    }
    catch(const FDMultiplexerException& e) {
      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
    }
  }

  void resetForNewQuery()
  {
    d_buffer.resize(sizeof(uint16_t));
    d_currentPos = 0;
    d_querySize = 0;
    d_responseSize = 0;
    d_downstreamFailures = 0;
    d_state = State::readingQuerySize;
    d_lastIOState = IOState::Done;
    d_selfGeneratedResponse = false;
  }

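  /* Compute the "time to die" (TTD) for a client read: the earlier of the
     remaining allowed connection duration and the configured receive timeout,
     or none if neither limit is set. */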
  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
      return boost::none;
    }

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
        return now;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
        now.tv_sec += remaining;
        return now;
      }
    }

    now.tv_sec += g_tcpRecvTimeout;
    return now;
  }

  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendReadTTD() without any backend selected");
    }
    if (d_ds->tcpRecvTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpRecvTimeout;

    return res;
  }

  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
        return res;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
        res.tv_sec += remaining;
        return res;
      }
    }

    res.tv_sec += g_tcpSendTimeout;
    return res;
  }

  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendWriteTTD() called without any backend selected");
    }
    if (d_ds->tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpSendTimeout;

    return res;
  }

  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
  {
    if (maxConnectionDuration) {
      time_t curtime = now.tv_sec;
      unsigned int elapsed = 0;
      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
        elapsed = curtime - d_connectionStartTime.tv_sec;
      }
      if (elapsed >= maxConnectionDuration) {
        return true;
      }
      d_remainingTime = maxConnectionDuration - elapsed;
    }

    return false;
  }

  void dump() const
  {
    static std::mutex s_mutex;

    struct timeval now;
    gettimeofday(&now, 0);

    {
      std::lock_guard<std::mutex> lock(s_mutex);
      fprintf(stderr, "State is %p\n", this);
      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
      if (d_state > State::doingHandshake) {
        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuery) {
        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
      }
      if (d_state > State::sendingQueryToBackend) {
        cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
      }
      if (d_state > State::readingResponseFromBackend) {
        cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
      }
    }
  }

  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };

  std::vector<uint8_t> d_buffer;
  std::vector<uint8_t> d_responseBuffer;
  TCPClientThreadData& d_threadData;
  IDState d_ids;
  ConnectionInfo d_ci;
  TCPIOHandler d_handler;
  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  dnsheader d_cleartextDH;
  struct timeval d_connectionStartTime;
  struct timeval d_handshakeDoneTime;
  struct timeval d_firstQuerySizeReadTime;
  struct timeval d_querySizeReadTime;
  struct timeval d_queryReadTime;
  struct timeval d_querySentTime;
  struct timeval d_responseReadTime;
  size_t d_currentPos{0};
  size_t d_queriesCount{0};
  unsigned int d_remainingTime{0};
  uint16_t d_querySize{0};
  uint16_t d_responseSize{0};
  uint16_t d_downstreamFailures{0};
  State d_state{State::doingHandshake};
  IOState d_lastIOState{IOState::Done};
  bool d_readingFirstQuery{true};
  bool d_outstanding{false};
  bool d_firstResponsePacket{true};
  bool d_isXFR{false};
  bool d_xfrStarted{false};
  bool d_selfGeneratedResponse{false};
};

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);

static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);

  if (state->d_isXFR && state->d_downstreamConnection) {
    /* we need to resume reading from the backend! */
    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
    state->d_currentPos = 0;
    handleDownstreamIO(state, now);
    return;
  }

  if (state->d_selfGeneratedResponse == false && state->d_ds) {
    /* if we have no downstream server selected, this was a self-answered response
       but cache hits have a selected server as well, so be careful */
    struct timespec answertime;
    gettime(&answertime);
    double udiff = state->d_ids.sentTime.udiff();
    g_rings.insertResponse(answertime, state->d_ci.remote, state->d_ids.qname, state->d_ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), state->d_cleartextDH, state->d_ds->remote);
    vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", state->d_ds->remote.toStringWithPort(), state->d_ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);
  }

  switch (state->d_cleartextDH.rcode) {
  case RCode::NXDomain:
    ++g_stats.frontendNXDomain;
    break;
  case RCode::ServFail:
    ++g_stats.servfailResponses;
    ++g_stats.frontendServFail;
    break;
  case RCode::NoError:
    ++g_stats.frontendNoError;
    break;
  }

  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
    return;
  }

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    return;
  }

  state->resetForNewQuery();

  handleIO(state, now);
}

static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
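  /* DNS over TCP (RFC 1035, section 4.2.2) requires every message to be
     preceded by a two-byte length field in network byte order; build it here */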
  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);

  state->d_currentPos = 0;

  handleIO(state, now);
}

static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_responseSize < sizeof(dnsheader) || !state->d_ds) {
    return;
  }

  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
  unsigned int consumed;
  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
    return;
  }
  state->d_firstResponsePacket = false;

  if (state->d_outstanding) {
    --state->d_ds->outstanding;
    state->d_outstanding = false;
  }

  auto dh = reinterpret_cast<struct dnsheader*>(response);
  uint16_t addRoom = 0;
  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
  if (dr.dnsCryptQuery) {
    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
  }

  memcpy(&state->d_cleartextDH, dr.dh, sizeof(state->d_cleartextDH));

  std::vector<uint8_t> rewrittenResponse;
  size_t responseSize = state->d_responseBuffer.size();
  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
    return;
  }

  if (!rewrittenResponse.empty()) {
    /* responseSize has been updated as well but we don't really care since it will match
       the capacity of rewrittenResponse anyway */
    state->d_responseBuffer = std::move(rewrittenResponse);
    state->d_responseSize = state->d_responseBuffer.size();
  } else {
    /* the size might have been updated (shrunk) if we removed the whole OPT RR, for example */
    state->d_responseBuffer.resize(state->d_responseSize);
  }

  if (state->d_isXFR && !state->d_xfrStarted) {
    /* don't bother parsing the content of the response for now */
    state->d_xfrStarted = true;
    ++g_stats.responses;
    ++state->d_ci.cs->responses;
    ++state->d_ds->responses;
  }

  if (!state->d_isXFR) {
    ++g_stats.responses;
    ++state->d_ci.cs->responses;
    ++state->d_ds->responses;
  }

  sendResponse(state, now);
}

static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  auto ds = state->d_ds;
  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
  state->d_currentPos = 0;
  state->d_firstResponsePacket = true;
  state->d_downstreamConnection.reset();

  if (state->d_xfrStarted) {
    /* sorry, but we are not going to resume an XFR if we have already sent some packets
       to the client */
    return;
  }

  if (state->d_downstreamFailures < state->d_ds->retries) {
    try {
      state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
    }
    catch (const std::runtime_error& e) {
      state->d_downstreamConnection.reset();
    }
  }

  if (!state->d_downstreamConnection) {
    ++ds->tcpGaveUp;
    ++state->d_ci.cs->tcpGaveUp;
    vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
    return;
  }

  vinfolog("Got query for %s|%s from %s (%s), relayed to %s", state->d_ids.qname.toLogString(), QType(state->d_ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), ds->getName());

  handleDownstreamIO(state, now);
  return;
}

static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_querySize < sizeof(dnsheader)) {
    ++g_stats.nonCompliantQueries;
    return;
  }

  state->d_readingFirstQuery = false;
  ++state->d_queriesCount;
  ++state->d_ci.cs->queries;
  ++g_stats.queries;

  if (state->d_handler.isTLS()) {
    auto tlsVersion = state->d_handler.getTLSVersion();
    switch (tlsVersion) {
    case LibsslTLSVersion::TLS10:
      ++state->d_ci.cs->tls10queries;
      break;
    case LibsslTLSVersion::TLS11:
      ++state->d_ci.cs->tls11queries;
      break;
    case LibsslTLSVersion::TLS12:
      ++state->d_ci.cs->tls12queries;
      break;
    case LibsslTLSVersion::TLS13:
      ++state->d_ci.cs->tls13queries;
      break;
    default:
      ++state->d_ci.cs->tlsUnknownqueries;
    }
  }

  /* we need an accurate ("real") value for the response and
     to store into the IDS, but not for insertion into the
     rings for example */
  struct timespec queryRealTime;
  gettime(&queryRealTime, true);

  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
  if (dnsCryptResponse) {
    state->d_responseBuffer = std::move(*dnsCryptResponse);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  const auto& dh = reinterpret_cast<dnsheader*>(query);
  if (!checkQueryHeaders(dh)) {
    return;
  }

  uint16_t qtype, qclass;
  unsigned int consumed = 0;
  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
  dq.dnsCryptQuery = std::move(dnsCryptQuery);
  dq.sni = state->d_handler.getServerNameIndication();

  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
  if (state->d_isXFR) {
    dq.skipCache = true;
  }

  state->d_ds.reset();
  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);

  if (result == ProcessQueryResult::Drop) {
    return;
  }

  if (result == ProcessQueryResult::SendAnswer) {
    state->d_selfGeneratedResponse = true;
    state->d_buffer.resize(dq.len);
    state->d_responseBuffer = std::move(state->d_buffer);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
    return;
  }

  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));

  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
  dq.len = dq.len + 2;
  dq.dh = reinterpret_cast<dnsheader*>(&state->d_buffer.at(0));
  dq.size = state->d_buffer.size();

  if (dq.addProxyProtocol && state->d_ds->useProxyProtocol) {
    addProxyProtocol(dq);
  }

  state->d_buffer.resize(dq.len);

  sendQueryToBackend(state, now);
}

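/* Reconcile the desired IO state (NeedRead, NeedWrite or Done) with the one
   currently registered in the multiplexer for this file descriptor, adding,
   removing or re-arming the FD as needed. */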
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
{
  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;

  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
    state->d_threadData.mplexer->removeReadFD(fd);
    //cerr<<__func__<<": remove read FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }
  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
    state->d_threadData.mplexer->removeWriteFD(fd);
    //cerr<<__func__<<": remove write FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }

  if (iostate == IOState::NeedRead) {
    if (state->d_lastIOState == IOState::NeedRead) {
      if (ttd) {
        /* let's update the TTD ! */
        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
      }
      return;
    }

    state->d_lastIOState = IOState::NeedRead;
    //cerr<<__func__<<": add read FD "<<fd<<endl;
    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::NeedWrite) {
    if (state->d_lastIOState == IOState::NeedWrite) {
      return;
    }

    state->d_lastIOState = IOState::NeedWrite;
    //cerr<<__func__<<": add write FD "<<fd<<endl;
    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::Done) {
    state->d_lastIOState = IOState::Done;
  }
}

static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }

  int fd = state->d_downstreamConnection->getHandle();
  IOState iostate = IOState::Done;
  bool connectionDied = false;

  try {
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      int socketFlags = 0;
#ifdef MSG_FASTOPEN
      if (state->d_downstreamConnection->isFastOpenEnabled()) {
        socketFlags |= MSG_FASTOPEN;
      }
#endif /* MSG_FASTOPEN */

      size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, socketFlags);
      if (sent == state->d_buffer.size()) {
        /* request sent ! */
        state->d_downstreamConnection->incQueries();
        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
        state->d_currentPos = 0;
        state->d_querySentTime = now;
        iostate = IOState::NeedRead;
        if (!state->d_isXFR && !state->d_outstanding) {
          /* don't bother with the outstanding count for XFR queries */
          ++state->d_ds->outstanding;
          state->d_outstanding = true;
        }
      }
      else {
        state->d_currentPos += sent;
        iostate = IOState::NeedWrite;
        /* disable fast open on partial write */
        state->d_downstreamConnection->disableFastOpen();
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
      // then we need to allocate a new buffer (new because we might need to re-send the query
      // if the backend dies on us)
      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
      // this should very likely be a TCPIOHandler d_downstreamHandler
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
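        /* if the response will need to be DNSCrypt-encrypted, reserve extra room
           for the padding and MAC, as long as the total still fits in 16 bits */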
        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);

        if (state->d_isXFR) {
          /* Don't reuse the TCP connection after an {A,I}XFR */
          /* but don't reset it either, we will need to read more messages */
        }
        else {
          releaseDownstreamConnection(std::move(state->d_downstreamConnection));
        }
        fd = -1;

        state->d_responseReadTime = now;
        try {
          handleResponse(state, now);
        }
        catch (const std::exception& e) {
          vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", state->d_ds ? state->d_ds->getName() : "unknown", state->d_ci.remote.toStringWithPort(), e.what());
        }
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      ++state->d_ds->tcpDiedSendingQuery;
    }
    else {
      ++state->d_ds->tcpDiedReadingResponse;
    }

    /* don't increase this counter when reusing connections */
    if (state->d_downstreamConnection && state->d_downstreamConnection->isFresh()) {
      ++state->d_downstreamFailures;
    }

    if (state->d_outstanding) {
      state->d_outstanding = false;

      if (state->d_ds != nullptr) {
        --state->d_ds->outstanding;
      }
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
    connectionDied = true;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
  }

  if (connectionDied) {
    sendQueryToBackend(state, now);
  }
}

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }
  if (fd != state->d_downstreamConnection->getHandle()) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
  }

  struct timeval now;
  gettimeofday(&now, 0);
  handleDownstreamIO(state, now);
}

static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  int fd = state->d_ci.fd;
  IOState iostate = IOState::Done;

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
    return;
  }

  try {
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
      iostate = state->d_handler.tryHandshake();
      if (iostate == IOState::Done) {
        if (state->d_handler.isTLS()) {
          if (!state->d_handler.hasTLSSessionBeenResumed()) {
            ++state->d_ci.cs->tlsNewSessions;
          }
          else {
            ++state->d_ci.cs->tlsResumptions;
          }
          if (state->d_handler.getResumedFromInactiveTicketKey()) {
            ++state->d_ci.cs->tlsInactiveTicketKey;
          }
          if (state->d_handler.getUnknownTicketKey()) {
            ++state->d_ci.cs->tlsUnknownTicketKey;
          }
        }

        state->d_handshakeDoneTime = now;
        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingQuery;
        state->d_querySizeReadTime = now;
        if (state->d_queriesCount == 0) {
          state->d_firstQuerySizeReadTime = now;
        }
        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
        if (state->d_querySize < sizeof(dnsheader)) {
          /* go away */
          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
          return;
        }

        /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
           or to add ECS without allocating a new buffer */
        state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
        handleQuery(state, now);
        return;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
      if (iostate == IOState::Done) {
        handleResponseSent(state, now);
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      ++state->d_ci.cs->tcpDiedReadingQuery;
    }
    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      ++state->d_ci.cs->tcpDiedSendingResponse;
    }

    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
    }
    else {
      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
  }
}

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (fd != state->d_ci.fd) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
  }
  struct timeval now;
  gettimeofday(&now, 0);

  handleIO(state, now);
}

static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
{
  auto threadData = boost::any_cast<TCPClientThreadData*>(param);

  ConnectionInfo* citmp{nullptr};

  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
  if (got == 0) {
    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }
  else if (got == -1) {
    if (errno == EAGAIN || errno == EINTR) {
      return;
    }
    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + stringerror());
  }
  else if (got != sizeof(citmp)) {
    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }

  try {
    g_tcpclientthreads->decrementQueuedCount();

    struct timeval now;
    gettimeofday(&now, 0);
    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
    delete citmp;
    citmp = nullptr;

    /* let's update the remaining time */
    state->d_remainingTime = g_maxTCPConnectionDuration;

    handleIO(state, now);
  }
  catch(...) {
    delete citmp;
    citmp = nullptr;
    throw;
  }
}

void tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  setThreadName("dnsdist/tcpClie");

  TCPClientThreadData data;

  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
  struct timeval now;
  gettimeofday(&now, 0);
  time_t lastTCPCleanup = now.tv_sec;
  time_t lastTimeoutScan = now.tv_sec;

  for (;;) {
    data.mplexer->run(&now);

    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
      cleanupClosedTCPConnections();
      lastTCPCleanup = now.tv_sec;
    }

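    /* scan, at most once per second, for connections whose read or write
       deadline (TTD) has expired, and unregister their pending IO */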
    if (now.tv_sec > lastTimeoutScan) {
      lastTimeoutScan = now.tv_sec;
      auto expiredReadConns = data.mplexer->getTimeouts(now, false);
      for(const auto& conn : expiredReadConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpReadTimeouts;
        }
        data.mplexer->removeReadFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }

      auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
      for(const auto& conn : expiredWriteConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpWriteTimeouts;
        }
        data.mplexer->removeWriteFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }
    }
  }
}

/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
   they will hand off to worker threads & spawn more of them if required
*/
void tcpAcceptorThread(void* p)
{
  setThreadName("dnsdist/tcpAcce");
  ClientState* cs = (ClientState*) p;
  bool tcpClientCountIncremented = false;
  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  if(!g_tcpclientthreads->hasReachedMaxThreads()) {
    g_tcpclientthreads->addTCPClientThread();
  }

  auto acl = g_ACL.getLocal();
  for(;;) {
    bool queuedCounterIncremented = false;
    std::unique_ptr<ConnectionInfo> ci;
    tcpClientCountIncremented = false;
    try {
      socklen_t remlen = remote.getSocklen();
      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo(cs));
#ifdef HAVE_ACCEPT4
      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
#else
      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
#endif
      ++cs->tcpCurrentConnections;

      if(ci->fd < 0) {
        throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
      }

      if(!acl->match(remote)) {
        ++g_stats.aclDrops;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

#ifndef HAVE_ACCEPT4
      if (!setNonBlocking(ci->fd)) {
        continue;
      }
#endif
      setTCPNoDelay(ci->fd); // disable Nagle's algorithm
      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      if (g_maxTCPConnectionsPerClient) {
        std::lock_guard<std::mutex> lock(tcpClientsCountMutex);

        if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
          vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
          continue;
        }
        tcpClientsCount[remote]++;
        tcpClientCountIncremented = true;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
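      /* hand the connection over to a worker thread by writing the raw
         ConnectionInfo pointer to its pipe; on success the worker owns it */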
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        queuedCounterIncremented = true;
        auto tmp = ci.release();
        try {
          writen2WithTimeout(pipe, &tmp, sizeof(tmp), 0);
        }
        catch(...) {
          delete tmp;
          tmp = nullptr;
          throw;
        }
      }
      else {
        g_tcpclientthreads->decrementQueuedCount();
        queuedCounterIncremented = false;
        if(tcpClientCountIncremented) {
          decrementTCPClientCount(remote);
        }
      }
    }
    catch(const std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(tcpClientCountIncremented) {
        decrementTCPClientCount(remote);
      }
      if (queuedCounterIncremented) {
        g_tcpclientthreads->decrementQueuedCount();
      }
    }
    catch(...){}
  }
}