]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Implement Proxy Protocol support for outgoing DoH
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 6 Sep 2021 14:46:38 +0000 (16:46 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 13 Sep 2021 13:34:33 +0000 (15:34 +0200)
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdistdist/dnsdist-healthchecks.cc
pdns/dnsdistdist/dnsdist-nghttp2.cc
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_OutgoingDOH.py
regression-tests.dnsdist/test_ProxyProtocol.py

index f839b321f222202b0b3c7e50ebbf342148ea91de..20d1f00fcb4d025b9d896ea8c0bb0cfa155b94f9 100644 (file)
@@ -552,7 +552,6 @@ public:
   {
     query = InternalQuery(std::move(buffer), std::move(ids));
     downstream = ds;
-    #warning handle proxy protocol payload
     proxyProtocolPayloadSize = 0;
   }
 
@@ -694,11 +693,19 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
 
   ++state->d_currentQueriesCount;
 
+  std::string proxyProtocolPayload;
   if (ds->isDoH()) {
     vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
 
+    /* we need to do this _before_ creating the cross protocol query because
+       after that the buffer will have been moved */
+    if (ds->useProxyProtocol) {
+      proxyProtocolPayload = getProxyProtocolPayload(dq);
+    }
+
     auto incoming = std::make_shared<TCPCrossProtocolQuerySender>(state, state->d_threadData.crossProtocolResponsesPipe);
     auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, incoming);
+    cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
     ds->passCrossProtocolQuery(std::move(cpq));
     return;
@@ -709,7 +716,6 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
   auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
 
   bool proxyProtocolPayloadAdded = false;
-  std::string proxyProtocolPayload;
 
   if (ds->useProxyProtocol) {
     /* if we ever sent a TLV over a connection, we can never go back */
index dbcbbcc4c8b8cb79ae6f93913ed77229c8fe4a1a..c8b8d3f407620779bde5b14bc641885221225e5c 100644 (file)
@@ -1493,6 +1493,13 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
 
     if (ss->isTCPOnly()) {
+      std::string proxyProtocolPayload;
+      /* we need to do this _before_ creating the cross protocol query because
+         after that the buffer will have been moved */
+      if (ss->useProxyProtocol) {
+        proxyProtocolPayload = getProxyProtocolPayload(dq);
+      }
+
       IDState ids;
       ids.cs = &cs;
       ids.origFD = cs.udpFD;
@@ -1505,6 +1512,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
         ids.origDest = cs.local;
       }
       auto cpq = std::make_unique<UDPCrossProtocolQuery>(std::move(query), std::move(ids), ss);
+      cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
       ss->passCrossProtocolQuery(std::move(cpq));
       return;
index f452932d65c6eb7070d98e2aee4c493a52f29c50..ebca3585326e1794dfbcd60a016fc863b3fee14f 100644 (file)
@@ -341,11 +341,14 @@ bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared
 
     /* we need to compute that _before_ adding the proxy protocol payload */
     uint16_t packetSize = packet.size();
+    std::string proxyProtocolPayload;
     size_t proxyProtocolPayloadSize = 0;
     if (ds->useProxyProtocol) {
-      auto payload = makeLocalProxyHeader();
-      proxyProtocolPayloadSize = payload.size();
-      packet.insert(packet.begin(), payload.begin(), payload.end());
+      proxyProtocolPayload = makeLocalProxyHeader();
+      proxyProtocolPayloadSize = proxyProtocolPayload.size();
+      if (!ds->isDoH()) {
+        packet.insert(packet.begin(), proxyProtocolPayload.begin(), proxyProtocolPayload.end());
+      }
     }
 
     Socket sock(ds->remote.sin4.sin_family, ds->doHealthcheckOverTCP() ? SOCK_STREAM : SOCK_DGRAM);
@@ -397,6 +400,7 @@ bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared
     }
     else if (ds->isDoH()) {
       InternalQuery query(std::move(packet), IDState());
+      query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
       auto sender = std::shared_ptr<TCPQuerySender>(new HealthCheckQuerySender(data));
       if (!sendH2Query(ds, mplexer, sender, std::move(query), true)) {
         updateHealthCheckResult(data->d_ds, data->d_initial, false);
index 2819d159e23ba0b34f3b296d92781214d788de1a..dbb3c86e3c39c73121048c9f18c47c48cd2b3b92 100644 (file)
@@ -46,7 +46,7 @@ uint16_t g_outgoingDoHWorkerThreads{0};
 class DoHConnectionToBackend : public TCPConnectionToBackend
 {
 public:
-  DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now);
+  DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload);
 
   void handleTimeout(const struct timeval& now, bool write) override;
   void queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query) override;
@@ -109,16 +109,18 @@ private:
   std::unordered_map<int32_t, PendingRequest> d_currentStreams;
   PacketBuffer d_out;
   PacketBuffer d_in;
+  std::string d_proxyProtocolPayload;
   size_t d_outPos{0};
   size_t d_inPos{0};
   uint32_t d_highestStreamID{0};
   bool d_healthCheckQuery{false};
+  bool d_proxyProtocolPayloadSent{false};
 };
 
 class DownstreamDoHConnectionsManager
 {
 public:
-  static std::shared_ptr<DoHConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& ds, const struct timeval& now);
+  static std::shared_ptr<DoHConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& ds, const struct timeval& now, std::string&& proxyProtocolPayload);
   static void releaseDownstreamConnection(std::shared_ptr<DoHConnectionToBackend>&& conn);
   static bool removeDownstreamConnection(std::shared_ptr<DoHConnectionToBackend>& conn);
   static void cleanupClosedConnections(struct timeval now);
