_proxyResponderPort = proxyResponderPort
_config_params = ['_proxyResponderPort']
- def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[]):
+ def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None):
proxy = ProxyProtocol()
self.assertTrue(proxy.parseHeader(receivedProxyPayload))
self.assertEquals(proxy.version, 0x02)
self.assertEquals(proxy.command, 0x01)
- self.assertEquals(proxy.family, 0x01)
+ if v6:
+ self.assertEquals(proxy.family, 0x02)
+ else:
+ self.assertEquals(proxy.family, 0x01)
if not isTCP:
self.assertEquals(proxy.protocol, 0x02)
else:
self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
self.assertEquals(proxy.source, source)
self.assertEquals(proxy.destination, destination)
- #self.assertEquals(proxy.sourcePort, sourcePort)
- self.assertEquals(proxy.destinationPort, self._dnsDistPort)
+ if sourcePort:
+ self.assertEquals(proxy.sourcePort, sourcePort)
+ if destinationPort:
+ self.assertEquals(proxy.destinationPort, destinationPort)
+ else:
+ self.assertEquals(proxy.destinationPort, self._dnsDistPort)
self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
proxy.values.sort()
self.assertEquals(receivedQuery, query)
self.assertEquals(receivedResponse, response)
self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [])
+
+class TestProxyProtocolIncoming(ProxyProtocolTest):
+ """
+ dnsdist is configured to prepend a Proxy Protocol header to the query and expect one on incoming queries
+ """
+
+ _config_template = """
+ setProxyProtocolACL( { "127.0.0.1/32" } )
+ newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+
+ function addValues(dq)
+ dq:addProxyProtocolValue(0, 'foo')
+ dq:addProxyProtocolValue(42, 'bar')
+ return DNSAction.None
+ end
+
+ -- refuse queries with no TLV value type 2
+ addAction(NotRule(ProxyProtocolValueRule(2)), RCodeAction(DNSRCode.REFUSED))
+ -- or with a TLV value type 3 different from "proxy"
+ addAction(NotRule(ProxyProtocolValueRule(3, "proxy")), RCodeAction(DNSRCode.REFUSED))
+
+ function answerBasedOnForwardedDest(dq)
+ local port = dq.localaddr:getPort()
+ local dest = dq.localaddr:toString()
+ return DNSAction.Spoof, "address-was-"..dest.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com."
+ end
+ addAction("get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedDest))
+
+ function answerBasedOnForwardedSrc(dq)
+ local port = dq.remoteaddr:getPort()
+ local src = dq.remoteaddr:toString()
+ return DNSAction.Spoof, "address-was-"..src.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com."
+ end
+ addAction("get-forwarded-src.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedSrc))
+
+ -- add these values for all queries
+ addAction("proxy-protocol-incoming.tests.powerdns.com.", LuaAction(addValues))
+ addAction("proxy-protocol-incoming.tests.powerdns.com.", AddProxyProtocolValueAction(1, "dnsdist"))
+ addAction("proxy-protocol-incoming.tests.powerdns.com.", AddProxyProtocolValueAction(255, "proxy-protocol"))
+
+ -- override all existing values
+ addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"}))
+ """
+ _config_params = ['_proxyResponderPort']
+ _verboseMode = True
+
+ def testNoHeader(self):
+ """
+ Incoming Proxy Protocol: no header
+ """
+ # no proxy protocol header while one is expected, should be dropped
+ name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None)
+ self.assertEquals(receivedResponse, None)
+
+ def testIncomingProxyDest(self):
+ """
+ Incoming Proxy Protocol: values from Lua
+ """
+ name = 'get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+
+ destAddr = "2001:db8::9"
+ destPort = 9999
+ srcAddr = "2001:db8::8"
+ srcPort = 8888
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.CNAME,
+ "address-was-{}-port-was-{}.proxy-protocol-incoming.tests.powerdns.com.".format(destAddr, destPort, self._dnsDistPort))
+ response.answer.append(rrset)
+
+ udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+ self.assertEquals(receivedResponse, response)
+
+ tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ wire = query.to_wire()
+
+ receivedResponse = None
+ try:
+ conn = self.openTCPConnection(2.0)
+ conn.send(tcpPayload)
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+ self.assertEquals(receivedResponse, response)
+
+ def testProxyUDPWithValuesFromLua(self):
+ """
+ Incoming Proxy Protocol: values from Lua (UDP)
+ """
+ name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ destAddr = "2001:db8::9"
+ destPort = 9999
+ srcAddr = "2001:db8::8"
+ srcPort = 8888
+ response = dns.message.make_response(query)
+
+ udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ toProxyQueue.put(response, True, 2.0)
+ (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+ def testProxyTCPWithValuesFromLua(self):
+ """
+ Incoming Proxy Protocol: values from Lua (TCP)
+ """
+ name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ destAddr = "2001:db8::9"
+ destPort = 9999
+ srcAddr = "2001:db8::8"
+ srcPort = 8888
+ response = dns.message.make_response(query)
+
+ tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+
+ toProxyQueue.put(response, True, 2.0)
+
+ wire = query.to_wire()
+
+ receivedResponse = None
+ try:
+ conn = self.openTCPConnection(2.0)
+ conn.send(tcpPayload)
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+ self.assertEquals(receivedResponse, response)
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+ def testProxyUDPWithValueOverride(self):
+ """
+ Incoming Proxy Protocol: override existing value (UDP)
+ """
+ name = 'override.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ destAddr = "2001:db8::9"
+ destPort = 9999
+ srcAddr = "2001:db8::8"
+ srcPort = 8888
+ response = dns.message.make_response(query)
+
+ udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [2, b'foo'], [3, b'proxy'], [ 50, b'initial-value']])
+ toProxyQueue.put(response, True, 2.0)
+ (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [50, b'overridden'] ], True, srcPort, destPort)
+
+ def testProxyTCPSeveralQueriesOverConnection(self):
+ """
+ Incoming Proxy Protocol: Several queries over the same connection (TCP)
+ """
+ name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ destAddr = "2001:db8::9"
+ destPort = 9999
+ srcAddr = "2001:db8::8"
+ srcPort = 8888
+ response = dns.message.make_response(query)
+
+ tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+
+ toProxyQueue.put(response, True, 2.0)
+
+ wire = query.to_wire()
+
+ receivedResponse = None
+ conn = self.openTCPConnection(2.0)
+ try:
+ conn.send(tcpPayload)
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+ self.assertEquals(receivedResponse, response)
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+ for idx in range(5):
+ receivedResponse = None
+ toProxyQueue.put(response, True, 2.0)
+ try:
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+
+ self.assertEquals(receivedResponse, response)
+
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+class TestProxyProtocolNotExpected(DNSDistTest):
+ """
+ dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1
+ """
+
+ _config_template = """
+ setProxyProtocolACL( { "192.0.2.1/32" } )
+ newServer{address="127.0.0.1:%d"}
+ """
+ # NORMAL responder, does not expect a proxy protocol payload!
+ _config_params = ['_testServerPort']
+ _verboseMode = True
+
+ def testNoHeader(self):
+ """
+ Unexpected Proxy Protocol: no header
+ """
+ # no proxy protocol header and none is expected from this source, should be passed on
+ name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+
+ def testIncomingProxyDest(self):
+ """
+ Unexpected Proxy Protocol: should be dropped
+ """
+ name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ # Make sure that the proxy payload does NOT turn into a legal qname
+ destAddr = "ff:db8::ffff"
+ destPort = 65535
+ srcAddr = "ff:db8::ffff"
+ srcPort = 65535
+
+ udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+ self.assertEquals(receivedResponse, None)
+
+ tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ wire = query.to_wire()
+
+ receivedResponse = None
+ try:
+ conn = self.openTCPConnection(2.0)
+ conn.send(tcpPayload)
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+ self.assertEquals(receivedResponse, None)