From: Remi Gacogne Date: Wed, 26 Feb 2020 11:20:00 +0000 (+0100) Subject: rec: Add regression tests for the proxy protocol X-Git-Tag: dnsdist-1.5.0-alpha1~12^2~17 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=949cd0f281d3749ad9d8ecd43a8b2d4e505164a2;p=thirdparty%2Fpdns.git rec: Add regression tests for the proxy protocol --- diff --git a/regression-tests.common/proxyprotocol.py b/regression-tests.common/proxyprotocol.py new file mode 100644 index 0000000000..0677b0dbbd --- /dev/null +++ b/regression-tests.common/proxyprotocol.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python + +import copy +import socket +import struct + +class ProxyProtocol(object): + MAGIC = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A' + # Header is magic + versioncommand (1) + family (1) + content length (2) + HEADER_SIZE = len(MAGIC) + 1 + 1 + 2 + PORT_SIZE = 2 + + def consumed(self): + return self.offset + + def parseHeader(self, data): + if len(data) < self.HEADER_SIZE: + return False + + if data[:len(self.MAGIC)] != self.MAGIC: + return False + + value = struct.unpack('!B', bytes(bytearray([data[12]])))[0] + self.version = value >> 4 + if self.version != 0x02: + return False + + self.command = value & ~0x20 + self.local = False + self.offset = self.HEADER_SIZE + + if self.command == 0x00: + self.local = True + elif self.command == 0x01: + value = struct.unpack('!B', bytes(bytearray([data[13]])))[0] + self.family = value >> 4 + if self.family == 0x01: + self.addrSize = 4 + elif self.family == 0x02: + self.addrSize = 16 + else: + return False + + self.protocol = value & ~0xF0 + if self.protocol == 0x01: + self.tcp = True + elif self.protocol == 0x02: + self.tcp = False + else: + return False + else: + return False + + self.contentLen = struct.unpack("!H", data[14:16])[0] + + if not self.local: + if self.contentLen < (self.addrSize * 2 + self.PORT_SIZE * 2): + return False + + return True + + def getAddr(self, data): + if len(data) < (self.consumed() + self.addrSize): + return False + + value = None + if self.family == 0x01: + value = socket.inet_ntop(socket.AF_INET, data[self.offset:self.offset + self.addrSize]) + else: + value = socket.inet_ntop(socket.AF_INET6, data[self.offset:self.offset + self.addrSize]) + + self.offset = self.offset + self.addrSize + return value + + def getPort(self, data): + if len(data) < (self.consumed() + self.PORT_SIZE): + return False + + value = struct.unpack('!H', data[self.offset:self.offset + self.PORT_SIZE])[0] + self.offset = self.offset + self.PORT_SIZE + return value + + def parseAddressesAndPorts(self, data): + if self.local: + return True + + if len(data) < (self.consumed() + self.addrSize * 2 + self.PORT_SIZE * 2): + return False + + self.source = self.getAddr(data) + self.destination = self.getAddr(data) + self.sourcePort = self.getPort(data) + self.destinationPort = self.getPort(data) + return True + + def parseAdditionalValues(self, data): + self.values = [] + if self.local: + return True + + if len(data) < (self.HEADER_SIZE + self.contentLen): + return False + + remaining = self.HEADER_SIZE + self.contentLen - self.consumed() + if len(data) < remaining: + return False + + while remaining >= 3: + valueType = struct.unpack("!B", bytes(bytearray([data[self.offset]])))[0] + self.offset = self.offset + 1 + valueLen = struct.unpack("!H", data[self.offset:self.offset+2])[0] + self.offset = self.offset + 2 + + remaining = remaining - 3 + if valueLen > 0: + if valueLen > remaining: + return False + self.values.append([valueType, data[self.offset:self.offset+valueLen]]) + self.offset = self.offset + valueLen + remaining = remaining - valueLen + + else: + self.values.append([valueType, ""]) + + return True + + @classmethod + def getPayload(cls, local, tcp, v6, source, destination, sourcePort, destinationPort, values): + payload = copy.deepcopy(cls.MAGIC) + version = 0x02 + + if local: + command = 0x00 + else: + command = 0x01 + + value = struct.pack('!B', (version << 4) + command) + payload = payload + value + + addrSize = 0 + family = 0x00 + protocol = 0x00 + if not local: + if tcp: + protocol = 0x01 + else: + protocol = 0x02 + # sorry but compatibility with python 2 is awful for this, + # not going to waste time on it + if not v6: + family = 0x01 + addrSize = 4 + else: + family = 0x02 + addrSize = 16 + + value = struct.pack('!B', (family << 4) + protocol) + payload = payload + value + + contentSize = 0 + if not local: + contentSize = contentSize + addrSize * 2 + cls.PORT_SIZE *2 + + valuesSize = 0 + for value in values: + valuesSize = valuesSize + 3 + len(value[1]) + + contentSize = contentSize + valuesSize + + value = struct.pack('!H', contentSize) + payload = payload + value + + if not local: + if family == 0x01: + af = socket.AF_INET + else: + af = socket.AF_INET6 + + value = socket.inet_pton(af, source) + payload = payload + value + value = socket.inet_pton(af, destination) + payload = payload + value + value = struct.pack('!H', sourcePort) + payload = payload + value + value = struct.pack('!H', destinationPort) + payload = payload + value + + for value in values: + valueType = struct.pack('!B', value[0]) + valueLen = struct.pack('!H', len(value[1])) + payload = payload + valueType + valueLen + value[1] + + return payload diff --git a/regression-tests.dnsdist/proxyprotocol.py b/regression-tests.dnsdist/proxyprotocol.py deleted file mode 100644 index a7a74907a8..0000000000 --- a/regression-tests.dnsdist/proxyprotocol.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python - -import socket -import struct - -class ProxyProtocol(object): - MAGIC = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A' - # Header is magic + versioncommand (1) + family (1) + content length (2) - HEADER_SIZE = len(MAGIC) + 1 + 1 + 2 - PORT_SIZE = 2 - - def consumed(self): - return self.offset - - def parseHeader(self, data): - if len(data) < self.HEADER_SIZE: - return False - - if data[:len(self.MAGIC)] != self.MAGIC: - return False - - value = struct.unpack('!B', bytes(bytearray([data[12]])))[0] - self.version = value >> 4 - if self.version != 0x02: - return False - - self.command = value & ~0x20 - self.local = False - self.offset = self.HEADER_SIZE - - if self.command == 0x00: - self.local = True - elif self.command == 0x01: - value = struct.unpack('!B', bytes(bytearray([data[13]])))[0] - self.family = value >> 4 - if self.family == 0x01: - self.addrSize = 4 - elif self.family == 0x02: - self.addrSize = 16 - else: - return False - - self.protocol = value & ~0xF0 - if self.protocol == 0x01: - self.tcp = True - elif self.protocol == 0x02: - self.tcp = False - else: - return False - else: - return False - - self.contentLen = struct.unpack("!H", data[14:16])[0] - - if not self.local: - if self.contentLen < (self.addrSize * 2 + self.PORT_SIZE * 2): - return False - - return True - - def getAddr(self, data): - if len(data) < (self.consumed() + self.addrSize): - return False - - value = None - if self.family == 0x01: - value = socket.inet_ntop(socket.AF_INET, data[self.offset:self.offset + self.addrSize]) - else: - value = socket.inet_ntop(socket.AF_INET6, data[self.offset:self.offset + self.addrSize]) - - self.offset = self.offset + self.addrSize - return value - - def getPort(self, data): - if len(data) < (self.consumed() + self.PORT_SIZE): - return False - - value = struct.unpack('!H', data[self.offset:self.offset + self.PORT_SIZE])[0] - self.offset = self.offset + self.PORT_SIZE - return value - - def parseAddressesAndPorts(self, data): - if self.local: - return True - - if len(data) < (self.consumed() + self.addrSize * 2 + self.PORT_SIZE * 2): - return False - - self.source = self.getAddr(data) - self.destination = self.getAddr(data) - self.sourcePort = self.getPort(data) - self.destinationPort = self.getPort(data) - return True - - def parseAdditionalValues(self, data): - self.values = [] - if self.local: - return True - - if len(data) < (self.HEADER_SIZE + self.contentLen): - return False - - remaining = self.HEADER_SIZE + self.contentLen - self.consumed() - if len(data) < remaining: - return False - - while remaining >= 3: - valueType = struct.unpack("!B", bytes(bytearray([data[self.offset]])))[0] - self.offset = self.offset + 1 - valueLen = struct.unpack("!H", data[self.offset:self.offset+2])[0] - self.offset = self.offset + 2 - - remaining = remaining - 3 - if valueLen > 0: - if valueLen > remaining: - return False - self.values.append([valueType, data[self.offset:self.offset+valueLen]]) - self.offset = self.offset + valueLen - remaining = remaining - valueLen - - else: - self.values.append([valueType, ""]) - - return True diff --git a/regression-tests.dnsdist/proxyprotocol.py b/regression-tests.dnsdist/proxyprotocol.py new file mode 120000 index 0000000000..2a3d79b075 --- /dev/null +++ b/regression-tests.dnsdist/proxyprotocol.py @@ -0,0 +1 @@ +../regression-tests.common/proxyprotocol.py \ No newline at end of file diff --git a/regression-tests.recursor-dnssec/proxyprotocol.py b/regression-tests.recursor-dnssec/proxyprotocol.py new file mode 120000 index 0000000000..2a3d79b075 --- /dev/null +++ b/regression-tests.recursor-dnssec/proxyprotocol.py @@ -0,0 +1 @@ +../regression-tests.common/proxyprotocol.py \ No newline at end of file diff --git a/regression-tests.recursor-dnssec/recursortests.py b/regression-tests.recursor-dnssec/recursortests.py index a6f688fc3d..2770e94477 100644 --- a/regression-tests.recursor-dnssec/recursortests.py +++ b/regression-tests.recursor-dnssec/recursortests.py @@ -40,6 +40,7 @@ max-cache-ttl=15 threads=1 loglevel=9 disable-syslog=yes +log-common-errors=yes """ _config_template = """ """ diff --git a/regression-tests.recursor-dnssec/test_ProxyProtocol.py b/regression-tests.recursor-dnssec/test_ProxyProtocol.py new file mode 100644 index 0000000000..d243462078 --- /dev/null +++ b/regression-tests.recursor-dnssec/test_ProxyProtocol.py @@ -0,0 +1,562 @@ +import dns +import os +import socket +import struct +import sys +import time + +from recursortests import RecursorTest +from proxyprotocol import ProxyProtocol + +class ProxyProtocolRecursorTest(RecursorTest): + + @classmethod + def setUpClass(cls): + + # we don't need all the auth stuff + cls.setUpSockets() + cls.startResponders() + + confdir = os.path.join('configs', cls._confdir) + cls.createConfigDir(confdir) + + cls.generateRecursorConfig(confdir) + cls.startRecursor(confdir, cls._recursorPort) + + @classmethod + def tearDownClass(cls): + cls.tearDownRecursor() + + @classmethod + def sendUDPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=[], timeout=2.0): + queryPayload = query.to_wire() + ppPayload = ProxyProtocol.getPayload(False, False, v6, source, destination, sourcePort, destinationPort, values) + payload = ppPayload + queryPayload + + if timeout: + cls._sock.settimeout(timeout) + + try: + cls._sock.send(payload) + data = cls._sock.recv(4096) + except socket.timeout: + data = None + finally: + if timeout: + cls._sock.settimeout(None) + + message = None + if data: + message = dns.message.from_wire(data) + return message + + @classmethod + def sendTCPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=[], timeout=2.0): + queryPayload = query.to_wire() + ppPayload = ProxyProtocol.getPayload(False, False, v6, source, destination, sourcePort, destinationPort, values) + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if timeout: + sock.settimeout(timeout) + + sock.connect(("127.0.0.1", cls._recursorPort)) + + try: + sock.send(ppPayload) + sock.send(struct.pack("!H", len(queryPayload))) + sock.send(queryPayload) + data = sock.recv(2) + if data: + (datalen,) = struct.unpack("!H", data) + data = sock.recv(datalen) + except socket.timeout as e: + print("Timeout: %s" % (str(e))) + data = None + except socket.error as e: + print("Network error: %s" % (str(e))) + data = None + finally: + sock.close() + + message = None + if data: + message = dns.message.from_wire(data) + return message + +class ProxyProtocolAllowedRecursorTest(ProxyProtocolRecursorTest): + _confdir = 'ProxyProtocol' + _lua_dns_script_file = """ + + function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp, proxyProtocolValues) + local remoteaddr = remote:toStringWithPort() + local localaddr = localip:toStringWithPort() + local foundFoo = false + local foundBar = false + + if remoteaddr ~= '127.0.0.42:0' and remoteaddr ~= '[::42]:0' then + pdnslog('gettag: invalid source '..remoteaddr) + return 1 + end + if localaddr ~= '255.255.255.255:65535' and localaddr ~= '[2001:db8::ff]:65535' then + pdnslog('gettag: invalid dest '..localaddr) + return 2 + end + + for k,v in pairs(proxyProtocolValues) do + local type = v:getType() + local content = v:getContent() + if type == 0 and content == 'foo' then + foundFoo = true + end + if type == 255 and content == 'bar' then + foundBar = true + end + end + + if not foundFoo or not foundBar then + pdnslog('gettag: TLV not found') + return 3 + end + + return 42 + end + + function preresolve(dq) + local foundFoo = false + local foundBar = false + local values = dq:getProxyProtocolValues() + for k,v in pairs(values) do + local type = v:getType() + local content = v:getContent() + if type == 0 and content == 'foo' then + foundFoo = true + end + if type == 255 and content == 'bar' then + foundBar = true + end + end + + if not foundFoo or not foundBar then + pdnslog('TLV not found') + dq:addAnswer(pdns.A, '192.0.2.255', 60) + return true + end + + local remoteaddr = dq.remoteaddr:toStringWithPort() + local localaddr = dq.localaddr:toStringWithPort() + + if remoteaddr ~= '127.0.0.42:0' and remoteaddr ~= '[::42]:0' then + pdnslog('invalid source '..remoteaddr) + dq:addAnswer(pdns.A, '192.0.2.128', 60) + return true + end + if localaddr ~= '255.255.255.255:65535' and localaddr ~= '[2001:db8::ff]:65535' then + pdnslog('invalid dest '..localaddr) + dq:addAnswer(pdns.A, '192.0.2.129', 60) + return true + end + + if dq.tag ~= 42 then + pdnslog('invalid tag '..dq.tag) + dq:addAnswer(pdns.A, '192.0.2.130', 60) + return true + end + + dq:addAnswer(pdns.A, '192.0.2.1', 60) + return true + end + """ + + _config_template = """ + proxy-protocol-from=127.0.0.1 + allow-from=127.0.0.0/24, ::1/128, ::42/128 +""" % () + + def testLocalProxyProtocol(self): + qname = 'local.proxy-protocol.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.255') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + queryPayload = query.to_wire() + ppPayload = ProxyProtocol.getPayload(True, False, False, None, None, None, None, []) + payload = ppPayload + queryPayload + + # UDP + + self._sock.settimeout(2.0) + + try: + self._sock.send(payload) + data = self._sock.recv(4096) + except socket.timeout: + data = None + finally: + self._sock.settimeout(None) + + res = None + if data: + res = dns.message.from_wire(data) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + # TCP + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + sock.connect(("127.0.0.1", self._recursorPort)) + + try: + sock.send(ppPayload) + sock.send(struct.pack("!H", len(queryPayload))) + sock.send(queryPayload) + data = sock.recv(2) + if data: + (datalen,) = struct.unpack("!H", data) + data = sock.recv(datalen) + except socket.timeout as e: + print("Timeout: %s" % (str(e))) + data = None + except socket.error as e: + print("Network error: %s" % (str(e))) + data = None + finally: + sock.close() + + res = None + if data: + res = dns.message.from_wire(data) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + def testInvalidMagicProxyProtocol(self): + qname = 'invalid-magic.proxy-protocol.recursor-tests.powerdns.com.' + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + queryPayload = query.to_wire() + ppPayload = ProxyProtocol.getPayload(True, False, False, None, None, None, None, []) + ppPayload = b'\x00' + ppPayload[1:] + payload = ppPayload + queryPayload + + # UDP + + self._sock.settimeout(2.0) + + try: + self._sock.send(payload) + data = self._sock.recv(4096) + except socket.timeout: + data = None + finally: + self._sock.settimeout(None) + + res = None + if data: + res = dns.message.from_wire(data) + self.assertEqual(res, None) + + # TCP + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + sock.connect(("127.0.0.1", self._recursorPort)) + + try: + sock.send(ppPayload) + sock.send(struct.pack("!H", len(queryPayload))) + sock.send(queryPayload) + data = sock.recv(2) + if data: + (datalen,) = struct.unpack("!H", data) + data = sock.recv(datalen) + except socket.timeout as e: + print("Timeout: %s" % (str(e))) + data = None + except socket.error as e: + print("Network error: %s" % (str(e))) + data = None + finally: + sock.close() + + res = None + if data: + res = dns.message.from_wire(data) + self.assertEqual(res, None) + + def testTCPOneByteAtATimeProxyProtocol(self): + qname = 'tcp-one-byte-at-a-time.proxy-protocol.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + queryPayload = query.to_wire() + ppPayload = ProxyProtocol.getPayload(False, True, False, '127.0.0.42', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ]) + + # TCP + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + sock.connect(("127.0.0.1", self._recursorPort)) + + try: + for i in range(len(ppPayload)): + sock.send(ppPayload[i:i+1]) + time.sleep(0.01) + value = struct.pack("!H", len(queryPayload)) + for i in range(len(value)): + sock.send(value[i:i+1]) + time.sleep(0.01) + for i in range(len(queryPayload)): + sock.send(queryPayload[i:i+1]) + time.sleep(0.01) + + data = sock.recv(2) + if data: + (datalen,) = struct.unpack("!H", data) + data = sock.recv(datalen) + except socket.timeout as e: + print("Timeout: %s" % (str(e))) + data = None + except socket.error as e: + print("Network error: %s" % (str(e))) + data = None + finally: + sock.close() + + res = None + if data: + res = dns.message.from_wire(data) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + def testNoHeaderProxyProtocol(self): + qname = 'no-header.proxy-protocol.recursor-tests.powerdns.com.' + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + res = sender(query) + self.assertEqual(res, None) + + def testIPv4ProxyProtocol(self): + qname = 'ipv4.proxy-protocol.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"): + sender = getattr(self, method) + res = sender(query, False, '127.0.0.42', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ]) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + def testIPv4NoValuesProxyProtocol(self): + qname = 'ipv4-no-values.proxy-protocol.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.255') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"): + sender = getattr(self, method) + res = sender(query, False, '127.0.0.42', '255.255.255.255', 0, 65535) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + def testIPv4ProxyProtocolNotAuthorized(self): + qname = 'ipv4-not-authorized.proxy-protocol.recursor-tests.powerdns.com.' + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"): + sender = getattr(self, method) + res = sender(query, False, '192.0.2.255', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ]) + self.assertEqual(res, None) + + def testIPv6ProxyProtocol(self): + qname = 'ipv6.proxy-protocol.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"): + sender = getattr(self, method) + res = sender(query, True, '::42', '2001:db8::ff', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ]) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + def testIPv6NoValuesProxyProtocol(self): + qname = 'ipv6-no-values.proxy-protocol.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.255') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"): + sender = getattr(self, method) + res = sender(query, True, '::42', '2001:db8::ff', 0, 65535) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + def testIPv6ProxyProtocolNotAuthorized(self): + qname = 'ipv6-not-authorized.proxy-protocol.recursor-tests.powerdns.com.' + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"): + sender = getattr(self, method) + res = sender(query, True, '2001:db8::1', '2001:db8::ff', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ]) + self.assertEqual(res, None) + +class ProxyProtocolAllowedFFIRecursorTest(ProxyProtocolAllowedRecursorTest): + # same tests than ProxyProtocolAllowedRecursorTest but with the Lua FFI interface instead of the regular one + _confdir = 'ProxyProtocolFFI' + _lua_dns_script_file = """ + local ffi = require("ffi") + + ffi.cdef[[ + typedef struct pdns_ffi_param pdns_ffi_param_t; + + typedef struct pdns_proxyprotocol_value { + uint8_t type; + uint16_t len; + const void* data; + } pdns_proxyprotocol_value_t; + + size_t pdns_ffi_param_get_proxy_protocol_values(pdns_ffi_param_t* ref, const pdns_proxyprotocol_value_t** out); + const char* pdns_ffi_param_get_remote(pdns_ffi_param_t* ref); + const char* pdns_ffi_param_get_local(pdns_ffi_param_t* ref); + uint16_t pdns_ffi_param_get_remote_port(const pdns_ffi_param_t* ref); + uint16_t pdns_ffi_param_get_local_port(const pdns_ffi_param_t* ref); + + void pdns_ffi_param_set_tag(pdns_ffi_param_t* ref, unsigned int tag); + ]] + + function gettag_ffi(obj) + local remoteaddr = ffi.string(ffi.C.pdns_ffi_param_get_remote(obj)) + local localaddr = ffi.string(ffi.C.pdns_ffi_param_get_local(obj)) + local foundFoo = false + local foundBar = false + + if remoteaddr ~= '127.0.0.42' and remoteaddr ~= '::42' then + pdnslog('gettag-ffi: invalid source '..remoteaddr) + ffi.C.pdns_ffi_param_set_tag(obj, 1) + return + end + if localaddr ~= '255.255.255.255' and localaddr ~= '2001:db8::ff' then + pdnslog('gettag-ffi: invalid dest '..localaddr) + ffi.C.pdns_ffi_param_set_tag(obj, 2) + return + end + + if ffi.C.pdns_ffi_param_get_remote_port(obj) ~= 0 then + pdnslog('gettag-ffi: invalid source port '..ffi.C.pdns_ffi_param_get_remote_port(obj)) + ffi.C.pdns_ffi_param_set_tag(obj, 1) + return + end + + if ffi.C.pdns_ffi_param_get_local_port(obj) ~= 65535 then + pdnslog('gettag-ffi: invalid source port '..ffi.C.pdns_ffi_param_get_local_port(obj)) + ffi.C.pdns_ffi_param_set_tag(obj, 2) + return + end + + local ret_ptr = ffi.new("const pdns_proxyprotocol_value_t *[1]") + local ret_ptr_param = ffi.cast("const pdns_proxyprotocol_value_t **", ret_ptr) + local values_count = ffi.C.pdns_ffi_param_get_proxy_protocol_values(obj, ret_ptr_param) + + if values_count > 0 then + for i = 0,tonumber(values_count)-1 do + local type = ret_ptr[0][i].type + local content = ffi.string(ret_ptr[0][i].data, ret_ptr[0][i].len) + if type == 0 and content == 'foo' then + foundFoo = true + end + if type == 255 and content == 'bar' then + foundBar = true + end + end + end + + if not foundFoo or not foundBar then + pdnslog('gettag-ffi: TLV not found') + ffi.C.pdns_ffi_param_set_tag(obj, 3) + return + end + + ffi.C.pdns_ffi_param_set_tag(obj, 42) + end + + function preresolve(dq) + local foundFoo = false + local foundBar = false + local values = dq:getProxyProtocolValues() + for k,v in pairs(values) do + local type = v:getType() + local content = v:getContent() + if type == 0 and content == 'foo' then + foundFoo = true + end + if type == 255 and content == 'bar' then + foundBar = true + end + end + + if not foundFoo or not foundBar then + pdnslog('TLV not found') + dq:addAnswer(pdns.A, '192.0.2.255', 60) + return true + end + + local remoteaddr = dq.remoteaddr:toStringWithPort() + local localaddr = dq.localaddr:toStringWithPort() + + if remoteaddr ~= '127.0.0.42:0' and remoteaddr ~= '[::42]:0' then + pdnslog('invalid source '..remoteaddr) + dq:addAnswer(pdns.A, '192.0.2.128', 60) + return true + end + if localaddr ~= '255.255.255.255:65535' and localaddr ~= '[2001:db8::ff]:65535' then + pdnslog('invalid dest '..localaddr) + dq:addAnswer(pdns.A, '192.0.2.129', 60) + return true + end + + if dq.tag ~= 42 then + pdnslog('invalid tag '..dq.tag) + dq:addAnswer(pdns.A, '192.0.2.130', 60) + return true + end + + dq:addAnswer(pdns.A, '192.0.2.1', 60) + return true + end + """ + + _config_template = """ + proxy-protocol-from=127.0.0.1 + allow-from=127.0.0.0/24, ::1/128, ::42/128 +""" % () + +class ProxyProtocolNotAllowedRecursorTest(ProxyProtocolRecursorTest): + _confdir = 'ProxyProtocolNotAllowed' + _lua_dns_script_file = """ + + function preresolve(dq) + dq:addAnswer(pdns.A, '192.0.2.1', 60) + return true + end + """ + + _config_template = """ + proxy-protocol-from=192.0.2.1/32 + allow-from=127.0.0.0/24, ::1/128 +""" % () + + def testNoHeaderProxyProtocol(self): + qname = 'no-header.proxy-protocol-not-allowed.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + res = sender(query) + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertRRsetInAnswer(res, expected) + + def testIPv4ProxyProtocol(self): + qname = 'ipv4.proxy-protocol-not-allowed.recursor-tests.powerdns.com.' + expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1') + + query = dns.message.make_query(qname, 'A', want_dnssec=True) + for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"): + sender = getattr(self, method) + res = sender(query, False, '127.0.0.42', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ]) + self.assertEqual(res, None)