@@ -193,6 +195,11 @@ bool DoHConnectionToBackend::canBeReused() const
   if (d_connectionDied) {
     return false;
   }
+
+  if (!d_proxyProtocolPayload.empty()) {
+    return false;
+  }
+
   const uint32_t maximumStreamID = (static_cast<uint32_t>(1) << 31) - 1;
   if (d_highestStreamID == maximumStreamID) {
     return false;
@@ -525,6 +532,11 @@ ssize_t DoHConnectionToBackend::send_callback(nghttp2_session* session, const ui
 {
   DoHConnectionToBackend* conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
   bool bufferWasEmpty = conn->d_out.empty();
+  if (!conn->d_proxyProtocolPayloadSent && !conn->d_proxyProtocolPayload.empty()) {
+    conn->d_out.insert(conn->d_out.end(), conn->d_proxyProtocolPayload.begin(), conn->d_proxyProtocolPayload.end());
+    conn->d_proxyProtocolPayloadSent = true;
+  }
+
   conn->d_out.insert(conn->d_out.end(), data, data + length);
 
   if (bufferWasEmpty) {
@@ -685,7 +697,7 @@ int DoHConnectionToBackend::on_stream_close_callback(nghttp2_session* session, i
   if (request.d_query.d_downstreamFailures < conn->d_ds->d_retries) {
     // cerr<<"in "<<__PRETTY_FUNCTION__<<", looking for a connection to send a query of size "<<request.d_query.d_buffer.size()<<endl;
     ++request.d_query.d_downstreamFailures;
-    auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(conn->d_mplexer, conn->d_ds, now);
+    auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(conn->d_mplexer, conn->d_ds, now, std::string(conn->d_proxyProtocolPayload));
     downstream->queueQuery(request.d_sender, std::move(request.d_query));
   }
   else {
@@ -741,8 +753,8 @@ int DoHConnectionToBackend::on_error_callback(nghttp2_session* session, int lib_
   return 0;
 }
 
-DoHConnectionToBackend::DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now) :
-  TCPConnectionToBackend(ds, mplexer, now)
+DoHConnectionToBackend::DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload) :
+  TCPConnectionToBackend(ds, mplexer, now), d_proxyProtocolPayload(std::move(proxyProtocolPayload))
 {
   // inherit most of the stuff from the TCPConnectionToBackend()
   d_ioState = make_unique<IOStateHandler>(*d_mplexer, d_handler->getDescriptor());
@@ -863,7 +875,7 @@ void DownstreamDoHConnectionsManager::cleanupClosedConnections(struct timeval no
   }
 }
 
-std::shared_ptr<DoHConnectionToBackend> DownstreamDoHConnectionsManager::getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
+std::shared_ptr<DoHConnectionToBackend> DownstreamDoHConnectionsManager::getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& ds, const struct timeval& now, std::string&& proxyProtocolPayload)
 {
   std::shared_ptr<DoHConnectionToBackend> result;
   struct timeval freshCutOff = now;
@@ -877,7 +889,8 @@ std::shared_ptr<DoHConnectionToBackend> DownstreamDoHConnectionsManager::getConn
     cleanupClosedConnections(now);
   }
 
-  {
+  const bool haveProxyProtocol = !proxyProtocolPayload.empty();
+  if (!haveProxyProtocol) {
     //cerr<<"looking for existing connection"<<endl;
     const auto& it = t_downstreamConnections.find(backendId);
     if (it != t_downstreamConnections.end()) {
@@ -906,11 +919,13 @@ std::shared_ptr<DoHConnectionToBackend> DownstreamDoHConnectionsManager::getConn
         ++listIt;
       }
     }
+  }
 
-    auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now);
+  auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now, std::move(proxyProtocolPayload));
+  if (!haveProxyProtocol) {
     t_downstreamConnections[backendId].push_back(newConnection);
-    return newConnection;
   }
+  return newConnection;
 }
 
 static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
