if (!tcp && query->paddedLen < responseLen) {
struct dnsheader* dh = (struct dnsheader*) response;
- responseLen = query->paddedLen;
+ size_t questionSize = 0;
+
+ if (responseLen > sizeof(dnsheader)) {
+ unsigned int consumed = 0;
+ DNSName qname(response, responseLen, sizeof(dnsheader), false, 0, 0, &consumed);
+ if (consumed > 0) {
+ questionSize = consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
+ }
+ }
+
+ responseLen = sizeof(dnsheader) + questionSize;
+
+ if (responseLen > query->paddedLen) {
+ responseLen = query->paddedLen;
+ }
+ dh->ancount = dh->arcount = dh->nscount = 0;
dh->tc = 1;
}
BOOST_CHECK_EQUAL(mdp.d_header.arcount, 0);
BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "2.name.");
- BOOST_CHECK_EQUAL(mdp.d_qclass, QClass::IN);
- BOOST_CHECK_EQUAL(mdp.d_qtype, QType::TXT);
+ BOOST_CHECK(mdp.d_qclass == QClass::IN);
+ BOOST_CHECK(mdp.d_qtype == QType::TXT);
}
// invalid plaintext query (A)
BOOST_CHECK_EQUAL(mdp.d_header.arcount, 0);
BOOST_CHECK_EQUAL(mdp.d_qname, name);
- BOOST_CHECK_EQUAL(mdp.d_qclass, QClass::IN);
- BOOST_CHECK_EQUAL(mdp.d_qtype, QType::AAAA);
+ BOOST_CHECK(mdp.d_qclass == QClass::IN);
+ BOOST_CHECK(mdp.d_qtype == QType::AAAA);
}
// valid encrypted query with not enough room
BOOST_CHECK_EQUAL(mdp.d_header.arcount, 0);
BOOST_CHECK_EQUAL(mdp.d_qname, name);
- BOOST_CHECK_EQUAL(mdp.d_qclass, QClass::IN);
- BOOST_CHECK_EQUAL(mdp.d_qtype, QType::AAAA);
+ BOOST_CHECK(mdp.d_qclass == QClass::IN);
+ BOOST_CHECK(mdp.d_qtype == QType::AAAA);
}
// valid encrypted query with wrong key
_dnsDistPort = 5340
_dnsDistPortDNSCrypt = 8443
_config_template = """
- generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", 42, %d, %d)
+ generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
newServer{address="127.0.0.1:%s"}
"""
- _dnsdistcmd = (os.environ['DNSDISTBIN'] + " -C dnsdist_DNSCrypt.conf --acl 127.0.0.1/32 -l 127.0.0.1:" + str(_dnsDistPort)).split()
_providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA'
_providerName = "2.provider.name"
-
- @classmethod
- def startDNSDist(cls, shutUp=True):
- print("Launching dnsdist..")
- # valid from 60s ago until 2h from now
- validFrom = time.time() - 60
- validUntil = time.time() + 7200
- with open('dnsdist_DNSCrypt.conf', 'w') as conf:
- conf.write(cls._config_template % (validFrom, validUntil, cls._dnsDistPortDNSCrypt, cls._providerName, str(cls._testServerPort)))
-
- print(' '.join(cls._dnsdistcmd))
- if shutUp:
- with open(os.devnull, 'w') as fdDevNull:
- cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True, stdout=fdDevNull, stderr=fdDevNull)
- else:
- cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True)
-
- time.sleep(1)
-
- if cls._dnsdist.poll() is not None:
- cls._dnsdist.terminate()
- cls._dnsdist.wait()
- sys.exit(cls._dnsdist.returncode)
-
+ _resolverCertificateSerial = 42
+ # valid from 60s ago until 2h from now
+ _resolverCertificateValidFrom = time.time() - 60
+ _resolverCertificateValidUntil = time.time() + 7200
+ _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
def testSimpleA(self):
"""
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
+ def testResponseLargerThanPaddedQuery(self):
+ """
+ Send a small encrypted query (don't forget to take
+ the padding into account) and check that the response
+ is truncated.
+ """
+ client = dnscrypt.DNSCryptClient(self._providerName, self._providerFingerprint, "127.0.0.1", 8443)
+ name = 'smallquerylargeresponse.dnscrypt.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'TXT', 'IN', use_edns=True, payload=4096)
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.TXT,
+ 'A'*255)
+ response.answer.append(rrset)
+
+ self._toResponderQueue.put(response)
+ data = client.query(query.to_wire())
+ receivedQuery = None
+ if not self._fromResponderQueue.empty():
+ receivedQuery = self._fromResponderQueue.get(query)
+
+ receivedResponse = dns.message.from_wire(data)
+
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse.question, response.question)
+ self.assertTrue(receivedResponse.flags & ~dns.flags.TC)
+ self.assertTrue(len(receivedResponse.answer) == 0)
+ self.assertTrue(len(receivedResponse.authority) == 0)
+ self.assertTrue(len(receivedResponse.additional) == 0)
+
if __name__ == '__main__':
unittest.main()
exit(0)