]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/test_ProxyProtocol.py
Merge pull request #12908 from wwijkander/master
[thirdparty/pdns.git] / regression-tests.dnsdist / test_ProxyProtocol.py
1 #!/usr/bin/env python
2
3 import copy
4 import dns
5 import selectors
6 import socket
7 import ssl
8 import struct
9 import sys
10 import threading
11 import time
12
13 from dnsdisttests import DNSDistTest, pickAvailablePort
14 from proxyprotocol import ProxyProtocol
15 from dnsdistdohtests import DNSDistDOHTest
16
17 # Python2/3 compatibility hacks
18 try:
19 from queue import Queue
20 except ImportError:
21 from Queue import Queue
22
def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
    """Fake UDP backend expecting a Proxy Protocol header in front of each query.

    Listens on 127.0.0.1:`port`. Datagrams shorter than the fixed Proxy
    Protocol header, or whose header does not parse, are ignored. For a
    header flagged as LOCAL (likely a health check) an empty response to the
    embedded query is sent straight back. Otherwise the pair
    ``[proxyPayload, dnsQuery]`` is pushed onto `toQueue`, a response is taken
    from `fromQueue`, its ID is set to the query's ID and it is sent back.

    If no response is queued within 2 seconds the query is dropped, rather
    than letting queue.Empty terminate this thread for every later test.

    If the port cannot be bound, an error is printed and sys.exit(1) is
    called; in a thread this only ends the thread, not the process.
    """
    try:
        from queue import Empty
    except ImportError:
        from Queue import Empty

    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
        sys.exit(1)

    while True:
        data, addr = sock.recvfrom(4096)

        proxy = ProxyProtocol()
        if len(data) < proxy.HEADER_SIZE:
            continue

        if not proxy.parseHeader(data):
            continue

        if proxy.local:
            # likely a healthcheck: answer directly, nothing to hand to the tests
            data = data[proxy.HEADER_SIZE:]
            request = dns.message.from_wire(data)
            response = dns.message.make_response(request)
            wire = response.to_wire()
            sock.settimeout(2.0)
            sock.sendto(wire, addr)
            sock.settimeout(None)
            continue

        # the header is followed by contentLen bytes (addresses, ports, TLVs),
        # then by the DNS query itself
        headerEnd = proxy.HEADER_SIZE + proxy.contentLen
        payload = data[:headerEnd]
        dnsData = data[headerEnd:]
        toQueue.put([payload, dnsData], True, 2.0)
        # computing the correct ID for the response
        request = dns.message.from_wire(dnsData)
        try:
            response = fromQueue.get(True, 2.0)
        except Empty:
            # no response was queued by the test: drop this query only
            print("No response queued for a query in the Proxy Protocol UDP responder")
            continue
        response.id = request.id

        sock.settimeout(2.0)
        sock.sendto(response.to_wire(), addr)
        sock.settimeout(None)
67
68 def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
69 # be aware that this responder will not accept a new connection
70 # until the last one has been closed. This is done on purpose to
71 # to check for connection reuse, making sure that a lot of connections
72 # are not opened in parallel.
73 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
74 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
75 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
76 try:
77 sock.bind(("127.0.0.1", port))
78 except socket.error as e:
79 print("Error binding in the TCP responder: %s" % str(e))
80 sys.exit(1)
81
82 sock.listen(100)
83 while True:
84 (conn, _) = sock.accept()
85 conn.settimeout(5.0)
86 # try to read the entire Proxy Protocol header
87 proxy = ProxyProtocol()
88 header = conn.recv(proxy.HEADER_SIZE)
89 if not header:
90 conn.close()
91 continue
92
93 if not proxy.parseHeader(header):
94 conn.close()
95 continue
96
97 proxyContent = conn.recv(proxy.contentLen)
98 if not proxyContent:
99 conn.close()
100 continue
101
102 payload = header + proxyContent
103 while True:
104 try:
105 data = conn.recv(2)
106 except socket.timeout:
107 data = None
108
109 if not data:
110 conn.close()
111 break
112
113 (datalen,) = struct.unpack("!H", data)
114 data = conn.recv(datalen)
115
116 toQueue.put([payload, data], True, 2.0)
117
118 response = copy.deepcopy(fromQueue.get(True, 2.0))
119 if not response:
120 conn.close()
121 break
122
123 # computing the correct ID for the response
124 request = dns.message.from_wire(data)
125 response.id = request.id
126
127 wire = response.to_wire()
128 conn.send(struct.pack("!H", len(wire)))
129 conn.send(wire)
130
131 conn.close()
132
133 sock.close()
134
# State shared between the test cases and the Proxy Protocol responders:
# a test puts the response the backend should send on toProxyQueue, and the
# responders push the [proxyPayload, dnsQuery] pair they received onto
# fromProxyQueue (the responders take (port, fromQueue, toQueue)).
toProxyQueue = Queue()
fromProxyQueue = Queue()
proxyResponderPort = pickAvailablePort()

