]>
Commit | Line | Data |
---|---|---|
1d0bd88a RG |
1 | #!/usr/bin/env python |
2 | import Queue | |
3 | import threading | |
4 | import socket | |
5 | import struct | |
6 | import sys | |
7 | import time | |
8 | from dnsdisttests import DNSDistTest | |
9 | ||
10 | import dns | |
11 | import dnsmessage_pb2 | |
12 | ||
class TestProtobuf(DNSDistTest):
    """
    Regression test for dnsdist's protobuf remote logging.

    A small TCP listener is started in a background thread; dnsdist is
    configured (see _config_template) to log every query and every
    response to it. The test then sends queries through dnsdist and
    checks the protobuf messages collected by the listener.
    """

    # TCP port the protobuf listener binds to on 127.0.0.1
    _protobufServerPort = 4242
    # raw protobuf payloads received by the listener thread, in arrival order
    _protobufQueue = Queue.Queue()
    _protobufCounter = 0
    # names of the attributes used to fill the '%s' placeholders of
    # _config_template (presumably substituted in order by DNSDistTest)
    _config_params = ['_testServerPort', '_protobufServerPort']
    _config_template = """
    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')
    addAction(AllRule(), RemoteLogAction(rl))
    addResponseAction(AllRule(), RemoteLogResponseAction(rl))
    """

    @staticmethod
    def _recvExactly(conn, size):
        """
        Read exactly `size` bytes from the connected socket `conn`.

        recv() on a stream socket may return fewer bytes than requested,
        so keep reading until the whole amount has been received.

        Returns the bytes read (an empty string when `size` is 0), or
        None if the peer closed the connection before `size` bytes
        were available.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                # connection closed by the peer
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def ProtobufListener(cls, port):
        """
        Accept TCP connections on 127.0.0.1:`port` and queue the messages.

        Each message on the wire is a 2-byte big-endian length followed
        by that many bytes of payload. Every complete payload is pushed
        to cls._protobufQueue. Connections are served one at a time.
        This function loops forever and is meant to run in a daemon thread.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the protbuf listener: %s" % str(e))
            sock.close()
            # NOTE: raises SystemExit in the calling thread only
            sys.exit(1)

        try:
            sock.listen(100)
            while True:
                (conn, _) = sock.accept()
                try:
                    while True:
                        header = cls._recvExactly(conn, 2)
                        if not header:
                            break
                        (datalen,) = struct.unpack("!H", header)
                        payload = cls._recvExactly(conn, datalen)
                        if not payload:
                            break

                        cls._protobufQueue.put(payload, True, timeout=2.0)
                finally:
                    # release the connection even if recv() raised
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP backend responders and the protobuf listener.

        All three run as daemon threads so they do not prevent the test
        process from exiting.
        """
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
        cls._protobufListener.daemon = True
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """
        Pop the oldest queued payload and return it parsed as a PBDNSMessage.

        Fails the test if the queue is empty or the payload is empty.
        """
        self.assertFalse(self._protobufQueue.empty())
        data = self._protobufQueue.get(False)
        self.assertTrue(data)
        msg = dnsmessage_pb2.PBDNSMessage()
        msg.ParseFromString(data)
        return msg

    def checkProtobufBase(self, msg, protocol, query):
        """
        Check the fields common to query and response protobuf messages.

        msg      -- parsed PBDNSMessage to check
        protocol -- expected socketProtocol value (PBDNSMessage.UDP or .TCP)
        query    -- DNS message whose id and wire length must match
                    msg.id and msg.inBytes
        """
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, so the field cannot be read as msg.from
        fromvalue = getattr(msg, 'from')
        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue), '127.0.0.1')
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEqual(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEqual(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        self.assertEqual(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set the existing EDNS Subnet for now,
        # although it might be set from Lua
        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
        # self.assertEquals(len(msg.originalRequestorSubnet), 4)
        # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')

    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname):
        """
        Check a protobuf message describing a DNS query.

        msg      -- parsed PBDNSMessage to check
        protocol -- expected socketProtocol value
        query    -- the DNS query that was sent
        qclass, qtype, qname -- expected content of the question section
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query)
        # The 'to' (responder) field is only checked for queries:
        # dnsdist doesn't fill it for responses because it doesn't
        # keep the information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEqual(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        # compare the type with the type, not with the class
        self.assertEqual(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEqual(msg.question.qName, qname)

    def checkProtobufResponse(self, msg, protocol, response):
        """
        Check a protobuf message describing a DNS response.

        msg      -- parsed PBDNSMessage to check
        protocol -- expected socketProtocol value
        response -- the DNS response that was sent back
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def _checkProtobufARecord(self, msg, name, ttl, address):
        """
        Check that msg.response holds exactly one IN/A record with the
        given owner name, TTL and IPv4 address (dotted-quad string).
        """
        self.assertEqual(len(msg.response.rrs), 1)
        for rr in msg.response.rrs:
            self.assertTrue(rr.HasField('class'))
            # 'class' is a Python keyword, hence getattr()
            self.assertEqual(getattr(rr, 'class'), dns.rdataclass.IN)
            self.assertTrue(rr.HasField('type'))
            self.assertEqual(rr.type, dns.rdatatype.A)
            self.assertTrue(rr.HasField('name'))
            self.assertEqual(rr.name, name)
            self.assertTrue(rr.HasField('ttl'))
            self.assertEqual(rr.ttl, ttl)
            self.assertTrue(rr.HasField('rdata'))
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), address)

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server
        """
        name = 'query.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        # the same exchange is checked over UDP first, then over TCP
        transports = (
            (self.sendUDPQuery, dnsmessage_pb2.PBDNSMessage.UDP),
            (self.sendTCPQuery, dnsmessage_pb2.PBDNSMessage.TCP),
        )
        for (sendQuery, protocol) in transports:
            (receivedQuery, receivedResponse) = sendQuery(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

            # give the protobuf messages the time to get there
            time.sleep(1)

            # check the protobuf message corresponding to the query
            msg = self.getFirstProtobufMessage()
            self.checkProtobufQuery(msg, protocol, query, dns.rdataclass.IN, dns.rdatatype.A, name)

            # check the protobuf message corresponding to the response
            msg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(msg, protocol, response)
            self._checkProtobufARecord(msg, name, 3600, '127.0.0.1')