]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Use guard objects to do the TCP connection bookkeeping and cleanup if needed.
authorOtto <otto.moerbeek@open-xchange.com>
Wed, 24 Nov 2021 10:12:16 +0000 (11:12 +0100)
committerOtto <otto.moerbeek@open-xchange.com>
Wed, 24 Nov 2021 13:59:12 +0000 (14:59 +0100)
If a policy drop is to be handled for a TCP connection, do not
answer that query, but do handle already in-flight queries and then close.

pdns/pdns_recursor.cc
pdns/syncres.hh

index 6e8678b19c48245d380fc8690c3d36a4b9914405..a5e8c7b6c72607f1d2783685b2eaed5deb1f94c9 100644 (file)
@@ -933,7 +933,8 @@ static void finishTCPReply(std::unique_ptr<DNSComboWriter>& dc, bool hadError, b
     return;
   }
   dc->d_tcpConnection->queriesCount++;
-  if (g_tcpMaxQueriesPerConn && dc->d_tcpConnection->queriesCount >= g_tcpMaxQueriesPerConn) {
+  if ((g_tcpMaxQueriesPerConn && dc->d_tcpConnection->queriesCount >= g_tcpMaxQueriesPerConn) ||
+      (dc->d_tcpConnection->isDropOnIdle() && dc->d_tcpConnection->d_requestsInFlight == 0)) {
     try {
       t_fdm->removeReadFD(dc->d_socket);
     }
@@ -1144,9 +1145,29 @@ static bool addRecordToPacket(DNSPacketWriter& pw, const DNSRecord& rec, uint32_
   return true;
 }
 
+class RunningTCPResolve {
+public:
+  RunningTCPResolve(std::unique_ptr<DNSComboWriter>& dc) : d_dc(dc) {
+  }
+  ~RunningTCPResolve() {
+    if (!d_handled && d_dc->d_tcp) {
+      finishTCPReply(d_dc, false, true);
+    }
+  }
+  void setHandled() {
+    d_handled = true;
+  }
+  void setDropOnIdle() {
+    d_dc->d_tcpConnection->setDropOnIdle();
+  }
+private:
+  std::unique_ptr<DNSComboWriter>& d_dc;
+  bool d_handled{false};
+};
+
 enum class PolicyResult : uint8_t { NoAction, HaveAnswer, Drop };
 
-static PolicyResult handlePolicyHit(const DNSFilterEngine::Policy& appliedPolicy, const std::unique_ptr<DNSComboWriter>& dc, SyncRes& sr, int& res, vector<DNSRecord>& ret, DNSPacketWriter& pw)
+static PolicyResult handlePolicyHit(const DNSFilterEngine::Policy& appliedPolicy, const std::unique_ptr<DNSComboWriter>& dc, SyncRes& sr, int& res, vector<DNSRecord>& ret, DNSPacketWriter& pw, RunningTCPResolve& tcpGuard)
 {
   /* don't account truncate actions for TCP queries, since they are not applied */
   if (appliedPolicy.d_kind != DNSFilterEngine::PolicyKind::Truncate || !dc->d_tcp) {
@@ -1169,6 +1190,7 @@ static PolicyResult handlePolicyHit(const DNSFilterEngine::Policy& appliedPolicy
       return PolicyResult::NoAction;
 
   case DNSFilterEngine::PolicyKind::Drop:
+    tcpGuard.setDropOnIdle();
     ++g_stats.policyDrops;
     return PolicyResult::Drop;
 
@@ -1770,6 +1792,8 @@ static void startDoResolve(void *p)
     dq.extendedErrorExtra = &dc->d_extendedErrorExtra;
     dq.meta = std::move(dc->d_meta);
 
+    RunningTCPResolve tcpGuard(dc);
+
     if(ednsExtRCode != 0 || dc->d_mdp.d_header.opcode == Opcode::Notify) {
       goto sendit;
     }
@@ -1863,7 +1887,7 @@ static void startDoResolve(void *p)
           appliedPolicy = DNSFilterEngine::Policy();
         }
         else {
-          auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw);
+          auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw, tcpGuard);
           if (policyResult == PolicyResult::HaveAnswer) {
             if (g_dns64Prefix && dq.qtype == QType::AAAA && answerIsNOData(dc->d_mdp.d_qtype, res, ret)) {
               res = getFakeAAAARecords(dq.qname, *g_dns64Prefix, ret);
@@ -1924,7 +1948,7 @@ static void startDoResolve(void *p)
         if (appliedPolicy.d_kind == DNSFilterEngine::PolicyKind::NoAction) {
           throw PDNSException("NoAction policy returned while a NSDNAME or NSIP trigger was hit");
         }
-        auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw);
+        auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw, tcpGuard);
         if (policyResult == PolicyResult::HaveAnswer) {
           goto haveAnswer;
         }
@@ -1938,7 +1962,7 @@ static void startDoResolve(void *p)
           if (answerIsNOData(dc->d_mdp.d_qtype, res, ret)) {
             if (t_pdl && t_pdl->nodata(dq, res, sr.d_eventTrace)) {
               shouldNotValidate = true;
-              auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw);
+              auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw, tcpGuard);
               if (policyResult == PolicyResult::HaveAnswer) {
                 goto haveAnswer;
               }
@@ -1954,7 +1978,7 @@ static void startDoResolve(void *p)
        }
        else if (res == RCode::NXDomain && t_pdl && t_pdl->nxdomain(dq, res, sr.d_eventTrace)) {
           shouldNotValidate = true;
-          auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw);
+          auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw, tcpGuard);
           if (policyResult == PolicyResult::HaveAnswer) {
             goto haveAnswer;
           }
@@ -1965,7 +1989,7 @@ static void startDoResolve(void *p)
 
        if (t_pdl && t_pdl->postresolve(dq, res, sr.d_eventTrace)) {
           shouldNotValidate = true;
-          auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw);
+          auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw, tcpGuard);
           // haveAnswer case redundant
           if (policyResult == PolicyResult::Drop) {
             return;
@@ -1976,7 +2000,7 @@ static void startDoResolve(void *p)
     else if (t_pdl) {
       // preresolve returned true
       shouldNotValidate = true;
-      auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw);
+      auto policyResult = handlePolicyHit(appliedPolicy, dc, sr, res, ret, pw, tcpGuard);
       // haveAnswer case redundant
       if (policyResult == PolicyResult::Drop) {
         return;
@@ -2313,6 +2337,7 @@ static void startDoResolve(void *p)
     else {
       bool hadError = sendResponseOverTCP(dc, packet);
       finishTCPReply(dc, hadError, true);
+      tcpGuard.setHandled();
     }
 
     sr.d_eventTrace.add(RecEventTrace::AnswerSent);
@@ -2628,14 +2653,35 @@ static void requestWipeCaches(const DNSName& canon)
   }
 }
 
