/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-rings.hh"
#include "dnsdist-xpf.hh"

#include "dnsparser.hh"
#include "ednsoptions.hh"
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
#include "tcpiohandler.hh"
#include "threadname.hh"
#include <thread>
#include <atomic>
#include <netinet/tcp.h>

#include "sstuff.hh"

using std::thread;
using std::atomic;

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are at most
   65535 bytes, since that is all the 16-bit length prefix allows.
   An answer might theoretically consist of multiple messages (for example in the case
   of AXFR), but initially we will not go there.

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a
   downstream has been set up. This symmetry is broken by head-of-line blocking within
   TCP though, necessitating additional connections to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, to forward
   messages to and from them, and never to queue: whenever an answer comes in, we know
   where it needs to go.

   Let's start naively.
*/
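
/* Rough lifecycle of an incoming connection, as implemented by the
   IncomingTCPConnectionState state machine below:

     doingHandshake -> readingQuerySize -> readingQuery
       -> sendingQueryToBackend -> readingResponseSizeFromBackend
       -> readingResponseFromBackend -> sendingResponse
       -> back to readingQuerySize if the connection is kept alive

   Self-answered queries (DNSCrypt certificates, rule-generated answers,
   cache hits) skip the backend states and go straight to sendingResponse. */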

static std::mutex tcpClientsCountMutex;
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
static const size_t g_maxCachedConnectionsPerDownstream = 20;
uint64_t g_maxTCPQueuedConnections{1000};
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};
size_t g_maxTCPConnectionsPerClient{0};
uint16_t g_downstreamTCPCleanupInterval{60};
bool g_useTCPSinglePipe{false};
static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
{
  std::unique_ptr<Socket> result;

  do {
    vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
    try {
      result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
      if (!IsAnyAddress(ds->sourceAddr)) {
        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
#ifdef IP_BIND_ADDRESS_NO_PORT
        if (ds->ipBindAddrNoPort) {
          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
        }
#endif
        result->bind(ds->sourceAddr, false);
      }
      result->setNonBlocking();
#ifdef MSG_FASTOPEN
      if (!ds->tcpFastOpen) {
        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
      }
#else
      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
#endif /* MSG_FASTOPEN */
      return result;
    }
    catch(const std::runtime_error& e) {
      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
      downstreamFailures++;
      if (downstreamFailures > ds->retries) {
        throw;
      }
    }
  } while(downstreamFailures <= ds->retries);

  return nullptr;
}
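
/* Owns an established connection to a backend: increments the backend's
   tcpCurrentConnections gauge on creation and, on destruction, decrements
   it and reports the number of queries sent and the connection duration to
   the backend's TCP metrics. */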
class TCPConnectionToBackend
{
public:
  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen)
  {
    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
    ++d_ds->tcpCurrentConnections;
  }

  ~TCPConnectionToBackend()
  {
    if (d_ds && d_socket) {
      --d_ds->tcpCurrentConnections;
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
    }
  }

  int getHandle() const
  {
    if (!d_socket) {
      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
    }

    return d_socket->getHandle();
  }

  const ComboAddress& getRemote() const
  {
    return d_ds->remote;
  }

  bool isFresh() const
  {
    return d_fresh;
  }

  void incQueries()
  {
    ++d_queries;
  }

  void setReused()
  {
    d_fresh = false;
  }

  void disableFastOpen()
  {
    d_enableFastOpen = false;
  }

  bool isFastOpenEnabled()
  {
    return d_enableFastOpen;
  }

private:
  std::unique_ptr<Socket> d_socket{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  struct timeval d_connectionStartTime;
  uint64_t d_queries{0};
  bool d_fresh{true};
  bool d_enableFastOpen{false};
};

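/* Per-thread cache of idle connections to the backends, keyed by backend
   address. getConnectionToDownstream() hands out a cached connection when
   one is available and opens a new one otherwise, while
   releaseDownstreamConnection() puts one back, dropping it if we already
   hold g_maxCachedConnectionsPerDownstream connections for that backend. */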
static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;

static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
{
  std::unique_ptr<TCPConnectionToBackend> result;

  const auto& it = t_downstreamConnections.find(ds->remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (!list.empty()) {
      result = std::move(list.front());
      list.pop_front();
      result->setReused();
      return result;
    }
  }

  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
}

static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
{
  if (conn == nullptr) {
    return;
  }

  const auto& remote = conn->getRemote();
  const auto& it = t_downstreamConnections.find(remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
      /* too many connections queued already */
      conn.reset();
      return;
    }
    list.push_back(std::move(conn));
  }
  else {
    t_downstreamConnections[remote].push_back(std::move(conn));
  }
}

struct ConnectionInfo
{
  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
  {
  }
  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
  {
    rhs.cs = nullptr;
    rhs.fd = -1;
  }

  ConnectionInfo(const ConnectionInfo& rhs) = delete;
  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;

  ConnectionInfo& operator=(ConnectionInfo&& rhs)
  {
    remote = rhs.remote;
    cs = rhs.cs;
    rhs.cs = nullptr;
    fd = rhs.fd;
    rhs.fd = -1;
    return *this;
  }

  ~ConnectionInfo()
  {
    if (fd != -1) {
      close(fd);
      fd = -1;
    }
    if (cs) {
      --cs->tcpCurrentConnections;
    }
  }

  ComboAddress remote;
  ClientState* cs{nullptr};
  int fd{-1};
};

void tcpClientThread(int pipefd);

static void decrementTCPClientCount(const ComboAddress& client)
{
  if (g_maxTCPConnectionsPerClient) {
    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
    tcpClientsCount[client]--;
    if (tcpClientsCount[client] == 0) {
      tcpClientsCount.erase(client);
    }
  }
}
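
/* Starts a new TCP worker thread and registers the write end of its
   communication pipe (or of the shared pipe, in single-pipe mode), over
   which accepted connections will be handed to it. */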
void TCPClientCollection::addTCPClientThread()
{
  int pipefds[2] = { -1, -1};

  vinfolog("Adding TCP Client thread");

  if (d_useSinglePipe) {
    pipefds[0] = d_singlePipe[0];
    pipefds[1] = d_singlePipe[1];
  }
  else {
    if (pipe(pipefds) < 0) {
      errlog("Error creating the TCP thread communication pipe: %s", strerror(errno));
      return;
    }

    if (!setNonBlocking(pipefds[0])) {
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
      return;
    }

    if (!setNonBlocking(pipefds[1])) {
      close(pipefds[0]);
      close(pipefds[1]);
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
      return;
    }
  }

  {
    std::lock_guard<std::mutex> lock(d_mutex);

    if (d_numthreads >= d_tcpclientthreads.capacity()) {
      warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    try {
      thread t1(tcpClientThread, pipefds[0]);
      t1.detach();
    }
    catch(const std::runtime_error& e) {
      /* the thread creation failed, don't leak */
      errlog("Error creating a TCP thread: %s", e.what());
      if (!d_useSinglePipe) {
        close(pipefds[0]);
        close(pipefds[1]);
      }
      return;
    }

    d_tcpclientthreads.push_back(pipefds[1]);
  }

  ++d_numthreads;
}

static void cleanupClosedTCPConnections()
{
  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
        ++connIt;
      }
      else {
        connIt = dsIt->second.erase(connIt);
      }
    }

    if (!dsIt->second.empty()) {
      ++dsIt;
    }
    else {
      dsIt = t_downstreamConnections.erase(dsIt);
    }
  }
}

/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
   Updates pos every time a successful read occurs,
   throws a std::runtime_error in case of an IO error (and a std::out_of_range
   if the buffer is too small),
   returns IOState::Done when toRead bytes have been read, or IOState::NeedRead
   if the IO operation would block.
*/
// XXX could probably be implemented as a TCPIOHandler
IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
{
  if (buffer.size() < (pos + toRead)) {
    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
  }

  size_t got = 0;
  do {
    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
    if (res == 0) {
      throw runtime_error("EOF while reading message");
    }
    if (res < 0) {
      if (errno == EAGAIN || errno == EWOULDBLOCK) {
        return IOState::NeedRead;
      }
      else {
        throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
      }
    }

    pos += static_cast<size_t>(res);
    got += static_cast<size_t>(res);
  }
  while (got < toRead);

  return IOState::Done;
}
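
/* An illustrative sketch, not part of the build: reading one length-prefixed
   DNS message from a non-blocking socket `sock` with tryRead(), resuming
   whenever the multiplexer reports the socket readable again:

     std::vector<uint8_t> buf(sizeof(uint16_t));
     size_t pos = 0;
     if (tryRead(sock, buf, pos, sizeof(uint16_t) - pos) == IOState::Done) {
       uint16_t len = buf.at(0) * 256 + buf.at(1);
       buf.resize(len);
       pos = 0;
       // keep calling tryRead(sock, buf, pos, len - pos) on 'readable'
       // events until it returns IOState::Done, then parse the message
     }
*/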

std::unique_ptr<TCPClientCollection> g_tcpclientthreads;

class TCPClientThreadData
{
public:
  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
  {
  }

  LocalHolders holders;
  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
};

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);

class IncomingTCPConnectionState
{
public:
  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
  {
    d_ids.origDest.reset();
    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
    socklen_t socklen = d_ids.origDest.getSocklen();
    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
      d_ids.origDest = d_ci.cs->local;
    }
  }

  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;

  ~IncomingTCPConnectionState()
  {
    decrementTCPClientCount(d_ci.remote);
    if (d_ci.cs != nullptr) {
      struct timeval now;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
    }

    if (d_ds != nullptr) {
      if (d_outstanding) {
        --d_ds->outstanding;
        d_outstanding = false;
      }

      if (d_downstreamConnection) {
        try {
          if (d_lastIOState == IOState::NeedRead) {
            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
          }
          else if (d_lastIOState == IOState::NeedWrite) {
            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
            d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
          }
        }
        catch(const FDMultiplexerException& e) {
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
        catch(const std::runtime_error& e) {
          /* might be thrown by getHandle() */
          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
        }
      }
    }

    try {
      if (d_lastIOState == IOState::NeedRead) {
        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeReadFD(d_ci.fd);
      }
      else if (d_lastIOState == IOState::NeedWrite) {
        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeWriteFD(d_ci.fd);
      }
    }
    catch(const FDMultiplexerException& e) {
      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
    }
  }

  void resetForNewQuery()
  {
    d_buffer.resize(sizeof(uint16_t));
    d_currentPos = 0;
    d_querySize = 0;
    d_responseSize = 0;
    d_downstreamFailures = 0;
    d_state = State::readingQuerySize;
    d_lastIOState = IOState::Done;
    d_selfGeneratedResponse = false;
  }

  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
      return boost::none;
    }

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
        return now;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
        now.tv_sec += remaining;
        return now;
      }
    }

    now.tv_sec += g_tcpRecvTimeout;
    return now;
  }

  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendReadTTD() without any backend selected");
    }
    if (d_ds->tcpRecvTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpRecvTimeout;

    return res;
  }

  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
  {
    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
        return res;
      }
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
        res.tv_sec += remaining;
        return res;
      }
    }

    res.tv_sec += g_tcpSendTimeout;
    return res;
  }

  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
  {
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendWriteTTD() called without any backend selected");
    }
    if (d_ds->tcpSendTimeout == 0) {
      return boost::none;
    }

    struct timeval res = now;
    res.tv_sec += d_ds->tcpSendTimeout;

    return res;
  }

  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
  {
    if (maxConnectionDuration) {
      time_t curtime = now.tv_sec;
      unsigned int elapsed = 0;
      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
        elapsed = curtime - d_connectionStartTime.tv_sec;
      }
      if (elapsed >= maxConnectionDuration) {
        return true;
      }
      d_remainingTime = maxConnectionDuration - elapsed;
    }

    return false;
  }

  void dump() const
  {
    static std::mutex s_mutex;

    struct timeval now;
    gettimeofday(&now, 0);

    {
      std::lock_guard<std::mutex> lock(s_mutex);
      fprintf(stderr, "State is %p\n", this);
      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
      if (d_state > State::doingHandshake) {
        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuerySize) {
        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
      }
      if (d_state > State::readingQuery) {
        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
      }
      if (d_state > State::sendingQueryToBackend) {
        cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
      }
      if (d_state > State::readingResponseFromBackend) {
        cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
      }
    }
  }

  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };

  std::vector<uint8_t> d_buffer;
  std::vector<uint8_t> d_responseBuffer;
  TCPClientThreadData& d_threadData;
  IDState d_ids;
  ConnectionInfo d_ci;
  TCPIOHandler d_handler;
  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  dnsheader d_cleartextDH;
  struct timeval d_connectionStartTime;
  struct timeval d_handshakeDoneTime;
  struct timeval d_firstQuerySizeReadTime;
  struct timeval d_querySizeReadTime;
  struct timeval d_queryReadTime;
  struct timeval d_querySentTime;
  struct timeval d_responseReadTime;
  size_t d_currentPos{0};
  size_t d_queriesCount{0};
  unsigned int d_remainingTime{0};
  uint16_t d_querySize{0};
  uint16_t d_responseSize{0};
  uint16_t d_downstreamFailures{0};
  State d_state{State::doingHandshake};
  IOState d_lastIOState{IOState::Done};
  bool d_readingFirstQuery{true};
  bool d_outstanding{false};
  bool d_firstResponsePacket{true};
  bool d_isXFR{false};
  bool d_xfrStarted{false};
  bool d_selfGeneratedResponse{false};
};

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
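
/* handleNewIOState() (defined below) reconciles the desired IO state of a
   file descriptor with what is currently registered in the multiplexer: it
   removes the FD when it is no longer waiting for the same event, and
   (re)adds it with the right callback and TTD otherwise. d_lastIOState
   tracks what is currently registered, so that we never add the same FD
   twice or remove one that is not there. */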

static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);

  if (state->d_isXFR && state->d_downstreamConnection) {
    /* we need to resume reading from the backend! */
    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
    state->d_currentPos = 0;
    handleDownstreamIO(state, now);
    return;
  }

  if (state->d_selfGeneratedResponse == false && state->d_ds) {
    /* if we have no downstream server selected, this was a self-answered response
       but cache hits have a selected server as well, so be careful */
    struct timespec answertime;
    gettime(&answertime);
    double udiff = state->d_ids.sentTime.udiff();
    g_rings.insertResponse(answertime, state->d_ci.remote, state->d_ids.qname, state->d_ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), state->d_cleartextDH, state->d_ds->remote);
    vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", state->d_ds->remote.toStringWithPort(), state->d_ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);
  }

  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
    return;
  }

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    return;
  }

  state->resetForNewQuery();

  handleIO(state, now);
}

static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);

  state->d_currentPos = 0;

  handleIO(state, now);
}
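
/* Sanity-checks a response read from the backend (the first packet has to
   match the pending query), applies the response rules, and hands the
   possibly rewritten result over to sendResponse(). */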
static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_responseSize < sizeof(dnsheader) || !state->d_ds) {
    return;
  }

  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
  unsigned int consumed;
  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
    return;
  }
  state->d_firstResponsePacket = false;

  if (state->d_outstanding) {
    --state->d_ds->outstanding;
    state->d_outstanding = false;
  }

  auto dh = reinterpret_cast<struct dnsheader*>(response);
  uint16_t addRoom = 0;
  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
  if (dr.dnsCryptQuery) {
    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
  }

  memcpy(&state->d_cleartextDH, dr.dh, sizeof(state->d_cleartextDH));

  std::vector<uint8_t> rewrittenResponse;
  size_t responseSize = state->d_responseBuffer.size();
  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
    return;
  }

  if (!rewrittenResponse.empty()) {
    /* responseSize has been updated as well but we don't really care since it will match
       the capacity of rewrittenResponse anyway */
    state->d_responseBuffer = std::move(rewrittenResponse);
    state->d_responseSize = state->d_responseBuffer.size();
  } else {
    /* the size might have been updated (shrunk) if we removed the whole OPT RR, for example */
    state->d_responseBuffer.resize(state->d_responseSize);
  }

  if (state->d_isXFR && !state->d_xfrStarted) {
    /* don't bother parsing the content of the response for now */
    state->d_xfrStarted = true;
  }

  ++g_stats.responses;

  sendResponse(state, now);
}

static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  auto ds = state->d_ds;
  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
  state->d_currentPos = 0;
  state->d_firstResponsePacket = true;
  state->d_downstreamConnection.reset();

  if (state->d_xfrStarted) {
    /* sorry, but we are not going to resume an XFR if we have already sent some packets
       to the client */
    return;
  }

  if (state->d_downstreamFailures < state->d_ds->retries) {
    try {
      state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
    }
    catch (const std::runtime_error& e) {
      state->d_downstreamConnection.reset();
    }
  }

  if (!state->d_downstreamConnection) {
    ++ds->tcpGaveUp;
    ++state->d_ci.cs->tcpGaveUp;
    vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
    return;
  }

  vinfolog("Got query for %s|%s from %s (%s), relayed to %s", state->d_ids.qname.toString(), QType(state->d_ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), ds->getName());

  handleDownstreamIO(state, now);
  return;
}
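
/* Called once a full query has been read from the client: handles DNSCrypt,
   parses the question, applies the query rules, and either sends back a
   self-generated answer or prepends the two-byte length prefix and passes
   the query to sendQueryToBackend(). */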
static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_querySize < sizeof(dnsheader)) {
    ++g_stats.nonCompliantQueries;
    return;
  }

  state->d_readingFirstQuery = false;
  ++state->d_queriesCount;
  ++state->d_ci.cs->queries;
  ++g_stats.queries;

  /* we need an accurate ("real") value for the response and
     to store into the IDS, but not for insertion into the
     rings for example */
  struct timespec queryRealTime;
  gettime(&queryRealTime, true);

  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
  if (dnsCryptResponse) {
    state->d_responseBuffer = std::move(*dnsCryptResponse);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  const auto& dh = reinterpret_cast<dnsheader*>(query);
  if (!checkQueryHeaders(dh)) {
    return;
  }

  uint16_t qtype, qclass;
  unsigned int consumed = 0;
  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
  dq.dnsCryptQuery = std::move(dnsCryptQuery);
  dq.sni = state->d_handler.getServerNameIndication();

  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
  if (state->d_isXFR) {
    dq.skipCache = true;
  }

  state->d_ds.reset();
  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);

  if (result == ProcessQueryResult::Drop) {
    return;
  }

  if (result == ProcessQueryResult::SendAnswer) {
    state->d_selfGeneratedResponse = true;
    state->d_buffer.resize(dq.len);
    state->d_responseBuffer = std::move(state->d_buffer);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);
    return;
  }

  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
    return;
  }

  state->d_buffer.resize(dq.len);
  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));

  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
  sendQueryToBackend(state, now);
}

static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
{
  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;

  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
    state->d_threadData.mplexer->removeReadFD(fd);
    //cerr<<__func__<<": remove read FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }
  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
    state->d_threadData.mplexer->removeWriteFD(fd);
    //cerr<<__func__<<": remove write FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  }

  if (iostate == IOState::NeedRead) {
    if (state->d_lastIOState == IOState::NeedRead) {
      if (ttd) {
        /* let's update the TTD! */
        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
      }
      return;
    }

    state->d_lastIOState = IOState::NeedRead;
    //cerr<<__func__<<": add read FD "<<fd<<endl;
    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::NeedWrite) {
    if (state->d_lastIOState == IOState::NeedWrite) {
      return;
    }

    state->d_lastIOState = IOState::NeedWrite;
    //cerr<<__func__<<": add write FD "<<fd<<endl;
    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
  }
  else if (iostate == IOState::Done) {
    state->d_lastIOState = IOState::Done;
  }
}
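
/* Drives the backend side of the state machine: sends the query, possibly
   using TCP Fast Open on the first write, then reads the two-byte response
   size followed by the response itself. Once a response has been fully read
   the connection is either kept around, for the remaining messages of an
   XFR, or returned to the per-thread cache; if the connection dies on us
   the query is retried on a fresh one via sendQueryToBackend(). */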
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }

  int fd = state->d_downstreamConnection->getHandle();
  IOState iostate = IOState::Done;
  bool connectionDied = false;

  try {
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      int socketFlags = 0;
#ifdef MSG_FASTOPEN
      if (state->d_downstreamConnection->isFastOpenEnabled()) {
        socketFlags |= MSG_FASTOPEN;
      }
#endif /* MSG_FASTOPEN */

      size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, socketFlags);
      state->d_currentPos += sent;
      if (state->d_currentPos == state->d_buffer.size()) {
        /* request sent! note that we compare the total amount written so far
           to the buffer size, so that completing a previously partial write
           is detected correctly */
        state->d_downstreamConnection->incQueries();
        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
        state->d_currentPos = 0;
        state->d_querySentTime = now;
        iostate = IOState::NeedRead;
        if (!state->d_isXFR && !state->d_outstanding) {
          /* don't bother with the outstanding count for XFR queries */
          ++state->d_ds->outstanding;
          state->d_outstanding = true;
        }
      }
      else {
        iostate = IOState::NeedWrite;
        /* disable fast open on partial write */
        state->d_downstreamConnection->disableFastOpen();
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
      // then we need to allocate a new buffer (new because we might need to re-send the query if the
      // backend dies on us).
      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
      // should very likely be a TCPIOHandler d_downstreamHandler
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);

        if (state->d_isXFR) {
          /* Don't reuse the TCP connection after an {A,I}XFR */
          /* but don't reset it either, we will need to read more messages */
        }
        else {
          releaseDownstreamConnection(std::move(state->d_downstreamConnection));
        }
        fd = -1;

        state->d_responseReadTime = now;
        try {
          handleResponse(state, now);
        }
        catch (const std::exception& e) {
          vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", state->d_ds ? state->d_ds->getName() : "unknown", state->d_ci.remote.toStringWithPort(), e.what());
        }
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
      vinfolog("Unexpected state %d in handleDownstreamIO", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      ++state->d_ds->tcpDiedSendingQuery;
    }
    else {
      ++state->d_ds->tcpDiedReadingResponse;
    }

    /* don't increase this counter when reusing connections */
    if (state->d_downstreamConnection && state->d_downstreamConnection->isFresh()) {
      ++state->d_downstreamFailures;
    }

    if (state->d_outstanding) {
      state->d_outstanding = false;

      if (state->d_ds != nullptr) {
        --state->d_ds->outstanding;
      }
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
    connectionDied = true;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
  }

  if (connectionDied) {
    sendQueryToBackend(state, now);
  }
}

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  }
  if (fd != state->d_downstreamConnection->getHandle()) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
  }

  struct timeval now;
  gettimeofday(&now, 0);
  handleDownstreamIO(state, now);
}

static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
{
  int fd = state->d_ci.fd;
  IOState iostate = IOState::Done;

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
    return;
  }

  try {
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
      iostate = state->d_handler.tryHandshake();
      if (iostate == IOState::Done) {
        state->d_handshakeDoneTime = now;
        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingQuery;
        state->d_querySizeReadTime = now;
        if (state->d_queriesCount == 0) {
          state->d_firstQuerySizeReadTime = now;
        }
        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
        if (state->d_querySize < sizeof(dnsheader)) {
          /* go away */
          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
          return;
        }

        /* allocate a bit more memory to be able to spoof the content,
           or to add ECS without allocating a new buffer */
        state->d_buffer.resize(state->d_querySize + 512);
        state->d_currentPos = 0;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
        handleQuery(state, now);
        return;
      }
    }

    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
      if (iostate == IOState::Done) {
        handleResponseSent(state, now);
        return;
      }
    }

    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
      vinfolog("Unexpected state %d in handleIO", static_cast<int>(state->d_state));
    }
  }
  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection
    */
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      ++state->d_ci.cs->tcpDiedReadingQuery;
    }
    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      ++state->d_ci.cs->tcpDiedSendingResponse;
    }

    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
    }
    else {
      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
    }
    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
  }

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleIOCallback);
  }
  else {
    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
  }
}

static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (fd != state->d_ci.fd) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
  }
  struct timeval now;
  gettimeofday(&now, 0);

  handleIO(state, now);
}

static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
{
  auto threadData = boost::any_cast<TCPClientThreadData*>(param);

  ConnectionInfo* citmp{nullptr};

  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
  if (got == 0) {
    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }
  else if (got == -1) {
    if (errno == EAGAIN || errno == EINTR) {
      return;
    }
    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + strerror(errno));
  }
  else if (got != sizeof(citmp)) {
    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  }

  try {
    g_tcpclientthreads->decrementQueuedCount();

    struct timeval now;
    gettimeofday(&now, 0);
    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
    delete citmp;
    citmp = nullptr;

    /* let's update the remaining time */
    state->d_remainingTime = g_maxTCPConnectionDuration;

    handleIO(state, now);
  }
  catch(...) {
    delete citmp;
    citmp = nullptr;
    throw;
  }
}
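
/* Note the ownership handoff: the acceptor thread writes a raw
   ConnectionInfo pointer into the pipe, and from that point on the worker
   thread reading it in handleIncomingTCPQuery() above is responsible for
   deleting it. */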

void tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  setThreadName("dnsdist/tcpClie");

  TCPClientThreadData data;

  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
  struct timeval now;
  gettimeofday(&now, 0);
  time_t lastTCPCleanup = now.tv_sec;
  time_t lastTimeoutScan = now.tv_sec;

  for (;;) {
    data.mplexer->run(&now);

    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
      cleanupClosedTCPConnections();
      lastTCPCleanup = now.tv_sec;
    }

    if (now.tv_sec > lastTimeoutScan) {
      lastTimeoutScan = now.tv_sec;
      auto expiredReadConns = data.mplexer->getTimeouts(now, false);
      for(const auto& conn : expiredReadConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpReadTimeouts;
        }
        data.mplexer->removeReadFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }

      auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
      for(const auto& conn : expiredWriteConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        }
        else if (state->d_ds) {
          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpWriteTimeouts;
        }
        data.mplexer->removeWriteFD(conn.first);
        state->d_lastIOState = IOState::Done;
      }
    }
  }
}

/* spawn as many of these as required, they call accept() on a socket on which they will accept queries, and
   they will hand off to worker threads & spawn more of them if required
*/
void tcpAcceptorThread(void* p)
{
  setThreadName("dnsdist/tcpAcce");
  ClientState* cs = (ClientState*) p;
  bool tcpClientCountIncremented = false;
  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  g_tcpclientthreads->addTCPClientThread();

  auto acl = g_ACL.getLocal();
  for(;;) {
    bool queuedCounterIncremented = false;
    std::unique_ptr<ConnectionInfo> ci;
    tcpClientCountIncremented = false;
    try {
      socklen_t remlen = remote.getSocklen();
      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo(cs));
#ifdef HAVE_ACCEPT4
      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
#else
      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
#endif
      ++cs->tcpCurrentConnections;

      if(ci->fd < 0) {
        throw std::runtime_error((boost::format("accepting new connection on socket: %s") % strerror(errno)).str());
      }

      if(!acl->match(remote)) {
        ++g_stats.aclDrops;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

#ifndef HAVE_ACCEPT4
      if (!setNonBlocking(ci->fd)) {
        continue;
      }
#endif
      setTCPNoDelay(ci->fd);  // disable Nagle's algorithm
      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      if (g_maxTCPConnectionsPerClient) {
        std::lock_guard<std::mutex> lock(tcpClientsCountMutex);

        if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
          vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
          continue;
        }
        tcpClientsCount[remote]++;
        tcpClientCountIncremented = true;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        queuedCounterIncremented = true;
        auto tmp = ci.release();
        try {
          writen2WithTimeout(pipe, &tmp, sizeof(tmp), 0);
        }
        catch(...) {
          delete tmp;
          tmp = nullptr;
          throw;
        }
      }
      else {
        g_tcpclientthreads->decrementQueuedCount();
        queuedCounterIncremented = false;
        if(tcpClientCountIncremented) {
          decrementTCPClientCount(remote);
        }
      }
    }
    catch(const std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(tcpClientCountIncremented) {
        decrementTCPClientCount(remote);
      }
      if (queuedCounterIncremented) {
        g_tcpclientthreads->decrementQueuedCount();
      }
    }
    catch(...){}
  }
}