From: Remi Gacogne Date: Mon, 6 Sep 2021 14:46:38 +0000 (+0200) Subject: dnsdist: Implement Proxy Protocol support for outgoing DoH X-Git-Tag: dnsdist-1.7.0-alpha1~23^2~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0e6892c688047c21282a54df36ca71af466ba527;p=thirdparty%2Fpdns.git dnsdist: Implement Proxy Protocol support for outgoing DoH --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index f839b321f2..20d1f00fcb 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -552,7 +552,6 @@ public: { query = InternalQuery(std::move(buffer), std::move(ids)); downstream = ds; - #warning handle proxy protocol payload proxyProtocolPayloadSize = 0; } @@ -694,11 +693,19 @@ static void handleQuery(std::shared_ptr& state, cons ++state->d_currentQueriesCount; + std::string proxyProtocolPayload; if (ds->isDoH()) { vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName()); + /* we need to do this _before_ creating the cross protocol query because + after that the buffer will have been moved */ + if (ds->useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dq); + } + auto incoming = std::make_shared(state, state->d_threadData.crossProtocolResponsesPipe); auto cpq = std::make_unique(std::move(state->d_buffer), std::move(ids), ds, incoming); + cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); ds->passCrossProtocolQuery(std::move(cpq)); return; @@ -709,7 +716,6 @@ static void handleQuery(std::shared_ptr& state, cons auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); bool proxyProtocolPayloadAdded = false; - std::string proxyProtocolPayload; if (ds->useProxyProtocol) { /* if we ever sent a TLV over a connection, we can never go back */ diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index dbcbbcc4c8..c8b8d3f407 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1493,6 +1493,13 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } if (ss->isTCPOnly()) { + std::string proxyProtocolPayload; + /* we need to do this _before_ creating the cross protocol query because + after that the buffer will have been moved */ + if (ss->useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dq); + } + IDState ids; ids.cs = &cs; ids.origFD = cs.udpFD; @@ -1505,6 +1512,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids.origDest = cs.local; } auto cpq = std::make_unique(std::move(query), std::move(ids), ss); + cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); ss->passCrossProtocolQuery(std::move(cpq)); return; diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index f452932d65..ebca358532 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -341,11 +341,14 @@ bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared /* we need to compute that _before_ adding the proxy protocol payload */ uint16_t packetSize = packet.size(); + std::string proxyProtocolPayload; size_t proxyProtocolPayloadSize = 0; if (ds->useProxyProtocol) { - auto payload = makeLocalProxyHeader(); - proxyProtocolPayloadSize = payload.size(); - packet.insert(packet.begin(), payload.begin(), payload.end()); + proxyProtocolPayload = makeLocalProxyHeader(); + proxyProtocolPayloadSize = proxyProtocolPayload.size(); + if (!ds->isDoH()) { + packet.insert(packet.begin(), proxyProtocolPayload.begin(), proxyProtocolPayload.end()); + } } Socket sock(ds->remote.sin4.sin_family, ds->doHealthcheckOverTCP() ? SOCK_STREAM : SOCK_DGRAM); @@ -397,6 +400,7 @@ bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared } else if (ds->isDoH()) { InternalQuery query(std::move(packet), IDState()); + query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); auto sender = std::shared_ptr(new HealthCheckQuerySender(data)); if (!sendH2Query(ds, mplexer, sender, std::move(query), true)) { updateHealthCheckResult(data->d_ds, data->d_initial, false); diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 2819d159e2..dbb3c86e3c 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -46,7 +46,7 @@ uint16_t g_outgoingDoHWorkerThreads{0}; class DoHConnectionToBackend : public TCPConnectionToBackend { public: - DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now); + DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload); void handleTimeout(const struct timeval& now, bool write) override; void queueQuery(std::shared_ptr& sender, TCPQuery&& query) override; @@ -109,16 +109,18 @@ private: std::unordered_map d_currentStreams; PacketBuffer d_out; PacketBuffer d_in; + std::string d_proxyProtocolPayload; size_t d_outPos{0}; size_t d_inPos{0}; uint32_t d_highestStreamID{0}; bool d_healthCheckQuery{false}; + bool d_proxyProtocolPayloadSent{false}; }; class DownstreamDoHConnectionsManager { public: - static std::shared_ptr getConnectionToDownstream(std::unique_ptr& mplexer, const std::shared_ptr& ds, const struct timeval& now); + static std::shared_ptr getConnectionToDownstream(std::unique_ptr& mplexer, const std::shared_ptr& ds, const struct timeval& now, std::string&& proxyProtocolPayload); static void releaseDownstreamConnection(std::shared_ptr&& conn); static bool removeDownstreamConnection(std::shared_ptr& conn); static void cleanupClosedConnections(struct timeval now); @@ -193,6 +195,11 @@ bool DoHConnectionToBackend::canBeReused() const if (d_connectionDied) { return false; } + + if (!d_proxyProtocolPayload.empty()) { + return false; + } + const uint32_t maximumStreamID = (static_cast(1) << 31) - 1; if (d_highestStreamID == maximumStreamID) { return false; @@ -525,6 +532,11 @@ ssize_t DoHConnectionToBackend::send_callback(nghttp2_session* session, const ui { DoHConnectionToBackend* conn = reinterpret_cast(user_data); bool bufferWasEmpty = conn->d_out.empty(); + if (!conn->d_proxyProtocolPayloadSent && !conn->d_proxyProtocolPayload.empty()) { + conn->d_out.insert(conn->d_out.end(), conn->d_proxyProtocolPayload.begin(), conn->d_proxyProtocolPayload.end()); + conn->d_proxyProtocolPayloadSent = true; + } + conn->d_out.insert(conn->d_out.end(), data, data + length); if (bufferWasEmpty) { @@ -685,7 +697,7 @@ int DoHConnectionToBackend::on_stream_close_callback(nghttp2_session* session, i if (request.d_query.d_downstreamFailures < conn->d_ds->d_retries) { // cerr<<"in "<<__PRETTY_FUNCTION__<<", looking for a connection to send a query of size "<d_mplexer, conn->d_ds, now); + auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(conn->d_mplexer, conn->d_ds, now, std::string(conn->d_proxyProtocolPayload)); downstream->queueQuery(request.d_sender, std::move(request.d_query)); } else { @@ -741,8 +753,8 @@ int DoHConnectionToBackend::on_error_callback(nghttp2_session* session, int lib_ return 0; } -DoHConnectionToBackend::DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now) : - TCPConnectionToBackend(ds, mplexer, now) +DoHConnectionToBackend::DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload) : + TCPConnectionToBackend(ds, mplexer, now), d_proxyProtocolPayload(std::move(proxyProtocolPayload)) { // inherit most of the stuff from the TCPConnectionToBackend() d_ioState = make_unique(*d_mplexer, d_handler->getDescriptor()); @@ -863,7 +875,7 @@ void DownstreamDoHConnectionsManager::cleanupClosedConnections(struct timeval no } } -std::shared_ptr DownstreamDoHConnectionsManager::getConnectionToDownstream(std::unique_ptr& mplexer, const std::shared_ptr& ds, const struct timeval& now) +std::shared_ptr DownstreamDoHConnectionsManager::getConnectionToDownstream(std::unique_ptr& mplexer, const std::shared_ptr& ds, const struct timeval& now, std::string&& proxyProtocolPayload) { std::shared_ptr result; struct timeval freshCutOff = now; @@ -877,7 +889,8 @@ std::shared_ptr DownstreamDoHConnectionsManager::getConn cleanupClosedConnections(now); } - { + const bool haveProxyProtocol = !proxyProtocolPayload.empty(); + if (!haveProxyProtocol) { //cerr<<"looking for existing connection"< DownstreamDoHConnectionsManager::getConn ++listIt; } } + } - auto newConnection = std::make_shared(ds, mplexer, now); + auto newConnection = std::make_shared(ds, mplexer, now, std::move(proxyProtocolPayload)); + if (!haveProxyProtocol) { t_downstreamConnections[backendId].push_back(newConnection); - return newConnection; } + return newConnection; } static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) @@ -943,9 +958,7 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par tmp = nullptr; try { - auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now); - -#warning what about the proxy protocol payload, here, do we need to remove it? we likely need to handle forward-for headers? + auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::move(query.d_proxyProtocolPayload)); downstream->queueQuery(tqs, std::move(query)); } catch (...) { @@ -1193,12 +1206,12 @@ bool sendH2Query(const std::shared_ptr& ds, std::unique_ptr(ds, mplexer, now); + auto newConnection = std::make_shared(ds, mplexer, now, std::move(query.d_proxyProtocolPayload)); newConnection->setHealthCheck(healthCheck); newConnection->queueQuery(sender, std::move(query)); } else { - auto connection = DownstreamDoHConnectionsManager::getConnectionToDownstream(mplexer, ds, now); + auto connection = DownstreamDoHConnectionsManager::getConnectionToDownstream(mplexer, ds, now, std::move(query.d_proxyProtocolPayload)); connection->queueQuery(sender, std::move(query)); } diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index fad20531a9..4441fea264 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -502,6 +502,11 @@ public: DoHCrossProtocolQuery(DOHUnit* du_): du(du_) { query = InternalQuery(std::move(du->query), std::move(du->ids)); + /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload, + clearing query.d_proxyProtocolPayloadAdded and du->proxyProtocolPayloadSize. + Leave it for now because we know that the onky case where the payload has been + added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT, + and we know the TCP/DoT code can handle it. */ query.d_proxyProtocolPayloadAdded = du->proxyProtocolPayloadSize > 0; downstream = du->downstream; proxyProtocolPayloadSize = du->proxyProtocolPayloadSize; @@ -619,9 +624,16 @@ static int processDOHQuery(DOHUnit* du) } if (du->downstream->isTCPOnly()) { - auto cpq = std::make_unique(du); + std::string proxyProtocolPayload; + /* we need to do this _before_ creating the cross protocol query because + after that the buffer will have been moved */ + if (du->downstream->useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dq); + } + auto cpq = std::make_unique(du); du->get(); + cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); du->tcp = true; du->ids.origID = htons(queryId); du->ids.cs = &cs; diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc index 5571d65fb5..c794d13a1d 100644 --- a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc @@ -81,8 +81,8 @@ std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d struct DOHConnection { - DOHConnection() : - d_session(std::unique_ptr(nullptr, nghttp2_session_del)) + DOHConnection(bool needProxyProtocol) : + d_session(std::unique_ptr(nullptr, nghttp2_session_del)), d_needProxyProtocol(needProxyProtocol) { nghttp2_session_callbacks* cbs = nullptr; nghttp2_session_callbacks_new(&cbs); @@ -102,15 +102,47 @@ struct DOHConnection } PacketBuffer d_serverOutBuffer; + PacketBuffer d_proxyProtocolBuffer; std::map d_queries; std::map d_responses; std::unique_ptr d_session; /* used to replace the stream ID in outgoing frames. Ugly but the library does not let us test weird cases without that */ std::map d_idMapping; + bool d_needProxyProtocol; size_t submitIncoming(const PacketBuffer& data, size_t pos, size_t toWrite) { + size_t consumed = 0; + if (d_needProxyProtocol) { + do { + auto bytesRemaining = isProxyHeaderComplete(d_proxyProtocolBuffer); + if (bytesRemaining < 0) { + size_t toConsume = toWrite > static_cast(-bytesRemaining) ? static_cast(-bytesRemaining) : toWrite; + d_proxyProtocolBuffer.insert(d_proxyProtocolBuffer.end(), data.begin() + pos, data.begin() + pos + toConsume); + pos += toConsume; + toWrite -= toConsume; + consumed += toConsume; + + bytesRemaining = isProxyHeaderComplete(d_proxyProtocolBuffer); + if (bytesRemaining > 0) { + d_needProxyProtocol = false; + } + else if (bytesRemaining == 0) { + throw("Fatal error while parsing proxy protocol payload"); + } + } + else if (bytesRemaining == 0) { + throw("Fatal error while parsing proxy protocol payload"); + } + + if (toWrite == 0) { + return consumed; + } + } + while (d_needProxyProtocol && toWrite > 0); + } + ssize_t readlen = nghttp2_session_mem_recv(d_session.get(), &data.at(pos), toWrite); if (readlen < 0) { throw("Fatal error while submitting: " + std::string(nghttp2_strerror(static_cast(readlen)))); @@ -250,10 +282,8 @@ private: static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data) { DOHConnection* conn = reinterpret_cast(user_data); - // cerr<<"in "<<__PRETTY_FUNCTION__<d_queries[stream_id]; query.insert(query.end(), data, data + len); - // cerr<<"out "<<__PRETTY_FUNCTION__<> s_connectionBuffers; class MockupTLSConnection : public TLSConnection { public: - MockupTLSConnection(int descriptor, bool client = false) : + MockupTLSConnection(int descriptor, bool client = false, bool needProxyProtocol = false) : d_descriptor(descriptor), d_client(client) { - s_connectionBuffers[d_descriptor] = std::make_unique(); + s_connectionBuffers[d_descriptor] = std::make_unique(needProxyProtocol); } ~MockupTLSConnection() {} @@ -346,11 +376,9 @@ public: BOOST_REQUIRE_GE(buffer.size(), toRead); - // cerr<<"in server try read, adding "< getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override { - return std::make_unique(socket, true); + return std::make_unique(socket, true, d_needProxyProtocol); } void rotateTicketsKey(time_t now) override @@ -470,6 +498,8 @@ public: { return "Mockup TLS"; } + + bool d_needProxyProtocol{false}; }; class MockupFDMultiplexer : public FDMultiplexer @@ -1766,5 +1796,104 @@ BOOST_FIXTURE_TEST_CASE(test_WrongStreamID, TestFixture) BOOST_CHECK_EQUAL(clearH2Connections(), 0U); } +BOOST_FIXTURE_TEST_CASE(test_ProxyProtocol, TestFixture) +{ + ComboAddress local("192.0.2.1:80"); + ClientState localCS(local, true, false, false, "", {}); + auto tlsCtx = std::make_shared(); + tlsCtx->d_needProxyProtocol = true; + localCS.tlsFrontend = std::make_shared(tlsCtx); + + struct timeval now; + gettimeofday(&now, nullptr); + + auto backend = std::make_shared(ComboAddress("192.0.2.42:53"), ComboAddress("0.0.0.0:0"), 0, std::string(), 1, false); + backend->d_tlsCtx = tlsCtx; + backend->d_tlsSubjectName = "backend.powerdns.com"; + backend->d_dohPath = "/dns-query"; + backend->d_addXForwardedHeaders = true; + backend->useProxyProtocol = true; + + size_t numberOfQueries = 2; + std::vector, InternalQuery>> queries; + for (size_t counter = 0; counter < numberOfQueries; counter++) { + DNSName name("powerdns.com."); + PacketBuffer query; + GenericDNSPacketWriter pwQ(query, name, QType::A, QClass::IN, 0); + pwQ.getHeader()->rd = 1; + pwQ.getHeader()->id = htons(counter); + + PacketBuffer response; + GenericDNSPacketWriter pwR(response, name, QType::A, QClass::IN, 0); + pwR.getHeader()->qr = 1; + pwR.getHeader()->rd = 1; + pwR.getHeader()->ra = 1; + pwR.getHeader()->id = htons(counter); + pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER); + pwR.xfr32BitInt(0x01020304); + pwR.commit(); + + s_responses[counter] = {query, response}; + + auto sender = std::make_shared(); + sender->d_id = counter; + std::string payload = makeProxyHeader(counter % 2, local, local, {}); + InternalQuery internalQuery(std::move(query), IDState()); + internalQuery.d_proxyProtocolPayload = std::move(payload); + queries.push_back({std::move(sender), std::move(internalQuery)}); + } + + s_steps = { + {ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done}, + /* proxy protocol data + opening */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + /* settings */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + /* headers */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + /* data */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max(), [](int desc, const ExpectedStep& step) { + /* set the outgoing descriptor (backend connection) as ready */ + dynamic_cast(s_mplexer.get())->setReady(desc); + }}, + {ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done}, + /* proxy protocol data + opening */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + /* settings */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + /* headers */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + /* data */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max(), [](int desc, const ExpectedStep& step) { + /* set the outgoing descriptor (backend connection) as ready */ + dynamic_cast(s_mplexer.get())->setReady(desc); + }}, + /* read settings, headers and responses from the server */ + {ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, std::numeric_limits::max()}, + /* acknowledge settings */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + {ExpectedStep::ExpectedRequest::closeBackend, IOState::Done}, + /* read settings, headers and responses from the server */ + {ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, std::numeric_limits::max()}, + /* acknowledge settings */ + {ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, std::numeric_limits::max()}, + {ExpectedStep::ExpectedRequest::closeBackend, IOState::Done}, + }; + + for (auto& query : queries) { + auto sliced = std::static_pointer_cast(query.first); + bool result = sendH2Query(backend, s_mplexer, sliced, std::move(query.second), false); + BOOST_CHECK_EQUAL(result, true); + } + + while (s_mplexer->getWatchedFDCount(false) != 0 || s_mplexer->getWatchedFDCount(true) != 0) { + s_mplexer->run(&now); + } + + for (auto& query : queries) { + BOOST_CHECK_EQUAL(query.first->d_valid, true); + } +} + BOOST_AUTO_TEST_SUITE_END(); #endif /* HAVE_NGHTTP2 */ diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index 9c7e1be99f..9e7a12d3da 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -3046,7 +3046,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) proxyEnabledBackend->d_tlsCtx = tlsCtx; /* enable out-of-order on the backend side as well */ proxyEnabledBackend->d_maxInFlightQueriesPerConn = 65536; - proxyEnabledBackend-> useProxyProtocol = true; + proxyEnabledBackend->useProxyProtocol = true; expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end()); expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end()); diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 57d3d2ccbf..60d8c8932b 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -24,6 +24,7 @@ import h2.events import h2.config from eqdnsmessage import AssertEqualDNSMessageMixin +from proxyprotocol import ProxyProtocol # Python2/3 compatibility hacks try: @@ -322,7 +323,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): sock.close() @classmethod - def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None): + def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False): # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. @@ -358,6 +359,30 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): conn.sendall(h2conn.data_to_send()) dnsData = {} + if useProxyProtocol: + # try to read the entire Proxy Protocol header + proxy = ProxyProtocol() + header = conn.recv(proxy.HEADER_SIZE) + if not header: + print('unable to get header') + conn.close() + continue + + if not proxy.parseHeader(header): + print('unable to parse header') + print(header) + conn.close() + continue + + proxyContent = conn.recv(proxy.contentLen) + if not proxyContent: + print('unable to get content') + conn.close() + continue + + payload = header + proxyContent + toQueue.put(payload, True, cls._queueTimeout) + while True: data = conn.recv(65535) if not data: @@ -519,7 +544,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): cls.sendTCPQueryOverConnection(sock, query, rawQuery) message = cls.recvTCPResponseOverConnection(sock) except socket.timeout as e: - print("Timeout: %s" % (str(e))) + print("Timeout while sending or receiving TCP data: %s" % (str(e))) except socket.error as e: print("Network error: %s" % (str(e))) finally: @@ -566,7 +591,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): messages.append(dns.message.from_wire(data)) except socket.timeout as e: - print("Timeout: %s" % (str(e))) + print("Timeout while receiving multiple TCP responses: %s" % (str(e))) except socket.error as e: print("Network error: %s" % (str(e))) finally: @@ -743,3 +768,33 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): for inFileName in ['server.pem', 'ca.pem']: with open(inFileName) as inFile: outFile.write(inFile.read()) + + def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None): + proxy = ProxyProtocol() + self.assertTrue(proxy.parseHeader(receivedProxyPayload)) + self.assertEqual(proxy.version, 0x02) + self.assertEqual(proxy.command, 0x01) + if v6: + self.assertEqual(proxy.family, 0x02) + else: + self.assertEqual(proxy.family, 0x01) + if not isTCP: + self.assertEqual(proxy.protocol, 0x02) + else: + self.assertEqual(proxy.protocol, 0x01) + self.assertGreater(proxy.contentLen, 0) + + self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload)) + self.assertEqual(proxy.source, source) + self.assertEqual(proxy.destination, destination) + if sourcePort: + self.assertEqual(proxy.sourcePort, sourcePort) + if destinationPort: + self.assertEqual(proxy.destinationPort, destinationPort) + else: + self.assertEqual(proxy.destinationPort, self._dnsDistPort) + + self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload)) + proxy.values.sort() + values.sort() + self.assertEqual(proxy.values, values) diff --git a/regression-tests.dnsdist/test_OutgoingDOH.py b/regression-tests.dnsdist/test_OutgoingDOH.py index 205518c1bb..42b30c7489 100644 --- a/regression-tests.dnsdist/test_OutgoingDOH.py +++ b/regression-tests.dnsdist/test_OutgoingDOH.py @@ -390,3 +390,50 @@ class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenRespons cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext]) cls._DOHResponder.setDaemon(True) cls._DOHResponder.start() + +class TestOutgoingDOHProxyProtocol(DNSDistTest): + + _tlsBackendPort = 10551 + _config_params = ['_tlsBackendPort'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', useProxyProtocol=true}:setUp() + """ + _verboseMode = True + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.set_alpn_protocols(["h2"]) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH woth Proxy Protocol responder..") + cls._DOHResponder = threading.Thread(name='DOH with Proxy Protocol Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext, True]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + + def testPP(self): + """ + Outgoing DOH with Proxy Protocol + """ + name = 'proxy-protocol.outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse.answer.append(rrset) + + (receivedProxyPayload, receivedResponse) = self.sendUDPQuery(query, expectedResponse) + receivedQuery = self._fromResponderQueue.get(True, 1.0) + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, expectedResponse) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', False) + + (receivedProxyPayload, receivedResponse) = self.sendTCPQuery(query, expectedResponse) + receivedQuery = self._fromResponderQueue.get(True, 1.0) + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, expectedResponse) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True) diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index 18b9efb414..16b27fde47 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -142,36 +142,6 @@ class ProxyProtocolTest(DNSDistTest): _proxyResponderPort = proxyResponderPort _config_params = ['_proxyResponderPort'] - def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None): - proxy = ProxyProtocol() - self.assertTrue(proxy.parseHeader(receivedProxyPayload)) - self.assertEqual(proxy.version, 0x02) - self.assertEqual(proxy.command, 0x01) - if v6: - self.assertEqual(proxy.family, 0x02) - else: - self.assertEqual(proxy.family, 0x01) - if not isTCP: - self.assertEqual(proxy.protocol, 0x02) - else: - self.assertEqual(proxy.protocol, 0x01) - self.assertGreater(proxy.contentLen, 0) - - self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload)) - self.assertEqual(proxy.source, source) - self.assertEqual(proxy.destination, destination) - if sourcePort: - self.assertEqual(proxy.sourcePort, sourcePort) - if destinationPort: - self.assertEqual(proxy.destinationPort, destinationPort) - else: - self.assertEqual(proxy.destinationPort, self._dnsDistPort) - - self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload)) - proxy.values.sort() - values.sort() - self.assertEqual(proxy.values, values) - class TestProxyProtocol(ProxyProtocolTest): """ dnsdist is configured to prepend a Proxy Protocol header to the query