@@ -943,9 +958,7 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
     tmp = nullptr;
 
     try {
-      auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now);
-
-#warning what about the proxy protocol payload, here, do we need to remove it? we likely need to handle forward-for headers?
+      auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::move(query.d_proxyProtocolPayload));
       downstream->queueQuery(tqs, std::move(query));
     }
     catch (...) {
@@ -1193,12 +1206,12 @@ bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDM
 
   if (healthCheck) {
     /* always do health-checks over a new connection */
-    auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now);
+    auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now, std::move(query.d_proxyProtocolPayload));
     newConnection->setHealthCheck(healthCheck);
     newConnection->queueQuery(sender, std::move(query));
   }
   else {
-    auto connection = DownstreamDoHConnectionsManager::getConnectionToDownstream(mplexer, ds, now);
+    auto connection = DownstreamDoHConnectionsManager::getConnectionToDownstream(mplexer, ds, now, std::move(query.d_proxyProtocolPayload));
     connection->queueQuery(sender, std::move(query));
   }
 
index fad20531a974315f889cdaf313842bc166c3efcb..4441fea264588f67b1d7a4a3656d17814487e825 100644 (file)
@@ -502,6 +502,11 @@ public:
   DoHCrossProtocolQuery(DOHUnit* du_): du(du_)
   {
     query = InternalQuery(std::move(du->query), std::move(du->ids));
+    /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload,
+       clearing query.d_proxyProtocolPayloadAdded and du->proxyProtocolPayloadSize.
+       Leave it for now because we know that the onky case where the payload has been
+       added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT,
+       and we know the TCP/DoT code can handle it. */
     query.d_proxyProtocolPayloadAdded = du->proxyProtocolPayloadSize > 0;
     downstream = du->downstream;
     proxyProtocolPayloadSize = du->proxyProtocolPayloadSize;
@@ -619,9 +624,16 @@ static int processDOHQuery(DOHUnit* du)
     }
 
     if (du->downstream->isTCPOnly()) {
-      auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
+      std::string proxyProtocolPayload;
+      /* we need to do this _before_ creating the cross protocol query because
+         after that the buffer will have been moved */
+      if (du->downstream->useProxyProtocol) {
+        proxyProtocolPayload = getProxyProtocolPayload(dq);
+      }
 
+      auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
       du->get();
+      cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
       du->tcp = true;
       du->ids.origID = htons(queryId);
       du->ids.cs = &cs;
index 5571d65fb58d199328802da9ab637d533dd89ef0..c794d13a1db9dc85531dd0db511baf87d3af630a 100644 (file)
@@ -81,8 +81,8 @@ std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d
 
 struct DOHConnection
 {
-  DOHConnection() :
-    d_session(std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)>(nullptr, nghttp2_session_del))
+  DOHConnection(bool needProxyProtocol) :
+    d_session(std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)>(nullptr, nghttp2_session_del)), d_needProxyProtocol(needProxyProtocol)
   {
     nghttp2_session_callbacks* cbs = nullptr;
     nghttp2_session_callbacks_new(&cbs);
@@ -102,15 +102,47 @@ struct DOHConnection
   }
 
   PacketBuffer d_serverOutBuffer;
