WriteLock l(&d_lock);
}
-bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass)
+bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp)
{
- if (cachedValue.qname != qname || cachedValue.qtype != qtype || cachedValue.qclass != qclass)
+ if (cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname)
return false;
return true;
}
-void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen)
+void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp)
{
if (responseLen == 0)
return;
newValue.len = responseLen;
newValue.validity = newValidity;
newValue.added = now;
+ newValue.tcp = tcp;
newValue.value = std::string(response, responseLen);
{
CacheValue& value = it->second;
bool wasExpired = value.validity <= now;
- if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass)) {
+ if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass, tcp)) {
d_insertCollisions++;
return;
}
}
}
-bool DNSDistPacketCache::get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging)
+bool DNSDistPacketCache::get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, bool tcp, uint32_t* keyOut, bool skipAging)
{
- uint32_t key = getKey(qname, consumed, query, queryLen);
+ uint32_t key = getKey(qname, consumed, query, queryLen, tcp);
if (keyOut)
*keyOut = key;
}
/* check for collision */
- if (!cachedValueMatches(value, qname, qtype, qclass)) {
+ if (!cachedValueMatches(value, qname, qtype, qclass, tcp)) {
d_misses++;
d_lookupCollisions++;
return false;
return result;
}
-uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen)
+uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
{
uint32_t result = 0;
/* skip the query ID */
string lc(qname.toDNSStringLC());
result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
+ result = burtle((const unsigned char*) &tcp, sizeof(tcp), result);
return result;
}
DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=60);
~DNSDistPacketCache();
- void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen);
- bool get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, bool skipAging=false);
+ void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp);
+ bool get(const unsigned char* query, uint16_t queryLen, const DNSName& qname, uint16_t qtype, uint16_t qclass, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, bool tcp, uint32_t* keyOut, bool skipAging=false);
void purge(size_t upTo=0);
void expunge(const DNSName& name, uint16_t qtype=QType::ANY);
bool isFull();
time_t added{0};
time_t validity{0};
uint16_t len{0};
+ bool tcp{false};
};
- static uint32_t getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen);
- static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass);
+ static uint32_t getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp);
+ static bool cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp);
pthread_rwlock_t d_lock;
std::unordered_map<uint32_t,CacheValue> d_map;
if (serverPool->packetCache && !dq.skipCache) {
char cachedResponse[4096];
uint16_t cachedResponseSize = sizeof cachedResponse;
- if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) {
+ if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, true, &cacheKey)) {
if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout))
writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout);
g_stats.cacheHits++;
}
if (serverPool->packetCache && !dq.skipCache) {
- serverPool->packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen);
+ serverPool->packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true);
}
#ifdef HAVE_DNSCRYPT
g_stats.responses++;
if (ids->packetCache && !ids->skipCache) {
- ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen);
+ ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false);
}
#ifdef HAVE_DNSCRYPT
if (serverPool->packetCache && !dq.skipCache) {
char cachedResponse[4096];
uint16_t cachedResponseSize = sizeof cachedResponse;
- if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, &cacheKey)) {
+ if (serverPool->packetCache->get((unsigned char*) query, dq.len, *dq.qname, dq.qtype, dq.qclass, consumed, dh->id, cachedResponse, &cachedResponseSize, false, &cacheKey)) {
ComboAddress dest;
if(HarvestDestinationAddress(&msgh, &dest))
sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote);
char responseBuf[4096];
uint16_t responseBufSize = sizeof(responseBuf);
uint32_t key = 0;
- bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+ bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key);
BOOST_CHECK_EQUAL(found, false);
- PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen);
+ PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false);
- found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, true);
+ found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, false, &key, true);
if (found == true) {
BOOST_CHECK_EQUAL(responseBufSize, responseLen);
int match = memcmp(responseBuf, response.data(), responseLen);
char responseBuf[4096];
uint16_t responseBufSize = sizeof(responseBuf);
uint32_t key = 0;
- bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+ bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key);
if (found == true) {
PC.expunge(a);
deleted++;
uint32_t key = 0;
char response[4096];
uint16_t responseSize = sizeof(response);
- if(PC.get(query.data(), len, a, QType::A, QClass::IN, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key)) {
+ if(PC.get(query.data(), len, a, QType::A, QClass::IN, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, false, &key)) {
matches++;
}
}
char responseBuf[4096];
uint16_t responseBufSize = sizeof(responseBuf);
uint32_t key = 0;
- PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+ PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key);
- PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen);
+ PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false);
}
}
catch(PDNSException& e) {
char responseBuf[4096];
uint16_t responseBufSize = sizeof(responseBuf);
uint32_t key = 0;
- bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, &key);
+ bool found = PC.get((const unsigned char*) query.data(), query.size(), a, QType::A, QClass::IN, a.wirelength(), 0, responseBuf, &responseBufSize, false, &key);
if (!found) {
g_missing++;
}
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
for _ in range(numberOfQueries):
(_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
+ total = 0
+ for key in TestAdvancedCaching._responsesCounter:
+ total += TestAdvancedCaching._responsesCounter[key]
+ TestAdvancedCaching._responsesCounter[key] = 0
+
+ self.assertEquals(total, 1)
+
+ # TCP should not be cached
+ # first query to fill the cache
+ (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+
+ for _ in range(numberOfQueries):
(_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
total = 0
for key in TestAdvancedCaching._responsesCounter:
total += TestAdvancedCaching._responsesCounter[key]
+ TestAdvancedCaching._responsesCounter[key] = 0
self.assertEquals(total, 1)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
misses += 1
# next queries should hit the cache
(_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
- self.assertEquals(receivedResponse, response)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
# now we wait a bit for the cache entry to expire
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
misses += 1
# following queries should hit the cache again
(_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
- self.assertEquals(receivedResponse, response)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
total = 0
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
misses += 1
# next queries should hit the cache
(_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
- self.assertEquals(receivedResponse, response)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
# now we wait a bit for the cache entry to expire
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
misses += 1
# following queries should hit the cache again
(_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
- self.assertEquals(receivedResponse, response)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
total = 0
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
misses += 1
# next queries should hit the cache
(_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
- self.assertEquals(receivedResponse, response)
- for an in receivedResponse.answer:
- self.assertTrue(an.ttl <= ttl)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
for an in receivedResponse.answer:
self.assertTrue(an.ttl <= ttl)
# next queries should hit the cache
(_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
- self.assertEquals(receivedResponse, response)
- for an in receivedResponse.answer:
- self.assertTrue(an.ttl < ttl)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- receivedResponse.id = response.id
self.assertEquals(receivedResponse, response)
for an in receivedResponse.answer:
self.assertTrue(an.ttl < ttl)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(receivedResponse, response)
# different case query should still hit the cache
(_, receivedResponse) = self.sendUDPQuery(differentCaseQuery, response=None, useQueue=False)
- receivedResponse.id = differentCaseResponse.id
self.assertEquals(receivedResponse, differentCaseResponse)
- (_, receivedResponse) = self.sendTCPQuery(differentCaseQuery, response=None, useQueue=False)
- receivedResponse.id = differentCaseResponse.id
- self.assertEquals(receivedResponse, differentCaseResponse)
class TestAdvancedCachingWithExistingEDNS(DNSDistTest):
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
misses += 1
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
misses += 1
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
self.assertTrue(receivedQuery)
self.assertTrue(receivedResponse)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)