]>
git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-cache.cc
3e085bd7d9ef27eebfb317577e08d1253f829ba4
3 #include "dnsparser.hh"
4 #include "dnsdist-cache.hh"
6 DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries
, uint32_t maxTTL
, uint32_t minTTL
, uint32_t servFailTTL
, uint32_t staleTTL
): d_maxEntries(maxEntries
), d_maxTTL(maxTTL
), d_servFailTTL(servFailTTL
), d_minTTL(minTTL
), d_staleTTL(staleTTL
)
8 pthread_rwlock_init(&d_lock
, 0);
9 /* we reserve maxEntries + 1 to avoid rehashing from occuring
10 when we get to maxEntries, as it means a load factor of 1 */
11 d_map
.reserve(maxEntries
+ 1);
// Destructor.
// NOTE(review): the body of this destructor is missing from this scraped copy of
// the file, so what it does (if anything) cannot be documented from here —
// check against the upstream source.
14 DNSDistPacketCache::~DNSDistPacketCache()
19 bool DNSDistPacketCache::cachedValueMatches(const CacheValue
& cachedValue
, const DNSName
& qname
, uint16_t qtype
, uint16_t qclass
, bool tcp
)
21 if (cachedValue
.tcp
!= tcp
|| cachedValue
.qtype
!= qtype
|| cachedValue
.qclass
!= qclass
|| cachedValue
.qname
!= qname
)
// Caches `response` (responseLen bytes) under the hash `key` for the query
// (qname, qtype, qclass, tcp).
// NOTE(review): this is a scraped copy of the file; many statements (early
// returns, counter updates, braces, the declarations of `minTTL`, `result` and
// `newValue`) are missing, so the notes below describe only what is visible.
26 void DNSDistPacketCache::insert(uint32_t key
, const DNSName
& qname
, uint16_t qtype
, uint16_t qclass
, const char* response
, uint16_t responseLen
, bool tcp
, bool servFail
)
// Responses too short to hold a DNS header are rejected here (the statement
// taken in that case is not visible).
28 if (responseLen
< sizeof(dnsheader
))
// SERVFAIL path (presumably selected by `servFail`; the condition is not
// visible): the entry lives for d_servFailTTL seconds.
34 minTTL
= d_servFailTTL
;
// Otherwise the lowest TTL found in the response is used.
37 minTTL
= getMinTTL(response
, responseLen
);
// Presumably caps minTTL at d_maxTTL; the assignment is not visible.
38 if (minTTL
> d_maxTTL
)
// TTLs below d_minTTL are handled here; the branch body is not visible
// (presumably the response is not cached).
41 if (minTTL
< d_minTTL
) {
// Non-blocking read lock used for the size check; the handling of a failed
// attempt is not visible.
48 TryReadLock
r(&d_lock
);
// Cache already at capacity: handled here (presumably nothing is inserted).
53 if (d_map
.size() >= d_maxEntries
) {
// The new entry is prepared before the write lock is requested.
58 const time_t now
= time(NULL
);
59 std::unordered_map
<uint32_t,CacheValue
>::iterator it
;
// Absolute expiry time of the new entry.
61 time_t newValidity
= now
+ minTTL
;
63 newValue
.qname
= qname
;
64 newValue
.qtype
= qtype
;
65 newValue
.qclass
= qclass
;
66 newValue
.len
= responseLen
;
67 newValue
.validity
= newValidity
;
// Keep a private copy of the wire-format response.
70 newValue
.value
= std::string(response
, responseLen
);
// Non-blocking write lock; the handling of a failed attempt is not visible.
73 TryWriteLock
w(&d_lock
);
// std::unordered_map::insert leaves an existing entry with the same key in
// place; `it` then refers to that existing entry and `result` is false.
80 tie(it
, result
) = d_map
.insert({key
, newValue
});
86 /* in case of collision, don't override the existing entry
87 except if it has expired */
88 CacheValue
& value
= it
->second
;
89 bool wasExpired
= value
.validity
<= now
;
// A still-valid entry for a different query keeps its slot (body not visible).
91 if (!wasExpired
&& !cachedValueMatches(value
, qname
, qtype
, qclass
, tcp
)) {
96 /* if the existing entry had a longer TTD, keep it */
97 if (newValidity
<= value
.validity
) {
// Looks up the cached answer for the query described by `dq` and, on a hit,
// copies it into `response`. In the copy, the client's query ID and the client's
// wire-format qname replace the cached ones. The TTLs may then be aged.
//  - consumed:     forwarded to getKey(); presumably the number of bytes of the
//                  question name following the DNS header in dq's packet.
//  - queryId:      written over the first bytes of the copied response.
//  - response:     destination buffer.
//  - responseLen:  on input, compared with the cached length (presumably the
//                  capacity of `response`); on a hit, set to the copied length.
//  - allowExpired: an entry that expired at least this many seconds ago is
//                  presumably treated as a miss (branch bodies not visible).
//  - keyOut, skipAging: not referenced in the lines visible here.
// NOTE(review): this is a scraped copy; the return statements, counter updates,
// the declaration of `age` and several branch conditions are missing.
105 bool DNSDistPacketCache::get(const DNSQuestion
& dq
, uint16_t consumed
, uint16_t queryId
, char* response
, uint16_t* responseLen
, uint32_t* keyOut
, uint32_t allowExpired
, bool skipAging
)
// Cache key for this query (see getKey()).
107 uint32_t key
= getKey(*dq
.qname
, consumed
, (const unsigned char*)dq
.dh
, dq
.len
, dq
.tcp
);
111 time_t now
= time(NULL
);
// Non-blocking read lock; the handling of a failed attempt is not visible.
115 TryReadLock
r(&d_lock
);
121 std::unordered_map
<uint32_t,CacheValue
>::const_iterator it
= d_map
.find(key
);
// Key not present: cache miss (body not visible).
122 if (it
== d_map
.end()) {
127 const CacheValue
& value
= it
->second
;
// Entry past its validity: only entries that expired less than allowExpired
// seconds ago can still be used (branch bodies not visible).
128 if (value
.validity
< now
) {
129 if ((now
- value
.validity
) >= static_cast<time_t>(allowExpired
)) {
// The caller's buffer cannot hold the cached response (body not visible).
138 if (*responseLen
< value
.len
) {
142 /* check for collision */
143 if (!cachedValueMatches(value
, *dq
.qname
, dq
.qtype
, dq
.qclass
, dq
.tcp
)) {
144 d_lookupCollisions
++;
// The client's qname in wire format. It replaces the cached qname in the copy,
// so the answer carries the client's spelling of the name. Keys are built from
// the lower-cased name in getKey().
148 string
dnsQName(dq
.qname
->toDNSString());
149 const size_t dnsQNameLen
= dnsQName
.length();
// Cached response too short to hold a header plus this qname (body not visible).
150 if (value
.len
< (sizeof(dnsheader
) + dnsQNameLen
)) {
// Assemble the answer in order: the client's query ID, the rest of the cached
// header, the client's qname, then whatever follows the qname in the cached
// packet (if anything).
154 memcpy(response
, &queryId
, sizeof(queryId
));
155 memcpy(response
+ sizeof(queryId
), value
.value
.c_str() + sizeof(queryId
), sizeof(dnsheader
) - sizeof(queryId
));
156 memcpy(response
+ sizeof(dnsheader
), dnsQName
.c_str(), dnsQNameLen
);
157 if (value
.len
> (sizeof(dnsheader
) + dnsQNameLen
)) {
158 memcpy(response
+ sizeof(dnsheader
) + dnsQNameLen
, value
.value
.c_str() + sizeof(dnsheader
) + dnsQNameLen
, value
.len
- (sizeof(dnsheader
) + dnsQNameLen
));
160 *responseLen
= value
.len
;
// Age = seconds elapsed since the entry was added (the condition selecting
// this branch is not visible).
162 age
= now
- value
.added
;
// Alternative age: the entry's original lifetime minus d_staleTTL. The
// condition selecting this branch is not visible (possibly the stale-entry
// case — confirm upstream).
165 age
= (value
.validity
- value
.added
) - d_staleTTL
;
// Lower the TTLs in the copied response by `age` (ageDNSPacket() is defined
// elsewhere; any guard around this call is not visible).
170 ageDNSPacket(response
, *responseLen
, age
);
177 /* Remove expired entries, until the cache has at most
180 void DNSDistPacketCache::purgeExpired(size_t upTo
)
182 time_t now
= time(NULL
);
183 WriteLock
w(&d_lock
);
184 if (upTo
>= d_map
.size()) {
188 size_t toRemove
= d_map
.size() - upTo
;
189 for(auto it
= d_map
.begin(); toRemove
> 0 && it
!= d_map
.end(); ) {
190 const CacheValue
& value
= it
->second
;
192 if (value
.validity
< now
) {
193 it
= d_map
.erase(it
);
201 /* Remove all entries, keeping only upTo
202 entries in the cache */
203 void DNSDistPacketCache::expunge(size_t upTo
)
205 WriteLock
w(&d_lock
);
207 if (upTo
>= d_map
.size()) {
211 size_t toRemove
= d_map
.size() - upTo
;
212 auto beginIt
= d_map
.begin();
213 auto endIt
= beginIt
;
214 std::advance(endIt
, toRemove
);
215 d_map
.erase(beginIt
, endIt
);
// Removes every cached entry whose qname equals `name` and whose qtype equals
// `qtype`. QType::ANY matches entries of every type. The qname and qtype are
// re-parsed from each cached packet rather than read from the CacheValue fields.
// NOTE(review): this is a scraped copy; the declaration of `cqtype` and the
// else-branch that advances `it` for kept entries are missing here.
218 void DNSDistPacketCache::expungeByName(const DNSName
& name
, uint16_t qtype
)
220 WriteLock
w(&d_lock
);
// Full scan of the map under the write lock; removed entries advance `it`
// through erase()'s return value.
222 for(auto it
= d_map
.begin(); it
!= d_map
.end(); ) {
223 const CacheValue
& value
= it
->second
;
// Parse the question from the cached response: the name starts right after
// the DNS header; its type and class are written to cqtype/cqclass.
225 uint16_t cqclass
= 0;
226 DNSName
cqname(value
.value
.c_str(), value
.len
, sizeof(dnsheader
), false, &cqtype
, &cqclass
, nullptr);
228 if (cqname
== name
&& (qtype
== QType::ANY
|| qtype
== cqtype
)) {
229 it
= d_map
.erase(it
);
// Reports whether the cache has reached its configured capacity, i.e. holds
// d_maxEntries entries or more.
// NOTE(review): one source line between the signature and the return is missing
// from this scraped copy, so it cannot be told from here whether d_map is read
// under d_lock — confirm upstream.
236 bool DNSDistPacketCache::isFull()
239 return (d_map
.size() >= d_maxEntries
);
242 uint32_t DNSDistPacketCache::getMinTTL(const char* packet
, uint16_t length
)
244 return getDNSPacketMinTTL(packet
, length
);
// Computes the 32-bit cache key of a query packet. A burtle hash is chained
// over, in order:
//  - the DNS header minus its first two bytes (the query ID),
//  - the lower-cased wire-format `qname` (not the raw name bytes of the packet),
//  - the bytes following the first `consumed` bytes after the header, if any,
//  - the `tcp` flag, so UDP and TCP answers are cached separately.
// Throws std::range_error when packetLen is smaller than a DNS header, or
// smaller than the header plus `consumed`.
// NOTE(review): this is a scraped copy; the declaration/initial value of
// `result` (the hash seed) and the final return are missing here.
247 uint32_t DNSDistPacketCache::getKey(const DNSName
& qname
, uint16_t consumed
, const unsigned char* packet
, uint16_t packetLen
, bool tcp
)
250 /* skip the query ID */
251 if (packetLen
< sizeof(dnsheader
))
252 throw std::range_error("Computing packet cache key for an invalid packet size");
253 result
= burtle(packet
+ 2, sizeof(dnsheader
) - 2, result
);
// Hashing the lower-cased name makes the key case-insensitive for the qname.
254 string
lc(qname
.toDNSStringLC());
255 result
= burtle((const unsigned char*) lc
.c_str(), lc
.length(), result
);
256 if (packetLen
< sizeof(dnsheader
) + consumed
) {
257 throw std::range_error("Computing packet cache key for an invalid packet");
// Hash whatever follows the question name in the packet, if anything.
259 if (packetLen
> ((sizeof(dnsheader
) + consumed
))) {
260 result
= burtle(packet
+ sizeof(dnsheader
) + consumed
, packetLen
- (sizeof(dnsheader
) + consumed
), result
);
262 result
= burtle((const unsigned char*) &tcp
, sizeof(tcp
), result
);
// Returns the fill level as "<current entries>/<max entries>", e.g. "12/100".
// NOTE(review): one source line between the signature and the return is missing
// from this scraped copy (possibly a lock on d_lock — confirm upstream).
266 string
DNSDistPacketCache::toString()
269 return std::to_string(d_map
.size()) + "/" + std::to_string(d_maxEntries
);
272 uint64_t DNSDistPacketCache::getEntriesCount()