+  PacketBuffer d_proxyProtocolBuffer;
   std::map<uint32_t, PacketBuffer> d_queries;
   std::map<uint32_t, PacketBuffer> d_responses;
   std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)> d_session;
   /* used to replace the stream ID in outgoing frames. Ugly but the library does not let us
      test weird cases without that */
   std::map<uint32_t, uint32_t> d_idMapping;
+  bool d_needProxyProtocol;
 
   size_t submitIncoming(const PacketBuffer& data, size_t pos, size_t toWrite)
   {
+    size_t consumed = 0;
+    if (d_needProxyProtocol) {
+      do {
+        auto bytesRemaining = isProxyHeaderComplete(d_proxyProtocolBuffer);
+        if (bytesRemaining < 0) {
+          size_t toConsume = toWrite > static_cast<size_t>(-bytesRemaining) ? static_cast<size_t>(-bytesRemaining) : toWrite;
+          d_proxyProtocolBuffer.insert(d_proxyProtocolBuffer.end(), data.begin() + pos, data.begin() + pos + toConsume);
+          pos += toConsume;
+          toWrite -= toConsume;
+          consumed += toConsume;
+
+          bytesRemaining = isProxyHeaderComplete(d_proxyProtocolBuffer);
+          if (bytesRemaining > 0) {
+            d_needProxyProtocol = false;
+          }
+          else if (bytesRemaining == 0) {
+            throw("Fatal error while parsing proxy protocol payload");
+          }
+        }
+        else if (bytesRemaining == 0) {
+          throw("Fatal error while parsing proxy protocol payload");
+        }
+
+        if (toWrite == 0) {
+          return consumed;
+        }
+      }
+      while (d_needProxyProtocol && toWrite > 0);
+    }
+
     ssize_t readlen = nghttp2_session_mem_recv(d_session.get(), &data.at(pos), toWrite);
     if (readlen < 0) {
       throw("Fatal error while submitting: " + std::string(nghttp2_strerror(static_cast<int>(readlen))));
@@ -250,10 +282,8 @@ private:
   static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data)
   {
     DOHConnection* conn = reinterpret_cast<DOHConnection*>(user_data);
-    // cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
     auto& query = conn->d_queries[stream_id];
     query.insert(query.end(), data, data + len);
-    // cerr<<"out "<<__PRETTY_FUNCTION__<<endl;
     return 0;
   }
 
@@ -274,10 +304,10 @@ static std::map<int, std::unique_ptr<DOHConnection>> s_connectionBuffers;
 class MockupTLSConnection : public TLSConnection
 {
 public:
-  MockupTLSConnection(int descriptor, bool client = false) :
+  MockupTLSConnection(int descriptor, bool client = false, bool needProxyProtocol = false) :
     d_descriptor(descriptor), d_client(client)
   {
-    s_connectionBuffers[d_descriptor] = std::make_unique<DOHConnection>();
+    s_connectionBuffers[d_descriptor] = std::make_unique<DOHConnection>(needProxyProtocol);
   }
 
   ~MockupTLSConnection() {}
@@ -346,11 +376,9 @@ public:
 
     BOOST_REQUIRE_GE(buffer.size(), toRead);
 
-    // cerr<<"in server try read, adding "<<toRead<<" bytes from the buffer of size "<<externalBuffer.size()<<" at position "<<pos<<", buffer had a size of "<<buffer.size()<<endl;
     std::copy(externalBuffer.begin(), externalBuffer.begin() + toRead, buffer.begin() + pos);
     pos += toRead;
     externalBuffer.erase(externalBuffer.begin(), externalBuffer.begin() + toRead);
-    // cerr<<"external buffer has "<<externalBuffer.size()<<" remaining"<<endl;
 
     return step.nextState;
   }
@@ -454,7 +482,7 @@ public:
 
   std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override
   {
-    return std::make_unique<MockupTLSConnection>(socket, true);
+    return std::make_unique<MockupTLSConnection>(socket, true, d_needProxyProtocol);
   }
 
   void rotateTicketsKey(time_t now) override
@@ -470,6 +498,8 @@ public:
   {
     return "Mockup TLS";
   }
+
+  bool d_needProxyProtocol{false};
 };
 
 class MockupFDMultiplexer : public FDMultiplexer
