]>
git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-cache.cc
/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
26 #include "dnsparser.hh"
27 #include "dnsdist-cache.hh"
28 #include "dnsdist-ecs.hh"
29 #include "ednsoptions.hh"
30 #include "ednssubnet.hh"
32 DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries
, uint32_t maxTTL
, uint32_t minTTL
, uint32_t tempFailureTTL
, uint32_t maxNegativeTTL
, uint32_t staleTTL
, bool dontAge
, uint32_t shards
, bool deferrableInsertLock
, bool parseECS
): d_maxEntries(maxEntries
), d_shardCount(shards
), d_maxTTL(maxTTL
), d_tempFailureTTL(tempFailureTTL
), d_maxNegativeTTL(maxNegativeTTL
), d_minTTL(minTTL
), d_staleTTL(staleTTL
), d_dontAge(dontAge
), d_deferrableInsertLock(deferrableInsertLock
), d_parseECS(parseECS
)
34 d_shards
.resize(d_shardCount
);
36 /* we reserve maxEntries + 1 to avoid rehashing from occurring
37 when we get to maxEntries, as it means a load factor of 1 */
38 for (auto& shard
: d_shards
) {
39 shard
.setSize((maxEntries
/ d_shardCount
) + 1);
// Destructor. The visible statements acquire a WriteLock on the d_lock of
// every shard and keep all of them in `locks`, so they stay held until the
// vector goes out of scope — presumably to wait for threads still using a shard.
// NOTE(review): this listing is a lossy extract: the embedded original line
// numbers skip 44-45 and 49-53 and no braces are present, so any surrounding
// control flow (e.g. exception handling) cannot be seen here — confirm
// against the upstream file.
43 DNSDistPacketCache::~DNSDistPacketCache()
46 vector
<std::unique_ptr
<WriteLock
>> locks
;
47 for (uint32_t shardIndex
= 0; shardIndex
< d_shardCount
; shardIndex
++) {
48 locks
.push_back(std::unique_ptr
<WriteLock
>(new WriteLock(&d_shards
.at(shardIndex
).d_lock
)));
// Tries to extract the EDNS Client Subnet (ECS) option from a DNS packet.
// Visible flow:
//  - getEDNSOptionsStart() is asked for the position of the OPT record data
//    (optRDPosition) and the number of bytes remaining after it;
//  - getEDNSOption() then searches that area for EDNSOptionCode::ECS;
//  - when the option is found (res == 0) and is larger than its own
//    code + length header, the payload following that header is parsed with
//    getEDNSSubnetOptsFromString() into `eso`.
// NOTE(review): the declarations of `remaining` and `eso`, the assignment to
// `subnet` and every return statement are missing from this extract (the
// embedded line numbers skip them) — confirm against the upstream file.
55 bool DNSDistPacketCache::getClientSubnet(const char* packet
, unsigned int consumed
, uint16_t len
, boost::optional
<Netmask
>& subnet
)
57 uint16_t optRDPosition
;
60 int res
= getEDNSOptionsStart(const_cast<char*>(packet
), consumed
, len
, &optRDPosition
, &remaining
);
63 char * ecsOptionStart
= nullptr;
64 size_t ecsOptionSize
= 0;
66 res
= getEDNSOption(const_cast<char*>(packet
) + optRDPosition
, remaining
, EDNSOptionCode::ECS
, &ecsOptionStart
, &ecsOptionSize
);
68 if (res
== 0 && ecsOptionSize
> (EDNS_OPTION_CODE_SIZE
+ EDNS_OPTION_LENGTH_SIZE
)) {
71 if (getEDNSSubnetOptsFromString(ecsOptionStart
+ (EDNS_OPTION_CODE_SIZE
+ EDNS_OPTION_LENGTH_SIZE
), ecsOptionSize
- (EDNS_OPTION_CODE_SIZE
+ EDNS_OPTION_LENGTH_SIZE
), &eso
) == true) {
// Checks whether a cached entry describes the same query as the one given.
// The first test compares the query flags, the DNSSEC OK bit, the transport
// (tcp), qtype, qclass and qname; the second test additionally compares the
// ECS subnet, but only when d_parseECS is set.
// NOTE(review): the return statements are not part of this extract;
// presumably each failed test returns false and the function returns true
// otherwise — confirm against the upstream file.
81 bool DNSDistPacketCache::cachedValueMatches(const CacheValue
& cachedValue
, uint16_t queryFlags
, const DNSName
& qname
, uint16_t qtype
, uint16_t qclass
, bool tcp
, bool dnssecOK
, const boost::optional
<Netmask
>& subnet
) const
83 if (cachedValue
.queryFlags
!= queryFlags
|| cachedValue
.dnssecOK
!= dnssecOK
|| cachedValue
.tcp
!= tcp
|| cachedValue
.qtype
!= qtype
|| cachedValue
.qclass
!= qclass
|| cachedValue
.qname
!= qname
) {
87 if (d_parseECS
&& cachedValue
.subnet
!= subnet
) {
// Stores newValue under `key` in the given shard's map. The name and the
// "now that we hold the lock" comment below indicate that the caller is
// expected to hold the shard's lock already.
// Visible logic:
//  - the map size is re-checked against the per-shard capacity
//    (d_maxEntries / d_shardCount);
//  - map.insert() is attempted and the shard's d_entriesCount is incremented
//    (presumably only when a new entry was created — the guarding condition
//    is not in this extract);
//  - for an already existing key, the existing entry counts as expired when
//    its validity is <= newValue.added; a non-expired entry that does not
//    match the new query (cachedValueMatches()) is counted in
//    d_insertCollisions;
//  - finally the validity of the new and existing entries are compared so an
//    existing entry with a longer or equal lifetime can be kept.
// NOTE(review): the declaration of `result`, the return statements and the
// final overwrite of the existing entry are missing from this extract —
// confirm against the upstream file.
94 void DNSDistPacketCache::insertLocked(CacheShard
& shard
, uint32_t key
, CacheValue
& newValue
)
96 auto& map
= shard
.d_map
;
97 /* check again now that we hold the lock to prevent a race */
98 if (map
.size() >= (d_maxEntries
/ d_shardCount
)) {
102 std::unordered_map
<uint32_t,CacheValue
>::iterator it
;
104 tie(it
, result
) = map
.insert({key
, newValue
});
107 shard
.d_entriesCount
++;
111 /* in case of collision, don't override the existing entry
112 except if it has expired */
113 CacheValue
& value
= it
->second
;
114 bool wasExpired
= value
.validity
<= newValue
.added
;
116 if (!wasExpired
&& !cachedValueMatches(value
, newValue
.queryFlags
, newValue
.qname
, newValue
.qtype
, newValue
.qclass
, newValue
.tcp
, newValue
.dnssecOK
, newValue
.subnet
)) {
117 d_insertCollisions
++;
121 /* if the existing entry had a longer TTD, keep it */
122 if (newValue
.validity
<= value
.validity
) {
// Determines the TTL to apply to a response and stores the response in the
// shard selected by getShardIndex(key).
// Visible logic:
//  - responses shorter than a dnsheader are tested for first (the branch body
//    is not in this extract);
//  - for ServFail and Refused the TTL is tempFailureTTL when provided,
//    d_tempFailureTTL otherwise;
//  - for other rcodes the TTL comes from getMinTTL(); a result equal to
//    std::numeric_limits<uint32_t>::max() means no TTL was found and, per the
//    comment, the response should not be cached;
//  - NXDomain, or NoError with seenAuthSOA set, is capped to d_maxNegativeTTL;
//    otherwise the TTL is compared against d_maxTTL;
//  - the TTL is then compared against d_minTTL;
//  - the shard's d_entriesCount is checked against the per-shard capacity
//    (d_maxEntries / d_shardCount) before the entry is built;
//  - the new CacheValue gets validity = now + TTL and added = now, plus the
//    query properties and a copy of the response bytes;
//  - depending on d_deferrableInsertLock a TryWriteLock or a WriteLock is
//    taken on the shard before calling insertLocked().
// NOTE(review): this extract is lossy — the declarations of `minTTL` and
// `newValue`, the bodies of most branches, the return statements and the
// handling of a failed TryWriteLock are not visible; confirm against the
// upstream file.
129 void DNSDistPacketCache::insert(uint32_t key
, const boost::optional
<Netmask
>& subnet
, uint16_t queryFlags
, bool dnssecOK
, const DNSName
& qname
, uint16_t qtype
, uint16_t qclass
, const char* response
, uint16_t responseLen
, bool tcp
, uint8_t rcode
, boost::optional
<uint32_t> tempFailureTTL
)
131 if (responseLen
< sizeof(dnsheader
)) {
137 if (rcode
== RCode::ServFail
|| rcode
== RCode::Refused
) {
138 minTTL
= tempFailureTTL
== boost::none
? d_tempFailureTTL
: *tempFailureTTL
;
144 bool seenAuthSOA
= false;
145 minTTL
= getMinTTL(response
, responseLen
, &seenAuthSOA
);
147 /* no TTL found, we don't want to cache this */
148 if (minTTL
== std::numeric_limits
<uint32_t>::max()) {
152 if (rcode
== RCode::NXDomain
|| (rcode
== RCode::NoError
&& seenAuthSOA
)) {
153 minTTL
= std::min(minTTL
, d_maxNegativeTTL
);
155 else if (minTTL
> d_maxTTL
) {
159 if (minTTL
< d_minTTL
) {
165 uint32_t shardIndex
= getShardIndex(key
);
167 if (d_shards
.at(shardIndex
).d_entriesCount
>= (d_maxEntries
/ d_shardCount
)) {
171 const time_t now
= time(nullptr);
172 time_t newValidity
= now
+ minTTL
;
174 newValue
.qname
= qname
;
175 newValue
.qtype
= qtype
;
176 newValue
.qclass
= qclass
;
177 newValue
.queryFlags
= queryFlags
;
178 newValue
.len
= responseLen
;
179 newValue
.validity
= newValidity
;
180 newValue
.added
= now
;
182 newValue
.dnssecOK
= dnssecOK
;
183 newValue
.value
= std::string(response
, responseLen
);
184 newValue
.subnet
= subnet
;
186 auto& shard
= d_shards
.at(shardIndex
);
188 if (d_deferrableInsertLock
) {
189 TryWriteLock
w(&shard
.d_lock
);
195 insertLocked(shard
, key
, newValue
);
198 WriteLock
w(&shard
.d_lock
);
200 insertLocked(shard
, key
, newValue
);
// Looks up a cached response for the query `dq` and, on a hit, writes it
// into `response` / `*responseLen`.
// Visible logic:
//  - the key is computed by getKey() from the wire-format qname, the query
//    packet and the transport (dq.tcp);
//  - getClientSubnet() is called to fill `subnet` from the query (presumably
//    only when ECS parsing is enabled — the condition is not in this extract);
//  - the shard is selected with getShardIndex(key) and a TryReadLock is taken
//    on it;
//  - the key is looked up in the shard's map;
//  - an entry whose validity is <= now is treated as expired, and its age
//    past expiry is compared against allowExpired;
//  - the caller's buffer must be at least value.len bytes and the cached
//    value at least as large as a dnsheader;
//  - an entry that does not match the query (cachedValueMatches()) is counted
//    in d_lookupCollisions;
//  - the response is assembled from: the caller's queryId, the rest of the
//    cached header, the qname as spelled in the query (dnsQName), and the
//    cached bytes following the qname;
//  - `age` is either now - value.added or (validity - added) - d_staleTTL
//    (the condition selecting between them is not in this extract) and is
//    applied with ageDNSPacket() unless d_dontAge or skipAging is set.
// NOTE(review): this extract is lossy — the use of keyOut, the declaration of
// `age`, the bodies of most branches, the statistics updates and all return
// statements are not visible; confirm against the upstream file.
204 bool DNSDistPacketCache::get(const DNSQuestion
& dq
, uint16_t consumed
, uint16_t queryId
, char* response
, uint16_t* responseLen
, uint32_t* keyOut
, boost::optional
<Netmask
>& subnet
, bool dnssecOK
, uint32_t allowExpired
, bool skipAging
)
206 std::string
dnsQName(dq
.qname
->toDNSString());
207 uint32_t key
= getKey(dnsQName
, consumed
, reinterpret_cast<const unsigned char*>(dq
.dh
), dq
.len
, dq
.tcp
);
213 getClientSubnet(reinterpret_cast<const char*>(dq
.dh
), consumed
, dq
.len
, subnet
);
216 uint32_t shardIndex
= getShardIndex(key
);
217 time_t now
= time(nullptr);
220 auto& shard
= d_shards
.at(shardIndex
);
221 auto& map
= shard
.d_map
;
223 TryReadLock
r(&shard
.d_lock
);
229 std::unordered_map
<uint32_t,CacheValue
>::const_iterator it
= map
.find(key
);
230 if (it
== map
.end()) {
235 const CacheValue
& value
= it
->second
;
236 if (value
.validity
<= now
) {
237 if ((now
- value
.validity
) >= static_cast<time_t>(allowExpired
)) {
246 if (*responseLen
< value
.len
|| value
.len
< sizeof(dnsheader
)) {
250 /* check for collision */
251 if (!cachedValueMatches(value
, *(getFlagsFromDNSHeader(dq
.dh
)), *dq
.qname
, dq
.qtype
, dq
.qclass
, dq
.tcp
, dnssecOK
, subnet
)) {
252 d_lookupCollisions
++;
256 memcpy(response
, &queryId
, sizeof(queryId
));
257 memcpy(response
+ sizeof(queryId
), value
.value
.c_str() + sizeof(queryId
), sizeof(dnsheader
) - sizeof(queryId
));
259 if (value
.len
== sizeof(dnsheader
)) {
260 /* DNS header only, our work here is done */
261 *responseLen
= value
.len
;
266 const size_t dnsQNameLen
= dnsQName
.length();
267 if (value
.len
< (sizeof(dnsheader
) + dnsQNameLen
)) {
271 memcpy(response
+ sizeof(dnsheader
), dnsQName
.c_str(), dnsQNameLen
);
272 if (value
.len
> (sizeof(dnsheader
) + dnsQNameLen
)) {
273 memcpy(response
+ sizeof(dnsheader
) + dnsQNameLen
, value
.value
.c_str() + sizeof(dnsheader
) + dnsQNameLen
, value
.len
- (sizeof(dnsheader
) + dnsQNameLen
));
275 *responseLen
= value
.len
;
277 age
= now
- value
.added
;
280 age
= (value
.validity
- value
.added
) - d_staleTTL
;
284 if (!d_dontAge
&& !skipAging
) {
285 ageDNSPacket(response
, *responseLen
, age
);
// purgeExpired(upTo): removes expired entries while the cache holds more
// than `upTo` entries.
// Visible logic:
//  - nothing needs removing when the cache is empty or already holds no more
//    than upTo entries (size == 0 || upTo >= size);
//  - otherwise toRemove = size - upTo;
//  - shards are visited one at a time, starting from d_expungeIndex (which is
//    post-incremented and reduced modulo d_shardCount), each under a
//    WriteLock;
//  - within a shard, entries whose validity is <= now are the ones acted on,
//    and the shard's d_entriesCount is decremented for them;
//  - the outer loop continues while toRemove > 0 and fewer than d_shardCount
//    maps have been scanned.
// NOTE(review): this extract is lossy — the original leading comment is
// truncated (its closing line is missing), and the early return, the `do`
// keyword, the erase call, the counters' updates and the return value are not
// visible; confirm against the upstream file.
292 /* Remove expired entries, until the cache has at most
295 size_t DNSDistPacketCache::purgeExpired(size_t upTo
)
298 uint64_t size
= getSize();
300 if (size
== 0 || upTo
>= size
) {
304 size_t toRemove
= size
- upTo
;
306 size_t scannedMaps
= 0;
308 const time_t now
= time(nullptr);
310 uint32_t shardIndex
= (d_expungeIndex
++ % d_shardCount
);
311 WriteLock
w(&d_shards
.at(shardIndex
).d_lock
);
312 auto& map
= d_shards
[shardIndex
].d_map
;
314 for(auto it
= map
.begin(); toRemove
> 0 && it
!= map
.end(); ) {
315 const CacheValue
& value
= it
->second
;
317 if (value
.validity
<= now
) {
320 d_shards
[shardIndex
].d_entriesCount
--;
329 while (toRemove
> 0 && scannedMaps
< d_shardCount
);
// expunge(upTo): removes entries regardless of their expiry.
// Visible logic:
//  - toRemove = size - upTo, where size is getSize();
//  - every shard is visited in order under a WriteLock;
//  - the share to remove from a shard is
//    (toRemove - removed) / (d_shardCount - shardIndex), i.e. what is still
//    to be removed is spread evenly over the shards not yet visited;
//  - when the shard holds at least that many entries, that many are erased
//    starting from map.begin() and both d_entriesCount and `removed` are
//    updated accordingly;
//  - otherwise `removed` grows by the whole map size and d_entriesCount is
//    reset to 0.
// NOTE(review): this extract is lossy — the guard for upTo >= size, the
// declaration of `removed`, the `else` keyword, the clearing of the map in
// the second branch and the return value are not visible; confirm against
// the upstream file.
334 /* Remove all entries, keeping only upTo
335 entries in the cache */
336 size_t DNSDistPacketCache::expunge(size_t upTo
)
339 const uint64_t size
= getSize();
345 size_t toRemove
= size
- upTo
;
347 for (uint32_t shardIndex
= 0; shardIndex
< d_shardCount
; shardIndex
++) {
348 WriteLock
w(&d_shards
.at(shardIndex
).d_lock
);
349 auto& map
= d_shards
[shardIndex
].d_map
;
350 auto beginIt
= map
.begin();
351 auto endIt
= beginIt
;
352 size_t removeFromThisShard
= (toRemove
- removed
) / (d_shardCount
- shardIndex
);
353 if (map
.size() >= removeFromThisShard
) {
354 std::advance(endIt
, removeFromThisShard
);
355 map
.erase(beginIt
, endIt
);
356 d_shards
[shardIndex
].d_entriesCount
-= removeFromThisShard
;
357 removed
+= removeFromThisShard
;
360 removed
+= map
.size();
362 d_shards
[shardIndex
].d_entriesCount
= 0;
// expungeByName(name, qtype, suffixMatch): walks every shard under a
// WriteLock and selects the entries whose qname equals `name` — or, when
// suffixMatch is set, whose qname isPartOf(name) — and whose qtype equals
// `qtype` (QType::ANY selects every type). The shard's d_entriesCount is
// decremented for each selected entry.
// NOTE(review): this extract is lossy — the erase call, the iterator
// advance, the removed-entries counter and the return value are not visible;
// confirm against the upstream file.
369 size_t DNSDistPacketCache::expungeByName(const DNSName
& name
, uint16_t qtype
, bool suffixMatch
)
373 for (uint32_t shardIndex
= 0; shardIndex
< d_shardCount
; shardIndex
++) {
374 WriteLock
w(&d_shards
.at(shardIndex
).d_lock
);
375 auto& map
= d_shards
[shardIndex
].d_map
;
377 for(auto it
= map
.begin(); it
!= map
.end(); ) {
378 const CacheValue
& value
= it
->second
;
380 if ((value
.qname
== name
|| (suffixMatch
&& value
.qname
.isPartOf(name
))) && (qtype
== QType::ANY
|| qtype
== value
.qtype
)) {
382 d_shards
[shardIndex
].d_entriesCount
--;
393 bool DNSDistPacketCache::isFull()
395 return (getSize() >= d_maxEntries
);
398 uint64_t DNSDistPacketCache::getSize()
402 for (uint32_t shardIndex
= 0; shardIndex
< d_shardCount
; shardIndex
++) {
403 count
+= d_shards
.at(shardIndex
).d_entriesCount
;
409 uint32_t DNSDistPacketCache::getMinTTL(const char* packet
, uint16_t length
, bool* seenNoDataSOA
)
411 return getDNSPacketMinTTL(packet
, length
, seenNoDataSOA
);
414 uint32_t DNSDistPacketCache::getKey(const std::string
& qname
, uint16_t consumed
, const unsigned char* packet
, uint16_t packetLen
, bool tcp
)
417 /* skip the query ID */
418 if (packetLen
< sizeof(dnsheader
))
419 throw std::range_error("Computing packet cache key for an invalid packet size");
420 result
= burtle(packet
+ 2, sizeof(dnsheader
) - 2, result
);
421 string
lc(toLower(qname
));
422 result
= burtle((const unsigned char*) lc
.c_str(), lc
.length(), result
);
423 if (packetLen
< sizeof(dnsheader
) + consumed
) {
424 throw std::range_error("Computing packet cache key for an invalid packet");
426 if (packetLen
> ((sizeof(dnsheader
) + consumed
))) {
427 result
= burtle(packet
+ sizeof(dnsheader
) + consumed
, packetLen
- (sizeof(dnsheader
) + consumed
), result
);
429 result
= burtle((const unsigned char*) &tcp
, sizeof(tcp
), result
);
433 uint32_t DNSDistPacketCache::getShardIndex(uint32_t key
) const
435 return key
% d_shardCount
;
438 string
DNSDistPacketCache::toString()
440 return std::to_string(getSize()) + "/" + std::to_string(d_maxEntries
);
// Returns a 64-bit count of cache entries.
// NOTE(review): only the signature survives in this extract; the body
// (original lines 444-446) is missing — confirm against the upstream file.
443 uint64_t DNSDistPacketCache::getEntriesCount()
448 uint64_t DNSDistPacketCache::dump(int fd
)
450 FILE * fp
= fdopen(dup(fd
), "w");
455 fprintf(fp
, "; dnsdist's packet cache dump follows\n;\n");
458 time_t now
= time(nullptr);
459 for (uint32_t shardIndex
= 0; shardIndex
< d_shardCount
; shardIndex
++) {
460 ReadLock
w(&d_shards
.at(shardIndex
).d_lock
);
461 auto& map
= d_shards
[shardIndex
].d_map
;
463 for(const auto entry
: map
) {
464 const CacheValue
& value
= entry
.second
;
468 fprintf(fp
, "%s %" PRId64
" %s ; key %" PRIu32
", length %" PRIu16
", tcp %d, added %" PRId64
"\n", value
.qname
.toString().c_str(), static_cast<int64_t>(value
.validity
- now
), QType(value
.qtype
).getName().c_str(), entry
.first
, value
.len
, value
.tcp
, static_cast<int64_t>(value
.added
));
471 fprintf(fp
, "; error printing '%s'\n", value
.qname
.empty() ? "EMPTY" : value
.qname
.toString().c_str());