From: Remi Gacogne Date: Thu, 15 Oct 2020 15:34:17 +0000 (+0200) Subject: dnsdist: Add regression tests for incoming Proxy Protocol X-Git-Tag: rec-4.5.0-alpha1~19^2~4 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f7ec81e20b4feeaf205d71b8e1c0eef8cda7db80;p=thirdparty%2Fpdns.git dnsdist: Add regression tests for incoming Proxy Protocol --- diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index 9dd9a3fac0..761caee597 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -142,12 +142,15 @@ class ProxyProtocolTest(DNSDistTest): _proxyResponderPort = proxyResponderPort _config_params = ['_proxyResponderPort'] - def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[]): + def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None): proxy = ProxyProtocol() self.assertTrue(proxy.parseHeader(receivedProxyPayload)) self.assertEquals(proxy.version, 0x02) self.assertEquals(proxy.command, 0x01) - self.assertEquals(proxy.family, 0x01) + if v6: + self.assertEquals(proxy.family, 0x02) + else: + self.assertEquals(proxy.family, 0x01) if not isTCP: self.assertEquals(proxy.protocol, 0x02) else: @@ -157,8 +160,12 @@ class ProxyProtocolTest(DNSDistTest): self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload)) self.assertEquals(proxy.source, source) self.assertEquals(proxy.destination, destination) - #self.assertEquals(proxy.sourcePort, sourcePort) - self.assertEquals(proxy.destinationPort, self._dnsDistPort) + if sourcePort: + self.assertEquals(proxy.sourcePort, sourcePort) + if destinationPort: + self.assertEquals(proxy.destinationPort, destinationPort) + else: + self.assertEquals(proxy.destinationPort, self._dnsDistPort) self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload)) proxy.values.sort() @@ -410,3 +417,337 @@ class TestProxyProtocol(ProxyProtocolTest): self.assertEquals(receivedQuery, query) self.assertEquals(receivedResponse, response) self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, []) + +class TestProxyProtocolIncoming(ProxyProtocolTest): + """ + dnsdist is configured to prepend a Proxy Protocol header to the query and expect one on incoming queries + """ + + _config_template = """ + setProxyProtocolACL( { "127.0.0.1/32" } ) + newServer{address="127.0.0.1:%d", useProxyProtocol=true} + + function addValues(dq) + dq:addProxyProtocolValue(0, 'foo') + dq:addProxyProtocolValue(42, 'bar') + return DNSAction.None + end + + -- refuse queries with no TLV value type 2 + addAction(NotRule(ProxyProtocolValueRule(2)), RCodeAction(DNSRCode.REFUSED)) + -- or with a TLV value type 3 different from "proxy" + addAction(NotRule(ProxyProtocolValueRule(3, "proxy")), RCodeAction(DNSRCode.REFUSED)) + + function answerBasedOnForwardedDest(dq) + local port = dq.localaddr:getPort() + local dest = dq.localaddr:toString() + return DNSAction.Spoof, "address-was-"..dest.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com." + end + addAction("get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedDest)) + + function answerBasedOnForwardedSrc(dq) + local port = dq.remoteaddr:getPort() + local src = dq.remoteaddr:toString() + return DNSAction.Spoof, "address-was-"..src.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com." + end + addAction("get-forwarded-src.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedSrc)) + + -- add these values for all queries + addAction("proxy-protocol-incoming.tests.powerdns.com.", LuaAction(addValues)) + addAction("proxy-protocol-incoming.tests.powerdns.com.", AddProxyProtocolValueAction(1, "dnsdist")) + addAction("proxy-protocol-incoming.tests.powerdns.com.", AddProxyProtocolValueAction(255, "proxy-protocol")) + + -- override all existing values + addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"})) + """ + _config_params = ['_proxyResponderPort'] + _verboseMode = True + + def testNoHeader(self): + """ + Incoming Proxy Protocol: no header + """ + # no proxy protocol header while one is expected, should be dropped + name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None) + self.assertEquals(receivedResponse, None) + + def testIncomingProxyDest(self): + """ + Incoming Proxy Protocol: values from Lua + """ + name = 'get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + + destAddr = "2001:db8::9" + destPort = 9999 + srcAddr = "2001:db8::8" + srcPort = 8888 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.CNAME, + "address-was-{}-port-was-{}.proxy-protocol-incoming.tests.powerdns.com.".format(destAddr, destPort, self._dnsDistPort)) + response.answer.append(rrset) + + udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True) + self.assertEquals(receivedResponse, response) + + tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + wire = query.to_wire() + + receivedResponse = None + try: + conn = self.openTCPConnection(2.0) + conn.send(tcpPayload) + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + receivedResponse = self.recvTCPResponseOverConnection(conn) + except socket.timeout: + print('timeout') + self.assertEquals(receivedResponse, response) + + def testProxyUDPWithValuesFromLua(self): + """ + Incoming Proxy Protocol: values from Lua (UDP) + """ + name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + destAddr = "2001:db8::9" + destPort = 9999 + srcAddr = "2001:db8::8" + srcPort = 8888 + response = dns.message.make_response(query) + + udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + toProxyQueue.put(response, True, 2.0) + (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True) + + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort) + + def testProxyTCPWithValuesFromLua(self): + """ + Incoming Proxy Protocol: values from Lua (TCP) + """ + name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + destAddr = "2001:db8::9" + destPort = 9999 + srcAddr = "2001:db8::8" + srcPort = 8888 + response = dns.message.make_response(query) + + tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + + toProxyQueue.put(response, True, 2.0) + + wire = query.to_wire() + + receivedResponse = None + try: + conn = self.openTCPConnection(2.0) + conn.send(tcpPayload) + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + receivedResponse = self.recvTCPResponseOverConnection(conn) + except socket.timeout: + print('timeout') + self.assertEquals(receivedResponse, response) + + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort) + + def testProxyUDPWithValueOverride(self): + """ + Incoming Proxy Protocol: override existing value (UDP) + """ + name = 'override.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + destAddr = "2001:db8::9" + destPort = 9999 + srcAddr = "2001:db8::8" + srcPort = 8888 + response = dns.message.make_response(query) + + udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [2, b'foo'], [3, b'proxy'], [ 50, b'initial-value']]) + toProxyQueue.put(response, True, 2.0) + (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True) + + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [50, b'overridden'] ], True, srcPort, destPort) + + def testProxyTCPSeveralQueriesOverConnection(self): + """ + Incoming Proxy Protocol: Several queries over the same connection (TCP) + """ + name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + destAddr = "2001:db8::9" + destPort = 9999 + srcAddr = "2001:db8::8" + srcPort = 8888 + response = dns.message.make_response(query) + + tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + + toProxyQueue.put(response, True, 2.0) + + wire = query.to_wire() + + receivedResponse = None + conn = self.openTCPConnection(2.0) + try: + conn.send(tcpPayload) + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + receivedResponse = self.recvTCPResponseOverConnection(conn) + except socket.timeout: + print('timeout') + self.assertEquals(receivedResponse, response) + + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort) + + for idx in range(5): + receivedResponse = None + toProxyQueue.put(response, True, 2.0) + try: + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + receivedResponse = self.recvTCPResponseOverConnection(conn) + except socket.timeout: + print('timeout') + + self.assertEquals(receivedResponse, response) + + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort) + +class TestProxyProtocolNotExpected(DNSDistTest): + """ + dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1 + """ + + _config_template = """ + setProxyProtocolACL( { "192.0.2.1/32" } ) + newServer{address="127.0.0.1:%d"} + """ + # NORMAL responder, does not expect a proxy protocol payload! + _config_params = ['_testServerPort'] + _verboseMode = True + + def testNoHeader(self): + """ + Unexpected Proxy Protocol: no header + """ + # no proxy protocol header and none is expected from this source, should be passed on + name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + def testIncomingProxyDest(self): + """ + Unexpected Proxy Protocol: should be dropped + """ + name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + # Make sure that the proxy payload does NOT turn into a legal qname + destAddr = "ff:db8::ffff" + destPort = 65535 + srcAddr = "ff:db8::ffff" + srcPort = 65535 + + udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True) + self.assertEquals(receivedResponse, None) + + tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + wire = query.to_wire() + + receivedResponse = None + try: + conn = self.openTCPConnection(2.0) + conn.send(tcpPayload) + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + receivedResponse = self.recvTCPResponseOverConnection(conn) + except socket.timeout: + print('timeout') + self.assertEquals(receivedResponse, None)