+class RunningTCPGuard {
+public:
+  RunningTCPGuard(int fd) {
+    d_fd = fd;
+  }
+  ~RunningTCPGuard() {
+    if (d_fd != -1) {
+      terminateTCPConnection(d_fd);
+      d_fd = -1;
+    }
+  }
+  void keep() {
+    d_fd = -1;
+  }
+private:
+  int d_fd{-1};
+};
+
 static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
 {
   shared_ptr<TCPConnection> conn=boost::any_cast<shared_ptr<TCPConnection> >(var);
 
+  RunningTCPGuard tcpGuard{fd};
+
   if (conn->state == TCPConnection::PROXYPROTOCOLHEADER) {
     ssize_t bytes = recv(conn->getFD(), &conn->data.at(conn->proxyProtocolGot), conn->proxyProtocolNeed, 0);
     if (bytes <= 0) {
       handleTCPReadResult(fd, bytes);
+      tcpGuard.keep();
       return;
     }
 
@@ -2647,12 +2693,12 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
         g_log<<Logger::Error<<"Unable to consume proxy protocol header in packet from TCP client "<< conn->d_remote.toStringWithPort() <<endl;
       }
       ++g_stats.proxyProtocolInvalidCount;
-      terminateTCPConnection(fd);
       return;
     }
     else if (remaining < 0) {
       conn->proxyProtocolNeed = -remaining;
       conn->data.resize(conn->proxyProtocolGot + conn->proxyProtocolNeed);
+      tcpGuard.keep();
       return;
     }
     else {
@@ -2667,7 +2713,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
           g_log<<Logger::Error<<"Unable to parse proxy protocol header in packet from TCP client "<< conn->d_remote.toStringWithPort() <<endl;
         }
         ++g_stats.proxyProtocolInvalidCount;
-        terminateTCPConnection(fd);
         return;
       }
       else if (static_cast<size_t>(used) > g_proxyProtocolMaximumSize) {
@@ -2675,7 +2720,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
           g_log<<Logger::Error<<"Proxy protocol header in packet from TCP client "<< conn->d_remote.toStringWithPort() << " is larger than proxy-protocol-maximum-size (" << used << "), dropping"<< endl;
         }
         ++g_stats.proxyProtocolInvalidCount;
-        terminateTCPConnection(fd);
         return;
       }
 
@@ -2688,7 +2732,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
         }
 
         ++g_stats.unauthorizedTCP;
-        terminateTCPConnection(fd);
         return;
       }
 
@@ -2709,6 +2752,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
     }
     if (bytes <= 0) {
       handleTCPReadResult(fd, bytes);
+      tcpGuard.keep();
       return;
     }
   }
