/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-rings.hh"
#include "dnsdist-xpf.hh"

#include "dnsparser.hh"
#include "ednsoptions.hh"
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
#include "tcpiohandler.hh"
#include "threadname.hh"
#include <thread>
#include <atomic>
#include <netinet/tcp.h>

#include "sstuff.hh"

using std::thread;
using std::atomic;

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are at most 65535 bytes.
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR);
   initially we will not go there.

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
   This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
   So whenever an answer comes in, we know where it needs to go.

   Let's start naively.
*/

static std::mutex tcpClientsCountMutex;
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
static const size_t g_maxCachedConnectionsPerDownstream = 20;
uint64_t g_maxTCPQueuedConnections{1000};
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};
size_t g_maxTCPConnectionsPerClient{0};
uint16_t g_downstreamTCPCleanupInterval{60};
bool g_useTCPSinglePipe{false};

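/* Establishes a new TCP connection to the given downstream server, retrying
   until ds->retries is exceeded. The socket is made non-blocking before
   connecting so that the caller can enforce its own timeouts; when TCP Fast
   Open is available and enabled, the connect is deferred to the first write. */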
static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
{
  std::unique_ptr<Socket> result;

  do {
    vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
    result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
    try {
      if (!IsAnyAddress(ds->sourceAddr)) {
        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
#ifdef IP_BIND_ADDRESS_NO_PORT
        if (ds->ipBindAddrNoPort) {
          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
        }
#endif
        result->bind(ds->sourceAddr, false);
      }
      result->setNonBlocking();
#ifdef MSG_FASTOPEN
      if (!ds->tcpFastOpen) {
        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
      }
#else
      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
#endif /* MSG_FASTOPEN */
      return result;
    }
    catch(const std::runtime_error& e) {
      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
      downstreamFailures++;
      if (downstreamFailures > ds->retries) {
        throw;
      }
    }
  } while(downstreamFailures <= ds->retries);

  return nullptr;
}

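/* RAII wrapper around a TCP connection to a backend: owns the socket,
   counts the queries sent over it, and updates the backend's TCP metrics
   (number of queries and connection duration) when the connection goes away. */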
class TCPConnectionToBackend
{
public:
  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now)
  {
    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
    ++d_ds->tcpCurrentConnections;
  }

  ~TCPConnectionToBackend()
  {
    if (d_ds && d_socket) {
      --d_ds->tcpCurrentConnections;
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
    }
  }

  int getHandle() const
  {
    if (!d_socket) {
      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
    }

    return d_socket->getHandle();
  }

  const ComboAddress& getRemote() const
  {
    return d_ds->remote;
  }

  bool isFresh() const
  {
    return d_fresh;
  }

  void incQueries()
  {
    ++d_queries;
  }

  void setReused()
  {
    d_fresh = false;
  }

private:
  std::unique_ptr<Socket> d_socket{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  struct timeval d_connectionStartTime;
  uint64_t d_queries{0};
  bool d_fresh{true};
};

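/* Per-thread cache of idle backend connections, keyed by backend address,
   so a worker thread can reuse an established connection for a new query. */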
static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;

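/* Returns a connection to the given backend, reusing a cached one (marked
   as no longer fresh) when available, and opening a new one otherwise. */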
static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
{
  std::unique_ptr<TCPConnectionToBackend> result;

  const auto& it = t_downstreamConnections.find(ds->remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (!list.empty()) {
      result = std::move(list.front());
      list.pop_front();
      result->setReused();
      return result;
    }
  }

  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
}

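/* Puts a connection back into the per-thread cache for later reuse, unless
   we already have g_maxCachedConnectionsPerDownstream connections cached for
   that backend, in which case it is simply closed. */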
static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
{
  if (conn == nullptr) {
    return;
  }

  const auto& remote = conn->getRemote();
  const auto& it = t_downstreamConnections.find(remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
      /* too many connections queued already */
      conn.reset();
      return;
    }
    list.push_back(std::move(conn));
  }
  else {
    t_downstreamConnections[remote].push_back(std::move(conn));
  }
}

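/* Describes an accepted client connection: the remote address, the frontend
   (ClientState) it arrived on and the client socket, which is closed and
   accounted for when this object is destroyed. Move-only. */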
struct ConnectionInfo
{
  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
  {
  }
  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
  {
    rhs.cs = nullptr;
    rhs.fd = -1;
  }

  ConnectionInfo(const ConnectionInfo& rhs) = delete;
  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;

  ConnectionInfo& operator=(ConnectionInfo&& rhs)
  {
    remote = rhs.remote;
    cs = rhs.cs;
    rhs.cs = nullptr;
    fd = rhs.fd;
    rhs.fd = -1;
    return *this;
  }

  ~ConnectionInfo()
  {
    if (fd != -1) {
      close(fd);
      fd = -1;
    }
    if (cs) {
      --cs->tcpCurrentConnections;
    }
  }

  ComboAddress remote;
  ClientState* cs{nullptr};
  int fd{-1};
};

void tcpClientThread(int pipefd);

static void decrementTCPClientCount(const ComboAddress& client)
{
  if (g_maxTCPConnectionsPerClient) {
    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
    tcpClientsCount[client]--;
    if (tcpClientsCount[client] == 0) {
      tcpClientsCount.erase(client);
    }
  }
}

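/* Spawns a new TCP worker thread, creating the pipe used to hand accepted
   connections over to it (unless a single shared pipe is in use), and
   registers the writing end so the acceptor threads can dispatch to it. */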
void TCPClientCollection::addTCPClientThread()
{
  int pipefds[2] = { -1, -1};

  vinfolog("Adding TCP Client thread");

  if (d_useSinglePipe) {
    pipefds[0] = d_singlePipe[0];
    pipefds[1] = d_singlePipe[1];
  }
  else {
    if (pipe(pipefds) < 0) {
      errlog("Error creating the TCP thread communication pipe: %s", strerror(errno));
      return;
    }

    if (!setNonBlocking(pipefds[0])) {
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
      return;
    }

    if (!setNonBlocking(pipefds[1])) {
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
      return;
    }
  }

  {
    std::lock_guard<std::mutex> lock(d_mutex);

    if (d_numthreads >= d_tcpclientthreads.capacity()) {
      warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    try {
      thread t1(tcpClientThread, pipefds[0]);
      t1.detach();
    }
    catch(const std::runtime_error& e) {
      /* the thread creation failed, don't leak */
      errlog("Error creating a TCP thread: %s", e.what());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    d_tcpclientthreads.push_back(pipefds[1]);
  }

  ++d_numthreads;
}

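/* Walks the per-thread connection cache and drops connections whose socket
   is no longer usable (closed or errored on the backend side). */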
static void cleanupClosedTCPConnections()
{
  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
        ++connIt;
      }
      else {
        connIt = dsIt->second.erase(connIt);
      }
    }

    if (!dsIt->second.empty()) {
      ++dsIt;
    }
    else {
      dsIt = t_downstreamConnections.erase(dsIt);
    }
  }
}

/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
   Updates pos every time a successful read occurs,
   throws a std::runtime_error in case of an IO error,
   and returns Done when toRead bytes have been read, or NeedRead if the
   IO operation would block.
*/
// XXX could probably be implemented as a TCPIOHandler
IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
{
  if (buffer.size() < (pos + toRead)) {
    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
  }

  size_t got = 0;
  do {
    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
    if (res == 0) {
      throw runtime_error("EOF while reading message");
    }
    if (res < 0) {
      if (errno == EAGAIN || errno == EWOULDBLOCK) {
        return IOState::NeedRead;
      }
      else {
        throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
      }
    }

    pos += static_cast<size_t>(res);
    got += static_cast<size_t>(res);
  }
  while (got < toRead);

  return IOState::Done;
}

std::unique_ptr<TCPClientCollection> g_tcpclientthreads;

class TCPClientThreadData
{
public:
  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
  {
  }

  LocalHolders holders;
  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
};

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);

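/* Full state of one incoming TCP connection as it moves through the state
   machine: (TLS) handshake, reading the query size and the query itself,
   forwarding to a backend, reading the response and sending it back to the
   client. Also tracks the timestamps used by dump() and the timeout logic. */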
class IncomingTCPConnectionState
{
public:
  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
  {
    d_ids.origDest.reset();
    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
    socklen_t socklen = d_ids.origDest.getSocklen();
    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
      d_ids.origDest = d_ci.cs->local;
    }
  }

  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;

  ~IncomingTCPConnectionState()
  {
    decrementTCPClientCount(d_ci.remote);
    if (d_ci.cs != nullptr) {
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
    }

    if (d_ds != nullptr) {
      if (d_outstanding) {
        --d_ds->outstanding;
        d_outstanding = false;
      }

      if (d_downstreamConnection) {
        try {
          if (d_lastIOState == IOState::NeedRead) {
            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
          }
          else if (d_lastIOState == IOState::NeedWrite) {
            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
          }
        }
        catch(const FDMultiplexerException& e) {
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
        catch(const std::runtime_error& e) {
          /* might be thrown by getHandle() */
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
      }
    }

    try {
      if (d_lastIOState == IOState::NeedRead) {
        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeReadFD(d_ci.fd);
      }
      else if (d_lastIOState == IOState::NeedWrite) {
        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeWriteFD(d_ci.fd);
      }
    }
    catch(const FDMultiplexerException& e) {
      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
    }
  }

  void resetForNewQuery()
  {
    d_buffer.resize(sizeof(uint16_t));
    d_currentPos = 0;
    d_querySize = 0;
    d_responseSize = 0;
    d_downstreamFailures = 0;
    d_state = State::readingQuerySize;
    d_lastIOState = IOState::Done;
  }

  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
      return boost::none;
    }

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
        return now;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
        now.tv_sec += remaining;
        return now;
      }
    }

    now.tv_sec += g_tcpRecvTimeout;
    return now;
  }

  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendReadTTD() without any backend selected");
    }
    if (d_ds->tcpRecvTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpRecvTimeout;

    return res;
  }

  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
        return res;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
        res.tv_sec += remaining;
        return res;
      }
    }

    res.tv_sec += g_tcpSendTimeout;
    return res;
  }

  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendWriteTTD() called without any backend selected");
    }
    if (d_ds->tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpSendTimeout;

    return res;
  }

  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
  {
    if (maxConnectionDuration) {
      time_t curtime = now.tv_sec;
      unsigned int elapsed = 0;
      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
        elapsed = curtime - d_connectionStartTime.tv_sec;
      }
      if (elapsed >= maxConnectionDuration) {
        return true;
      }
      d_remainingTime = maxConnectionDuration - elapsed;
    }

    return false;
  }

  void dump() const
  {
    static std::mutex s_mutex;

    struct timeval now;
    gettimeofday(&now, 0);

    {
      std::lock_guard<std::mutex> lock(s_mutex);
      fprintf(stderr, "State is %p\n", this);
      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
      if (d_state > State::doingHandshake) {
        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuery) {
        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
      }
      if (d_state > State::sendingQueryToBackend) {
        cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
      }
      if (d_state > State::readingResponseFromBackend) {
        cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
      }
    }
  }

  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };

  std::vector<uint8_t> d_buffer;
  std::vector<uint8_t> d_responseBuffer;
  TCPClientThreadData& d_threadData;
  IDState d_ids;
  ConnectionInfo d_ci;
  TCPIOHandler d_handler;
  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  struct timeval d_connectionStartTime;
  struct timeval d_handshakeDoneTime;
  struct timeval d_firstQuerySizeReadTime;
  struct timeval d_querySizeReadTime;
  struct timeval d_queryReadTime;
  struct timeval d_querySentTime;
  struct timeval d_responseReadTime;
  size_t d_currentPos{0};
  size_t d_queriesCount{0};
  unsigned int d_remainingTime{0};
  uint16_t d_querySize{0};
  uint16_t d_responseSize{0};
  uint16_t d_downstreamFailures{0};
  State d_state{State::doingHandshake};
  IOState d_lastIOState{IOState::Done};
  bool d_readingFirstQuery{true};
  bool d_outstanding{false};
  bool d_firstResponsePacket{true};
  bool d_isXFR{false};
  bool d_xfrStarted{false};
};

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);

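/* Called once a response has been fully written to the client: resumes
   reading from the backend for an ongoing XFR, terminates the connection
   if a per-connection limit has been reached, and otherwise resets the
   state so the next query can be read. */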
static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);

  if (state->d_isXFR && state->d_downstreamConnection) {
    /* we need to resume reading from the backend! */
    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
    state->d_currentPos = 0;
    handleDownstreamIO(state, now);
    return;
  }

  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
    return;
  }

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    return;
  }

  state->resetForNewQuery();

  handleIO(state, now);
}

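/* Queues the response for delivery to the client: prepends the two-byte
   length field and kicks off the write through handleIO(). */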
static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);

  state->d_currentPos = 0;

  handleIO(state, now);
}

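/* Processes a response received from the backend: checks that it matches
   the query, applies the response rules, then sends it to the client and
   records it in the rings. */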
static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_responseSize < sizeof(dnsheader)) {
    return;
  }

  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
  unsigned int consumed;
  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
    return;
  }
  state->d_firstResponsePacket = false;

  if (state->d_outstanding) {
    --state->d_ds->outstanding;
    state->d_outstanding = false;
  }

  auto dh = reinterpret_cast<struct dnsheader*>(response);
  uint16_t addRoom = 0;
  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
  if (dr.dnsCryptQuery) {
    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
  }

  dnsheader cleartextDH;
  memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));

  std::vector<uint8_t> rewrittenResponse;
  size_t responseSize = state->d_responseBuffer.size();
  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
    return;
  }

  if (!rewrittenResponse.empty()) {
    /* responseSize has been updated as well but we don't really care since it will match
       the capacity of rewrittenResponse anyway */
    state->d_responseBuffer = std::move(rewrittenResponse);
    state->d_responseSize = state->d_responseBuffer.size();
  } else {
    /* the size might have been updated (shrunk) if we removed the whole OPT RR, for example */
    state->d_responseBuffer.resize(state->d_responseSize);
  }

  if (state->d_isXFR && !state->d_xfrStarted) {
    /* don't bother parsing the content of the response for now */
    state->d_xfrStarted = true;
  }

  sendResponse(state, now);

  ++g_stats.responses;
  struct timespec answertime;
  gettime(&answertime);
  double udiff = state->d_ids.sentTime.udiff();
  g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote);
}

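/* Obtains a connection to the selected backend (cached or new) and starts
   writing the query to it, giving up after too many consecutive failures. */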
static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  auto ds = state->d_ds;
  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
  state->d_currentPos = 0;
  state->d_firstResponsePacket = true;
  state->d_downstreamConnection.reset();

  if (state->d_xfrStarted) {
    /* sorry, but we are not going to resume an XFR if we have already sent some packets
       to the client */
    return;
  }

  while (state->d_downstreamFailures < state->d_ds->retries)
  {
    state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);

    if (!state->d_downstreamConnection) {
      ++ds->tcpGaveUp;
      ++state->d_ci.cs->tcpGaveUp;
      vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
      return;
    }

    handleDownstreamIO(state, now);
    return;
  }

  ++ds->tcpGaveUp;
  ++state->d_ci.cs->tcpGaveUp;
  vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
}

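/* Processes a query freshly read from the client: handles DNSCrypt,
   sanity-checks the headers, runs the rules, and either answers directly
   or forwards the (length-prefixed) query to the selected backend. */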
static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_querySize < sizeof(dnsheader)) {
    ++g_stats.nonCompliantQueries;
    return;
  }

  state->d_readingFirstQuery = false;
  ++state->d_queriesCount;
  ++state->d_ci.cs->queries;
  ++g_stats.queries;

  /* we need an accurate ("real") value for the response and
     to store into the IDS, but not for insertion into the
     rings for example */
  struct timespec queryRealTime;
  gettime(&queryRealTime, true);

  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
  if (dnsCryptResponse) {
    state->d_responseBuffer = std::move(*dnsCryptResponse);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  const auto& dh = reinterpret_cast<dnsheader*>(query);
  if (!checkQueryHeaders(dh)) {
    return;
  }

  uint16_t qtype, qclass;
  unsigned int consumed = 0;
  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
  dq.dnsCryptQuery = std::move(dnsCryptQuery);
  dq.sni = state->d_handler.getServerNameIndication();

  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
  if (state->d_isXFR) {
    dq.skipCache = true;
  }

  state->d_ds.reset();
  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);

  if (result == ProcessQueryResult::Drop) {
    return;
  }

  if (result == ProcessQueryResult::SendAnswer) {
    state->d_buffer.resize(dq.len);
    state->d_responseBuffer = std::move(state->d_buffer);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
    return;
  }

  state->d_buffer.resize(dq.len);
  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));

  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
  sendQueryToBackend(state, now);
}

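/* Reconciles the IO state we last registered with the multiplexer with the
   new one: removes the FD when the direction changes or the operation is
   done, and adds (or re-arms) it when we now need to read or write,
   updating the TTD along the way. */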
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
{
  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;

  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
    state->d_threadData.mplexer->removeReadFD(fd);
    //cerr<<__func__<<": remove read FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }
  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
    state->d_threadData.mplexer->removeWriteFD(fd);
    //cerr<<__func__<<": remove write FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }

  if (iostate == IOState::NeedRead) {
    if (state->d_lastIOState == IOState::NeedRead) {
      if (ttd) {
        /* let's update the TTD ! */
        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
      }
      return;
    }

    state->d_lastIOState = IOState::NeedRead;
    //cerr<<__func__<<": add read FD "<<fd<<endl;
    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::NeedWrite) {
    if (state->d_lastIOState == IOState::NeedWrite) {
      return;
    }

    state->d_lastIOState = IOState::NeedWrite;
    //cerr<<__func__<<": add write FD "<<fd<<endl;
    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::Done) {
    state->d_lastIOState = IOState::Done;
  }
}

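/* Drives the IO with the backend: sends (the rest of) the query, then reads
   the response size and the response itself. On success the connection is
   either kept for the next XFR message or released back to the cache; on
   error the query is retried over a new connection. */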
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }

  int fd = state->d_downstreamConnection->getHandle();
  IOState iostate = IOState::Done;
  bool connectionDied = false;

  try {
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      int socketFlags = 0;
#ifdef MSG_FASTOPEN
      if (state->d_ds->tcpFastOpen && state->d_downstreamConnection->isFresh()) {
        socketFlags |= MSG_FASTOPEN;
      }
#endif /* MSG_FASTOPEN */

      size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, socketFlags);
      if (sent == state->d_buffer.size()) {
        /* request sent ! */
        state->d_downstreamConnection->incQueries();
        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
        state->d_currentPos = 0;
        state->d_querySentTime = now;
        iostate = IOState::NeedRead;
        if (!state->d_isXFR) {
          /* don't bother with the outstanding count for XFR queries */
          ++state->d_ds->outstanding;
          state->d_outstanding = true;
        }
      }
      else {
        state->d_currentPos += sent;
        iostate = IOState::NeedWrite;
        /* disable fast open on partial write */
        state->d_downstreamConnection->setReused();
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
      // then we need to allocate a new buffer (new because we might need to re-send the query if the
      // backend dies on us)
      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
      // should very likely be a TCPIOHandler d_downstreamHandler
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);

        if (state->d_isXFR) {
          /* Don't reuse the TCP connection after an {A,I}XFR */
          /* but don't reset it either, we will need to read more messages */
        }
        else {
          releaseDownstreamConnection(std::move(state->d_downstreamConnection));
        }
        fd = -1;

        state->d_responseReadTime = now;
        handleResponse(state, now);
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      ++state->d_ds->tcpDiedSendingQuery;
    }
    else {
      ++state->d_ds->tcpDiedReadingResponse;
    }

    /* don't increase this counter when reusing connections */
    if (state->d_downstreamConnection->isFresh()) {
      ++state->d_downstreamFailures;
    }
    if (state->d_outstanding && state->d_ds != nullptr) {
      --state->d_ds->outstanding;
      state->d_outstanding = false;
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
    connectionDied = true;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
  }

  if (connectionDied) {
    sendQueryToBackend(state, now);
  }
}

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }
  if (fd != state->d_downstreamConnection->getHandle()) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
  }

  struct timeval now;
  gettimeofday(&now, 0);
  handleDownstreamIO(state, now);
}

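/* Drives the IO with the client: performs the (possibly TLS) handshake,
   reads the query size and the query, and writes the response, moving the
   state machine forward whenever an operation completes. */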
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  int fd = state->d_ci.fd;
  IOState iostate = IOState::Done;

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
    return;
  }

  try {
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
      iostate = state->d_handler.tryHandshake();
      if (iostate == IOState::Done) {
        state->d_handshakeDoneTime = now;
        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingQuery;
        state->d_querySizeReadTime = now;
        if (state->d_queriesCount == 0) {
          state->d_firstQuerySizeReadTime = now;
        }
        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
        if (state->d_querySize < sizeof(dnsheader)) {
          /* go away */
          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
          return;
        }

        /* allocate a bit more memory to be able to spoof the content,
           or to add ECS without allocating a new buffer */
        state->d_buffer.resize(state->d_querySize + 512);
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
        handleQuery(state, now);
        return;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
      if (iostate == IOState::Done) {
        handleResponseSent(state, now);
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      ++state->d_ci.cs->tcpDiedReadingQuery;
    }
    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      ++state->d_ci.cs->tcpDiedSendingResponse;
    }

    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
    }
    else {
      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
  }
}

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (fd != state->d_ci.fd) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
  }
  struct timeval now;
  gettimeofday(&now, 0);

  handleIO(state, now);
}

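/* Invoked when a connection descriptor arrives over the acceptor pipe:
   takes ownership of the ConnectionInfo and starts processing the
   connection. */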
static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
{
  auto threadData = boost::any_cast<TCPClientThreadData*>(param);

  ConnectionInfo* citmp{nullptr};

  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
  if (got == 0) {
    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }
  else if (got == -1) {
    if (errno == EAGAIN || errno == EINTR) {
      return;
    }
    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + strerror(errno));
  }
  else if (got != sizeof(citmp)) {
    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }

  try {
    g_tcpclientthreads->decrementQueuedCount();

    struct timeval now;
    gettimeofday(&now, 0);
    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
    delete citmp;
    citmp = nullptr;

    /* let's update the remaining time */
    state->d_remainingTime = g_maxTCPConnectionDuration;

    handleIO(state, now);
  }
  catch(...) {
    delete citmp;
    citmp = nullptr;
    throw;
  }
}

void tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  setThreadName("dnsdist/tcpClie");

  TCPClientThreadData data;

  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
  struct timeval now;
  gettimeofday(&now, 0);
  time_t lastTCPCleanup = now.tv_sec;
  time_t lastTimeoutScan = now.tv_sec;

  for (;;) {
    data.mplexer->run(&now);

    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
      cleanupClosedTCPConnections();
      lastTCPCleanup = now.tv_sec;
    }

    if (now.tv_sec > lastTimeoutScan) {
      lastTimeoutScan = now.tv_sec;
      auto expiredReadConns = data.mplexer->getTimeouts(now, false);
      for(const auto& conn : expiredReadConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpReadTimeouts;
        }
        data.mplexer->removeReadFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }

      auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
      for(const auto& conn : expiredWriteConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpWriteTimeouts;
        }
        data.mplexer->removeWriteFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }
    }
  }
}

/* Spawn as many of these as required: each one calls accept() on a listening socket,
   and hands the accepted connections off to the worker threads (spawning more of
   them if required).
*/
void tcpAcceptorThread(void* p)
{
  setThreadName("dnsdist/tcpAcce");
  ClientState* cs = (ClientState*) p;
  bool tcpClientCountIncremented = false;
  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  g_tcpclientthreads->addTCPClientThread();

  auto acl = g_ACL.getLocal();
  for(;;) {
    bool queuedCounterIncremented = false;
    std::unique_ptr<ConnectionInfo> ci;
    tcpClientCountIncremented = false;
    try {
      socklen_t remlen = remote.getSocklen();
      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo(cs));
#ifdef HAVE_ACCEPT4
      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
#else
      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
#endif
      ++cs->tcpCurrentConnections;

      if(ci->fd < 0) {
        throw std::runtime_error((boost::format("accepting new connection on socket: %s") % strerror(errno)).str());
      }

      if(!acl->match(remote)) {
        ++g_stats.aclDrops;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

#ifndef HAVE_ACCEPT4
      if (!setNonBlocking(ci->fd)) {
        continue;
      }
#endif
      setTCPNoDelay(ci->fd);  // disable Nagle's algorithm
      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      if (g_maxTCPConnectionsPerClient) {
        std::lock_guard<std::mutex> lock(tcpClientsCountMutex);

        if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
          vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
          continue;
        }
        tcpClientsCount[remote]++;
        tcpClientCountIncremented = true;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        queuedCounterIncremented = true;
        auto tmp = ci.release();
        try {
          writen2WithTimeout(pipe, &tmp, sizeof(tmp), 0);
        }
        catch(...) {
          delete tmp;
          tmp = nullptr;
          throw;
        }
      }
      else {
        g_tcpclientthreads->decrementQueuedCount();
        queuedCounterIncremented = false;
        if(tcpClientCountIncremented) {
          decrementTCPClientCount(remote);
        }
      }
    }
    catch(const std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(tcpClientCountIncremented) {
        decrementTCPClientCount(remote);
      }
      if (queuedCounterIncremented) {
        g_tcpclientthreads->decrementQueuedCount();
      }
    }
    catch(...){}
  }
}