from dnsdisttests import DNSDistTest
from proxyprotocol import ProxyProtocol

# Python2/3 compatibility hacks
try:
    from queue import Empty, Queue
except ImportError:
    from Queue import Empty, Queue
def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
    """
    Fake UDP backend expecting a Proxy Protocol header before each query.

    Binds 127.0.0.1:port. For every datagram with a parsable header, it puts
    [payload, dnsData] on toQueue. payload is the Proxy Protocol header plus
    its announced content, and dnsData is the DNS query that follows. It then
    takes the response to send back from fromQueue (waiting up to 2 seconds),
    sets its ID to the query's ID and sends it to the originator.

    Loops forever; returns early only if binding the socket fails.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
        sock.close()
        return

    try:
        while True:
            data, addr = sock.recvfrom(4096)

            proxy = ProxyProtocol()
            if len(data) < proxy.HEADER_SIZE:
                continue
            if not proxy.parseHeader(data):
                continue

            # The fixed-size header announces how many address/TLV bytes
            # follow it; ignore datagrams too short to contain them.
            headerLen = proxy.HEADER_SIZE + proxy.contentLen
            if len(data) < headerLen:
                continue

            if proxy.command == 0x00:
                # Command 0x0 is LOCAL in Proxy Protocol v2 (no client
                # information): likely a healthcheck. Answer it directly
                # without involving the test, skipping any announced content.
                request = dns.message.from_wire(data[headerLen:])
                response = dns.message.make_response(request)
                sock.sendto(response.to_wire(), addr)
                continue

            payload = data[:headerLen]
            dnsData = data[headerLen:]
            toQueue.put([payload, dnsData], True, 2.0)
            # computing the correct ID for the response
            request = dns.message.from_wire(dnsData)
            try:
                response = fromQueue.get(True, 2.0)
            except Empty:
                # The test did not provide a response: drop this query
                # instead of letting the exception end the responder thread.
                continue
            response.id = request.id

            sock.sendto(response.to_wire(), addr)
    finally:
        sock.close()
def _recvExactly(conn, size):
    """
    Read exactly `size` bytes from the connected stream socket `conn`.

    A single recv() on a stream socket may return fewer bytes than requested,
    so keep reading until `size` bytes have arrived. Returns the bytes read
    (b'' when size is 0), or None if the peer closed the connection or the
    read timed out first.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        try:
            chunk = conn.recv(remaining)
        except socket.timeout:
            return None
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def _serveProxyProtocolTCPConnection(conn, fromQueue, toQueue):
    """
    Serve length-prefixed DNS queries on one accepted connection that starts
    with a Proxy Protocol header. Returns when the connection should be
    closed; closing it is left to the caller.
    """
    conn.settimeout(5.0)
    # try to read the entire Proxy Protocol header
    proxy = ProxyProtocol()
    header = _recvExactly(conn, proxy.HEADER_SIZE)
    if header is None or not proxy.parseHeader(header):
        return

    proxyContent = _recvExactly(conn, proxy.contentLen)
    if proxyContent is None:
        return

    payload = header + proxyContent
    # Command 0x0 is LOCAL in Proxy Protocol v2 (no client information):
    # likely a healthcheck, answered without involving the test.
    isLocal = proxy.command == 0x00

    while True:
        # each DNS message over TCP is preceded by its 2-byte length
        lengthPrefix = _recvExactly(conn, 2)
        if lengthPrefix is None:
            return
        (datalen,) = struct.unpack("!H", lengthPrefix)
        data = _recvExactly(conn, datalen)
        if data is None:
            return

        request = dns.message.from_wire(data)
        if isLocal:
            response = dns.message.make_response(request)
        else:
            toQueue.put([payload, data], True, 2.0)
            try:
                response = fromQueue.get(True, 2.0)
            except Empty:
                response = None
            if not response:
                # nothing to answer with: close the connection
                return
            # computing the correct ID for the response
            response.id = request.id

        wire = response.to_wire()
        # one write for prefix + message; sendall() handles partial sends
        conn.sendall(struct.pack("!H", len(wire)) + wire)


