--- /dev/null
+#!/usr/bin/env python
+
+import socket
+import struct
+
+class ProxyProtocol(object):
+ MAGIC = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
+ # Header is magic + versioncommand (1) + family (1) + content length (2)
+ HEADER_SIZE = len(MAGIC) + 1 + 1 + 2
+ PORT_SIZE = 2
+
+ def consumed(self):
+ return self.offset
+
+ def parseHeader(self, data):
+ if len(data) < self.HEADER_SIZE:
+ return False
+
+ if data[:len(self.MAGIC)] != self.MAGIC:
+ return False
+
+ self.version = int(data[12]) >> 4
+ if self.version != 0x02:
+ return False
+
+ self.command = int(data[12]) & ~0x20
+ self.local = False
+ self.offset = self.HEADER_SIZE
+
+ if self.command == 0x00:
+ self.local = True
+ elif self.command == 0x01:
+ self.family = int(data[13]) >> 4
+ if self.family == 0x01:
+ self.addrSize = 4
+ elif self.family == 0x02:
+ self.addrSize = 16
+ else:
+ return False
+
+ self.protocol = int(data[13]) & ~0xF0
+ if self.protocol == 0x01:
+ self.tcp = True
+ elif self.protocol == 0x02:
+ self.tcp = False
+ else:
+ return False
+ else:
+ return False
+
+ self.contentLen = struct.unpack("!H", data[14:16])[0]
+
+ if not self.local:
+ if self.contentLen < (self.addrSize * 2 + self.PORT_SIZE * 2):
+ return False
+
+ return True
+
+ def getAddr(self, data):
+ if len(data) < (self.consumed() + self.addrSize):
+ return False
+
+ value = None
+ if self.family == 0x01:
+ value = socket.inet_ntop(socket.AF_INET, data[self.offset:self.offset + self.addrSize])
+ else:
+ value = socket.inet_ntop(socket.AF_INET6, data[self.offset:self.offset + self.addrSize])
+
+ self.offset = self.offset + self.addrSize
+ return value
+
+ def getPort(self, data):
+ if len(data) < (self.consumed() + self.PORT_SIZE):
+ return False
+
+ value = struct.unpack('!H', data[self.offset:self.offset + self.PORT_SIZE])[0]
+ self.offset = self.offset + self.PORT_SIZE
+ return value
+
+ def parseAddressesAndPorts(self, data):
+ if self.local:
+ return True
+
+ if len(data) < (self.consumed() + self.addrSize * 2 + self.PORT_SIZE * 2):
+ return False
+
+ self.source = self.getAddr(data)
+ self.destination = self.getAddr(data)
+ self.sourcePort = self.getPort(data)
+ self.destinationPort = self.getPort(data)
+ return True
+
+ def parseAdditionalValues(self, data):
+ self.values = []
+ if self.local:
+ return True
+
+ if len(data) < (self.HEADER_SIZE + self.contentLen):
+ return False
+
+ remaining = self.HEADER_SIZE + self.contentLen - self.consumed()
+ if len(data) < remaining:
+ return False
+
+ while remaining >= 3:
+ valueType = data[self.offset]
+ self.offset = self.offset + 1
+ valueLen = struct.unpack("!H", data[self.offset:self.offset+2])[0]
+ self.offset = self.offset + 2
+
+ remaining = remaining - 3
+ if valueLen > 0:
+ if valueLen > remaining:
+ return False
+ self.values.append([valueType, data[self.offset:self.offset+valueLen]])
+ self.offset = self.offset + valueLen
+ remaining = remaining - valueLen
+
+ else:
+ self.values.append([valueType, ""])
+
+ return True
--- /dev/null
+#!/usr/bin/env python
+
+import dns
+import socket
+import struct
+import sys
+import threading
+
+from dnsdisttests import DNSDistTest
+from proxyprotocol import ProxyProtocol
+
+# Python2/3 compatibility hacks
+try:
+ from queue import Queue
+except ImportError:
+ from Queue import Queue
+
+def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ try:
+ sock.bind(("127.0.0.1", port))
+ except socket.error as e:
+ print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
+ sys.exit(1)
+
+ while True:
+ data, addr = sock.recvfrom(4096)
+
+ proxy = ProxyProtocol()
+ if len(data) < proxy.HEADER_SIZE:
+ continue
+
+ if not proxy.parseHeader(data):
+ continue
+
+ if proxy.local:
+ # likely a healthcheck
+ data = data[proxy.HEADER_SIZE:]
+ request = dns.message.from_wire(data)
+ response = dns.message.make_response(request)
+ wire = response.to_wire()
+ sock.settimeout(2.0)
+ sock.sendto(wire, addr)
+ sock.settimeout(None)
+
+ continue
+
+ payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
+ dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
+ toQueue.put([payload, dnsData], True, 2.0)
+ # computing the correct ID for the response
+ request = dns.message.from_wire(dnsData)
+ response = fromQueue.get(True, 2.0)
+ response.id = request.id
+
+ sock.settimeout(2.0)
+ sock.sendto(response.to_wire(), addr)
+ sock.settimeout(None)
+
+ sock.close()
+
+def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ try:
+ sock.bind(("127.0.0.1", port))
+ except socket.error as e:
+ print("Error binding in the TCP responder: %s" % str(e))
+ sys.exit(1)
+
+ sock.listen(100)
+ while True:
+ (conn, _) = sock.accept()
+ conn.settimeout(5.0)
+ # try to read the entire Proxy Protocol header
+ proxy = ProxyProtocol()
+ header = conn.recv(proxy.HEADER_SIZE)
+ if not header:
+ conn.close()
+ continue
+
+ if not proxy.parseHeader(header):
+ conn.close()
+ continue
+
+ proxyContent = conn.recv(proxy.contentLen)
+ if not proxyContent:
+ conn.close()
+ continue
+
+ payload = header + proxyContent
+
+ data = conn.recv(2)
+ (datalen,) = struct.unpack("!H", data)
+
+ data = conn.recv(datalen)
+
+ toQueue.put([payload, data], True, 2.0)
+
+ response = fromQueue.get(True, 2.0)
+ if not response:
+ conn.close()
+ continue
+
+ # computing the correct ID for the response
+ request = dns.message.from_wire(data)
+ response.id = request.id
+
+ wire = response.to_wire()
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ conn.close()
+
+ sock.close()
+
+toProxyQueue = Queue()
+fromProxyQueue = Queue()
+proxyResponderPort = 5470
+
+udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+udpResponder.setDaemon(True)
+udpResponder.start()
+tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+tcpResponder.setDaemon(True)
+tcpResponder.start()
+
+class ProxyProtocolTest(DNSDistTest):
+ _proxyResponderPort = proxyResponderPort
+ _config_params = ['_proxyResponderPort']
+
+ def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[]):
+ proxy = ProxyProtocol()
+ self.assertTrue(proxy.parseHeader(receivedProxyPayload))
+ self.assertEquals(proxy.version, 0x02)
+ self.assertEquals(proxy.command, 0x01)
+ self.assertEquals(proxy.family, 0x01)
+ if not isTCP:
+ self.assertEquals(proxy.protocol, 0x02)
+ else:
+ self.assertEquals(proxy.protocol, 0x01)
+ self.assertGreater(proxy.contentLen, 0)
+
+ self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
+ self.assertEquals(proxy.source, source)
+ self.assertEquals(proxy.destination, destination)
+ #self.assertEquals(proxy.sourcePort, sourcePort)
+ self.assertEquals(proxy.destinationPort, self._dnsDistPort)
+
+ self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
+ proxy.values.sort()
+ values.sort()
+ self.assertEquals(proxy.values, values)
+
+class TestProxyProtocol(ProxyProtocolTest):
+ """
+ dnsdist is configured to prepend a Proxy Protocol header to the query
+ """
+
+ _config_template = """
+ newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+
+ function addValues(dq)
+ local values = { ["0"]="foo", ["42"]="bar" }
+ dq:setProxyProtocolValues(values)
+ return DNSAction.None
+ end
+
+ addAction("values-lua.proxy.tests.powerdns.com.", LuaAction(addValues))
+ """
+ _config_params = ['_proxyResponderPort']
+
+ def testProxyUDP(self):
+ """
+ Proxy Protocol: no value (UDP)
+ """
+ name = 'simple-udp.proxy.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ toProxyQueue.put(response, True, 2.0)
+
+ data = query.to_wire()
+ self._sock.send(data)
+ receivedResponse = None
+ try:
+ self._sock.settimeout(2.0)
+ data = self._sock.recv(4096)
+ except socket.timeout:
+ print('timeout')
+ data = None
+ if data:
+ receivedResponse = dns.message.from_wire(data)
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', False)
+
+ def testProxyTCP(self):
+ """
+ Proxy Protocol: no value (TCP)
+ """
+ name = 'simple-tcp.proxy.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ toProxyQueue.put(response, True, 2.0)
+
+ conn = self.openTCPConnection(2.0)
+ data = query.to_wire()
+ self.sendTCPQueryOverConnection(conn, data, rawQuery=True)
+ receivedResponse = None
+ try:
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True)
+
+ def testProxyUDPWithValuesFromLua(self):
+ """
+ Proxy Protocol: values from Lua (UDP)
+ """
+ name = 'values-lua.proxy.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ toProxyQueue.put(response, True, 2.0)
+
+ data = query.to_wire()
+ self._sock.send(data)
+ receivedResponse = None
+ try:
+ self._sock.settimeout(2.0)
+ data = self._sock.recv(4096)
+ except socket.timeout:
+ print('timeout')
+ data = None
+ if data:
+ receivedResponse = dns.message.from_wire(data)
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', False, [ [0, b'foo'] , [ 42, b'bar'] ])
+
+ def testProxyTCPWithValuesFromLua(self):
+ """
+ Proxy Protocol: values from Lua (TCP)
+ """
+ name = 'values-lua.proxy.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ toProxyQueue.put(response, True, 2.0)
+
+ conn = self.openTCPConnection(2.0)
+ data = query.to_wire()
+ self.sendTCPQueryOverConnection(conn, data, rawQuery=True)
+ receivedResponse = None
+ try:
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'] , [ 42, b'bar'] ])