]>
Commit | Line | Data |
---|---|---|
eb353acc RG |
1 | #!/usr/bin/env python2 |
2 | ||
3 | import errno | |
4 | import shutil | |
5 | import os | |
6 | import socket | |
7 | import struct | |
8 | import subprocess | |
9 | import sys | |
10 | import time | |
11 | import unittest | |
12 | import dns | |
13 | import dns.message | |
14 | ||
5d5d29d4 PD |
15 | from eqdnsmessage import AssertEqualDNSMessageMixin |
16 | ||
class IXFRDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """Base class for ixfrdist regression tests.

    setUpClass writes an ixfrdist configuration file, launches the binary
    named by the IXFRDISTBIN environment variable and connects a UDP socket
    to it; tearDownClass stops the process and closes that socket.
    Subclasses customise the run through the class attributes below.
    """

    # Seconds to wait after launching ixfrdist before running tests
    # (0.5s is used instead when IXFRDIST_FAST_TESTS is set).
    _ixfrDistStartupDelay = 2.0
    # Port ixfrdist listens on; substituted into _config_template and used
    # by all the send*Query helpers.
    _ixfrDistPort = 5342

    # %-format template for ixfrdist.yml; its placeholders are filled from
    # the class attributes named in _config_params, in order.
    _config_template = """
listen:
  - '127.0.0.1:%d'
acl:
  - '127.0.0.0/8'
axfr-timeout: 20
keep: 20
tcp-in-threads: 10
work-dir: 'ixfrdist.dir'
failed-soa-retry: 3
"""
    # Mapping of domain -> master; when not None a "domains:" section is
    # appended to the generated configuration.
    _config_domains = None
    # Names of class attributes substituted, in order, into _config_template.
    _config_params = ['_ixfrDistPort']

    # Running ixfrdist process (set by startIXFRDist) and the connected UDP
    # socket (set by setUpSockets).
    _ixfrdist = None
    _sock = None

    @classmethod
    def startIXFRDist(cls):
        """Write ixfrdist.yml, launch ixfrdist and wait for it to start.

        ixfrdist's stdout and stderr are written to ixfrdist.log.  If the
        process has already exited once the startup delay has elapsed, the
        run is aborted via sys.exit() with the process's return code.
        """
        print("Launching ixfrdist..")
        conffile = 'ixfrdist.yml'
        params = tuple([getattr(cls, param) for param in cls._config_params])
        print(params)
        with open(conffile, 'w') as conf:
            conf.write("# Autogenerated by ixfrdisttests.py\n")
            conf.write(cls._config_template % params)

            if cls._config_domains is not None:
                conf.write("domains:\n")

                for domain, master in cls._config_domains.items():
                    conf.write("  - domain: %s\n" % (domain))
                    conf.write("    master: %s\n" % (master))

        ixfrdistcmd = [os.environ['IXFRDISTBIN'], '--config', conffile, '--debug']

        logFile = 'ixfrdist.log'
        with open(logFile, 'w') as fdLog:
            # The child inherits its own copy of the log descriptor, so
            # closing ours at the end of this block does not stop its logging.
            cls._ixfrdist = subprocess.Popen(ixfrdistcmd, close_fds=True,
                                             stdout=fdLog, stderr=fdLog)

        if 'IXFRDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._ixfrDistStartupDelay

        time.sleep(delay)

        if cls._ixfrdist.poll() is not None:
            # poll() has already reaped the exited process.  Calling kill()
            # here (as before) would signal a stale PID under Python 2,
            # whose Popen.send_signal does not check returncode.
            sys.exit(cls._ixfrdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the shared UDP socket and connect it to ixfrdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._ixfrDistPort))

    @classmethod
    def setUpClass(cls):
        cls.startIXFRDist()
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop ixfrdist and close the UDP socket opened by setUpSockets."""
        cls.tearDownIXFRDist()
        if cls._sock is not None:
            cls._sock.close()
            cls._sock = None

    @classmethod
    def tearDownIXFRDist(cls):
        """Stop ixfrdist: terminate() it, then kill() it if still running.

        After terminate(), waits `delay` seconds (0.1s with
        IXFRDIST_FAST_TESTS, otherwise 1.0s) before escalating to kill(),
        then reaps the process with wait().
        """
        if 'IXFRDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0

        try:
            if cls._ixfrdist:
                cls._ixfrdist.terminate()
                if cls._ixfrdist.poll() is None:
                    time.sleep(delay)
                if cls._ixfrdist.poll() is None:
                    cls._ixfrdist.kill()
                cls._ixfrdist.wait()
        except OSError as e:
            # There is a race-condition with the poll() and
            # kill() statements, when the process is dead on the
            # kill(), this is fine
            if e.errno != errno.ESRCH:
                raise

    @staticmethod
    def _recvExactly(sock, length):
        """Read exactly `length` bytes from the stream socket `sock`.

        A single recv() on a TCP socket may return fewer bytes than asked
        for, so this loops until `length` bytes have arrived.  A shorter
        result (possibly empty) is returned only if the peer closes the
        connection first.  Socket exceptions propagate to the caller.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=None):
        """Send `query` over the shared UDP socket and return the response.

        :param query: DNS message to send (its to_wire() output is sent).
        :param timeout: when truthy, applied to the socket for this exchange;
            afterwards the socket is switched to blocking mode with
            settimeout(None).
        :param decode: when False, return the raw response bytes instead of
            a parsed message.
        :param fwparams: extra keyword arguments for dns.message.from_wire.
        :returns: the response, or None on timeout or an empty datagram.
        """
        if fwparams is None:
            fwparams = {}

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            cls._sock.send(query.to_wire())
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        message = None
        if data:
            if not decode:
                return data
            message = dns.message.from_wire(data, **fwparams)
        return message

    # FIXME: sendTCPQuery and sendTCPQueryMultiResponse, when they are done reading
    # should wait for a short while on the socket to see if more data is coming
    # and error if it does!
    @classmethod
    def sendTCPQuery(cls, query, timeout=2.0):
        """Send `query` over a fresh TCP connection and return one response.

        Uses the two-byte length prefix of DNS-over-TCP.  Timeouts and socket
        errors during the exchange are printed and yield None; errors from
        connect() propagate to the caller.  The socket is always closed.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        data = None
        try:
            if timeout:
                sock.settimeout(timeout)

            # Deliberately outside the handlers below, as in the original.
            sock.connect(("127.0.0.1", cls._ixfrDistPort))

            try:
                wire = query.to_wire()
                # sendall() retries partial writes, unlike send().
                sock.sendall(struct.pack("!H", len(wire)) + wire)
                header = cls._recvExactly(sock, 2)
                # An empty or truncated header means the server closed early.
                if len(header) == 2:
                    (datalen,) = struct.unpack("!H", header)
                    data = cls._recvExactly(sock, datalen)
            except socket.timeout as e:
                print("Timeout: %s" % (str(e)))
                data = None
            except socket.error as e:
                print("Network error: %s" % (str(e)))
                data = None
        finally:
            sock.close()

        message = None
        if data:
            message = dns.message.from_wire(data)
        return message

    @classmethod
    def sendTCPQueryMultiResponse(cls, query, timeout=2.0, count=1):
        """Send `query` over a fresh TCP connection; read up to `count` replies.

        Reading stops early if the server closes the connection.  Timeouts
        and socket errors while sending or receiving are re-raised as
        Exception.  The socket is always closed.

        :returns: list of parsed DNS messages, in the order received.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect(("127.0.0.1", cls._ixfrDistPort))

            try:
                wire = query.to_wire()
                sock.sendall(struct.pack("!H", len(wire)) + wire)
            except socket.timeout as e:
                raise Exception("Timeout: %s" % (str(e)))
            except socket.error as e:
                raise Exception("Network error: %s" % (str(e)))

            messages = []
            for _ in range(count):
                try:
                    header = cls._recvExactly(sock, 2)
                    if len(header) < 2:
                        # Connection closed: no further responses.
                        break
                    (datalen,) = struct.unpack("!H", header)
                    data = cls._recvExactly(sock, datalen)
                except socket.timeout as e:
                    raise Exception("Timeout: %s" % (str(e)))
                except socket.error as e:
                    raise Exception("Network error: %s" % (str(e)))
                messages.append(dns.message.from_wire(data))

            return messages
        finally:
            sock.close()

    def setUp(self):
        # This function is called before every tests
        super(IXFRDistTest, self).setUp()