def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
    """
    Fake TCP backend expecting a Proxy Protocol header at the start of
    each connection.

    Binds 127.0.0.1:port and serves accepted connections one at a time. For
    every query, it puts [payload, query] on toQueue. payload is the
    connection's Proxy Protocol header plus its announced content. It then
    answers with the response taken from fromQueue (waiting up to 2 seconds),
    with its ID set to the query's ID.

    Loops forever; returns early only if binding the socket fails.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the TCP responder: %s" % str(e))
        sock.close()
        return

    sock.listen(100)
    try:
        while True:
            (conn, _) = sock.accept()
            try:
                _serveProxyProtocolTCPConnection(conn, fromQueue, toQueue)
            finally:
                conn.close()
    finally:
        sock.close()
# Shared between the tests and the fake backends: a test puts the response
# it wants sent on toProxyQueue, and reads back the [proxy payload, DNS query]
# pair the backend received from fromProxyQueue.
toProxyQueue = Queue()
fromProxyQueue = Queue()
proxyResponderPort = 5470

# Both responders use the same port number, one over UDP and one over TCP.
# They are daemon threads so they do not keep the interpreter alive at exit.
udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
udpResponder.daemon = True
udpResponder.start()

tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
tcpResponder.daemon = True
tcpResponder.start()
class ProxyProtocolTest(DNSDistTest):
    """
    Base class for tests where dnsdist forwards queries, prefixed with a
    Proxy Protocol header, to the local fake responders.
    """
    _proxyResponderPort = proxyResponderPort
    _config_params = ['_proxyResponderPort']

    def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None):
        """
        Assert that receivedProxyPayload is the expected Proxy Protocol header.

        :param receivedProxyPayload: raw header plus its announced content,
                                     as received by the fake backend
        :param source: expected source address
        :param destination: expected destination address
        :param isTCP: True if the query was sent to dnsdist over TCP
        :param values: expected additional values, as [type, value] pairs.
                       Compared without regard to order; None means none are
                       expected. The caller's list is never modified.
        """
        if values is None:
            values = []

        proxy = ProxyProtocol()
        self.assertTrue(proxy.parseHeader(receivedProxyPayload))
        # version 2, PROXY command
        self.assertEqual(proxy.version, 0x02)
        self.assertEqual(proxy.command, 0x01)
        # address family 0x1 (AF_INET)
        self.assertEqual(proxy.family, 0x01)
        # transport protocol: 0x1 for TCP (stream), 0x2 for UDP (datagram)
        self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
        self.assertGreater(proxy.contentLen, 0)

        self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
        self.assertEqual(proxy.source, source)
        self.assertEqual(proxy.destination, destination)
        # callers do not pass the expected source port, so it is not checked
        self.assertEqual(proxy.destinationPort, self._dnsDistPort)

        self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
        # Values may come from a Lua table, whose iteration order is
        # unspecified, so compare sorted copies.
        self.assertEqual(sorted(proxy.values), sorted(values))
class TestProxyProtocol(ProxyProtocolTest):
    """
    dnsdist is configured to prepend a Proxy Protocol header to the query
    """

    _config_template = """
    newServer{address="127.0.0.1:%d", useProxyProtocol=true}

    function addValues(dq)
      local values = { [0]="foo", [42]="bar" }
      dq:setProxyProtocolValues(values)
      return DNSAction.None
    end

    addAction("values-lua.proxy.tests.powerdns.com.", LuaAction(addValues))
    addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"}))
    """
    _config_params = ['_proxyResponderPort']

    def _exchangeOverUDP(self, query):
        """
        Send query to dnsdist over UDP and return the parsed response,
        or None if nothing arrived within 2 seconds.
        """
        self._sock.send(query.to_wire())
        try:
            self._sock.settimeout(2.0)
            data = self._sock.recv(4096)
        except socket.timeout:
            print('timeout')
            return None
        return dns.message.from_wire(data)

    def _exchangeOverTCP(self, conn, query):
        """
        Send query to dnsdist over the open TCP connection conn and return
        the response, or None on timeout.
        """
        self.sendTCPQueryOverConnection(conn, query.to_wire(), rawQuery=True)
        try:
            return self.recvTCPResponseOverConnection(conn)
        except socket.timeout:
            print('timeout')
            return None

    def _checkProxiedExchange(self, query, response, receivedResponse, isTCP, values):
        """
        Check the data the fake backend queued for one query, and the
        response the client got, against what the test sent and expected.
        """
        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
        self.assertTrue(receivedProxyPayload)
        self.assertTrue(receivedDNSData)
        self.assertTrue(receivedResponse)

        receivedQuery = dns.message.from_wire(receivedDNSData)
        # The backend-side query ID need not match the client-side one,
        # so align IDs before comparing the rest of the messages.
        receivedQuery.id = query.id
        receivedResponse.id = response.id
        self.assertEqual(receivedQuery, query)
        self.assertEqual(receivedResponse, response)
        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', isTCP, values)

    def _checkProxiedUDPQuery(self, name, values=None):
        """Send one A query for name over UDP and check how it was proxied."""
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        toProxyQueue.put(response, True, 2.0)

        receivedResponse = self._exchangeOverUDP(query)
        self._checkProxiedExchange(query, response, receivedResponse, False, [] if values is None else values)

    def _checkProxiedTCPQueries(self, name, values=None, count=1):
        """
        Send count A queries for name over a single TCP connection and check
        how each of them was proxied.
        """
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        expectedValues = [] if values is None else values

        conn = self.openTCPConnection(2.0)
        try:
            for _ in range(count):
                toProxyQueue.put(response, True, 2.0)
                receivedResponse = self._exchangeOverTCP(conn, query)
                self._checkProxiedExchange(query, response, receivedResponse, True, expectedValues)
        finally:
            # release the client connection instead of leaking it
            conn.close()

    def testProxyUDP(self):
        """
        Proxy Protocol: no value (UDP)
        """
        self._checkProxiedUDPQuery('simple-udp.proxy.tests.powerdns.com.')

    def testProxyTCP(self):
        """
        Proxy Protocol: no value (TCP)
        """
        self._checkProxiedTCPQueries('simple-tcp.proxy.tests.powerdns.com.')

    def testProxyUDPWithValuesFromLua(self):
        """
        Proxy Protocol: values from Lua (UDP)
        """
        self._checkProxiedUDPQuery('values-lua.proxy.tests.powerdns.com.', [ [0, b'foo'] , [ 42, b'bar'] ])

    def testProxyTCPWithValuesFromLua(self):
        """
        Proxy Protocol: values from Lua (TCP)
        """
        self._checkProxiedTCPQueries('values-lua.proxy.tests.powerdns.com.', [ [0, b'foo'] , [ 42, b'bar'] ])

    def testProxyUDPWithValuesFromAction(self):
        """
        Proxy Protocol: values from Action (UDP)
        """
        self._checkProxiedUDPQuery('values-action.proxy.tests.powerdns.com.', [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])

    def testProxyTCPWithValuesFromAction(self):
        """
        Proxy Protocol: values from Action (TCP)
        """
        self._checkProxiedTCPQueries('values-action.proxy.tests.powerdns.com.', [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])

    def testProxyTCPSeveralQueriesOnSameConnection(self):
        """
        Proxy Protocol: Several queries on the same TCP connection
        """
        self._checkProxiedTCPQueries('several-queries-same-conn.proxy.tests.powerdns.com.', [], count=10)