# The UDP and TCP responders listen on the same port number, as daemon
# threads.
udpResponder = threading.Thread(
    name='UDP Proxy Protocol Responder',
    target=ProxyProtocolUDPResponder,
    args=[proxyResponderPort, toProxyQueue, fromProxyQueue],
    daemon=True,
)
udpResponder.start()

tcpResponder = threading.Thread(
    name='TCP Proxy Protocol Responder',
    target=ProxyProtocolTCPResponder,
    args=[proxyResponderPort, toProxyQueue, fromProxyQueue],
    daemon=True,
)
tcpResponder.start()

# Native thread id of each mock reverse proxy -> "keep running" flag;
# setting a flag to False asks that thread to stop.
backgroundThreads = {}
147
148 def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort, serverCtx=None, ca=None, sni=None):
149 # this responder accepts TCP connections on the listening port,
150 # and relay the raw content to a second TCP connection to the
151 # forwarding port, after adding a Proxy Protocol v2 payload
152 # containing the initial source IP and port, destination IP
153 # and port.
154 backgroundThreads[threading.get_native_id()] = True
155
156 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
157 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
158 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
159
160 if serverCtx is not None:
161 sock = serverCtx.wrap_socket(sock, server_side=True)
162
163 try:
164 sock.bind(("127.0.0.1", listeningPort))
165 except socket.error as e:
166 print("Error binding in the Mock TCP reverse proxy: %s" % str(e))
167 sys.exit(1)
168 sock.settimeout(0.5)
169 sock.listen(100)
170
171 while True:
172 try:
173 (incoming, _) = sock.accept()
174 except socket.timeout:
175 if backgroundThreads.get(threading.get_native_id(), False) == False:
176 del backgroundThreads[threading.get_native_id()]
177 break
178 else:
179 continue
180
181 incoming.settimeout(5.0)
182 payload = ProxyProtocol.getPayload(False, True, False, '127.0.0.1', '127.0.0.1', incoming.getpeername()[1], listeningPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
183
184 outgoing = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
185 outgoing.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
186 outgoing.settimeout(2.0)
187 if sni:
188 if hasattr(ssl, 'create_default_context'):
189 sslctx = ssl.create_default_context(cafile=ca)
190 if hasattr(sslctx, 'set_alpn_protocols'):
191 sslctx.set_alpn_protocols(['h2'])
192 outgoing = sslctx.wrap_socket(outgoing, server_hostname=sni)
193 else:
194 outgoing = ssl.wrap_socket(outgoing, ca_certs=ca, cert_reqs=ssl.CERT_REQUIRED)
195
196 outgoing.connect(('127.0.0.1', forwardingPort))
197
198 outgoing.send(payload)
199
200 sel = selectors.DefaultSelector()
201 def readFromClient(conn):
202 data = conn.recv(512)
203 if not data or len(data) == 0:
204 return False
205 outgoing.send(data)
206 return True
207
208 def readFromBackend(conn):
209 data = conn.recv(512)
210 if not data or len(data) == 0:
211 return False
212 incoming.send(data)
213 return True
214
215 sel.register(incoming, selectors.EVENT_READ, readFromClient)
216 sel.register(outgoing, selectors.EVENT_READ, readFromBackend)
217 done = False
218 while not done:
219 try:
220 events = sel.select()
221 for key, mask in events:
222 if not (key.data)(key.fileobj):
223 done = True
224 break
225 except socket.timeout:
226 break
227 except:
228 break
229
230 incoming.close()
231 outgoing.close()
232
233 sock.close()
234
class ProxyProtocolTest(DNSDistTest):
    """Base class for tests whose dnsdist backend is the Proxy Protocol responder.

    Exposes the port shared by the module-level UDP and TCP Proxy Protocol
    responders as `_proxyResponderPort`. It is listed in `_config_params`,
    presumably so the test harness substitutes it into `_config_template`.
    """

    # port of the module-level Proxy Protocol responders
    _proxyResponderPort = proxyResponderPort
    # attribute names fed into the dnsdist configuration template
    _config_params = ['_proxyResponderPort']
238
class TestProxyProtocol(ProxyProtocolTest):
    """
    dnsdist is configured to prepend a Proxy Protocol header to the query
    """

    _config_template = """
    newServer{address="127.0.0.1:%d", useProxyProtocol=true}

    function addValues(dq)
      local values = { [0]="foo", [42]="bar" }
      dq:setProxyProtocolValues(values)
      return DNSAction.None
    end

    addAction("values-lua.proxy.tests.powerdns.com.", LuaAction(addValues))
    addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"}))
    """
    _config_params = ['_proxyResponderPort']
    _verboseMode = True

    def _exchangeOverUDP(self, query):
        """Send `query` to dnsdist over UDP; return the parsed response, or None on timeout."""
        self._sock.send(query.to_wire())
        data = None
        try:
            self._sock.settimeout(2.0)
            data = self._sock.recv(4096)
        except socket.timeout:
            print('timeout')
        if not data:
            return None
        return dns.message.from_wire(data)

    def _exchangeOverTCP(self, conn, query):
        """Send `query` over the open TCP connection `conn`; return the response, or None on timeout."""
        self.sendTCPQueryOverConnection(conn, query.to_wire(), rawQuery=True)
        try:
            return self.recvTCPResponseOverConnection(conn)
        except socket.timeout:
            print('timeout')
            return None

    def _checkProxiedExchange(self, query, response, receivedResponse, isTCP, values=None):
        """Check both what the backend received and what the client got back.

        Pops the (Proxy Protocol payload, DNS query) pair recorded by the
        responder and compares the query with `query` and `receivedResponse`
        with `response`, ignoring message IDs. The payload is then checked
        with checkMessageProxyProtocol() for 127.0.0.1 as source and
        destination and the `isTCP` flag. `values` is forwarded to it when
        given.
        """
        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
        self.assertTrue(receivedProxyPayload)
        self.assertTrue(receivedDNSData)
        self.assertTrue(receivedResponse)

        receivedQuery = dns.message.from_wire(receivedDNSData)
        receivedQuery.id = query.id
        receivedResponse.id = response.id
        self.assertEqual(receivedQuery, query)
        self.assertEqual(receivedResponse, response)
        extra = () if values is None else (values,)
        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', isTCP, *extra)

    def testProxyUDP(self):
        """
        Proxy Protocol: no value (UDP)
        """
        name = 'simple-udp.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        toProxyQueue.put(response, True, 2.0)

        receivedResponse = self._exchangeOverUDP(query)
        self._checkProxiedExchange(query, response, receivedResponse, False)

    def testProxyTCP(self):
        """
        Proxy Protocol: no value (TCP)
        """
        name = 'simple-tcp.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        toProxyQueue.put(response, True, 2.0)

        conn = self.openTCPConnection(2.0)
        try:
            receivedResponse = self._exchangeOverTCP(conn, query)
            self._checkProxiedExchange(query, response, receivedResponse, True)
        finally:
            conn.close()

    def testProxyUDPWithValuesFromLua(self):
        """
        Proxy Protocol: values from Lua (UDP)
        """
        name = 'values-lua.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        toProxyQueue.put(response, True, 2.0)

        receivedResponse = self._exchangeOverUDP(query)
        self._checkProxiedExchange(query, response, receivedResponse, False, [ [0, b'foo'] , [ 42, b'bar'] ])

    def testProxyTCPWithValuesFromLua(self):
        """
        Proxy Protocol: values from Lua (TCP)
        """
        name = 'values-lua.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        toProxyQueue.put(response, True, 2.0)

        conn = self.openTCPConnection(2.0)
        try:
            receivedResponse = self._exchangeOverTCP(conn, query)
            self._checkProxiedExchange(query, response, receivedResponse, True, [ [0, b'foo'] , [ 42, b'bar'] ])
        finally:
            conn.close()

    def testProxyUDPWithValuesFromAction(self):
        """
        Proxy Protocol: values from Action (UDP)
        """
        name = 'values-action.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        toProxyQueue.put(response, True, 2.0)

        receivedResponse = self._exchangeOverUDP(query)
        self._checkProxiedExchange(query, response, receivedResponse, False, [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])

    def testProxyTCPWithValuesFromAction(self):
        """
        Proxy Protocol: values from Action (TCP)
        """
        name = 'values-action.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        toProxyQueue.put(response, True, 2.0)

        conn = self.openTCPConnection(2.0)
        try:
            receivedResponse = self._exchangeOverTCP(conn, query)
            self._checkProxiedExchange(query, response, receivedResponse, True, [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])
        finally:
            conn.close()

    def testProxyTCPSeveralQueriesOnSameConnection(self):
        """
        Proxy Protocol: Several queries on the same TCP connection
        """
        name = 'several-queries-same-conn.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        conn = self.openTCPConnection(2.0)
        try:
            for _ in range(10):
                toProxyQueue.put(response, True, 2.0)
                receivedResponse = self._exchangeOverTCP(conn, query)
                self._checkProxiedExchange(query, response, receivedResponse, True, [])
        finally:
            conn.close()
485
486 class TestProxyProtocolIncoming(ProxyProtocolTest):
487 """
488 dnsdist is configured to prepend a Proxy Protocol header to the query and expect one on incoming queries
489 """
490
491 _config_template = """
492 addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
493 addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
494 setProxyProtocolACL( { "127.0.0.1/32" } )
495 newServer{address="127.0.0.1:%d", useProxyProtocol=true}
496
497 function addValues(dq)
498 dq:addProxyProtocolValue(0, 'foo')
499 dq:addProxyProtocolValue(42, 'bar')
500 return DNSAction.None
501 end
502
503 -- refuse queries with no TLV value type 2
504 addAction(NotRule(ProxyProtocolValueRule(2)), RCodeAction(DNSRCode.REFUSED))
505 -- or with a TLV value type 3 different from "proxy"
506 addAction(NotRule(ProxyProtocolValueRule(3, "proxy")), RCodeAction(DNSRCode.REFUSED))
507
508 function answerBasedOnForwardedDest(dq)
509 local port = dq.localaddr:getPort()
510 local dest = dq.localaddr:toString()
511 return DNSAction.Spoof, "address-was-"..dest.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com."
512 end
513 addAction("get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedDest))
514
515 function answerBasedOnForwardedSrc(dq)
516 local port = dq.remoteaddr:getPort()
517 local src = dq.remoteaddr:toString()
518 return DNSAction.Spoof, "address-was-"..src.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com."
519 end
520 addAction("get-forwarded-src.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedSrc))
521
522 -- add these values for all queries
523 addAction("proxy-protocol-incoming.tests.powerdns.com.", LuaAction(addValues))
524 addAction("proxy-protocol-incoming.tests.powerdns.com.", SetAdditionalProxyProtocolValueAction(1, "dnsdist"))
525 addAction("proxy-protocol-incoming.tests.powerdns.com.", SetAdditionalProxyProtocolValueAction(255, "proxy-protocol"))
526
527 -- override all existing values
528 addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"}))
529 """
530 _serverKey = 'server.key'
531 _serverCert = 'server.chain'
532 _serverName = 'tls.tests.dnsdist.org'
533 _caCert = 'ca.pem'
534 _dohServerPPOutsidePort = pickAvailablePort()
535 _dohServerPPInsidePort = pickAvailablePort()
536 _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort']
537
538 def testNoHeader(self):
539 """
540 Incoming Proxy Protocol: no header
541 """
542 # no proxy protocol header while one is expected, should be dropped
543 name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.'
544 query = dns.message.make_query(name, 'A', 'IN')
545
546 for method in ("sendUDPQuery", "sendTCPQuery", "sendDOHQueryWrapper"):
547 sender = getattr(self, method)
548 try:
549 (_, receivedResponse) = sender(query, response=None)
550 except Exception:
551 receivedResponse = None
552 self.assertEqual(receivedResponse, None)
553
554 def testIncomingProxyDest(self):
555 """
556 Incoming Proxy Protocol: values from Lua
557 """
558 name = 'get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.'
559 query = dns.message.make_query(name, 'A', 'IN')
560 # dnsdist set RA = RD for spoofed responses
561 query.flags &= ~dns.flags.RD
562
563 destAddr = "2001:db8::9"
564 destPort = 9999
565 srcAddr = "2001:db8::8"
566 srcPort = 8888
567 response = dns.message.make_response(query)
568 rrset = dns.rrset.from_text(name,
569 60,
570 dns.rdataclass.IN,
571 dns.rdatatype.CNAME,
572 "address-was-{}-port-was-{}.proxy-protocol-incoming.tests.powerdns.com.".format(destAddr, destPort, self._dnsDistPort))
573 response.answer.append(rrset)
574
575 udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
576 (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
577 self.assertEqual(receivedResponse, response)
578
579 tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
580 wire = query.to_wire()
581
582 receivedResponse = None
583 try:
584 conn = self.openTCPConnection(2.0)
585 conn.send(tcpPayload)
586 conn.send(struct.pack("!H", len(wire)))
587 conn.send(wire)
588 receivedResponse = self.recvTCPResponseOverConnection(conn)
589 except socket.timeout:
590 print('timeout')
591 self.assertEqual(receivedResponse, response)
592
593 def testProxyUDPWithValuesFromLua(self):
594 """
595 Incoming Proxy Protocol: values from Lua (UDP)
596 """
597 name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.'
598 query = dns.message.make_query(name, 'A', 'IN')
599 response = dns.message.make_response(query)
600
601 destAddr = "2001:db8::9"
602 destPort = 9999
603 srcAddr = "2001:db8::8"
604 srcPort = 8888
605 response = dns.message.make_response(query)
606
607 udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
608 toProxyQueue.put(response, True, 2.0)
609 (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
610
611 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
612 self.assertTrue(receivedProxyPayload)
613 self.assertTrue(receivedDNSData)
614 self.assertTrue(receivedResponse)
615
616 receivedQuery = dns.message.from_wire(receivedDNSData)
617 receivedQuery.id = query.id
618 receivedResponse.id = response.id
619 self.assertEqual(receivedQuery, query)
620 self.assertEqual(receivedResponse, response)
621 self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
622
623 def testProxyTCPWithValuesFromLua(self):
624 """
625 Incoming Proxy Protocol: values from Lua (TCP)
626 """
627 name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.'
628 query = dns.message.make_query(name, 'A', 'IN')
629 response = dns.message.make_response(query)
630
631 destAddr = "2001:db8::9"
632 destPort = 9999
633 srcAddr = "2001:db8::8"
634 srcPort = 8888
635 response = dns.message.make_response(query)
636
637 tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
638
639 toProxyQueue.put(response, True, 2.0)
640
641 wire = query.to_wire()
642
643 receivedResponse = None
644 try:
645 conn = self.openTCPConnection(2.0)
646 conn.send(tcpPayload)
647 conn.send(struct.pack("!H", len(wire)))
648 conn.send(wire)
649 receivedResponse = self.recvTCPResponseOverConnection(conn)
650 except socket.timeout:
651 print('timeout')
652 self.assertEqual(receivedResponse, response)
653
654 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
655 self.assertTrue(receivedProxyPayload)
656 self.assertTrue(receivedDNSData)
657 self.assertTrue(receivedResponse)
658
659 receivedQuery = dns.message.from_wire(receivedDNSData)
660 receivedQuery.id = query.id
661 self.assertEqual(receivedQuery, query)
662 self.assertEqual(receivedResponse, response)
663 self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
664
665 def testProxyUDPWithValueOverride(self):
666 """
667 Incoming Proxy Protocol: override existing value (UDP)
668 """
669 name = 'override.proxy-protocol-incoming.tests.powerdns.com.'
670 query = dns.message.make_query(name, 'A', 'IN')
671 response = dns.message.make_response(query)
672
673 destAddr = "2001:db8::9"
674 destPort = 9999
675 srcAddr = "2001:db8::8"
676 srcPort = 8888
677 response = dns.message.make_response(query)
678
679 udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [2, b'foo'], [3, b'proxy'], [ 50, b'initial-value']])
680 toProxyQueue.put(response, True, 2.0)
681 (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
682
683 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
684 self.assertTrue(receivedProxyPayload)
685 self.assertTrue(receivedDNSData)
686 self.assertTrue(receivedResponse)
687
688 receivedQuery = dns.message.from_wire(receivedDNSData)
689 receivedQuery.id = query.id
690 receivedResponse.id = response.id
691 self.assertEqual(receivedQuery, query)
692 self.assertEqual(receivedResponse, response)
693 self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [50, b'overridden'] ], True, srcPort, destPort)
694
695 def testProxyTCPSeveralQueriesOverConnection(self):
696 """
697 Incoming Proxy Protocol: Several queries over the same connection (TCP)
698 """
699 name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.'
700 query = dns.message.make_query(name, 'A', 'IN')
701 response = dns.message.make_response(query)
702
703 destAddr = "2001:db8::9"
704 destPort = 9999
705 srcAddr = "2001:db8::8"
706 srcPort = 8888
707
708 tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
709
710 toProxyQueue.put(response, True, 2.0)
711
712 wire = query.to_wire()
713
714 receivedResponse = None
715 conn = self.openTCPConnection(2.0)
716 try:
717 conn.send(tcpPayload)
718 conn.send(struct.pack("!H", len(wire)))
719 conn.send(wire)
720 receivedResponse = self.recvTCPResponseOverConnection(conn)
721 except socket.timeout:
722 print('timeout')
723 self.assertEqual(receivedResponse, response)
724
725 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
726 self.assertTrue(receivedProxyPayload)
727 self.assertTrue(receivedDNSData)
728 self.assertTrue(receivedResponse)
729
730 receivedQuery = dns.message.from_wire(receivedDNSData)
731 receivedQuery.id = query.id
732 receivedResponse.id = response.id
733 self.assertEqual(receivedQuery, query)
734 self.assertEqual(receivedResponse, response)
735 self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
736
737 for idx in range(5):
738 receivedResponse = None
739 toProxyQueue.put(response, True, 2.0)
740 try:
741 conn.send(struct.pack("!H", len(wire)))
742 conn.send(wire)
743 receivedResponse = self.recvTCPResponseOverConnection(conn)
744 except socket.timeout:
745 print('timeout')
746
747 self.assertEqual(receivedResponse, response)
748
749 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
750 self.assertTrue(receivedProxyPayload)
751 self.assertTrue(receivedDNSData)
752 self.assertTrue(receivedResponse)
753
754 receivedQuery = dns.message.from_wire(receivedDNSData)
755 receivedQuery.id = query.id
756 self.assertEqual(receivedQuery, query)
757 self.assertEqual(receivedResponse, response)
758 self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
759
760 def testProxyDoHSeveralQueriesOverConnectionPPOutside(self):
761 """
762 Incoming Proxy Protocol: Several queries over the same connection (DoH, PP outside TLS)
763 """
764 name = 'several-queries.doh-outside.proxy-protocol-incoming.tests.powerdns.com.'
765 query = dns.message.make_query(name, 'A', 'IN')
766 response = dns.message.make_response(query)
767
768 toProxyQueue.put(response, True, 2.0)
769
770 wire = query.to_wire()
771
772 reverseProxyPort = pickAvailablePort()
773 reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPOutsidePort])
774 reverseProxy.start()
775 time.sleep(1)
776
777 receivedResponse = None
778 conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
779
780 reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
781 (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
782 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
783 self.assertTrue(receivedProxyPayload)
784 self.assertTrue(receivedDNSData)
785 self.assertTrue(receivedResponse)
786
787 receivedQuery = dns.message.from_wire(receivedDNSData)
788 receivedQuery.id = query.id
789 receivedResponse.id = response.id
790 self.assertEqual(receivedQuery, query)
791 self.assertEqual(receivedResponse, response)
792 self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
793
794 for idx in range(5):
795 receivedResponse = None
796 toProxyQueue.put(response, True, 2.0)
797 (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
798 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
799 self.assertTrue(receivedProxyPayload)
800 self.assertTrue(receivedDNSData)
801 self.assertTrue(receivedResponse)
802
803 receivedQuery = dns.message.from_wire(receivedDNSData)
804 receivedQuery.id = query.id
805 receivedResponse.id = response.id
806 self.assertEqual(receivedQuery, query)
807 self.assertEqual(receivedResponse, response)
808 self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
809
810 def testProxyDoHSeveralQueriesOverConnectionPPInside(self):
811 """
812 Incoming Proxy Protocol: Several queries over the same connection (DoH, PP inside TLS)
813 """
814 name = 'several-queries.doh-inside.proxy-protocol-incoming.tests.powerdns.com.'
815 query = dns.message.make_query(name, 'A', 'IN')
816 response = dns.message.make_response(query)
817
818 toProxyQueue.put(response, True, 2.0)
819
820 wire = query.to_wire()
821
822 reverseProxyPort = pickAvailablePort()
823 tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
824 tlsContext.load_cert_chain(self._serverCert, self._serverKey)
825 tlsContext.set_alpn_protocols(['h2'])
826 reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPInsidePort, tlsContext, self._caCert, self._serverName])
827 reverseProxy.start()
828
829 receivedResponse = None
830 time.sleep(1)
831 conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
832
833 reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
834 (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
835 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
836 self.assertTrue(receivedProxyPayload)
837 self.assertTrue(receivedDNSData)
838 self.assertTrue(receivedResponse)
839
840 receivedQuery = dns.message.from_wire(receivedDNSData)
841 receivedQuery.id = query.id
842 receivedResponse.id = response.id
843 self.assertEqual(receivedQuery, query)
844 self.assertEqual(receivedResponse, response)
845 self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
846
847 for idx in range(5):
848 receivedResponse = None
849 toProxyQueue.put(response, True, 2.0)
850 (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
851 (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
852 self.assertTrue(receivedProxyPayload)
853 self.assertTrue(receivedDNSData)
854 self.assertTrue(receivedResponse)
855
856 receivedQuery = dns.message.from_wire(receivedDNSData)
857 receivedQuery.id = query.id
858 receivedResponse.id = response.id
859 self.assertEqual(receivedQuery, query)
860 self.assertEqual(receivedResponse, response)
861 self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
862
@classmethod
def tearDownClass(cls):
    """
    Release the class-level test resources.

    Closes the class socket, sets every entry of both the class-level
    ``_backgroundThreads`` dict and the module-level ``backgroundThreads``
    dict to False (presumably the flags the background threads poll to know
    when to exit), and finally kills the dnsdist process.
    """
    def clearFlags(registry):
        # mutate in place: the threads keep a reference to this very dict
        for threadKey in registry:
            registry[threadKey] = False

    cls._sock.close()
    clearFlags(cls._backgroundThreads)
    clearFlags(backgroundThreads)
    cls.killProcess(cls._dnsdist)
871
class TestProxyProtocolNotExpected(DNSDistTest):
    """
    dnsdist is configured to expect a Proxy Protocol header on incoming queries
    only from 192.0.2.1, so not from 127.0.0.1
    """

    _config_template = """
setProxyProtocolACL( { "192.0.2.1/32" } )
newServer{address="127.0.0.1:%d"}
"""
    # NORMAL responder, does not expect a proxy protocol payload!
    _config_params = ['_testServerPort']
    _verboseMode = True

    def testNoHeader(self):
        """
        Unexpected Proxy Protocol: no header
        """
        # no proxy protocol header and none is expected from this source, should be passed on
        name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    60,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')

        response.answer.append(rrset)

        # the same query must go through unchanged over both transports
        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (receivedQuery, receivedResponse) = sender(query, response)
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

    def testIncomingProxyDest(self):
        """
        Unexpected Proxy Protocol: should be dropped

        A Proxy Protocol header is prepended to the query although 127.0.0.1 is
        not in the Proxy Protocol ACL; no response is expected over either UDP
        or TCP.
        """
        name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')

        # Make sure that the proxy payload does NOT turn into a legal qname
        destAddr = "ff:db8::ffff"
        destPort = 65535
        srcAddr = "ff:db8::ffff"
        srcPort = 65535

        udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
        (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
        self.assertIsNone(receivedResponse)

        tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
        wire = query.to_wire()

        receivedResponse = None
        conn = None
        try:
            conn = self.openTCPConnection(2.0)
            # sendall: send() may write only part of the buffer
            conn.sendall(tcpPayload)
            conn.sendall(struct.pack("!H", len(wire)))
            conn.sendall(wire)
            receivedResponse = self.recvTCPResponseOverConnection(conn)
        except socket.timeout:
            print('timeout')
        finally:
            # always release the TCP connection, even when the read timed out
            if conn is not None:
                conn.close()
        self.assertIsNone(receivedResponse)
938
939 class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
940
941 _serverKey = 'server.key'
942 _serverCert = 'server.chain'
943 _serverName = 'tls.tests.dnsdist.org'
944 _caCert = 'ca.pem'
945 _dohWithNGHTTP2ServerPort = pickAvailablePort()
946 _dohWithNGHTTP2BaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohWithNGHTTP2ServerPort))
947 _dohWithH2OServerPort = pickAvailablePort()
948 _dohWithH2OBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohWithH2OServerPort))
949 _proxyResponderPort = proxyResponderPort
950 _config_template = """
951 newServer{address="127.0.0.1:%s", useProxyProtocol=true}
952 addDOHLocal("127.0.0.1:%d", "%s", "%s", { '/dns-query' }, { trustForwardedForHeader=true, library='nghttp2' })
953 addDOHLocal("127.0.0.1:%d", "%s", "%s", { '/dns-query' }, { trustForwardedForHeader=true, library='h2o' })
954 setACL( { "::1/128", "127.0.0.0/8" } )
955 """
956 _config_params = ['_proxyResponderPort', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey']
957 _verboseMode = True
958
def testTruncation(self):
    """
    DOH: Truncation over UDP

    For each DoH frontend, the backend first answers over UDP with TC=1, so
    dnsdist is expected to retry over TCP; both copies of the query must carry
    a valid proxy protocol header, and the client must get the TCP answer.
    """
    # the query is first forwarded over UDP, leading to a TC=1 answer from the
    # backend, then over TCP
    name = 'truncated-udp.doh.proxy-protocol.tests.powerdns.com.'
    query = dns.message.make_query(name, 'A', 'IN')
    query.id = 42
    # the backend is expected to see the query with an EDNS OPT record (payload 4096)
    expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
    expectedQuery.id = 42
    response = dns.message.make_response(query)
    rrset = dns.rrset.from_text(name,
                                3600,
                                dns.rdataclass.IN,
                                dns.rdatatype.A,
                                '127.0.0.1')
    response.answer.append(rrset)

    def checkBackendQuery(proxyPayload, dnsData, port):
        # Validate one copy of the query as seen by the proxy protocol responder.
        # Used for both the UDP and the TCP leg so they are checked identically.
        self.assertTrue(proxyPayload)
        self.assertTrue(dnsData)
        receivedQuery = dns.message.from_wire(dnsData)
        self.assertTrue(receivedQuery)
        receivedQuery.id = expectedQuery.id
        self.assertEqual(expectedQuery, receivedQuery)
        self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
        self.checkMessageProxyProtocol(proxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=port)

    for (port, url) in [(self._dohWithNGHTTP2ServerPort, self._dohWithNGHTTP2BaseURL), (self._dohWithH2OServerPort, self._dohWithH2OBaseURL)]:
        # first response is a TC=1
        tcResponse = dns.message.make_response(query)
        tcResponse.flags |= dns.flags.TC
        toProxyQueue.put(tcResponse, True, 2.0)

        ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(port, self._serverName, url, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
        # first query, received by the responder over UDP
        checkBackendQuery(receivedProxyPayload, receivedDNSData, port)

        # check the response
        self.assertTrue(receivedResponse)
        self.assertEqual(response, receivedResponse)

        # check the second query, received by the responder over TCP
        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
        checkBackendQuery(receivedProxyPayload, receivedDNSData, port)

        # make sure we consumed everything
        self.assertTrue(toProxyQueue.empty())
        self.assertTrue(fromProxyQueue.empty())
1012
def testAddressFamilyMismatch(self):
    """
    DOH with IPv6 X-Forwarded-For to an IPv4 endpoint

    For each DoH frontend, a query whose X-Forwarded-For header carries an
    IPv6 address is expected to get neither a query to the backend nor a
    response. The same query with an IPv4 X-Forwarded-For address is expected
    to reach the proxy protocol backend, with that address as the proxy
    protocol source.
    """
    qname = 'x-forwarded-for-af-mismatch.doh.outgoing-proxy-protocol.tests.powerdns.com.'
    clientQuery = dns.message.make_query(qname, 'A', 'IN', use_edns=False)
    clientQuery.id = 0
    # the backend is expected to receive an EDNS OPT record (payload 4096)
    # even though the client query carries none
    backendExpected = dns.message.make_query(qname, 'A', 'IN', use_edns=True, payload=4096)
    backendExpected.id = 0
    answer = dns.message.make_response(clientQuery)
    answer.answer.append(dns.rrset.from_text(qname, 3600, dns.rdataclass.IN, dns.rdatatype.A, '127.0.0.1'))

    frontends = (
        (self._dohWithNGHTTP2ServerPort, self._dohWithNGHTTP2BaseURL),
        (self._dohWithH2OServerPort, self._dohWithH2OBaseURL),
    )
    for dohPort, dohURL in frontends:
        # IPv6 client address: the query should be dropped
        (droppedQuery, droppedResponse) = self.sendDOHQuery(dohPort, self._serverName, dohURL, clientQuery, caFile=self._caCert, customHeaders=['x-forwarded-for: [::1]:8080'], useQueue=False)
        self.assertFalse(droppedQuery)
        self.assertFalse(droppedResponse)

        # give dnsdist a chance to detect a timeout, if there is one
        time.sleep(4)

        # IPv4 client address: the query should go through
        ((proxyPayload, backendWire), clientAnswer) = self.sendDOHQuery(dohPort, self._serverName, dohURL, clientQuery, caFile=self._caCert, customHeaders=['x-forwarded-for: 127.0.0.42:8080'], response=answer, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
        self.assertTrue(proxyPayload)
        self.assertTrue(backendWire)
        backendQuery = dns.message.from_wire(backendWire)
        self.assertTrue(backendQuery)
        backendQuery.id = backendExpected.id
        self.assertEqual(backendExpected, backendQuery)
        self.checkQueryEDNSWithoutECS(backendExpected, backendQuery)
        self.checkMessageProxyProtocol(proxyPayload, '127.0.0.42', '127.0.0.1', True, destinationPort=dohPort)

        # the client must get the backend's answer back
        self.assertTrue(clientAnswer)
        clientAnswer.id = answer.id
        self.assertEqual(answer, clientAnswer)