#!/usr/bin/env python
-import copy
import dns
import selectors
import socket
from dnsdisttests import DNSDistTest, pickAvailablePort
from proxyprotocol import ProxyProtocol
+from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder
from dnsdistdohtests import DNSDistDOHTest
# Python2/3 compatibility hacks
except ImportError:
from Queue import Queue
-def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
- try:
- sock.bind(("127.0.0.1", port))
- except socket.error as e:
- print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
- sys.exit(1)
-
- while True:
- data, addr = sock.recvfrom(4096)
-
- proxy = ProxyProtocol()
- if len(data) < proxy.HEADER_SIZE:
- continue
-
- if not proxy.parseHeader(data):
- continue
-
- if proxy.local:
- # likely a healthcheck
- data = data[proxy.HEADER_SIZE:]
- request = dns.message.from_wire(data)
- response = dns.message.make_response(request)
- wire = response.to_wire()
- sock.settimeout(2.0)
- sock.sendto(wire, addr)
- sock.settimeout(None)
-
- continue
-
- payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
- dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
- toQueue.put([payload, dnsData], True, 2.0)
- # computing the correct ID for the response
- request = dns.message.from_wire(dnsData)
- response = fromQueue.get(True, 2.0)
- response.id = request.id
-
- sock.settimeout(2.0)
- sock.sendto(response.to_wire(), addr)
- sock.settimeout(None)
-
- sock.close()
-
-def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
- # be aware that this responder will not accept a new connection
- # until the last one has been closed. This is done on purpose to
- # to check for connection reuse, making sure that a lot of connections
- # are not opened in parallel.
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
- sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
- try:
- sock.bind(("127.0.0.1", port))
- except socket.error as e:
- print("Error binding in the TCP responder: %s" % str(e))
- sys.exit(1)
-
- sock.listen(100)
- while True:
- (conn, _) = sock.accept()
- conn.settimeout(5.0)
- # try to read the entire Proxy Protocol header
- proxy = ProxyProtocol()
- header = conn.recv(proxy.HEADER_SIZE)
- if not header:
- conn.close()
- continue
-
- if not proxy.parseHeader(header):
- conn.close()
- continue
-
- proxyContent = conn.recv(proxy.contentLen)
- if not proxyContent:
- conn.close()
- continue
-
- payload = header + proxyContent
- while True:
- try:
- data = conn.recv(2)
- except socket.timeout:
- data = None
-
- if not data:
- conn.close()
- break
-
- (datalen,) = struct.unpack("!H", data)
- data = conn.recv(datalen)
-
- toQueue.put([payload, data], True, 2.0)
-
- response = copy.deepcopy(fromQueue.get(True, 2.0))
- if not response:
- conn.close()
- break
-
- # computing the correct ID for the response
- request = dns.message.from_wire(data)
- response.id = request.id
-
- wire = response.to_wire()
- conn.send(struct.pack("!H", len(wire)))
- conn.send(wire)
-
- conn.close()
-
- sock.close()
-
toProxyQueue = Queue()
fromProxyQueue = Queue()
proxyResponderPort = pickAvailablePort()
addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
setProxyProtocolACL( { "127.0.0.1/32" } )
- newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+ newServer{address="127.0.0.1:%d", useProxyProtocol=true, proxyProtocolAdvertiseTLS=true}
function addValues(dq)
dq:addProxyProtocolValue(0, 'foo')
receivedResponse.id = response.id
self.assertEqual(receivedQuery, query)
self.assertEqual(receivedResponse, response)
- self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
for idx in range(5):
receivedResponse = None
receivedResponse.id = response.id
self.assertEqual(receivedQuery, query)
self.assertEqual(receivedResponse, response)
- self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
def testProxyDoHSeveralQueriesOverConnectionPPInside(self):
"""
receivedResponse.id = response.id
self.assertEqual(receivedQuery, query)
self.assertEqual(receivedResponse, response)
- self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
for idx in range(5):
receivedResponse = None
receivedResponse.id = response.id
self.assertEqual(receivedQuery, query)
self.assertEqual(receivedResponse, response)
- self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
@classmethod
def tearDownClass(cls):
print('timeout')
self.assertEqual(receivedResponse, None)
+class TestProxyProtocolNotAllowedOnBind(DNSDistTest):
+ """
+ dnsdist is configured to expect a Proxy Protocol header on incoming queries but not on the 127.0.0.1 bind
+ """
+ _skipListeningOnCL = True
+ _config_template = """
+ -- proxy protocol payloads are not allowed on this bind address!
+ addLocal('127.0.0.1:%d', {enableProxyProtocol=false})
+ setProxyProtocolACL( { "127.0.0.1/8" } )
+ newServer{address="127.0.0.1:%d"}
+ """
+ # NORMAL responder, does not expect a proxy protocol payload!
+ _config_params = ['_dnsDistPort', '_testServerPort']
+
+ def testNoHeader(self):
+ """
+ Unexpected Proxy Protocol: no header
+ """
+ # no proxy protocol header and none is expected from this source, should be passed on
+ name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(response, receivedResponse)
+
+ def testIncomingProxyDest(self):
+ """
+ Unexpected Proxy Protocol: should be dropped
+ """
+ name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ # Make sure that the proxy payload does NOT turn into a legal qname
+ destAddr = "ff:db8::ffff"
+ destPort = 65535
+ srcAddr = "ff:db8::ffff"
+ srcPort = 65535
+
+ udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+ self.assertEqual(receivedResponse, None)
+
+ tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ wire = query.to_wire()
+
+ receivedResponse = None
+ try:
+ conn = self.openTCPConnection(2.0)
+ conn.send(tcpPayload)
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+ self.assertEqual(receivedResponse, None)
+
class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
_serverKey = 'server.key'