@@ -1766,5 +1796,104 @@ BOOST_FIXTURE_TEST_CASE(test_WrongStreamID, TestFixture)
   BOOST_CHECK_EQUAL(clearH2Connections(), 0U);
 }
 
+BOOST_FIXTURE_TEST_CASE(test_ProxyProtocol, TestFixture)
+{
+  ComboAddress local("192.0.2.1:80");
+  ClientState localCS(local, true, false, false, "", {});
+  auto tlsCtx = std::make_shared<MockupTLSCtx>();
+  tlsCtx->d_needProxyProtocol = true;
+  localCS.tlsFrontend = std::make_shared<TLSFrontend>(tlsCtx);
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+
+  auto backend = std::make_shared<DownstreamState>(ComboAddress("192.0.2.42:53"), ComboAddress("0.0.0.0:0"), 0, std::string(), 1, false);
+  backend->d_tlsCtx = tlsCtx;
+  backend->d_tlsSubjectName = "backend.powerdns.com";
+  backend->d_dohPath = "/dns-query";
+  backend->d_addXForwardedHeaders = true;
+  backend->useProxyProtocol = true;
+
+  size_t numberOfQueries = 2;
+  std::vector<std::pair<std::shared_ptr<MockupQuerySender>, InternalQuery>> queries;
+  for (size_t counter = 0; counter < numberOfQueries; counter++) {
+    DNSName name("powerdns.com.");
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> pwQ(query, name, QType::A, QClass::IN, 0);
+    pwQ.getHeader()->rd = 1;
+    pwQ.getHeader()->id = htons(counter);
+
+    PacketBuffer response;
+    GenericDNSPacketWriter<PacketBuffer> pwR(response, name, QType::A, QClass::IN, 0);
+    pwR.getHeader()->qr = 1;
+    pwR.getHeader()->rd = 1;
+    pwR.getHeader()->ra = 1;
+    pwR.getHeader()->id = htons(counter);
+    pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER);
+    pwR.xfr32BitInt(0x01020304);
+    pwR.commit();
+
+    s_responses[counter] = {query, response};
+
+    auto sender = std::make_shared<MockupQuerySender>();
+    sender->d_id = counter;
+    std::string payload = makeProxyHeader(counter % 2, local, local, {});
+    InternalQuery internalQuery(std::move(query), IDState());
+    internalQuery.d_proxyProtocolPayload = std::move(payload);
+    queries.push_back({std::move(sender), std::move(internalQuery)});
+  }
+
+  s_steps = {
+    {ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done},
+    /* proxy protocol data + opening */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* settings */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* headers */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* data */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max(), [](int desc, const ExpectedStep& step) {
+       /* set the outgoing descriptor (backend connection) as ready */
+       dynamic_cast<MockupFDMultiplexer*>(s_mplexer.get())->setReady(desc);
+     }},
+    {ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done},
+    /* proxy protocol data + opening */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* settings */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* headers */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* data */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max(), [](int desc, const ExpectedStep& step) {
+       /* set the outgoing descriptor (backend connection) as ready */
+       dynamic_cast<MockupFDMultiplexer*>(s_mplexer.get())->setReady(desc);
+     }},
+    /* read settings, headers and responses from the server */
+    {ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* acknowledge settings */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    {ExpectedStep::ExpectedRequest::closeBackend, IOState::Done},
+    /* read settings, headers and responses from the server */
+    {ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    /* acknowledge settings */
+    {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits<size_t>::max()},
+    {ExpectedStep::ExpectedRequest::closeBackend, IOState::Done},
+  };
+
+  for (auto& query : queries) {
+    auto sliced = std::static_pointer_cast<TCPQuerySender>(query.first);
+    bool result = sendH2Query(backend, s_mplexer, sliced, std::move(query.second), false);
+    BOOST_CHECK_EQUAL(result, true);
+  }
+
+  while (s_mplexer->getWatchedFDCount(false) != 0 || s_mplexer->getWatchedFDCount(true) != 0) {
+    s_mplexer->run(&now);
+  }
+
+  for (auto& query : queries) {
+    BOOST_CHECK_EQUAL(query.first->d_valid, true);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();
 #endif /* HAVE_NGHTTP2 */
index 9c7e1be99f4f734dde51d970f209b9618b30d73f..9e7a12d3dab0baa20ea750da7d7f1ac6214928fe 100644 (file)
@@ -3046,7 +3046,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     proxyEnabledBackend->d_tlsCtx = tlsCtx;
     /* enable out-of-order on the backend side as well */
     proxyEnabledBackend->d_maxInFlightQueriesPerConn = 65536;
-    proxyEnabledBackend-> useProxyProtocol = true;
+    proxyEnabledBackend->useProxyProtocol = true;
 
     expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end());
     expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end());
