]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Clean up the Downstream TCP code by using a TCPIOHandler
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 11 Feb 2021 18:02:03 +0000 (19:02 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 2 Mar 2021 09:52:41 +0000 (10:52 +0100)
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.hh
pdns/sdig.cc
pdns/tcpiohandler.hh

index 9e7634fc74b54abd5dbf602911a944048d5b1457..0c00d2c831813a83b4ff949159dc686faba00765 100644 (file)
@@ -879,6 +879,7 @@ struct DownstreamState
   std::mutex socketsLock;
   std::mutex connectLock;
   std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+  std::shared_ptr<TLSCtx> d_tlsCtx{nullptr};
   std::thread tid;
   const ComboAddress remote;
   QPSLimiter qps;
index 645abbd299473a08323e6cfa157169c541192497..e93d6f6a340dbdceb78886196f2d8a859883f99d 100644 (file)
@@ -16,7 +16,7 @@ void TCPConnectionToBackend::assignToClientConnection(std::shared_ptr<IncomingTC
 
   if (!d_clientConn) {
     d_clientConn = clientConn;
-    d_ioState = make_unique<IOStateHandler>(clientConn->getIOMPlexer(), d_socket->getHandle());
+    d_ioState = make_unique<IOStateHandler>(clientConn->getIOMPlexer(), d_handler->getDescriptor());
   }
   else if (d_clientConn != clientConn) {
     throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries");
@@ -48,88 +48,42 @@ IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBa
   return IOState::NeedWrite;
 }
 
-/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
-   Updates pos everytime a successful read occurs,
-   throws an std::runtime_error in case of IO error,
-   return Done when toRead bytes have been read, needRead or needWrite if the IO operation
-   would block.
-*/
-// XXX could probably be implemented as a TCPIOHandler
-static IOState tryRead(int fd, PacketBuffer& buffer, size_t& pos, size_t toRead)
+IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now)
 {
-  if (buffer.size() < (pos + toRead)) {
-    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
-  }
+  DEBUGLOG("sending query to backend "<<conn->getDS()->getName()<<" over FD "<<conn->d_handler->getDescriptor());
 
-  size_t got = 0;
-  do {
-    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
-    if (res == 0) {
-      throw runtime_error("EOF while reading message");
-    }
-    if (res < 0) {
-      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
-        return IOState::NeedRead;
-      }
-      else {
-        throw std::runtime_error(std::string("Error while reading message: ") + stringerror());
-      }
-    }
+#warning FIXME: TODO: this drops 1/ source selection other than SO_BINDTODEVICE, perhaps we should look into IP_SENDIF?
+  IOState state = conn->d_handler->tryWrite(conn->d_currentQuery.d_buffer, conn->d_currentPos, conn->d_currentQuery.d_buffer.size());
 
-    pos += static_cast<size_t>(res);
-    got += static_cast<size_t>(res);
+  if (state != IOState::Done) {
+    return state;
   }
-  while (got < toRead);
-
-  return IOState::Done;
-}
 
-IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now)
-{
-  int fd = conn->d_socket->getHandle();
-  DEBUGLOG("sending query to backend "<<conn->getDS()->getName()<<" over FD "<<fd);
-  int socketFlags = 0;
-#ifdef MSG_FASTOPEN
-  if (conn->isFastOpenEnabled()) {
-    socketFlags |= MSG_FASTOPEN;
-  }
-#endif /* MSG_FASTOPEN */
-
-  size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&conn->d_currentQuery.d_buffer.at(conn->d_currentPos)), conn->d_currentQuery.d_buffer.size() - conn->d_currentPos, &conn->d_ds->remote, &conn->d_ds->sourceAddr, conn->d_ds->sourceItf, socketFlags);
-  if (sent == conn->d_currentQuery.d_buffer.size()) {
-    DEBUGLOG("query sent to backend");
-    /* request sent ! */
-    conn->incQueries();
-    conn->d_currentPos = 0;
-
-    DEBUGLOG("adding a pending response for ID "<<conn->d_currentQuery.d_idstate.origID<<" and QNAME "<<conn->d_currentQuery.d_idstate.qname);
-    conn->d_pendingResponses[conn->d_currentQuery.d_idstate.origID] = std::move(conn->d_currentQuery);
-    conn->d_currentQuery.d_buffer.clear();
+  DEBUGLOG("query sent to backend");
+  /* request sent ! */
+  conn->incQueries();
+  conn->d_currentPos = 0;
 
-    if (!conn->d_usedForXFR) {
-      ++conn->d_ds->outstanding;
-    }
+  DEBUGLOG("adding a pending response for ID "<<conn->d_currentQuery.d_idstate.origID<<" and QNAME "<<conn->d_currentQuery.d_idstate.qname);
+  conn->d_pendingResponses[conn->d_currentQuery.d_idstate.origID] = std::move(conn->d_currentQuery);
+  conn->d_currentQuery.d_buffer.clear();
 
-    return IOState::Done;
-  }
-  else {
-    conn->d_currentPos += sent;
-    /* disable fast open on partial write */
-    conn->disableFastOpen();
-    return IOState::NeedWrite;
+  if (!conn->d_usedForXFR) {
+    ++conn->d_ds->outstanding;
   }
+
+  return state;
 }
 
 void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now)
 {
-  if (conn->d_socket == nullptr) {
+  if (conn->d_handler == nullptr) {
     throw std::runtime_error("No downstream socket in " + std::string(__PRETTY_FUNCTION__) + "!");
   }
 
   bool connectionDied = false;
   IOState iostate = IOState::Done;
   IOStateGuard ioGuard(conn->d_ioState);
-  int fd = conn->d_socket->getHandle();
 
   try {
     if (conn->d_state == State::sendingQueryToBackend) {
@@ -153,9 +107,8 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
       // then we need to allocate a new buffer (new because we might need to re-send the query if the
       // backend dies on us)
       // We also might need to read and send to the client more than one response in case of XFR (yeah!)
-      // should very likely be a TCPIOHandler
       conn->d_responseBuffer.resize(sizeof(uint16_t));
-      iostate = tryRead(fd, conn->d_responseBuffer, conn->d_currentPos, sizeof(uint16_t) - conn->d_currentPos);
+      iostate = conn->d_handler->tryRead(conn->d_responseBuffer, conn->d_currentPos, sizeof(uint16_t) - conn->d_currentPos);
       if (iostate == IOState::Done) {
         DEBUGLOG("got response size from backend");
         conn->d_state = State::readingResponseFromBackend;
@@ -168,7 +121,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
 
     if (conn->d_state == State::readingResponseFromBackend) {
       DEBUGLOG("reading response from backend");
-      iostate = tryRead(fd, conn->d_responseBuffer, conn->d_currentPos, conn->d_responseSize - conn->d_currentPos);
+      iostate = conn->d_handler->tryRead(conn->d_responseBuffer, conn->d_currentPos, conn->d_responseSize - conn->d_currentPos);
       if (iostate == IOState::Done) {
         DEBUGLOG("got response from backend");
         try {
@@ -225,7 +178,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
 
       try {
         if (conn->reconnect()) {
-          conn->d_ioState = make_unique<IOStateHandler>(conn->d_clientConn->getIOMPlexer(), conn->d_socket->getHandle());
+          conn->d_ioState = make_unique<IOStateHandler>(conn->d_clientConn->getIOMPlexer(), conn->d_handler->getDescriptor());
 
           /* we need to resend the queries that were in flight, if any */
           for (auto& pending : conn->d_pendingResponses) {
@@ -328,12 +281,9 @@ void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr<TCPCon
 
 bool TCPConnectionToBackend::reconnect()
 {
-  std::unique_ptr<Socket> result;
-
-  if (d_socket) {
-    DEBUGLOG("closing socket "<<d_socket->getHandle());
-    shutdown(d_socket->getHandle(), SHUT_RDWR);
-    d_socket.reset();
+  if (d_handler) {
+    DEBUGLOG("closing socket "<<d_handler->getDescriptor());
+    d_handler->close();
     d_ioState.reset();
     --d_ds->tcpCurrentConnections;
   }
@@ -344,36 +294,32 @@ bool TCPConnectionToBackend::reconnect()
     vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures);
     DEBUGLOG("Opening TCP connection to backend "<<d_ds->getNameWithAddr());
     try {
-      result = std::unique_ptr<Socket>(new Socket(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0));
-      DEBUGLOG("result of connect is "<<result->getHandle());
+      auto socket = std::make_unique<Socket>(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0);
+      DEBUGLOG("result of socket() is "<<socket->getHandle());
 
       if (!IsAnyAddress(d_ds->sourceAddr)) {
-        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
+        SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
 #ifdef IP_BIND_ADDRESS_NO_PORT
         if (d_ds->ipBindAddrNoPort) {
-          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+          SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
         }
 #endif
 #ifdef SO_BINDTODEVICE
         if (!d_ds->sourceItfName.empty()) {
-          int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length());
+          int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length());
           if (res != 0) {
             vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror());
           }
         }
 #endif
-        result->bind(d_ds->sourceAddr, false);
+        socket->bind(d_ds->sourceAddr, false);
       }
-      result->setNonBlocking();
-#ifdef MSG_FASTOPEN
-      if (!d_ds->tcpFastOpen || !isFastOpenEnabled()) {
-        SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0);
-      }
-#else
-      SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0);
-#endif /* MSG_FASTOPEN */
+      socket->setNonBlocking();
+
+      auto handler = std::make_unique<TCPIOHandler>("", socket->releaseHandle(), 0, d_ds->d_tlsCtx, time(nullptr));
+      handler->tryConnect(d_ds->tcpFastOpen && isFastOpenEnabled(), d_ds->remote);
 
-      d_socket = std::move(result);
+      d_handler = std::move(handler);
       ++d_ds->tcpCurrentConnections;
       return true;
     }
index 46b53a70ac1820955cd6feeec7355524b2941028..a26b29ca06f3565b65d7cc61d0504b6d89032cf9 100644 (file)
@@ -52,7 +52,7 @@ public:
 
   ~TCPConnectionToBackend()
   {
-    if (d_ds && d_socket) {
+    if (d_ds && d_handler) {
       --d_ds->tcpCurrentConnections;
       struct timeval now;
       gettimeofday(&now, nullptr);
@@ -66,11 +66,11 @@ public:
 
   int getHandle() const
   {
-    if (!d_socket) {
+    if (!d_handler) {
       throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
     }
 
-    return d_socket->getHandle();
+    return d_handler->getDescriptor();
   }
 
   const std::shared_ptr<DownstreamState>& getDS() const
@@ -172,7 +172,7 @@ public:
   std::string toString() const
   {
     ostringstream o;
-    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_socket ? std::to_string(d_socket->getHandle()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses, linked to "<<(d_clientConn ? " a client" : "no client");
+    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses, linked to "<<(d_clientConn ? " a client" : "no client");
     return o.str();
   }
 
@@ -228,7 +228,7 @@ private:
   std::deque<TCPQuery> d_pendingQueries;
   std::unordered_map<uint16_t, TCPQuery> d_pendingResponses;
   std::unique_ptr<std::vector<ProxyProtocolValue>> d_proxyProtocolValuesSent{nullptr};
-  std::unique_ptr<Socket> d_socket{nullptr};
+  std::unique_ptr<TCPIOHandler> d_handler{nullptr};
   std::unique_ptr<IOStateHandler> d_ioState{nullptr};
   std::shared_ptr<DownstreamState> d_ds{nullptr};
   std::shared_ptr<IncomingTCPConnectionState> d_clientConn;
index d9e2d7da30cf30b2ccd9425a472d9a4c9d1eea24..c4432a2008558f3717441c7a26766efa2c956683 100644 (file)
@@ -409,7 +409,6 @@ try {
     }
     uint16_t counter = 0;
     Socket sock(dest.sin4.sin_family, SOCK_STREAM);
-    SConnectWithTimeout(sock.getHandle(), dest, timeout);
     TCPIOHandler handler(subjectName, sock.releaseHandle(), timeout, tlsCtx, time(nullptr));
     handler.connect(fastOpen, dest, timeout);
     // we are writing the proxyheader inside the TLS connection. Is that right?
index e6c4983aa9ccba21437f67501aa268c71e4fa469..b9b271e384237d0f30fa2a383421299d529989db 100644 (file)
@@ -230,24 +230,64 @@ public:
 
   IOState tryConnect(bool fastOpen, const ComboAddress& remote)
   {
-    /* yes, this is only the TLS connect not the socket one,
-       sorry about that */
+    d_remote = remote;
+
+#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
+    if (fastOpen) {
+      int value = 1;
+      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
+      if (res == 0) {
+        fastOpen = false;
+      }
+    }
+#endif /* TCP_FASTOPEN_CONNECT */
+
+#ifdef MSG_FASTOPEN
+    if (!d_conn && fastOpen) {
+      d_fastOpen = true;
+    }
+    else {
+      SConnectWithTimeout(d_socket, remote, /* no timeout, we will handle it ourselves */ 0);
+    }
+#else
+    SConnectWithTimeout(d_socket, d_ds->remote, /* no timeout, we will handle it ourselves */ 0);
+#endif /* MSG_FASTOPEN */
+
     if (d_conn) {
       return d_conn->tryConnect(fastOpen, remote);
     }
-    d_fastOpen = fastOpen;
 
     return IOState::Done;
   }
 
   void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout)
   {
-    /* yes, this is only the TLS connect not the socket one,
-       sorry about that */
+    d_remote = remote;
+
+#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
+    if (fastOpen) {
+      int value = 1;
+      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
+      if (res == 0) {
+        fastOpen = false;
+      }
+    }
+#endif /* TCP_FASTOPEN_CONNECT */
+
+#ifdef MSG_FASTOPEN
+    if (!d_conn && fastOpen) {
+      d_fastOpen = true;
+    }
+    else {
+      SConnectWithTimeout(d_socket, remote, timeout);
+    }
+#else
+    SConnectWithTimeout(d_socket, d_ds->remote, timeout);
+#endif /* MSG_FASTOPEN */
+
     if (d_conn) {
       d_conn->connect(fastOpen, remote, timeout);
     }
-    d_fastOpen = fastOpen;
   }
 
   IOState tryHandshake()
@@ -319,8 +359,24 @@ public:
       return d_conn->tryWrite(buffer, pos, toWrite);
     }
 
+    if (d_fastOpen) {
+      int socketFlags = MSG_FASTOPEN;
+      size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos, &d_remote, nullptr, 0, socketFlags);
+      if (sent > 0) {
+        d_fastOpen = false;
+        pos += sent;
+      }
+
+      if (pos < toWrite) {
+        return IOState::NeedWrite;
+      }
+
+      return IOState::Done;
+    }
+
     do {
       ssize_t res = ::write(d_socket, reinterpret_cast<const char*>(&buffer.at(pos)), toWrite - pos);
+
       if (res == 0) {
         throw runtime_error("EOF while sending message");
       }
@@ -389,13 +445,14 @@ public:
     return d_conn && d_conn->getResumedFromInactiveTicketKey();
   }
 
-    bool getUnknownTicketKey() const
+  bool getUnknownTicketKey() const
   {
     return d_conn && d_conn->getUnknownTicketKey();
   }
 
 private:
   std::unique_ptr<TLSConnection> d_conn{nullptr};
+  ComboAddress d_remote;
   int d_socket{-1};
   bool d_fastOpen{false};
 };