]>
Commit | Line | Data |
---|---|---|
6ac8517d RG |
1 | #!/usr/bin/env python |
2 | import socket | |
3 | import struct | |
4 | import threading | |
5 | import time | |
6 | import dns | |
7 | from dnsdisttests import DNSDistTest | |
8 | ||
try:
    # Only Python 2 defines xrange; there it replaces the list-building range.
    range = xrange
except NameError:
    # Python 3: the builtin range is already lazy, nothing to alias.
    pass
13 | ||
class TestTCPShort(DNSDistTest):
    """
    Check that dnsdist copes with short reads from, and short writes to,
    clients, over plain TCP and over TLS.
    """
    # this test suite uses a different responder port
    # because, contrary to the other ones, its
    # responders allow trailing data and multiple responses,
    # and we don't want to mix things up.
    _testServerPort = 5361
    _serverKey = 'server.key'
    _serverCert = 'server.chain'
    _serverName = 'tls.tests.dnsdist.org'
    _caCert = 'ca.pem'
    _tlsServerPort = 8453
    _tcpSendTimeout = 60
    _config_template = """
    newServer{address="127.0.0.1:%s"}
    addTLSLocal("127.0.0.1:%s", "%s", "%s")
    setTCPSendTimeout(%d)
    """
    _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_tcpSendTimeout']

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP responders on _testServerPort as daemon threads.
        """
        print("Launching responders..")

        # NOTE(review): the extra trailing True arguments presumably enable the
        # "trailing data / multiple responses" behavior described above;
        # confirm against DNSDistTest.UDPResponder / TCPResponder.
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
        # Thread.setDaemon() is deprecated since Python 3.10
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True, True])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @staticmethod
    def _recvExactly(conn, count):
        """
        Read exactly `count` bytes from `conn`, looping over short reads.

        Returns fewer than `count` bytes only when the peer closed the
        connection before sending all of them.
        """
        chunks = []
        received = 0
        while received < count:
            chunk = conn.recv(count - received)
            if not chunk:
                break
            chunks.append(chunk)
            received += len(chunk)
        return b''.join(chunks)

    def _readTCPResponses(self, conn):
        """
        Read two-byte length-prefixed DNS messages from `conn` until the
        connection is closed by the other end, and return them parsed.

        A message cut short by the connection closing is discarded.
        """
        receivedResponses = []
        while True:
            # the length prefix itself may arrive split across two reads
            sizeBytes = self._recvExactly(conn, 2)
            if len(sizeBytes) < 2:
                break

            (datalen,) = struct.unpack("!H", sizeBytes)
            data = self._recvExactly(conn, datalen)
            if len(data) != datalen:
                # the connection was closed in the middle of a message
                break
            if data:
                receivedResponses.append(dns.message.from_wire(data))

        return receivedResponses

    def _checkShortRead(self, conn, name):
        """
        Send an A query for `name` over `conn`, split across several writes
        so that dnsdist has to perform short reads, then check that the
        query reaches the backend and the response comes back intact.
        """
        query = dns.message.make_query(name, 'A', 'IN')
        expectedResponse = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '192.0.2.1')
        expectedResponse.answer.append(rrset)

        wire = query.to_wire()
        # announce 7680 bytes (more than 4096, less than 8192 - the 512 bytes dnsdist is going to add)
        announcedSize = 7680
        paddingSize = announcedSize - len(wire)
        # pad to announcedSize bytes minus 1 so we get a second read
        wire = wire + (b'A' * (paddingSize - 1))
        self._toResponderQueue.put(expectedResponse, True, 2.0)

        sizeBytes = struct.pack("!H", announcedSize)
        # split the two-byte length prefix itself across two writes;
        # sendall() because send() may transmit only part of the buffer
        conn.sendall(sizeBytes[:1])
        time.sleep(1)
        conn.sendall(sizeBytes[1:])
        conn.sendall(wire)
        time.sleep(1)
        # send 1024 bytes: the first one completes the announced query,
        # the remaining 1023 are trailing data
        conn.sendall(b'A' * 1024)

        (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, True)

        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(receivedResponse, expectedResponse)

    @staticmethod
    def _buildAXFRResponses(name, query):
        """
        Build a large AXFR answer for `query`: an SOA message, then 200
        copies of a message holding one huge TXT RRset, then the SOA
        message again.
        """
        soa = dns.rrset.from_text(name,
                                  60,
                                  dns.rdataclass.IN,
                                  dns.rdatatype.SOA,
                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
        soaResponse = dns.message.make_response(query)
        soaResponse.use_edns(edns=False)
        soaResponse.answer.append(soa)

        response = dns.message.make_response(query)
        response.use_edns(edns=False)
        # 200 items, the i-th being str(i) repeated 50 times, joined by ', '
        content = ', '.join(str(i) * 50 for i in range(200))
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.TXT,
                                    content)
        response.answer.append(rrset)

        return [soaResponse] + [response] * 200 + [soaResponse]

    def _checkShortWrite(self, conn, name):
        """
        Send an AXFR query for `name` over `conn` and check that the whole,
        large answer is received even though dnsdist has to perform short
        writes towards the client.
        """
        query = dns.message.make_query(name, 'AXFR', 'IN')
        responses = self._buildAXFRResponses(name, query)

        for response in responses:
            self._toResponderQueue.put(response, True, 2.0)

        self.sendTCPQueryOverConnection(conn, query)

        # we sleep for one second, making sure that dnsdist
        # will fill its TCP window and buffers, which will result
        # in some short writes
        time.sleep(1)

        # we then read the messages
        receivedResponses = self._readTCPResponses(conn)

        receivedQuery = None
        if not self._fromResponderQueue.empty():
            receivedQuery = self._fromResponderQueue.get(True, 2.0)

        # and check that everything is good
        self.assertTrue(receivedQuery)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(len(receivedResponses), len(responses))
        self.assertEqual(receivedResponses, responses)

    def testTCPShortRead(self):
        """
        TCP: Short read from client
        """
        conn = self.openTCPConnection()
        try:
            self._checkShortRead(conn, 'short-read.tcp-short.tests.powerdns.com.')
        finally:
            conn.close()

    def testTCPTLSShortRead(self):
        """
        TCP/TLS: Short read from client
        """
        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
        try:
            self._checkShortRead(conn, 'short-read-tls.tcp-short.tests.powerdns.com.')
        finally:
            conn.close()

    def testTCPShortWrite(self):
        """
        TCP: Short write to client
        """
        conn = self.openTCPConnection()
        try:
            self._checkShortWrite(conn, 'short-write.tcp-short.tests.powerdns.com.')
        finally:
            conn.close()

    def testTCPTLSShortWrite(self):
        """
        TCP/TLS: Short write to client
        """
        # same as testTCPShortWrite but over TLS this time
        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
        try:
            self._checkShortWrite(conn, 'short-write-tls.tcp-short.tests.powerdns.com.')
        finally:
            conn.close()