5 from dnsdisttests
import DNSDistTest
class TestDNSCrypt(DNSDistTest):
    """
    dnsdist is configured to accept DNSCrypt queries on 127.0.0.1:_dnsDistPortDNSCrypt.
    The provider's keys have been generated with:
    generateDNSCryptProviderKeys("DNSCryptProviderPublic.key", "DNSCryptProviderPrivate.key")
    Be careful to change the _providerFingerprint below if you want to regenerate the keys.
    """

    _dnsDistPortDNSCrypt = 8443
    # The six placeholders line up, in order, with the attribute names listed
    # in _config_params (presumably substituted by the DNSDistTest base class
    # -- confirm there).
    _config_template = """
    generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
    addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
    newServer{address="127.0.0.1:%s"}
    """

    _providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA'
    _providerName = "2.provider.name"
    _resolverCertificateSerial = 42
    # valid from 60s ago until 2h from now
    _resolverCertificateValidFrom = time.time() - 60
    _resolverCertificateValidUntil = time.time() + 7200
    _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
    _dnsdistStartupDelay = 10

    def testSimpleA(self):
        """
        DNSCrypt: encrypted A query
        """
        # Use the class attribute rather than a literal port so the client
        # always targets the address passed to addDNSCryptBind().
        client = dnscrypt.DNSCryptClient(self._providerName,
                                         self._providerFingerprint,
                                         "127.0.0.1",
                                         self._dnsDistPortDNSCrypt)
        name = 'a.dnscrypt.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        # NOTE(review): the TTL/class/type/rdata arguments were not visible in
        # the reviewed excerpt and have been reconstructed -- confirm the
        # values against the upstream test.
        rrset = dns.rrset.from_text(name, 3600, 'IN', 'A', '192.0.2.1')
        response.answer.append(rrset)

        # Same exchange twice: first with the client's default transport,
        # then again with tcp=True.
        for queryArgs in ({}, {'tcp': True}):
            self._toResponderQueue.put(response)
            data = client.query(query.to_wire(), **queryArgs)
            receivedResponse = dns.message.from_wire(data)

            # Start from None so that an empty queue yields a clean assertion
            # failure below instead of a NameError.
            receivedQuery = None
            if not self._fromResponderQueue.empty():
                receivedQuery = self._fromResponderQueue.get(query)

            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            # Align the ids before comparing, so only the rest of the
            # message has to match.
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

    def testResponseLargerThanPaddedQuery(self):
        """
        DNSCrypt: response larger than query

        Send a small encrypted query (don't forget to take
        the padding into account) and check that the response
        is truncated.
        """
        client = dnscrypt.DNSCryptClient(self._providerName,
                                         self._providerFingerprint,
                                         "127.0.0.1",
                                         self._dnsDistPortDNSCrypt)
        name = 'smallquerylargeresponse.dnscrypt.tests.powerdns.com.'
        query = dns.message.make_query(name, 'TXT', 'IN', use_edns=True, payload=4096)
        response = dns.message.make_response(query)
        # NOTE(review): the TTL/class/type/rdata arguments were not visible in
        # the reviewed excerpt and have been reconstructed; the TXT payload
        # only needs to make the response larger than the padded query --
        # confirm against the upstream test.
        rrset = dns.rrset.from_text(name, 3600, 'IN', 'TXT', 'A' * 255)
        response.answer.append(rrset)

        self._toResponderQueue.put(response)
        data = client.query(query.to_wire())

        receivedQuery = None
        if not self._fromResponderQueue.empty():
            receivedQuery = self._fromResponderQueue.get(query)

        receivedResponse = dns.message.from_wire(data)

        self.assertTrue(receivedQuery)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(receivedResponse.question, response.question)
        # The previous check, `flags & ~dns.flags.TC`, was truthy as soon as
        # any other flag (e.g. QR) was set and therefore verified nothing.
        # The empty sections asserted below show that a truncated reply is
        # expected, so test the TC bit itself.
        self.assertTrue(receivedResponse.flags & dns.flags.TC)
        self.assertEqual(len(receivedResponse.answer), 0)
        self.assertEqual(len(receivedResponse.authority), 0)
        self.assertEqual(len(receivedResponse.additional), 0)
class TestDNSCryptWithCache(DNSDistTest):
    """
    Same DNSCrypt setup as the plain DNSCrypt test, with a packet cache
    attached to the default pool so that a repeated query can be answered
    without reaching the backend.
    """

    _dnsDistPortDNSCrypt = 8443
    _providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA'
    _providerName = "2.provider.name"
    _resolverCertificateSerial = 42
    # valid from 60s ago until 2h from now
    _resolverCertificateValidFrom = time.time() - 60
    _resolverCertificateValidUntil = time.time() + 7200
    _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
    # The six placeholders line up, in order, with the attribute names listed
    # in _config_params (presumably substituted by the DNSDistTest base class
    # -- confirm there).
    _config_template = """
    generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
    addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
    pc = newPacketCache(5, 86400, 1)
    getPool(""):setCache(pc)
    newServer{address="127.0.0.1:%s"}
    """

    def testCachedSimpleA(self):
        """
        DNSCrypt: encrypted A query served from cache
        """
        # NOTE(review): `misses` was referenced but never initialised in the
        # reviewed excerpt; it is tracked here as the number of queries that
        # are expected to reach the backend -- confirm against upstream.
        misses = 0
        # Use the class attribute rather than a literal port so the client
        # always targets the address passed to addDNSCryptBind().
        client = dnscrypt.DNSCryptClient(self._providerName,
                                         self._providerFingerprint,
                                         "127.0.0.1",
                                         self._dnsDistPortDNSCrypt)
        name = 'cacheda.dnscrypt.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        # NOTE(review): the TTL/class/type/rdata arguments were not visible in
        # the reviewed excerpt and have been reconstructed -- confirm the
        # values against the upstream test.
        rrset = dns.rrset.from_text(name, 3600, 'IN', 'A', '192.0.2.1')
        response.answer.append(rrset)

        # first query to fill the cache
        self._toResponderQueue.put(response)
        data = client.query(query.to_wire())
        receivedResponse = dns.message.from_wire(data)

        # Start from None so that an empty queue yields a clean assertion
        # failure below instead of a NameError.
        receivedQuery = None
        if not self._fromResponderQueue.empty():
            receivedQuery = self._fromResponderQueue.get(query)

        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        # Align the ids before comparing, so only the rest of the message
        # has to match.
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)
        misses += 1

        # second query should get a cached response
        data = client.query(query.to_wire())
        receivedResponse = dns.message.from_wire(data)

        receivedQuery = None
        if not self._fromResponderQueue.empty():
            receivedQuery = self._fromResponderQueue.get(query)

        # Nothing may have been handed to the responder this time.
        self.assertIsNone(receivedQuery)
        self.assertTrue(receivedResponse)
        self.assertEqual(response, receivedResponse)

        # Sum of all per-key counters in _responsesCounter must equal the
        # number of queries that were expected to miss the cache.
        total = sum(self._responsesCounter[key] for key in self._responsesCounter)
        self.assertEqual(total, misses)