git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/test_ProxyProtocol.py
Merge pull request #8945 from rgacogne/ddist-x-forwarded-for
[thirdparty/pdns.git] / regression-tests.dnsdist / test_ProxyProtocol.py
1 #!/usr/bin/env python
2
3 import dns
4 import socket
5 import struct
6 import sys
7 import threading
8
9 from dnsdisttests import DNSDistTest
10 from proxyprotocol import ProxyProtocol
11
12 # Python2/3 compatibility hacks
13 try:
14 from queue import Queue
15 except ImportError:
16 from Queue import Queue
17
def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
    """
    UDP backend expecting every datagram to start with a Proxy Protocol v2 header.

    For a proxied query, the raw Proxy Protocol block (fixed header + address/TLV
    content) and the DNS payload are pushed to `toQueue` as [payload, dnsData].
    The response to send back is then taken from `fromQueue`, and its ID is
    rewritten to match the query's.

    Datagrams carrying a LOCAL command (likely dnsdist health checks) are
    answered directly with dns.message.make_response(request), without
    involving the queues.

    port: UDP port to bind on 127.0.0.1
    fromQueue: queue the responses to send are read from
    toQueue: queue the received [proxy payload, DNS query] pairs are written to
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
        sys.exit(1)

    try:
        while True:
            data, addr = sock.recvfrom(4096)

            proxy = ProxyProtocol()
            if len(data) < proxy.HEADER_SIZE:
                continue

            if not proxy.parseHeader(data):
                continue

            # The whole Proxy Protocol block (fixed header + content announced by
            # the length field) must be skipped to reach the DNS payload, for
            # LOCAL commands as well as PROXY ones.
            proxyBlockLen = proxy.HEADER_SIZE + proxy.contentLen
            if len(data) < proxyBlockLen:
                # truncated datagram: parsing the DNS payload would raise and
                # terminate this responder thread for every later test
                continue

            payload = data[:proxyBlockLen]
            dnsData = data[proxyBlockLen:]

            if proxy.local:
                # likely a healthcheck
                request = dns.message.from_wire(dnsData)
                response = dns.message.make_response(request)
                wire = response.to_wire()
                sock.settimeout(2.0)
                sock.sendto(wire, addr)
                sock.settimeout(None)
                continue

            toQueue.put([payload, dnsData], True, 2.0)
            # computing the correct ID for the response
            request = dns.message.from_wire(dnsData)
            response = fromQueue.get(True, 2.0)
            response.id = request.id

            sock.settimeout(2.0)
            sock.sendto(response.to_wire(), addr)
            sock.settimeout(None)
    finally:
        sock.close()
62
def _recvExactly(conn, length):
    """
    Read exactly `length` bytes from the connected socket `conn`.

    socket.recv() may return fewer bytes than requested, so keep reading until
    everything has arrived. Returns the bytes read, or None if the peer closed
    the connection or the socket timed out first. A `length` of 0 returns b''.
    """
    chunks = []
    remaining = length
    while remaining > 0:
        try:
            chunk = conn.recv(remaining)
        except socket.timeout:
            return None
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)

def _serveProxyProtocolTCPConnection(conn, fromQueue, toQueue):
    """
    Serve the DNS queries received on one accepted connection.

    The connection must start with a Proxy Protocol header. Returns when the
    peer is done or sends something unexpected; the caller closes `conn`.
    """
    conn.settimeout(5.0)
    # read the entire Proxy Protocol header, then its address/TLV content
    proxy = ProxyProtocol()
    header = _recvExactly(conn, proxy.HEADER_SIZE)
    if not header or not proxy.parseHeader(header):
        return

    # a zero-length content block is valid, only a failed read is an error
    proxyContent = _recvExactly(conn, proxy.contentLen)
    if proxyContent is None:
        return

    # the same Proxy Protocol block is reported with every query of this connection
    payload = header + proxyContent
    while True:
        # DNS over TCP: every message is prefixed with its 16-bit length
        lengthPrefix = _recvExactly(conn, 2)
        if not lengthPrefix:
            return

        (datalen,) = struct.unpack("!H", lengthPrefix)
        data = _recvExactly(conn, datalen)
        if not data:
            return

        toQueue.put([payload, data], True, 2.0)

        response = fromQueue.get(True, 2.0)
        if not response:
            return

        # computing the correct ID for the response
        request = dns.message.from_wire(data)
        response.id = request.id

        wire = response.to_wire()
        # sendall() retries until every byte is written, unlike send()
        conn.sendall(struct.pack("!H", len(wire)) + wire)

def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
    """
    TCP backend expecting every connection to start with a Proxy Protocol v2 header.

    For each length-prefixed DNS query on a connection, the raw Proxy Protocol
    block and the DNS payload are pushed to `toQueue` as [payload, data]. The
    response to send back is then taken from `fromQueue`, and its ID is
    rewritten to match the query's.

    port: TCP port to listen on, on 127.0.0.1
    fromQueue: queue the responses to send are read from
    toQueue: queue the received [proxy payload, DNS query] pairs are written to
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the TCP responder: %s" % str(e))
        sys.exit(1)

    sock.listen(100)
    try:
        while True:
            (conn, _) = sock.accept()
            try:
                _serveProxyProtocolTCPConnection(conn, fromQueue, toQueue)
            finally:
                conn.close()
    finally:
        sock.close()
