--- /dev/null
+#!/usr/bin/env python
+
+import copy
+import socket
+import struct
+
+class ProxyProtocol(object):
+ MAGIC = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
+ # Header is magic + versioncommand (1) + family (1) + content length (2)
+ HEADER_SIZE = len(MAGIC) + 1 + 1 + 2
+ PORT_SIZE = 2
+
+ def consumed(self):
+ return self.offset
+
+ def parseHeader(self, data):
+ if len(data) < self.HEADER_SIZE:
+ return False
+
+ if data[:len(self.MAGIC)] != self.MAGIC:
+ return False
+
+ value = struct.unpack('!B', bytes(bytearray([data[12]])))[0]
+ self.version = value >> 4
+ if self.version != 0x02:
+ return False
+
+ self.command = value & ~0x20
+ self.local = False
+ self.offset = self.HEADER_SIZE
+
+ if self.command == 0x00:
+ self.local = True
+ elif self.command == 0x01:
+ value = struct.unpack('!B', bytes(bytearray([data[13]])))[0]
+ self.family = value >> 4
+ if self.family == 0x01:
+ self.addrSize = 4
+ elif self.family == 0x02:
+ self.addrSize = 16
+ else:
+ return False
+
+ self.protocol = value & ~0xF0
+ if self.protocol == 0x01:
+ self.tcp = True
+ elif self.protocol == 0x02:
+ self.tcp = False
+ else:
+ return False
+ else:
+ return False
+
+ self.contentLen = struct.unpack("!H", data[14:16])[0]
+
+ if not self.local:
+ if self.contentLen < (self.addrSize * 2 + self.PORT_SIZE * 2):
+ return False
+
+ return True
+
+ def getAddr(self, data):
+ if len(data) < (self.consumed() + self.addrSize):
+ return False
+
+ value = None
+ if self.family == 0x01:
+ value = socket.inet_ntop(socket.AF_INET, data[self.offset:self.offset + self.addrSize])
+ else:
+ value = socket.inet_ntop(socket.AF_INET6, data[self.offset:self.offset + self.addrSize])
+
+ self.offset = self.offset + self.addrSize
+ return value
+
+ def getPort(self, data):
+ if len(data) < (self.consumed() + self.PORT_SIZE):
+ return False
+
+ value = struct.unpack('!H', data[self.offset:self.offset + self.PORT_SIZE])[0]
+ self.offset = self.offset + self.PORT_SIZE
+ return value
+
+ def parseAddressesAndPorts(self, data):
+ if self.local:
+ return True
+
+ if len(data) < (self.consumed() + self.addrSize * 2 + self.PORT_SIZE * 2):
+ return False
+
+ self.source = self.getAddr(data)
+ self.destination = self.getAddr(data)
+ self.sourcePort = self.getPort(data)
+ self.destinationPort = self.getPort(data)
+ return True
+
+ def parseAdditionalValues(self, data):
+ self.values = []
+ if self.local:
+ return True
+
+ if len(data) < (self.HEADER_SIZE + self.contentLen):
+ return False
+
+ remaining = self.HEADER_SIZE + self.contentLen - self.consumed()
+ if len(data) < remaining:
+ return False
+
+ while remaining >= 3:
+ valueType = struct.unpack("!B", bytes(bytearray([data[self.offset]])))[0]
+ self.offset = self.offset + 1
+ valueLen = struct.unpack("!H", data[self.offset:self.offset+2])[0]
+ self.offset = self.offset + 2
+
+ remaining = remaining - 3
+ if valueLen > 0:
+ if valueLen > remaining:
+ return False
+ self.values.append([valueType, data[self.offset:self.offset+valueLen]])
+ self.offset = self.offset + valueLen
+ remaining = remaining - valueLen
+
+ else:
+ self.values.append([valueType, ""])
+
+ return True
+
+ @classmethod
+ def getPayload(cls, local, tcp, v6, source, destination, sourcePort, destinationPort, values):
+ payload = copy.deepcopy(cls.MAGIC)
+ version = 0x02
+
+ if local:
+ command = 0x00
+ else:
+ command = 0x01
+
+ value = struct.pack('!B', (version << 4) + command)
+ payload = payload + value
+
+ addrSize = 0
+ family = 0x00
+ protocol = 0x00
+ if not local:
+ if tcp:
+ protocol = 0x01
+ else:
+ protocol = 0x02
+ # sorry but compatibility with python 2 is awful for this,
+ # not going to waste time on it
+ if not v6:
+ family = 0x01
+ addrSize = 4
+ else:
+ family = 0x02
+ addrSize = 16
+
+ value = struct.pack('!B', (family << 4) + protocol)
+ payload = payload + value
+
+ contentSize = 0
+ if not local:
+ contentSize = contentSize + addrSize * 2 + cls.PORT_SIZE *2
+
+ valuesSize = 0
+ for value in values:
+ valuesSize = valuesSize + 3 + len(value[1])
+
+ contentSize = contentSize + valuesSize
+
+ value = struct.pack('!H', contentSize)
+ payload = payload + value
+
+ if not local:
+ if family == 0x01:
+ af = socket.AF_INET
+ else:
+ af = socket.AF_INET6
+
+ value = socket.inet_pton(af, source)
+ payload = payload + value
+ value = socket.inet_pton(af, destination)
+ payload = payload + value
+ value = struct.pack('!H', sourcePort)
+ payload = payload + value
+ value = struct.pack('!H', destinationPort)
+ payload = payload + value
+
+ for value in values:
+ valueType = struct.pack('!B', value[0])
+ valueLen = struct.pack('!H', len(value[1]))
+ payload = payload + valueType + valueLen + value[1]
+
+ return payload
--- /dev/null
+import dns
+import os
+import socket
+import struct
+import sys
+import time
+
+from recursortests import RecursorTest
+from proxyprotocol import ProxyProtocol
+
+class ProxyProtocolRecursorTest(RecursorTest):
+
+ @classmethod
+ def setUpClass(cls):
+
+ # we don't need all the auth stuff
+ cls.setUpSockets()
+ cls.startResponders()
+
+ confdir = os.path.join('configs', cls._confdir)
+ cls.createConfigDir(confdir)
+
+ cls.generateRecursorConfig(confdir)
+ cls.startRecursor(confdir, cls._recursorPort)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.tearDownRecursor()
+
+ @classmethod
+ def sendUDPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=[], timeout=2.0):
+ queryPayload = query.to_wire()
+ ppPayload = ProxyProtocol.getPayload(False, False, v6, source, destination, sourcePort, destinationPort, values)
+ payload = ppPayload + queryPayload
+
+ if timeout:
+ cls._sock.settimeout(timeout)
+
+ try:
+ cls._sock.send(payload)
+ data = cls._sock.recv(4096)
+ except socket.timeout:
+ data = None
+ finally:
+ if timeout:
+ cls._sock.settimeout(None)
+
+ message = None
+ if data:
+ message = dns.message.from_wire(data)
+ return message
+
+ @classmethod
+ def sendTCPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=[], timeout=2.0):
+ queryPayload = query.to_wire()
+ ppPayload = ProxyProtocol.getPayload(False, False, v6, source, destination, sourcePort, destinationPort, values)
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if timeout:
+ sock.settimeout(timeout)
+
+ sock.connect(("127.0.0.1", cls._recursorPort))
+
+ try:
+ sock.send(ppPayload)
+ sock.send(struct.pack("!H", len(queryPayload)))
+ sock.send(queryPayload)
+ data = sock.recv(2)
+ if data:
+ (datalen,) = struct.unpack("!H", data)
+ data = sock.recv(datalen)
+ except socket.timeout as e:
+ print("Timeout: %s" % (str(e)))
+ data = None
+ except socket.error as e:
+ print("Network error: %s" % (str(e)))
+ data = None
+ finally:
+ sock.close()
+
+ message = None
+ if data:
+ message = dns.message.from_wire(data)
+ return message
+
+class ProxyProtocolAllowedRecursorTest(ProxyProtocolRecursorTest):
+ _confdir = 'ProxyProtocol'
+ _lua_dns_script_file = """
+
+ function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp, proxyProtocolValues)
+ local remoteaddr = remote:toStringWithPort()
+ local localaddr = localip:toStringWithPort()
+ local foundFoo = false
+ local foundBar = false
+
+ if remoteaddr ~= '127.0.0.42:0' and remoteaddr ~= '[::42]:0' then
+ pdnslog('gettag: invalid source '..remoteaddr)
+ return 1
+ end
+ if localaddr ~= '255.255.255.255:65535' and localaddr ~= '[2001:db8::ff]:65535' then
+ pdnslog('gettag: invalid dest '..localaddr)
+ return 2
+ end
+
+ for k,v in pairs(proxyProtocolValues) do
+ local type = v:getType()
+ local content = v:getContent()
+ if type == 0 and content == 'foo' then
+ foundFoo = true
+ end
+ if type == 255 and content == 'bar' then
+ foundBar = true
+ end
+ end
+
+ if not foundFoo or not foundBar then
+ pdnslog('gettag: TLV not found')
+ return 3
+ end
+
+ return 42
+ end
+
+ function preresolve(dq)
+ local foundFoo = false
+ local foundBar = false
+ local values = dq:getProxyProtocolValues()
+ for k,v in pairs(values) do
+ local type = v:getType()
+ local content = v:getContent()
+ if type == 0 and content == 'foo' then
+ foundFoo = true
+ end
+ if type == 255 and content == 'bar' then
+ foundBar = true
+ end
+ end
+
+ if not foundFoo or not foundBar then
+ pdnslog('TLV not found')
+ dq:addAnswer(pdns.A, '192.0.2.255', 60)
+ return true
+ end
+
+ local remoteaddr = dq.remoteaddr:toStringWithPort()
+ local localaddr = dq.localaddr:toStringWithPort()
+
+ if remoteaddr ~= '127.0.0.42:0' and remoteaddr ~= '[::42]:0' then
+ pdnslog('invalid source '..remoteaddr)
+ dq:addAnswer(pdns.A, '192.0.2.128', 60)
+ return true
+ end
+ if localaddr ~= '255.255.255.255:65535' and localaddr ~= '[2001:db8::ff]:65535' then
+ pdnslog('invalid dest '..localaddr)
+ dq:addAnswer(pdns.A, '192.0.2.129', 60)
+ return true
+ end
+
+ if dq.tag ~= 42 then
+ pdnslog('invalid tag '..dq.tag)
+ dq:addAnswer(pdns.A, '192.0.2.130', 60)
+ return true
+ end
+
+ dq:addAnswer(pdns.A, '192.0.2.1', 60)
+ return true
+ end
+ """
+
+ _config_template = """
+ proxy-protocol-from=127.0.0.1
+ allow-from=127.0.0.0/24, ::1/128, ::42/128
+""" % ()
+
+ def testLocalProxyProtocol(self):
+ qname = 'local.proxy-protocol.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.255')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ queryPayload = query.to_wire()
+ ppPayload = ProxyProtocol.getPayload(True, False, False, None, None, None, None, [])
+ payload = ppPayload + queryPayload
+
+ # UDP
+
+ self._sock.settimeout(2.0)
+
+ try:
+ self._sock.send(payload)
+ data = self._sock.recv(4096)
+ except socket.timeout:
+ data = None
+ finally:
+ self._sock.settimeout(None)
+
+ res = None
+ if data:
+ res = dns.message.from_wire(data)
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ # TCP
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(2.0)
+ sock.connect(("127.0.0.1", self._recursorPort))
+
+ try:
+ sock.send(ppPayload)
+ sock.send(struct.pack("!H", len(queryPayload)))
+ sock.send(queryPayload)
+ data = sock.recv(2)
+ if data:
+ (datalen,) = struct.unpack("!H", data)
+ data = sock.recv(datalen)
+ except socket.timeout as e:
+ print("Timeout: %s" % (str(e)))
+ data = None
+ except socket.error as e:
+ print("Network error: %s" % (str(e)))
+ data = None
+ finally:
+ sock.close()
+
+ res = None
+ if data:
+ res = dns.message.from_wire(data)
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ def testInvalidMagicProxyProtocol(self):
+ qname = 'invalid-magic.proxy-protocol.recursor-tests.powerdns.com.'
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ queryPayload = query.to_wire()
+ ppPayload = ProxyProtocol.getPayload(True, False, False, None, None, None, None, [])
+ ppPayload = b'\x00' + ppPayload[1:]
+ payload = ppPayload + queryPayload
+
+ # UDP
+
+ self._sock.settimeout(2.0)
+
+ try:
+ self._sock.send(payload)
+ data = self._sock.recv(4096)
+ except socket.timeout:
+ data = None
+ finally:
+ self._sock.settimeout(None)
+
+ res = None
+ if data:
+ res = dns.message.from_wire(data)
+ self.assertEqual(res, None)
+
+ # TCP
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(2.0)
+ sock.connect(("127.0.0.1", self._recursorPort))
+
+ try:
+ sock.send(ppPayload)
+ sock.send(struct.pack("!H", len(queryPayload)))
+ sock.send(queryPayload)
+ data = sock.recv(2)
+ if data:
+ (datalen,) = struct.unpack("!H", data)
+ data = sock.recv(datalen)
+ except socket.timeout as e:
+ print("Timeout: %s" % (str(e)))
+ data = None
+ except socket.error as e:
+ print("Network error: %s" % (str(e)))
+ data = None
+ finally:
+ sock.close()
+
+ res = None
+ if data:
+ res = dns.message.from_wire(data)
+ self.assertEqual(res, None)
+
+ def testTCPOneByteAtATimeProxyProtocol(self):
+ qname = 'tcp-one-byte-at-a-time.proxy-protocol.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ queryPayload = query.to_wire()
+ ppPayload = ProxyProtocol.getPayload(False, True, False, '127.0.0.42', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
+
+ # TCP
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(2.0)
+ sock.connect(("127.0.0.1", self._recursorPort))
+
+ try:
+ for i in range(len(ppPayload)):
+ sock.send(ppPayload[i:i+1])
+ time.sleep(0.01)
+ value = struct.pack("!H", len(queryPayload))
+ for i in range(len(value)):
+ sock.send(value[i:i+1])
+ time.sleep(0.01)
+ for i in range(len(queryPayload)):
+ sock.send(queryPayload[i:i+1])
+ time.sleep(0.01)
+
+ data = sock.recv(2)
+ if data:
+ (datalen,) = struct.unpack("!H", data)
+ data = sock.recv(datalen)
+ except socket.timeout as e:
+ print("Timeout: %s" % (str(e)))
+ data = None
+ except socket.error as e:
+ print("Network error: %s" % (str(e)))
+ data = None
+ finally:
+ sock.close()
+
+ res = None
+ if data:
+ res = dns.message.from_wire(data)
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ def testNoHeaderProxyProtocol(self):
+ qname = 'no-header.proxy-protocol.recursor-tests.powerdns.com.'
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ res = sender(query)
+ self.assertEqual(res, None)
+
+ def testIPv4ProxyProtocol(self):
+ qname = 'ipv4.proxy-protocol.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"):
+ sender = getattr(self, method)
+ res = sender(query, False, '127.0.0.42', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ def testIPv4NoValuesProxyProtocol(self):
+ qname = 'ipv4-no-values.proxy-protocol.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.255')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"):
+ sender = getattr(self, method)
+ res = sender(query, False, '127.0.0.42', '255.255.255.255', 0, 65535)
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ def testIPv4ProxyProtocolNotAuthorized(self):
+ qname = 'ipv4-not-authorized.proxy-protocol.recursor-tests.powerdns.com.'
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"):
+ sender = getattr(self, method)
+ res = sender(query, False, '192.0.2.255', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
+ self.assertEqual(res, None)
+
+ def testIPv6ProxyProtocol(self):
+ qname = 'ipv6.proxy-protocol.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"):
+ sender = getattr(self, method)
+ res = sender(query, True, '::42', '2001:db8::ff', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ def testIPv6NoValuesProxyProtocol(self):
+ qname = 'ipv6-no-values.proxy-protocol.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.255')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"):
+ sender = getattr(self, method)
+ res = sender(query, True, '::42', '2001:db8::ff', 0, 65535)
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ def testIPv6ProxyProtocolNotAuthorized(self):
+ qname = 'ipv6-not-authorized.proxy-protocol.recursor-tests.powerdns.com.'
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"):
+ sender = getattr(self, method)
+ res = sender(query, True, '2001:db8::1', '2001:db8::ff', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
+ self.assertEqual(res, None)
+
+class ProxyProtocolAllowedFFIRecursorTest(ProxyProtocolAllowedRecursorTest):
+ # same tests than ProxyProtocolAllowedRecursorTest but with the Lua FFI interface instead of the regular one
+ _confdir = 'ProxyProtocolFFI'
+ _lua_dns_script_file = """
+ local ffi = require("ffi")
+
+ ffi.cdef[[
+ typedef struct pdns_ffi_param pdns_ffi_param_t;
+
+ typedef struct pdns_proxyprotocol_value {
+ uint8_t type;
+ uint16_t len;
+ const void* data;
+ } pdns_proxyprotocol_value_t;
+
+ size_t pdns_ffi_param_get_proxy_protocol_values(pdns_ffi_param_t* ref, const pdns_proxyprotocol_value_t** out);
+ const char* pdns_ffi_param_get_remote(pdns_ffi_param_t* ref);
+ const char* pdns_ffi_param_get_local(pdns_ffi_param_t* ref);
+ uint16_t pdns_ffi_param_get_remote_port(const pdns_ffi_param_t* ref);
+ uint16_t pdns_ffi_param_get_local_port(const pdns_ffi_param_t* ref);
+
+ void pdns_ffi_param_set_tag(pdns_ffi_param_t* ref, unsigned int tag);
+ ]]
+
+ function gettag_ffi(obj)
+ local remoteaddr = ffi.string(ffi.C.pdns_ffi_param_get_remote(obj))
+ local localaddr = ffi.string(ffi.C.pdns_ffi_param_get_local(obj))
+ local foundFoo = false
+ local foundBar = false
+
+ if remoteaddr ~= '127.0.0.42' and remoteaddr ~= '::42' then
+ pdnslog('gettag-ffi: invalid source '..remoteaddr)
+ ffi.C.pdns_ffi_param_set_tag(obj, 1)
+ return
+ end
+ if localaddr ~= '255.255.255.255' and localaddr ~= '2001:db8::ff' then
+ pdnslog('gettag-ffi: invalid dest '..localaddr)
+ ffi.C.pdns_ffi_param_set_tag(obj, 2)
+ return
+ end
+
+ if ffi.C.pdns_ffi_param_get_remote_port(obj) ~= 0 then
+ pdnslog('gettag-ffi: invalid source port '..ffi.C.pdns_ffi_param_get_remote_port(obj))
+ ffi.C.pdns_ffi_param_set_tag(obj, 1)
+ return
+ end
+
+ if ffi.C.pdns_ffi_param_get_local_port(obj) ~= 65535 then
+ pdnslog('gettag-ffi: invalid source port '..ffi.C.pdns_ffi_param_get_local_port(obj))
+ ffi.C.pdns_ffi_param_set_tag(obj, 2)
+ return
+ end
+
+ local ret_ptr = ffi.new("const pdns_proxyprotocol_value_t *[1]")
+ local ret_ptr_param = ffi.cast("const pdns_proxyprotocol_value_t **", ret_ptr)
+ local values_count = ffi.C.pdns_ffi_param_get_proxy_protocol_values(obj, ret_ptr_param)
+
+ if values_count > 0 then
+ for i = 0,tonumber(values_count)-1 do
+ local type = ret_ptr[0][i].type
+ local content = ffi.string(ret_ptr[0][i].data, ret_ptr[0][i].len)
+ if type == 0 and content == 'foo' then
+ foundFoo = true
+ end
+ if type == 255 and content == 'bar' then
+ foundBar = true
+ end
+ end
+ end
+
+ if not foundFoo or not foundBar then
+ pdnslog('gettag-ffi: TLV not found')
+ ffi.C.pdns_ffi_param_set_tag(obj, 3)
+ return
+ end
+
+ ffi.C.pdns_ffi_param_set_tag(obj, 42)
+ end
+
+ function preresolve(dq)
+ local foundFoo = false
+ local foundBar = false
+ local values = dq:getProxyProtocolValues()
+ for k,v in pairs(values) do
+ local type = v:getType()
+ local content = v:getContent()
+ if type == 0 and content == 'foo' then
+ foundFoo = true
+ end
+ if type == 255 and content == 'bar' then
+ foundBar = true
+ end
+ end
+
+ if not foundFoo or not foundBar then
+ pdnslog('TLV not found')
+ dq:addAnswer(pdns.A, '192.0.2.255', 60)
+ return true
+ end
+
+ local remoteaddr = dq.remoteaddr:toStringWithPort()
+ local localaddr = dq.localaddr:toStringWithPort()
+
+ if remoteaddr ~= '127.0.0.42:0' and remoteaddr ~= '[::42]:0' then
+ pdnslog('invalid source '..remoteaddr)
+ dq:addAnswer(pdns.A, '192.0.2.128', 60)
+ return true
+ end
+ if localaddr ~= '255.255.255.255:65535' and localaddr ~= '[2001:db8::ff]:65535' then
+ pdnslog('invalid dest '..localaddr)
+ dq:addAnswer(pdns.A, '192.0.2.129', 60)
+ return true
+ end
+
+ if dq.tag ~= 42 then
+ pdnslog('invalid tag '..dq.tag)
+ dq:addAnswer(pdns.A, '192.0.2.130', 60)
+ return true
+ end
+
+ dq:addAnswer(pdns.A, '192.0.2.1', 60)
+ return true
+ end
+ """
+
+ _config_template = """
+ proxy-protocol-from=127.0.0.1
+ allow-from=127.0.0.0/24, ::1/128, ::42/128
+""" % ()
+
+class ProxyProtocolNotAllowedRecursorTest(ProxyProtocolRecursorTest):
+ _confdir = 'ProxyProtocolNotAllowed'
+ _lua_dns_script_file = """
+
+ function preresolve(dq)
+ dq:addAnswer(pdns.A, '192.0.2.1', 60)
+ return true
+ end
+ """
+
+ _config_template = """
+ proxy-protocol-from=192.0.2.1/32
+ allow-from=127.0.0.0/24, ::1/128
+""" % ()
+
+ def testNoHeaderProxyProtocol(self):
+ qname = 'no-header.proxy-protocol-not-allowed.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ res = sender(query)
+ self.assertRcodeEqual(res, dns.rcode.NOERROR)
+ self.assertRRsetInAnswer(res, expected)
+
+ def testIPv4ProxyProtocol(self):
+ qname = 'ipv4.proxy-protocol-not-allowed.recursor-tests.powerdns.com.'
+ expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
+
+ query = dns.message.make_query(qname, 'A', want_dnssec=True)
+ for method in ("sendUDPQueryWithProxyProtocol", "sendTCPQueryWithProxyProtocol"):
+ sender = getattr(self, method)
+ res = sender(query, False, '127.0.0.42', '255.255.255.255', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
+ self.assertEqual(res, None)