@@ -2727,6 +2771,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
           g_log<<Logger::Error<<"TCP client "<< conn->d_remote.toStringWithPort() <<" disconnected after first byte"<<endl;
         }
       }
+      tcpGuard.keep();
       return;
     }
   }
@@ -2738,14 +2783,14 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
         if(g_logCommonErrors) {
           g_log<<Logger::Error<<"TCP client "<< conn->d_remote.toStringWithPort() <<" disconnected while reading question body"<<endl;
         }
-      }
+      }      
+      tcpGuard.keep();
       return;
     }
     else if (bytes > std::numeric_limits<std::uint16_t>::max()) {
       if(g_logCommonErrors) {
         g_log<<Logger::Error<<"TCP client "<< conn->d_remote.toStringWithPort() <<" sent an invalid question size while reading question body"<<endl;
       }
-      terminateTCPConnection(fd);
       return;
     }
     conn->bytesread+=(uint16_t)bytes;
@@ -2760,7 +2805,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
         if (g_logCommonErrors) {
           g_log<<Logger::Error<<"Unable to parse packet from TCP client "<< conn->d_remote.toStringWithPort() <<endl;
         }
-        terminateTCPConnection(fd);
         return;
       }
       dc->d_tcpConnection = conn; // carry the torch
@@ -2883,7 +2927,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
             g_log<<Logger::Notice<<t_id<<" ["<<MT->getTid()<<"/"<<MT->numProcesses()<<"] DROPPED TCP question from "<<dc->d_source.toStringWithPort()<<(dc->d_source != dc->d_remote ? " (via "+dc->d_remote.toStringWithPort()+")" : "")<<" based on policy"<<endl;
           }
           g_stats.policyDrops++;
-          terminateTCPConnection(fd);
           return;
         }
       }
@@ -2893,7 +2936,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
         if (g_logCommonErrors) {
           g_log<<Logger::Error<<"Ignoring answer from TCP client "<< dc->getRemote() <<" on server socket!"<<endl;
         }
-        terminateTCPConnection(fd);
         return;
       }
       if (dc->d_mdp.d_header.opcode != Opcode::Query && dc->d_mdp.d_header.opcode != Opcode::Notify) {
@@ -2902,6 +2944,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
           g_log<<Logger::Error<<"Ignoring unsupported opcode "<<Opcode::to_s(dc->d_mdp.d_header.opcode)<<" from TCP client "<< dc->getRemote() <<" on server socket!"<<endl;
         }
         sendErrorOverTCP(dc, RCode::NotImp);
+        tcpGuard.keep();
         return;
       }
       else if (dh->qdcount == 0) {
@@ -2910,6 +2953,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
           g_log<<Logger::Error<<"Ignoring empty (qdcount == 0) query from "<< dc->getRemote() <<" on server socket!"<<endl;
         }
         sendErrorOverTCP(dc, RCode::NotImp);
+        tcpGuard.keep();
         return;
       }
       else {
@@ -2924,7 +2968,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
             }
 
             g_stats.sourceDisallowedNotify++;
-            terminateTCPConnection(fd);
             return;
           }
 
@@ -2934,7 +2977,6 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
             }
 
             g_stats.zoneDisallowedNotify++;
-            terminateTCPConnection(fd);
             return;
           }
         }
@@ -2971,6 +3013,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
             if (dc->d_eventTrace.enabled() && SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_log) {
               g_log << Logger::Info << dc->d_eventTrace.toString() << endl;
             }
+            tcpGuard.keep();
             return;
           } // cache hit
         } // query opcode
@@ -2998,6 +3041,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
           struct timeval ttd = g_now;
           t_fdm->setReadTTD(fd, ttd, g_tcpTimeout);
         }
+        tcpGuard.keep();
         MT->makeThread(startDoResolve, dc.release()); // deletes dc
       } // good query
     } // read full query
index 5d51b4e5bbb470126bf388620001b0e889966e60..7cd408d19f6dbd92ff128e109c0d76c502030574 100644 (file)
@@ -1111,7 +1111,14 @@ public:
   {
     return d_fd;
   }
-
+  void setDropOnIdle()
+  {
+    d_dropOnIdle = true;
+  }
+  bool isDropOnIdle() const
+  {
+    return d_dropOnIdle;
+  }
   std::vector<ProxyProtocolValue> proxyProtocolValues;
   std::string data;
   const ComboAddress d_remote;
@@ -1130,6 +1137,7 @@ public:
 private:
   const int d_fd;
   static std::atomic<uint32_t> s_currentConnections; //!< total number of current TCP connections
+  bool d_dropOnIdle{false};
 };
 
 class ImmediateServFailException