index 57d3d2ccbf2526eaf376cc7f8282af3779c9b6b6..60d8c8932b6e4f35797c773e638bcac1a773b62f 100644 (file)
@@ -24,6 +24,7 @@ import h2.events
 import h2.config
 
 from eqdnsmessage import AssertEqualDNSMessageMixin
+from proxyprotocol import ProxyProtocol
 
 # Python2/3 compatibility hacks
 try:
@@ -322,7 +323,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         sock.close()
 
     @classmethod
-    def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None):
+    def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
@@ -358,6 +359,30 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             conn.sendall(h2conn.data_to_send())
             dnsData = {}
 
+            if useProxyProtocol:
+                # try to read the entire Proxy Protocol header
+                proxy = ProxyProtocol()
+                header = conn.recv(proxy.HEADER_SIZE)
+                if not header:
+                    print('unable to get header')
+                    conn.close()
+                    continue
+
+                if not proxy.parseHeader(header):
+                    print('unable to parse header')
+                    print(header)
+                    conn.close()
+                    continue
+
+                proxyContent = conn.recv(proxy.contentLen)
+                if not proxyContent:
+                    print('unable to get content')
+                    conn.close()
+                    continue
+
+                payload = header + proxyContent
+                toQueue.put(payload, True, cls._queueTimeout)
+
             while True:
                 data = conn.recv(65535)
                 if not data:
@@ -519,7 +544,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             cls.sendTCPQueryOverConnection(sock, query, rawQuery)
             message = cls.recvTCPResponseOverConnection(sock)
         except socket.timeout as e:
-            print("Timeout: %s" % (str(e)))
+            print("Timeout while sending or receiving TCP data: %s" % (str(e)))
         except socket.error as e:
             print("Network error: %s" % (str(e)))
         finally:
@@ -566,7 +591,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                 messages.append(dns.message.from_wire(data))
 
         except socket.timeout as e:
-            print("Timeout: %s" % (str(e)))
+            print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
         except socket.error as e:
             print("Network error: %s" % (str(e)))
         finally:
@@ -743,3 +768,33 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             for inFileName in ['server.pem', 'ca.pem']:
                 with open(inFileName) as inFile:
                     outFile.write(inFile.read())
