]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Refactor the tcp case of lwres.
authorOtto <otto.moerbeek@open-xchange.com>
Mon, 13 Sep 2021 11:39:45 +0000 (13:39 +0200)
committerOtto <otto.moerbeek@open-xchange.com>
Fri, 24 Sep 2021 07:59:44 +0000 (09:59 +0200)
Not as far as rgacogne suggested, but it's more readable right now.

pdns/lwres.cc

index 1dc85fbbd3476dce25027e30f5a7ee2636652da6..cc15c18171f5605761481554887337155466eb4e 100644 (file)
@@ -236,6 +236,7 @@ static bool tcpconnect(const struct timeval& now, const ComboAddress& ip, TCPOut
 {
   dnsOverTLS = SyncRes::s_dot_to_port_853 && ip.getPort() == 853;
 
+
   while (true) {
     connection = t_tcp_manager.get(ip);
     if (connection.d_handler) {
@@ -262,16 +263,52 @@ static bool tcpconnect(const struct timeval& now, const ComboAddress& ip, TCPOut
     }
     connection.d_handler = std::make_shared<TCPIOHandler>("", s.releaseHandle(), timeout, tlsCtx, now.tv_sec);
     // Returned state ignored
-    try {
-      connection.d_handler->tryConnect(SyncRes::s_tcp_fast_open_connect, ip);
-    }
-    catch (const std::runtime_error&) {
-      continue;
-    }
+    // This can throw an excepion, retry will need to happen at higher level
+    connection.d_handler->tryConnect(SyncRes::s_tcp_fast_open_connect, ip);
     return true;
   }
 }
 
+static LWResult::Result tcpsendrecv(const ComboAddress& ip, TCPOutConnectionManager::Connection& connection,
+                                    ComboAddress& localip, const vector<uint8_t>& vpacket, size_t& len, PacketBuffer& buf)
+{
+  socklen_t slen = ip.getSocklen();
+  uint16_t tlen = htons(vpacket.size());
+  const char *lenP = reinterpret_cast<const char*>(&tlen);
+  const char *msgP = reinterpret_cast<const char*>(&*vpacket.begin());
+
+  localip.sin4.sin_family = ip.sin4.sin_family;
+  getsockname(connection.d_handler->getDescriptor(), reinterpret_cast<sockaddr*>(&localip), &slen);
+
+  PacketBuffer packet;
+  packet.reserve(2 + vpacket.size());
+  packet.insert(packet.end(), lenP, lenP + 2);
+  packet.insert(packet.end(), msgP, msgP + vpacket.size());
+
+  LWResult::Result ret = asendtcp(packet, connection.d_handler);
+  if (ret != LWResult::Result::Success) {
+    return ret;
+  }
+
+  ret = arecvtcp(packet, 2, connection.d_handler, false);
+  if (ret != LWResult::Result::Success) {
+    return ret;
+  }
+
+  memcpy(&tlen, packet.data(), sizeof(tlen));
+  len = ntohs(tlen); // switch to the 'len' shared with the rest of the calling function
+
+  // XXX receive into buf directly?
+  packet.resize(len);
+  ret = arecvtcp(packet, len, connection.d_handler, false);
+  if (ret != LWResult::Result::Success) {
+    return ret;
+  }
+  buf.resize(len);
+  memcpy(buf.data(), packet.data(), len);
+  return LWResult::Result::Success;
+}
+
 /** lwr is only filled out in case 1 was returned, and even when returning 1 for 'success', lwr might contain DNS errors
     Never throws! 
  */
@@ -385,69 +422,32 @@ static LWResult::Result asyncresolve(const ComboAddress& ip, const DNSName& doma
     ret = arecvfrom(buf, 0, ip, &len, qid, domain, type, queryfd, now);
   }
   else {
-    try {
-      while (true) {
-        // If we get a new (not re-used) TCP connection that does not
-        // work, we give up. For reused connections, we assume the
-        // peer has closed it on error, so we retry. At some point we
-        // *will* get a new connection, so this loop is not endless.
-        bool isNew = tcpconnect(*now, ip, connection, dnsOverTLS);
-        localip.sin4.sin_family = ip.sin4.sin_family;
-        socklen_t slen = ip.getSocklen();
-        getsockname(connection.d_handler->getDescriptor(), reinterpret_cast<sockaddr*>(&localip), &slen);
-        uint16_t tlen = htons(vpacket.size());
-        char *lenP = (char*)&tlen;
-        const char *msgP=(const char*)&*vpacket.begin();
-        PacketBuffer packet;
-        packet.reserve(2 + vpacket.size());
-        packet.insert(packet.end(), lenP, lenP+2);
-        packet.insert(packet.end(), msgP, msgP+vpacket.size());
-        ret = asendtcp(packet, connection.d_handler);
-        if (ret != LWResult::Result::Success) {
-          if (isNew) {
-            connection.d_handler->close();
-            return ret;
-          } else {
-            continue;
-          }
-        }
+      bool isNew;
+      do {
+        try {
+          // If we get a new (not re-used) TCP connection that does not
+          // work, we give up. For reused connections, we assume the
+          // peer has closed it on error, so we retry. At some point we
+          // *will* get a new connection, so this loop is not endless.
+          isNew = tcpconnect(*now, ip, connection, dnsOverTLS);
+          ret = tcpsendrecv(ip, connection, localip, vpacket, len, buf);
 #ifdef HAVE_FSTRM
-        if (fstrmQEnabled) {
-          logFstreamQuery(fstrmLoggers, queryTime, localip, ip, !dnsOverTLS ? DnstapMessage::ProtocolType::DoTCP : DnstapMessage::ProtocolType::DoT, context ? context->d_auth : boost::none, vpacket);
-        }
+          if (fstrmQEnabled) {
+            logFstreamQuery(fstrmLoggers, queryTime, localip, ip, !dnsOverTLS ? DnstapMessage::ProtocolType::DoTCP : DnstapMessage::ProtocolType::DoT, context ? context->d_auth : boost::none, vpacket);
+          }
 #endif /* HAVE_FSTRM */
-
-        ret = arecvtcp(packet, 2, connection.d_handler, false);
-        if (ret != LWResult::Result::Success) {
-          if (isNew) {
-            return ret;
-          } else {
-            continue;
+          if (ret == LWResult::Result::Success) {
+            break;
           }
+          connection.d_handler->close();
         }
-
-        memcpy(&tlen, packet.data(), sizeof(tlen));
-        len = ntohs(tlen); // switch to the 'len' shared with the rest of the function
-
-        // XXX receive into buf directly?
-        packet.resize(len);
-        ret = arecvtcp(packet, len, connection.d_handler, false);
-        if (ret != LWResult::Result::Success) {
-          if (isNew) {
-            return ret;
-          } else {
-            continue;
-          }
+        catch (const NetworkError&) {
+          ret = LWResult::Result::OSLimitError; // OS limits error
         }
-        buf.resize(len);
-        memcpy(buf.data(), packet.data(), len);
-        ret = LWResult::Result::Success;
-        break;
-      }
-    }
-    catch (const NetworkError& ne) {
-      ret = LWResult::Result::OSLimitError; // OS limits error
-    }
+        catch (const runtime_error&) {
+          ret = LWResult::Result::OSLimitError; // OS limits error (PermanentError is transport related)
+        }
+      } while (!isNew);
   }
 
   lwr->d_usec=dt.udiff();
@@ -564,12 +564,12 @@ static LWResult::Result asyncresolve(const ComboAddress& ip, const DNSName& doma
 LWResult::Result asyncresolve(const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, const std::shared_ptr<std::vector<std::unique_ptr<RemoteLogger>>>& outgoingLoggers, const std::shared_ptr<std::vector<std::unique_ptr<FrameStreamLogger>>>& fstrmLoggers, const std::set<uint16_t>& exportTypes, LWResult *lwr, bool* chained)
 {
   TCPOutConnectionManager::Connection connection;
-  auto ret = asyncresolve(ip, domain, type,doTCP, sendRDQuery, EDNS0Level, now, srcmask, context, outgoingLoggers, fstrmLoggers, exportTypes, lwr, chained, connection);
+  auto ret = asyncresolve(ip, domain, type, doTCP, sendRDQuery, EDNS0Level, now, srcmask, context, outgoingLoggers, fstrmLoggers, exportTypes, lwr, chained, connection);
 
   if (doTCP) {
     if (!lwr->d_validpacket) {
-      ret = asyncresolve(ip, domain, type,doTCP, sendRDQuery, EDNS0Level, now, srcmask, context, outgoingLoggers, fstrmLoggers, exportTypes, lwr, chained, connection);
-    } 
+      ret = asyncresolve(ip, domain, type, doTCP, sendRDQuery, EDNS0Level, now, srcmask, context, outgoingLoggers, fstrmLoggers, exportTypes, lwr, chained, connection);
+    }
     if (connection.d_handler && lwr->d_validpacket) {
       t_tcp_manager.store(*now, ip, std::move(connection));
     }