From: Remi Gacogne Date: Fri, 17 Nov 2023 15:06:54 +0000 (+0100) Subject: dnsdist: Add Proxy Protocol v2 support to `TeeAction` X-Git-Tag: rec-5.1.0-alpha0~2^2~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=8d22a19d32b6dab6c69635b38b30e1873a503b84;p=thirdparty%2Fpdns.git dnsdist: Add Proxy Protocol v2 support to `TeeAction` --- diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index 575c8f3465..4463ffd5f8 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -30,6 +30,7 @@ #include "dnsdist-lua-ffi.hh" #include "dnsdist-mac-address.hh" #include "dnsdist-protobuf.hh" +#include "dnsdist-proxy-protocol.hh" #include "dnsdist-kvs.hh" #include "dnsdist-svc.hh" @@ -130,18 +131,18 @@ class TeeAction : public DNSAction { public: // this action does not stop the processing - TeeAction(const ComboAddress& rca, const boost::optional& lca, bool addECS=false); + TeeAction(const ComboAddress& rca, const boost::optional& lca, bool addECS = false, bool addProxyProtocol = false); ~TeeAction() override; DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override; std::string toString() const override; std::map getStats() const override; private: - ComboAddress d_remote; - std::thread d_worker; void worker(); - int d_fd{-1}; + ComboAddress d_remote; + std::thread d_worker; + Socket d_socket; mutable std::atomic d_senderrors{0}; unsigned long d_recverrors{0}; mutable std::atomic d_queries{0}; @@ -156,32 +157,26 @@ private: stat_t d_otherrcode{0}; std::atomic d_pleaseQuit{false}; bool d_addECS{false}; + bool d_addProxyProtocol{false}; }; -TeeAction::TeeAction(const ComboAddress& rca, const boost::optional& lca, bool addECS) - : d_remote(rca), d_addECS(addECS) +TeeAction::TeeAction(const ComboAddress& rca, const boost::optional& lca, bool addECS, bool addProxyProtocol) + : d_remote(rca), d_socket(d_remote.sin4.sin_family, SOCK_DGRAM, 0), d_addECS(addECS), d_addProxyProtocol(addProxyProtocol) { - d_fd=SSocket(d_remote.sin4.sin_family, SOCK_DGRAM, 0); - try { - if (lca) { - SBind(d_fd, *lca); - } - SConnect(d_fd, d_remote); - setNonBlocking(d_fd); - d_worker=std::thread([this](){worker();}); - } - catch (...) { - if (d_fd != -1) { - close(d_fd); - } - throw; + if (lca) { + d_socket.bind(*lca, false); } + d_socket.connect(d_remote); + d_socket.setNonBlocking(); + d_worker = std::thread([this]() { + worker(); + }); } TeeAction::~TeeAction() { - d_pleaseQuit=true; - close(d_fd); + d_pleaseQuit = true; + close(d_socket.releaseHandle()); d_worker.join(); } @@ -189,28 +184,38 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult { if (dq->overTCP()) { d_tcpdrops++; + return DNSAction::Action::None; } - else { - ssize_t res; - d_queries++; - if(d_addECS) { - PacketBuffer query(dq->getData()); - bool ednsAdded = false; - bool ecsAdded = false; + d_queries++; - std::string newECSOption; - generateECSOption(dq->ecs ? dq->ecs->getNetwork() : dq->ids.origRemote, newECSOption, dq->ecs ? dq->ecs->getBits() : dq->ecsPrefixLength); + PacketBuffer query; + if (d_addECS) { + query = dq->getData(); + bool ednsAdded = false; + bool ecsAdded = false; - if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->ids.qname.wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) { - return DNSAction::Action::None; - } + std::string newECSOption; + generateECSOption(dq->ecs ? dq->ecs->getNetwork() : dq->ids.origRemote, newECSOption, dq->ecs ? dq->ecs->getBits() : dq->ecsPrefixLength); - res = send(d_fd, query.data(), query.size(), 0); + if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->ids.qname.wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) { + return DNSAction::Action::None; } - else { - res = send(d_fd, dq->getData().data(), dq->getData().size(), 0); + } + + if (d_addProxyProtocol) { + auto proxyPayload = getProxyProtocolPayload(*dq); + if (query.empty()) { + query = dq->getData(); + } + if (!addProxyProtocol(query, proxyPayload)) { + return DNSAction::Action::None; } + } + + { + const PacketBuffer& payload = query.empty() ? dq->getData() : query; + auto res = send(d_socket.getHandle(), payload.data(), payload.size(), 0); if (res <= 0) { d_senderrors++; @@ -222,7 +227,7 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult std::string TeeAction::toString() const { - return "tee to "+d_remote.toStringWithPort(); + return "tee to " + d_remote.toStringWithPort(); } std::map TeeAction::getStats() const @@ -247,7 +252,7 @@ void TeeAction::worker() ssize_t res = 0; const dnsheader_aligned dh(packet.data()); for (;;) { - res = waitForData(d_fd, 0, 250000); + res = waitForData(d_socket.getHandle(), 0, 250000); if (d_pleaseQuit) { break; } @@ -259,7 +264,7 @@ void TeeAction::worker() if (res == 0) { continue; } - res = recv(d_fd, packet.data(), packet.size(), 0); + res = recv(d_socket.getHandle(), packet.data(), packet.size(), 0); if (static_cast(res) <= sizeof(struct dnsheader)) { d_recverrors++; } @@ -2739,13 +2744,13 @@ void setupLuaActions(LuaContext& luaCtx) }); #endif /* DISABLE_PROTOBUF */ - luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional addECS, boost::optional local) { + luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional addECS, boost::optional local, boost::optional addProxyProtocol) { boost::optional localAddr{boost::none}; if (local) { localAddr = ComboAddress(*local, 0); } - return std::shared_ptr(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false)); + return std::shared_ptr(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false, addProxyProtocol ? *addProxyProtocol : false)); }); luaCtx.writeFunction("SetECSPrefixLengthAction", [](uint16_t v4PrefixLength, uint16_t v6PrefixLength) { diff --git a/pdns/dnsdistdist/docs/rules-actions.rst b/pdns/dnsdistdist/docs/rules-actions.rst index 090b5a7701..8ba254b369 100644 --- a/pdns/dnsdistdist/docs/rules-actions.rst +++ b/pdns/dnsdistdist/docs/rules-actions.rst @@ -1850,18 +1850,24 @@ The following actions exist. Before 1.7.0 this action was performed even when the query had been received over TCP, which required the use of :func:`TCPRule` to prevent the TC bit from being set over TCP transports. -.. function:: TeeAction(remote[, addECS[, local]]) +.. function:: TeeAction(remote[, addECS[, local [, addProxyProtocol]]]) .. versionchanged:: 1.8.0 Added the optional parameter ``local``. + .. versionchanged:: 1.9.0 + Added the optional parameter ``addProxyProtocol``. + Send copy of query to ``remote``, keep stats on responses. If ``addECS`` is set to true, EDNS Client Subnet information will be added to the query. + If ``addProxyProtocol`` is set to true, a Proxy Protocol v2 payload will be prepended in front of the query. If ``local`` has provided a value like "192.0.2.53", :program:`dnsdist` will try binding that address as local address when sending the queries. Subsequent rules are processed after this action. :param string remote: An IP:PORT combination to send the copied queries to - :param bool addECS: Whether or not to add ECS information. Default false + :param bool addECS: Whether to add ECS information. Default false. + :param str local: The local address to use to send queries. The default is to let the kernel pick one. + :param bool addProxyProtocol: Whether to prepend a proxy protocol v2 payload in front of the query. Default to false. .. function:: TempFailureCacheTTLAction(ttl) diff --git a/regression-tests.dnsdist/proxyprotocolutils.py b/regression-tests.dnsdist/proxyprotocolutils.py new file mode 100644 index 0000000000..9d0021911b --- /dev/null +++ b/regression-tests.dnsdist/proxyprotocolutils.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +import copy +import dns +import socket +import struct +import sys + +from proxyprotocol import ProxyProtocol + +def ProxyProtocolUDPResponder(port, fromQueue, toQueue): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + try: + sock.bind(("127.0.0.1", port)) + except socket.error as e: + print("Error binding in the Proxy Protocol UDP responder: %s" % str(e)) + sys.exit(1) + + while True: + data, addr = sock.recvfrom(4096) + + proxy = ProxyProtocol() + if len(data) < proxy.HEADER_SIZE: + continue + + if not proxy.parseHeader(data): + continue + + if proxy.local: + # likely a healthcheck + data = data[proxy.HEADER_SIZE:] + request = dns.message.from_wire(data) + response = dns.message.make_response(request) + wire = response.to_wire() + sock.settimeout(2.0) + sock.sendto(wire, addr) + sock.settimeout(None) + + continue + + payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)] + dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):] + toQueue.put([payload, dnsData], True, 2.0) + # computing the correct ID for the response + request = dns.message.from_wire(dnsData) + response = fromQueue.get(True, 2.0) + response.id = request.id + + sock.settimeout(2.0) + sock.sendto(response.to_wire(), addr) + sock.settimeout(None) + + sock.close() + +def ProxyProtocolTCPResponder(port, fromQueue, toQueue): + # be aware that this responder will not accept a new connection + # until the last one has been closed. This is done on purpose to + # to check for connection reuse, making sure that a lot of connections + # are not opened in parallel. + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + try: + sock.bind(("127.0.0.1", port)) + except socket.error as e: + print("Error binding in the TCP responder: %s" % str(e)) + sys.exit(1) + + sock.listen(100) + while True: + (conn, _) = sock.accept() + conn.settimeout(5.0) + # try to read the entire Proxy Protocol header + proxy = ProxyProtocol() + header = conn.recv(proxy.HEADER_SIZE) + if not header: + conn.close() + continue + + if not proxy.parseHeader(header): + conn.close() + continue + + proxyContent = conn.recv(proxy.contentLen) + if not proxyContent: + conn.close() + continue + + payload = header + proxyContent + while True: + try: + data = conn.recv(2) + except socket.timeout: + data = None + + if not data: + conn.close() + break + + (datalen,) = struct.unpack("!H", data) + data = conn.recv(datalen) + + toQueue.put([payload, data], True, 2.0) + + response = copy.deepcopy(fromQueue.get(True, 2.0)) + if not response: + conn.close() + break + + # computing the correct ID for the response + request = dns.message.from_wire(data) + response.id = request.id + + wire = response.to_wire() + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + + conn.close() + + sock.close() diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index dd4ca4fbef..a29852c8b6 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -import copy import dns import selectors import socket @@ -12,6 +11,7 @@ import time from dnsdisttests import DNSDistTest, pickAvailablePort from proxyprotocol import ProxyProtocol +from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder from dnsdistdohtests import DNSDistDOHTest # Python2/3 compatibility hacks @@ -20,118 +20,6 @@ try: except ImportError: from Queue import Queue -def ProxyProtocolUDPResponder(port, fromQueue, toQueue): - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - try: - sock.bind(("127.0.0.1", port)) - except socket.error as e: - print("Error binding in the Proxy Protocol UDP responder: %s" % str(e)) - sys.exit(1) - - while True: - data, addr = sock.recvfrom(4096) - - proxy = ProxyProtocol() - if len(data) < proxy.HEADER_SIZE: - continue - - if not proxy.parseHeader(data): - continue - - if proxy.local: - # likely a healthcheck - data = data[proxy.HEADER_SIZE:] - request = dns.message.from_wire(data) - response = dns.message.make_response(request) - wire = response.to_wire() - sock.settimeout(2.0) - sock.sendto(wire, addr) - sock.settimeout(None) - - continue - - payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)] - dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):] - toQueue.put([payload, dnsData], True, 2.0) - # computing the correct ID for the response - request = dns.message.from_wire(dnsData) - response = fromQueue.get(True, 2.0) - response.id = request.id - - sock.settimeout(2.0) - sock.sendto(response.to_wire(), addr) - sock.settimeout(None) - - sock.close() - -def ProxyProtocolTCPResponder(port, fromQueue, toQueue): - # be aware that this responder will not accept a new connection - # until the last one has been closed. This is done on purpose to - # to check for connection reuse, making sure that a lot of connections - # are not opened in parallel. - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - try: - sock.bind(("127.0.0.1", port)) - except socket.error as e: - print("Error binding in the TCP responder: %s" % str(e)) - sys.exit(1) - - sock.listen(100) - while True: - (conn, _) = sock.accept() - conn.settimeout(5.0) - # try to read the entire Proxy Protocol header - proxy = ProxyProtocol() - header = conn.recv(proxy.HEADER_SIZE) - if not header: - conn.close() - continue - - if not proxy.parseHeader(header): - conn.close() - continue - - proxyContent = conn.recv(proxy.contentLen) - if not proxyContent: - conn.close() - continue - - payload = header + proxyContent - while True: - try: - data = conn.recv(2) - except socket.timeout: - data = None - - if not data: - conn.close() - break - - (datalen,) = struct.unpack("!H", data) - data = conn.recv(datalen) - - toQueue.put([payload, data], True, 2.0) - - response = copy.deepcopy(fromQueue.get(True, 2.0)) - if not response: - conn.close() - break - - # computing the correct ID for the response - request = dns.message.from_wire(data) - response.id = request.id - - wire = response.to_wire() - conn.send(struct.pack("!H", len(wire))) - conn.send(wire) - - conn.close() - - sock.close() - toProxyQueue = Queue() fromProxyQueue = Queue() proxyResponderPort = pickAvailablePort() diff --git a/regression-tests.dnsdist/test_TeeAction.py b/regression-tests.dnsdist/test_TeeAction.py index 373d90f645..0516fbc505 100644 --- a/regression-tests.dnsdist/test_TeeAction.py +++ b/regression-tests.dnsdist/test_TeeAction.py @@ -4,22 +4,27 @@ import threading import clientsubnetoption import dns from dnsdisttests import DNSDistTest, Queue, pickAvailablePort +from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder class TestTeeAction(DNSDistTest): _consoleKey = DNSDistTest.generateConsoleKey() _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') _teeServerPort = pickAvailablePort() + _teeProxyServerPort = pickAvailablePort() _toTeeQueue = Queue() _fromTeeQueue = Queue() + _toTeeProxyQueue = Queue() + _fromTeeProxyQueue = Queue() _config_template = """ setKey("%s") controlSocket("127.0.0.1:%s") newServer{address="127.0.0.1:%d"} addAction(QTypeRule(DNSQType.A), TeeAction("127.0.0.1:%d", true)) addAction(QTypeRule(DNSQType.AAAA), TeeAction("127.0.0.1:%d", false)) + addAction(QTypeRule(DNSQType.ANY), TeeAction("127.0.0.1:%d", false, '127.0.0.1', true)) """ - _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort', '_teeServerPort'] + _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort', '_teeServerPort', '_teeProxyServerPort'] @classmethod def startResponders(cls): print("Launching responders..") @@ -36,6 +41,10 @@ class TestTeeAction(DNSDistTest): cls._TeeResponder.daemon = True cls._TeeResponder.start() + cls._TeeProxyResponder = threading.Thread(name='Proxy Protocol Tee Responder', target=ProxyProtocolUDPResponder, args=[cls._teeProxyServerPort, cls._toTeeProxyQueue, cls._fromTeeProxyQueue]) + cls._TeeProxyResponder.daemon = True + cls._TeeProxyResponder.start() + def testTeeWithECS(self): """ TeeAction: ECS @@ -130,4 +139,50 @@ responses\t%d send-errors\t0 servfails\t0 tcp-drops\t0 +""" % (numberOfQueries, numberOfQueries, numberOfQueries)) + + def testTeeWithProxy(self): + """ + TeeAction: Proxy + """ + name = 'proxy.tee.tests.powerdns.com.' + query = dns.message.make_query(name, 'ANY', 'IN') + response = dns.message.make_response(query) + + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + numberOfQueries = 10 + for _ in range(numberOfQueries): + # push the response to the Tee Proxy server + self._toTeeProxyQueue.put(response, True, 2.0) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + # retrieve the query from the Tee Proxy server + [payload, teedQuery] = self._fromTeeProxyQueue.get(True, 2.0) + self.checkMessageNoEDNS(query, dns.message.from_wire(teedQuery)) + self.checkMessageProxyProtocol(payload, '127.0.0.1', '127.0.0.1', False) + + # check the TeeAction stats + stats = self.sendConsoleCommand("getAction(0):printStats()") + self.assertEqual(stats, """noerrors\t%d +nxdomains\t0 +other-rcode\t0 +queries\t%d +recv-errors\t0 +refuseds\t0 +responses\t%d +send-errors\t0 +servfails\t0 +tcp-drops\t0 """ % (numberOfQueries, numberOfQueries, numberOfQueries))