 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include "dnsdist-ecs.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsdist-rings.hh"
#include "dnsdist-xpf.hh"

#include "dnsparser.hh"
#include "ednsoptions.hh"

#include "tcpiohandler.hh"
#include "threadname.hh"

#include <netinet/tcp.h>
/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
   This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
   So whenever an answer comes in, we know where it needs to go.
*/
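/* Note: a 'message' here is a DNS message as carried over TCP: RFC 1035 section 4.2.2 prefixes
   every message with a two-byte, network byte order length field, which is why a single message
   can never exceed 65535 bytes. */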
static std::mutex tcpClientsCountMutex;
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
static const size_t g_maxCachedConnectionsPerDownstream = 20;
uint64_t g_maxTCPQueuedConnections{1000};
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};
size_t g_maxTCPConnectionsPerClient{0};
uint16_t g_downstreamTCPCleanupInterval{60};
bool g_useTCPSinglePipe{false};
static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
  std::unique_ptr<Socket> result;

    vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
    result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
      if (!IsAnyAddress(ds->sourceAddr)) {
        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
#ifdef IP_BIND_ADDRESS_NO_PORT
        if (ds->ipBindAddrNoPort) {
          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
#ifdef SO_BINDTODEVICE
        if (!ds->sourceItfName.empty()) {
          int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, ds->sourceItfName.c_str(), ds->sourceItfName.length());
            vinfolog("Error setting up the interface on backend TCP socket '%s': %s", ds->getNameWithAddr(), stringerror());
        result->bind(ds->sourceAddr, false);
      result->setNonBlocking();
      if (!ds->tcpFastOpen) {
        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
#endif /* MSG_FASTOPEN */
    catch(const std::runtime_error& e) {
      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
      downstreamFailures++;
      if (downstreamFailures > ds->retries) {
  } while(downstreamFailures <= ds->retries);
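/* Note: when TCP Fast Open is enabled for the backend, the connect() above is skipped on purpose:
   the connection is then established by the first sendMsgWithOptions() call in handleDownstreamIO(),
   which passes MSG_FASTOPEN so that the query rides along with the SYN. */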
class TCPConnectionToBackend
  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen)
    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
    ++d_ds->tcpCurrentConnections;

  ~TCPConnectionToBackend()
    if (d_ds && d_socket) {
      --d_ds->tcpCurrentConnections;
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);

  int getHandle() const
      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");

    return d_socket->getHandle();

  const ComboAddress& getRemote() const

  void disableFastOpen()
    d_enableFastOpen = false;

  bool isFastOpenEnabled()
    return d_enableFastOpen;

  bool canBeReused() const
    /* we can't reuse a connection where a proxy protocol payload has been sent,
       - it cannot be reused for a different client
       - we might have different TLV values for each query */
    if (d_ds && d_ds->useProxyProtocol) {

  bool matches(const std::shared_ptr<DownstreamState>& ds) const

  std::unique_ptr<Socket> d_socket{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  struct timeval d_connectionStartTime;
  uint64_t d_queries{0};
  bool d_enableFastOpen{false};
static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
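/* Note: per-worker-thread cache of idle backend connections, keyed by the backend address.
   getConnectionToDownstream() pops a cached connection when one is available,
   releaseDownstreamConnection() puts reusable connections back (capped at
   g_maxCachedConnectionsPerDownstream per backend), and cleanupClosedTCPConnections()
   periodically prunes entries whose socket is no longer usable. */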
static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
  std::unique_ptr<TCPConnectionToBackend> result;

  const auto& it = t_downstreamConnections.find(ds->remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
      result = std::move(list.front());

  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
  if (conn == nullptr) {

  if (!conn->canBeReused()) {

  const auto& remote = conn->getRemote();
  const auto& it = t_downstreamConnections.find(remote);
  if (it != t_downstreamConnections.end()) {
    auto& list = it->second;
    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
      /* too many connections queued already */

    list.push_back(std::move(conn));

  t_downstreamConnections[remote].push_back(std::move(conn));
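/* Note: a connection over which a proxy protocol payload has been sent is never returned to the pool,
   because canBeReused() rejects it: such a connection is tied to a single client and possibly to
   query-specific TLV values. It may still be kept on the incoming connection state and reused for
   that same client, see handleDownstreamIO() below. */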
struct ConnectionInfo
  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)

  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)

  ConnectionInfo(const ConnectionInfo& rhs) = delete;
  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;

  ConnectionInfo& operator=(ConnectionInfo&& rhs)

      --cs->tcpCurrentConnections;

  ClientState* cs{nullptr};
void tcpClientThread(int pipefd);

static void decrementTCPClientCount(const ComboAddress& client)
  if (g_maxTCPConnectionsPerClient) {
    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
    tcpClientsCount[client]--;
    if (tcpClientsCount[client] == 0) {
      tcpClientsCount.erase(client);
void TCPClientCollection::addTCPClientThread()
  int pipefds[2] = { -1, -1};

  vinfolog("Adding TCP Client thread");

  if (d_useSinglePipe) {
    pipefds[0] = d_singlePipe[0];
    pipefds[1] = d_singlePipe[1];

    if (pipe(pipefds) < 0) {
      errlog("Error creating the TCP thread communication pipe: %s", stringerror());

    if (!setNonBlocking(pipefds[0])) {
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));

    if (!setNonBlocking(pipefds[1])) {
      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));

    std::lock_guard<std::mutex> lock(d_mutex);

    if (d_numthreads >= d_tcpclientthreads.capacity()) {
      warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
      if (!d_useSinglePipe) {

      thread t1(tcpClientThread, pipefds[0]);
    catch(const std::runtime_error& e) {
      /* the thread creation failed, don't leak */
      errlog("Error creating a TCP thread: %s", e.what());
      if (!d_useSinglePipe) {

    d_tcpclientthreads.push_back(pipefds[1]);
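/* Note: each worker thread owns the read end of such a pipe. The acceptor thread hands a new
   connection over by writing a raw ConnectionInfo pointer into the pipe (see tcpAcceptorThread()
   below); the worker reads the pointer back in handleIncomingTCPQuery() and takes ownership of it. */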
static void cleanupClosedTCPConnections()
  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
        connIt = dsIt->second.erase(connIt);

    if (!dsIt->second.empty()) {
      dsIt = t_downstreamConnections.erase(dsIt);
/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
   Updates pos every time a successful read occurs,
   throws an std::runtime_error in case of IO error,
   returns Done when toRead bytes have been read, NeedRead or NeedWrite if the IO operation
   would block. */
// XXX could probably be implemented as a TCPIOHandler
IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
  if (buffer.size() < (pos + toRead)) {
    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));

    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
      throw runtime_error("EOF while reading message");
      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
        return IOState::NeedRead;
      throw std::runtime_error(std::string("Error while reading message: ") + stringerror());

    pos += static_cast<size_t>(res);
    got += static_cast<size_t>(res);
  while (got < toRead);

  return IOState::Done;
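/* Note: tryRead() operates on a raw, non-blocking file descriptor and is used below for reads from
   the backend socket. Reads from and writes to the client go through d_handler (a TCPIOHandler),
   which follows the same Done / NeedRead / NeedWrite convention and also takes care of TLS when the
   frontend is a DoT listener. */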
std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
class TCPClientThreadData
  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))

  LocalHolders holders;
  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
  std::unique_ptr<FDMultiplexer> mplexer{nullptr};

static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);
class IncomingTCPConnectionState
  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
    d_ids.origDest.reset();
    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
    socklen_t socklen = d_ids.origDest.getSocklen();
    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
      d_ids.origDest = d_ci.cs->local;

  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;

  ~IncomingTCPConnectionState()
    decrementTCPClientCount(d_ci.remote);

    if (d_ci.cs != nullptr) {
      gettimeofday(&now, nullptr);

      auto diff = now - d_connectionStartTime;
      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);

    if (d_ds != nullptr) {
        d_outstanding = false;

    if (d_downstreamConnection) {
        if (d_lastIOState == IOState::NeedRead) {
          cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
          d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
        else if (d_lastIOState == IOState::NeedWrite) {
          cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
          d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
      catch(const FDMultiplexerException& e) {
        vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
      catch(const std::runtime_error& e) {
        /* might be thrown by getHandle() */
        vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());

      if (d_lastIOState == IOState::NeedRead) {
        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeReadFD(d_ci.fd);
      else if (d_lastIOState == IOState::NeedWrite) {
        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
        d_threadData.mplexer->removeWriteFD(d_ci.fd);
    catch(const FDMultiplexerException& e) {
      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
  void resetForNewQuery()
    d_buffer.resize(sizeof(uint16_t));
    d_downstreamFailures = 0;
    d_state = State::readingQuerySize;
    d_lastIOState = IOState::Done;
    d_selfGeneratedResponse = false;
  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
        now.tv_sec += remaining;

    now.tv_sec += g_tcpRecvTimeout;
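  /* Note: the TTD ("time to die") returned here is the earlier of the per-read timeout
     (g_tcpRecvTimeout) and whatever is left of the maximum connection duration
     (g_maxTCPConnectionDuration), both counted from 'now'; boost::none means no timeout at all.
     The write and backend variants below are analogous, using g_tcpSendTimeout and the backend's
     own tcpRecvTimeout / tcpSendTimeout respectively. */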
  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendReadTTD() without any backend selected");
    if (d_ds->tcpRecvTimeout == 0) {

    struct timeval res = now;
    res.tv_sec += d_ds->tcpRecvTimeout;
  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {

    struct timeval res = now;

    if (g_maxTCPConnectionDuration > 0) {
      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
      auto remaining = g_maxTCPConnectionDuration - elapsed;
      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
        res.tv_sec += remaining;

    res.tv_sec += g_tcpSendTimeout;
  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
    if (d_ds == nullptr) {
      throw std::runtime_error("getBackendWriteTTD() called without any backend selected");
    if (d_ds->tcpSendTimeout == 0) {

    struct timeval res = now;
    res.tv_sec += d_ds->tcpSendTimeout;
  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
    if (maxConnectionDuration) {
      time_t curtime = now.tv_sec;
      unsigned int elapsed = 0;
      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
        elapsed = curtime - d_connectionStartTime.tv_sec;
      if (elapsed >= maxConnectionDuration) {
      d_remainingTime = maxConnectionDuration - elapsed;
    static std::mutex s_mutex;

    gettimeofday(&now, 0);

    std::lock_guard<std::mutex> lock(s_mutex);
    fprintf(stderr, "State is %p\n", this);
    cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
    cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
    cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
    if (d_state > State::doingHandshake) {
      cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
    if (d_state > State::readingQuerySize) {
      cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
    if (d_state > State::readingQuerySize) {
      cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
    if (d_state > State::readingQuery) {
      cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
    if (d_state > State::sendingQueryToBackend) {
      cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
    if (d_state > State::readingResponseFromBackend) {
      cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };
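  /* Note: per-query life cycle of a connection, as driven by handleIO() and handleDownstreamIO() below:
     doingHandshake -> readingQuerySize -> readingQuery -> sendingQueryToBackend
     -> readingResponseSizeFromBackend -> readingResponseFromBackend -> sendingResponse,
     then back to readingQuerySize (via resetForNewQuery()) if the connection is kept alive.
     Self-answered queries skip the backend states and go straight to sendingResponse. */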
  std::vector<uint8_t> d_buffer;
  std::vector<uint8_t> d_responseBuffer;
  TCPClientThreadData& d_threadData;
  TCPIOHandler d_handler;
  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
  std::shared_ptr<DownstreamState> d_ds{nullptr};
  dnsheader d_cleartextDH;
  struct timeval d_connectionStartTime;
  struct timeval d_handshakeDoneTime;
  struct timeval d_firstQuerySizeReadTime;
  struct timeval d_querySizeReadTime;
  struct timeval d_queryReadTime;
  struct timeval d_querySentTime;
  struct timeval d_responseReadTime;
  size_t d_currentPos{0};
  size_t d_queriesCount{0};
  unsigned int d_remainingTime{0};
  uint16_t d_querySize{0};
  uint16_t d_responseSize{0};
  uint16_t d_downstreamFailures{0};
  State d_state{State::doingHandshake};
  IOState d_lastIOState{IOState::Done};
  bool d_readingFirstQuery{true};
  bool d_outstanding{false};
  bool d_firstResponsePacket{true};
  bool d_xfrStarted{false};
  bool d_selfGeneratedResponse{false};
  bool d_proxyProtocolPayloadAdded{false};
  bool d_proxyProtocolPayloadHasTLV{false};
static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);

  if (state->d_isXFR && state->d_downstreamConnection) {
    /* we need to resume reading from the backend! */
    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
    state->d_currentPos = 0;
    handleDownstreamIO(state, now);

  if (state->d_selfGeneratedResponse == false && state->d_ds) {
    /* if we have no downstream server selected, this was a self-answered response
       but cache hits have a selected server as well, so be careful */
    struct timespec answertime;
    gettime(&answertime);
    double udiff = state->d_ids.sentTime.udiff();
    g_rings.insertResponse(answertime, state->d_ci.remote, state->d_ids.qname, state->d_ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), state->d_cleartextDH, state->d_ds->remote);
    vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", state->d_ds->remote.toStringWithPort(), state->d_ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);

  switch (state->d_cleartextDH.rcode) {
  case RCode::NXDomain:
    ++g_stats.frontendNXDomain;
  case RCode::ServFail:
    ++g_stats.servfailResponses;
    ++g_stats.frontendServFail;
    ++g_stats.frontendNoError;

  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());

  state->resetForNewQuery();

  handleIO(state, now);
static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
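  /* Note: this is the two-byte, network byte order (big-endian) length field that RFC 1035
     section 4.2.2 requires in front of every DNS message sent over TCP. */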
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);

  state->d_currentPos = 0;

  handleIO(state, now);
static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
  if (state->d_responseSize < sizeof(dnsheader) || !state->d_ds) {

  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
  unsigned int consumed;
  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {

  state->d_firstResponsePacket = false;

  if (state->d_outstanding) {
    --state->d_ds->outstanding;
    state->d_outstanding = false;

  auto dh = reinterpret_cast<struct dnsheader*>(response);
  uint16_t addRoom = 0;
  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
  if (dr.dnsCryptQuery) {
    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;

  memcpy(&state->d_cleartextDH, dr.dh, sizeof(state->d_cleartextDH));

  std::vector<uint8_t> rewrittenResponse;
  size_t responseSize = state->d_responseBuffer.size();
  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {

  if (!rewrittenResponse.empty()) {
    /* responseSize has been updated as well but we don't really care since it will match
       the capacity of rewrittenResponse anyway */
    state->d_responseBuffer = std::move(rewrittenResponse);
    state->d_responseSize = state->d_responseBuffer.size();

    /* the size might have been updated (shrunk) if we removed the whole OPT RR, for example */
    state->d_responseBuffer.resize(state->d_responseSize);

  if (state->d_isXFR && !state->d_xfrStarted) {
    /* don't bother parsing the content of the response for now */
    state->d_xfrStarted = true;
    ++state->d_ci.cs->responses;
    ++state->d_ds->responses;

  if (!state->d_isXFR) {
    ++state->d_ci.cs->responses;
    ++state->d_ds->responses;

  sendResponse(state, now);
static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
  auto ds = state->d_ds;
  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
  state->d_currentPos = 0;
  state->d_firstResponsePacket = true;

  if (state->d_xfrStarted) {
    /* sorry, but we are not going to resume a XFR if we have already sent some packets
       to the client */

  if (!state->d_downstreamConnection) {
    if (state->d_downstreamFailures < state->d_ds->retries) {
        state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
      catch (const std::runtime_error& e) {
        state->d_downstreamConnection.reset();

    if (!state->d_downstreamConnection) {
      ++state->d_ci.cs->tcpGaveUp;
      vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);

  if (ds->useProxyProtocol && !state->d_proxyProtocolPayloadAdded) {
    /* we know there are no TLV values to add, otherwise we would not have tried
       to reuse the connection and d_proxyProtocolPayloadAdded would be true already */
    addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, std::vector<ProxyProtocolValue>());
    state->d_proxyProtocolPayloadAdded = true;

  vinfolog("Got query for %s|%s from %s (%s), relayed to %s", state->d_ids.qname.toLogString(), QType(state->d_ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), ds->getName());

  handleDownstreamIO(state, now);
static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
  if (state->d_querySize < sizeof(dnsheader)) {
    ++g_stats.nonCompliantQueries;

  state->d_readingFirstQuery = false;
  state->d_proxyProtocolPayloadAdded = false;
  ++state->d_queriesCount;
  ++state->d_ci.cs->queries;

  if (state->d_handler.isTLS()) {
    auto tlsVersion = state->d_handler.getTLSVersion();
    switch (tlsVersion) {
    case LibsslTLSVersion::TLS10:
      ++state->d_ci.cs->tls10queries;
    case LibsslTLSVersion::TLS11:
      ++state->d_ci.cs->tls11queries;
    case LibsslTLSVersion::TLS12:
      ++state->d_ci.cs->tls12queries;
    case LibsslTLSVersion::TLS13:
      ++state->d_ci.cs->tls13queries;
      ++state->d_ci.cs->tlsUnknownqueries;

  /* we need an accurate ("real") value for the response and
     to store into the IDS, but not for insertion into the
     rings */
  struct timespec queryRealTime;
  gettime(&queryRealTime, true);

  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
  if (dnsCryptResponse) {
    state->d_responseBuffer = std::move(*dnsCryptResponse);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);

  const auto& dh = reinterpret_cast<dnsheader*>(query);
  if (!checkQueryHeaders(dh)) {

  uint16_t qtype, qclass;
  unsigned int consumed = 0;
  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
  dq.dnsCryptQuery = std::move(dnsCryptQuery);
  dq.sni = state->d_handler.getServerNameIndication();

  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
  if (state->d_isXFR) {

  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);

  if (result == ProcessQueryResult::Drop) {

  if (result == ProcessQueryResult::SendAnswer) {
    state->d_selfGeneratedResponse = true;
    state->d_buffer.resize(dq.len);
    state->d_responseBuffer = std::move(state->d_buffer);
    state->d_responseSize = state->d_responseBuffer.size();
    sendResponse(state, now);

  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {

  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));

  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);

  dq.dh = reinterpret_cast<dnsheader*>(&state->d_buffer.at(0));
  dq.size = state->d_buffer.size();
  state->d_buffer.resize(dq.len);

  if (state->d_ds->useProxyProtocol) {
    /* if we ever sent a TLV over a connection, we can never go back */
    if (!state->d_proxyProtocolPayloadHasTLV) {
      state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();

    if (state->d_downstreamConnection && !state->d_proxyProtocolPayloadHasTLV && state->d_downstreamConnection->matches(state->d_ds)) {
      /* we have an existing connection, on which we already sent a Proxy Protocol header with no values
         (if the previous query had TLV values we would have reset the connection afterwards),
         so let's reuse it as long as we still don't have any values */
      state->d_proxyProtocolPayloadAdded = false;

      state->d_downstreamConnection.reset();
      addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector<ProxyProtocolValue>());
      state->d_proxyProtocolPayloadAdded = true;

  sendQueryToBackend(state, now);
static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;

  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
    state->d_threadData.mplexer->removeReadFD(fd);
    //cerr<<__func__<<": remove read FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;
  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
    state->d_threadData.mplexer->removeWriteFD(fd);
    //cerr<<__func__<<": remove write FD "<<fd<<endl;
    state->d_lastIOState = IOState::Done;

  if (iostate == IOState::NeedRead) {
    if (state->d_lastIOState == IOState::NeedRead) {
        /* let's update the TTD ! */
        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);

    state->d_lastIOState = IOState::NeedRead;
    //cerr<<__func__<<": add read FD "<<fd<<endl;
    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
  else if (iostate == IOState::NeedWrite) {
    if (state->d_lastIOState == IOState::NeedWrite) {

    state->d_lastIOState = IOState::NeedWrite;
    //cerr<<__func__<<": add write FD "<<fd<<endl;
    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
  else if (iostate == IOState::Done) {
    state->d_lastIOState = IOState::Done;
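/* Note: handleNewIOState() reconciles the FD's previous registration in the multiplexer
   (d_lastIOState) with the newly requested one: it first removes a stale read/write registration
   that no longer matches, then either refreshes the TTD of an FD that is already registered for the
   same event or registers it with the supplied callback and optional TTD. */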
static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");

  int fd = state->d_downstreamConnection->getHandle();
  IOState iostate = IOState::Done;
  bool connectionDied = false;

    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      int socketFlags = 0;
      if (state->d_downstreamConnection->isFastOpenEnabled()) {
        socketFlags |= MSG_FASTOPEN;
#endif /* MSG_FASTOPEN */

      size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, socketFlags);
      if (sent == state->d_buffer.size()) {
        /* request sent ! */
        state->d_downstreamConnection->incQueries();
        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
        state->d_currentPos = 0;
        state->d_querySentTime = now;
        iostate = IOState::NeedRead;
        if (!state->d_isXFR && !state->d_outstanding) {
          /* don't bother with the outstanding count for XFR queries */
          ++state->d_ds->outstanding;
          state->d_outstanding = true;

        state->d_currentPos += sent;
        iostate = IOState::NeedWrite;
        /* disable fast open on partial write */
        state->d_downstreamConnection->disableFastOpen();

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
      // then we need to allocate a new buffer (new because we might need to re-send the query if the
      // backend dies on us)
      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
      // should very likely be a TCPIOHandler d_downstreamHandler
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
        state->d_currentPos = 0;

    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);

        if (state->d_isXFR) {
          /* Don't reuse the TCP connection after an {A,I}XFR */
          /* but don't reset it either, we will need to read more messages */

          /* if we did not send a Proxy Protocol header, let's pool the connection */
          if (state->d_ds && state->d_ds->useProxyProtocol == false) {
            releaseDownstreamConnection(std::move(state->d_downstreamConnection));
            if (state->d_proxyProtocolPayloadHasTLV) {
              /* sent a Proxy Protocol header with TLV values, we can't reuse it */
              state->d_downstreamConnection.reset();
            /* if we did but there were no TLV values, let's try to reuse it but only
               for this incoming connection */

        state->d_responseReadTime = now;
          handleResponse(state, now);
        catch (const std::exception& e) {
          vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", state->d_ds ? state->d_ds->getName() : "unknown", state->d_ci.remote.toStringWithPort(), e.what());

    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));

  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection */
    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
      ++state->d_ds->tcpDiedSendingQuery;
      ++state->d_ds->tcpDiedReadingResponse;

    /* don't increase this counter when reusing connections */
    if (state->d_downstreamConnection && state->d_downstreamConnection->isFresh()) {
      ++state->d_downstreamFailures;

    if (state->d_outstanding) {
      state->d_outstanding = false;
      if (state->d_ds != nullptr) {
        --state->d_ds->outstanding;

    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;
    connectionDied = true;

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));

  if (connectionDied) {
    state->d_downstreamConnection.reset();
    sendQueryToBackend(state, now);
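/* Note: when an IO error occurs while talking to the backend, the code above marks the connection
   as dead, unregisters the FD, resets d_downstreamConnection and calls sendQueryToBackend() again,
   which retries with a new connection until d_downstreamFailures exceeds the backend's retries
   setting, at which point tcpGaveUp is incremented and the query is given up on. */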
static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (state->d_downstreamConnection == nullptr) {
    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
  if (fd != state->d_downstreamConnection->getHandle()) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));

  gettimeofday(&now, 0);
  handleDownstreamIO(state, now);
static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
  int fd = state->d_ci.fd;
  IOState iostate = IOState::Done;

  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
    handleNewIOState(state, IOState::Done, fd, handleIOCallback);

    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
      iostate = state->d_handler.tryHandshake();
      if (iostate == IOState::Done) {
        if (state->d_handler.isTLS()) {
          if (!state->d_handler.hasTLSSessionBeenResumed()) {
            ++state->d_ci.cs->tlsNewSessions;
            ++state->d_ci.cs->tlsResumptions;
          if (state->d_handler.getResumedFromInactiveTicketKey()) {
            ++state->d_ci.cs->tlsInactiveTicketKey;
          if (state->d_handler.getUnknownTicketKey()) {
            ++state->d_ci.cs->tlsUnknownTicketKey;

        state->d_handshakeDoneTime = now;
        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;

    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
      if (iostate == IOState::Done) {
        state->d_state = IncomingTCPConnectionState::State::readingQuery;
        state->d_querySizeReadTime = now;
        if (state->d_queriesCount == 0) {
          state->d_firstQuerySizeReadTime = now;
        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
        if (state->d_querySize < sizeof(dnsheader)) {
          handleNewIOState(state, IOState::Done, fd, handleIOCallback);

        /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
           or to add ECS without allocating a new buffer */
        state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
        state->d_currentPos = 0;

    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
      if (iostate == IOState::Done) {
        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
        handleQuery(state, now);

    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
      if (iostate == IOState::Done) {
        handleResponseSent(state, now);

    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));

  catch(const std::exception& e) {
    /* most likely an EOF because the other end closed the connection,
       but it might also be a real IO error or something else.
       Let's just drop the connection */
    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
      ++state->d_ci.cs->tcpDiedReadingQuery;
    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
      ++state->d_ci.cs->tcpDiedSendingResponse;

    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());

    /* remove this FD from the IO multiplexer */
    iostate = IOState::Done;

  if (iostate == IOState::Done) {
    handleNewIOState(state, iostate, fd, handleIOCallback);
    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (fd != state->d_ci.fd) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));

  gettimeofday(&now, 0);
  handleIO(state, now);
static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
  auto threadData = boost::any_cast<TCPClientThreadData*>(param);

  ConnectionInfo* citmp{nullptr};

  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
  else if (got == -1) {
    if (errno == EAGAIN || errno == EINTR) {
    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
  else if (got != sizeof(citmp)) {
    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");

  g_tcpclientthreads->decrementQueuedCount();

  gettimeofday(&now, 0);
  auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);

  /* let's update the remaining time */
  state->d_remainingTime = g_maxTCPConnectionDuration;

  handleIO(state, now);
void tcpClientThread(int pipefd)
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  setThreadName("dnsdist/tcpClie");

  TCPClientThreadData data;

  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);

  gettimeofday(&now, 0);
  time_t lastTCPCleanup = now.tv_sec;
  time_t lastTimeoutScan = now.tv_sec;

    data.mplexer->run(&now);

    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
      cleanupClosedTCPConnections();
      lastTCPCleanup = now.tv_sec;

    if (now.tv_sec > lastTimeoutScan) {
      lastTimeoutScan = now.tv_sec;
      auto expiredReadConns = data.mplexer->getTimeouts(now, false);
      for(const auto& conn : expiredReadConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        else if (state->d_ds) {
          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpReadTimeouts;
        data.mplexer->removeReadFD(conn.first);
        state->d_lastIOState = IOState::Done;

      auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
      for(const auto& conn : expiredWriteConns) {
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
        if (conn.first == state->d_ci.fd) {
          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
          ++state->d_ci.cs->tcpClientTimeouts;
        else if (state->d_ds) {
          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
          ++state->d_ci.cs->tcpDownstreamTimeouts;
          ++state->d_ds->tcpWriteTimeouts;
        data.mplexer->removeWriteFD(conn.first);
        state->d_lastIOState = IOState::Done;
/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
   they will hand off to worker threads & spawn more of them if required */
void tcpAcceptorThread(void* p)
  setThreadName("dnsdist/tcpAcce");
  ClientState* cs = (ClientState*) p;
  bool tcpClientCountIncremented = false;
  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  if(!g_tcpclientthreads->hasReachedMaxThreads()) {
    g_tcpclientthreads->addTCPClientThread();

  auto acl = g_ACL.getLocal();

    bool queuedCounterIncremented = false;
    std::unique_ptr<ConnectionInfo> ci;
    tcpClientCountIncremented = false;
      socklen_t remlen = remote.getSocklen();
      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo(cs));

      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
      ++cs->tcpCurrentConnections;

        throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());

      if(!acl->match(remote)) {
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());

#ifndef HAVE_ACCEPT4
      if (!setNonBlocking(ci->fd)) {

      setTCPNoDelay(ci->fd);  // disable NAGLE
      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());

      if (g_maxTCPConnectionsPerClient) {
        std::lock_guard<std::mutex> lock(tcpClientsCountMutex);

        if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
          vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
        tcpClientsCount[remote]++;
        tcpClientCountIncremented = true;

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
        queuedCounterIncremented = true;
        auto tmp = ci.release();
          writen2WithTimeout(pipe, &tmp, sizeof(tmp), 0);

          g_tcpclientthreads->decrementQueuedCount();
          queuedCounterIncremented = false;
          if(tcpClientCountIncremented) {
            decrementTCPClientCount(remote);

    catch(const std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(tcpClientCountIncremented) {
        decrementTCPClientCount(remote);
      if (queuedCounterIncremented) {
        g_tcpclientthreads->decrementQueuedCount();