125
# Queues shared between the tests and the responder threads. Tests put the
# response the backend should send onto toProxyQueue (read by the responders
# as `fromQueue`). The responders put the received [proxy payload, DNS query]
# pairs onto fromProxyQueue (their `toQueue`).
toProxyQueue = Queue()
fromProxyQueue = Queue()
proxyResponderPort = 5470

# Daemon threads do not keep the interpreter alive once the tests are done.
# The `daemon` attribute replaces Thread.setDaemon(), which is deprecated
# since Python 3.10, and it is also available on Python 2.
udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
udpResponder.daemon = True
udpResponder.start()
tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
tcpResponder.daemon = True
tcpResponder.start()
136
class ProxyProtocolTest(DNSDistTest):
    """
    Base class for tests where dnsdist forwards queries, prefixed with a Proxy
    Protocol header, to the local Proxy Protocol responders.
    """
    _proxyResponderPort = proxyResponderPort
    _config_params = ['_proxyResponderPort']

    def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None):
        """
        Parse the Proxy Protocol block captured by a responder and check it.

        receivedProxyPayload: raw header + content bytes, as queued by a responder
        source: expected source address
        destination: expected destination address
        isTCP: whether the query reached dnsdist over TCP (otherwise UDP)
        values: expected additional values as a list of [type, value] pairs;
                the order does not matter. Defaults to no value. The list is
                not modified.
        """
        expectedValues = sorted(values) if values is not None else []

        proxy = ProxyProtocol()
        self.assertTrue(proxy.parseHeader(receivedProxyPayload))
        # PROXY protocol v2: version 2, PROXY command, AF_INET family
        self.assertEqual(proxy.version, 0x02)
        self.assertEqual(proxy.command, 0x01)
        self.assertEqual(proxy.family, 0x01)
        # transport protocol: 0x01 is STREAM (TCP), 0x02 is DGRAM (UDP)
        self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
        self.assertGreater(proxy.contentLen, 0)

        self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
        self.assertEqual(proxy.source, source)
        self.assertEqual(proxy.destination, destination)
        # the source port is not checked: callers do not pass an expected value
        self.assertEqual(proxy.destinationPort, self._dnsDistPort)

        self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
        self.assertEqual(sorted(proxy.values), expectedValues)