+
+    def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None):
+        proxy = ProxyProtocol()
+        self.assertTrue(proxy.parseHeader(receivedProxyPayload))
+        self.assertEqual(proxy.version, 0x02)
+        self.assertEqual(proxy.command, 0x01)
+        if v6:
+            self.assertEqual(proxy.family, 0x02)
+        else:
+            self.assertEqual(proxy.family, 0x01)
+        if not isTCP:
+            self.assertEqual(proxy.protocol, 0x02)
+        else:
+            self.assertEqual(proxy.protocol, 0x01)
+        self.assertGreater(proxy.contentLen, 0)
+
+        self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
+        self.assertEqual(proxy.source, source)
+        self.assertEqual(proxy.destination, destination)
+        if sourcePort:
+            self.assertEqual(proxy.sourcePort, sourcePort)
+        if destinationPort:
+            self.assertEqual(proxy.destinationPort, destinationPort)
+        else:
+            self.assertEqual(proxy.destinationPort, self._dnsDistPort)
+
+        self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
+        proxy.values.sort()
+        values.sort()
+        self.assertEqual(proxy.values, values)
index 205518c1bbf2f91c70e4c264c311ae0552d01890..42b30c748905f9eb4b35a5b0baeb2c3dfcdf8f53 100644 (file)
@@ -390,3 +390,50 @@ class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenRespons
         cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext])
         cls._DOHResponder.setDaemon(True)
         cls._DOHResponder.start()
+
+class TestOutgoingDOHProxyProtocol(DNSDistTest):
+
+    _tlsBackendPort = 10551
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', useProxyProtocol=true}:setUp()
+    """
+    _verboseMode = True
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.set_alpn_protocols(["h2"])
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH woth Proxy Protocol responder..")
+        cls._DOHResponder = threading.Thread(name='DOH with Proxy Protocol Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext, True])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+    def testPP(self):
+        """
+        Outgoing DOH with Proxy Protocol
+        """
+        name = 'proxy-protocol.outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        (receivedProxyPayload, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+        receivedQuery = self._fromResponderQueue.get(True, 1.0)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(receivedResponse, expectedResponse)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', False)
+
+        (receivedProxyPayload, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+        receivedQuery = self._fromResponderQueue.get(True, 1.0)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(receivedResponse, expectedResponse)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True)
index 18b9efb414e7c5ba69e99795eee546a7492cf8a2..16b27fde4713c4962ea608772d7e5282781b6411 100644 (file)
@@ -142,36 +142,6 @@ class ProxyProtocolTest(DNSDistTest):
     _proxyResponderPort = proxyResponderPort
     _config_params = ['_proxyResponderPort']
 
-    def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None):
-      proxy = ProxyProtocol()
-      self.assertTrue(proxy.parseHeader(receivedProxyPayload))
-      self.assertEqual(proxy.version, 0x02)
-      self.assertEqual(proxy.command, 0x01)
-      if v6:
-        self.assertEqual(proxy.family, 0x02)
-      else:
-        self.assertEqual(proxy.family, 0x01)
-      if not isTCP:
-        self.assertEqual(proxy.protocol, 0x02)
-      else:
-        self.assertEqual(proxy.protocol, 0x01)
-      self.assertGreater(proxy.contentLen, 0)
-
-      self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
-      self.assertEqual(proxy.source, source)
-      self.assertEqual(proxy.destination, destination)
-      if sourcePort:
-        self.assertEqual(proxy.sourcePort, sourcePort)
-      if destinationPort:
-        self.assertEqual(proxy.destinationPort, destinationPort)
-      else:
-        self.assertEqual(proxy.destinationPort, self._dnsDistPort)
-
-      self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
-      proxy.values.sort()
-      values.sort()
-      self.assertEqual(proxy.values, values)
-
 class TestProxyProtocol(ProxyProtocolTest):
     """
     dnsdist is configured to prepend a Proxy Protocol header to the query