163
class TestProxyProtocol(ProxyProtocolTest):
    """
    dnsdist is configured to prepend a Proxy Protocol header to the query
    """

    _config_template = """
    newServer{address="127.0.0.1:%d", useProxyProtocol=true}

    function addValues(dq)
      local values = { [0]="foo", [42]="bar" }
      dq:setProxyProtocolValues(values)
      return DNSAction.None
    end

    addAction("values-lua.proxy.tests.powerdns.com.", LuaAction(addValues))
    addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"}))
    """
    _config_params = ['_proxyResponderPort']

    def _sendUDPQuery(self, query):
        """
        Send `query` to dnsdist over the test's UDP socket and return the
        parsed response, or None if nothing arrived within 2 seconds.
        """
        self._sock.send(query.to_wire())
        try:
            self._sock.settimeout(2.0)
            data = self._sock.recv(4096)
        except socket.timeout:
            print('timeout')
            data = None
        if not data:
            return None
        return dns.message.from_wire(data)

    def _sendTCPQuery(self, conn, query):
        """
        Send `query` to dnsdist over the already opened TCP connection `conn`
        and return the response, or None on timeout.
        """
        self.sendTCPQueryOverConnection(conn, query.to_wire(), rawQuery=True)
        try:
            return self.recvTCPResponseOverConnection(conn)
        except socket.timeout:
            print('timeout')
            return None

    def _checkForwardedExchange(self, query, response, receivedResponse, isTCP, values):
        """
        Check what the backend responder received for `query` (taken from
        fromProxyQueue), that `receivedResponse` matches `response`, and that
        the Proxy Protocol block carries the expected `values`.
        """
        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
        self.assertTrue(receivedProxyPayload)
        self.assertTrue(receivedDNSData)
        self.assertTrue(receivedResponse)

        # IDs are excluded from the comparison: the ID seen by the backend may
        # differ from the client's
        receivedQuery = dns.message.from_wire(receivedDNSData)
        receivedQuery.id = query.id
        receivedResponse.id = response.id
        self.assertEqual(receivedQuery, query)
        self.assertEqual(receivedResponse, response)
        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', isTCP, values)

    def _runUDPExchange(self, name, values):
        """Send one UDP query for `name` through dnsdist and check both sides."""
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        # the responder thread answers with this message
        toProxyQueue.put(response, True, 2.0)

        receivedResponse = self._sendUDPQuery(query)
        self._checkForwardedExchange(query, response, receivedResponse, False, values)

    def _runTCPExchange(self, name, values):
        """Send one TCP query for `name` through dnsdist and check both sides."""
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        # the responder thread answers with this message
        toProxyQueue.put(response, True, 2.0)

        conn = self.openTCPConnection(2.0)
        try:
            receivedResponse = self._sendTCPQuery(conn, query)
            self._checkForwardedExchange(query, response, receivedResponse, True, values)
        finally:
            conn.close()

    def testProxyUDP(self):
        """
        Proxy Protocol: no value (UDP)
        """
        self._runUDPExchange('simple-udp.proxy.tests.powerdns.com.', [])

    def testProxyTCP(self):
        """
        Proxy Protocol: no value (TCP)
        """
        self._runTCPExchange('simple-tcp.proxy.tests.powerdns.com.', [])

    def testProxyUDPWithValuesFromLua(self):
        """
        Proxy Protocol: values from Lua (UDP)
        """
        self._runUDPExchange('values-lua.proxy.tests.powerdns.com.', [ [0, b'foo'] , [ 42, b'bar'] ])

    def testProxyTCPWithValuesFromLua(self):
        """
        Proxy Protocol: values from Lua (TCP)
        """
        self._runTCPExchange('values-lua.proxy.tests.powerdns.com.', [ [0, b'foo'] , [ 42, b'bar'] ])

    def testProxyUDPWithValuesFromAction(self):
        """
        Proxy Protocol: values from Action (UDP)
        """
        self._runUDPExchange('values-action.proxy.tests.powerdns.com.', [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])

    def testProxyTCPWithValuesFromAction(self):
        """
        Proxy Protocol: values from Action (TCP)
        """
        self._runTCPExchange('values-action.proxy.tests.powerdns.com.', [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])

    def testProxyTCPSeveralQueriesOnSameConnection(self):
        """
        Proxy Protocol: Several queries on the same TCP connection
        """
        name = 'several-queries-same-conn.proxy.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        conn = self.openTCPConnection(2.0)
        try:
            for _ in range(10):
                toProxyQueue.put(response, True, 2.0)
                receivedResponse = self._sendTCPQuery(conn, query)
                self._checkForwardedExchange(query, response, receivedResponse, True, [])
        finally:
            conn.close()