]> git.ipfire.org Git - thirdparty/pdns.git/blame - pdns/dnsdist-cache.cc
Sphinx 1.8.0 seems broken, use any other version available instead
[thirdparty/pdns.git] / pdns / dnsdist-cache.cc
CommitLineData
12471842
PL
/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
1ea747c0 22#include "dnsdist.hh"
886e2cf2 23#include "dolog.hh"
886e2cf2 24#include "dnsparser.hh"
1ea747c0 25#include "dnsdist-cache.hh"
886e2cf2 26
2b3eefc3 27DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t tempFailureTTL, uint32_t staleTTL, bool dontAge, uint32_t shards, bool deferrableInsertLock): d_maxEntries(maxEntries), d_shardCount(shards), d_maxTTL(maxTTL), d_tempFailureTTL(tempFailureTTL), d_minTTL(minTTL), d_staleTTL(staleTTL), d_dontAge(dontAge), d_deferrableInsertLock(deferrableInsertLock)
886e2cf2 28{
2b3eefc3
RG
29 d_shards.resize(d_shardCount);
30
ccac98a0 31 /* we reserve maxEntries + 1 to avoid rehashing from occurring
886e2cf2 32 when we get to maxEntries, as it means a load factor of 1 */
2b3eefc3
RG
33 for (auto& shard : d_shards) {
34 shard.setSize((maxEntries / d_shardCount) + 1);
35 }
886e2cf2
RG
36}
37
38DNSDistPacketCache::~DNSDistPacketCache()
39{
737a287f 40 try {
2b3eefc3
RG
41 vector<std::unique_ptr<WriteLock>> locks;
42 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
43 locks.push_back(std::unique_ptr<WriteLock>(new WriteLock(&d_shards.at(shardIndex).d_lock)));
44 }
737a287f 45 }
2b3eefc3 46 catch(...) {
737a287f 47 }
886e2cf2
RG
48}
49
a176d205 50bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp)
886e2cf2 51{
a176d205 52 if (cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname)
886e2cf2
RG
53 return false;
54 return true;
55}
56
2b3eefc3
RG
57void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, CacheValue& newValue, time_t now, time_t newValidity)
58{
59 auto& map = shard.d_map;
60 /* check again now that we hold the lock to prevent a race */
61 if (map.size() >= (d_maxEntries / d_shardCount)) {
62 return;
63 }
64
65 std::unordered_map<uint32_t,CacheValue>::iterator it;
66 bool result;
67 tie(it, result) = map.insert({key, newValue});
68
69 if (result) {
70 shard.d_entriesCount++;
71 return;
72 }
73
74 /* in case of collision, don't override the existing entry
75 except if it has expired */
76 CacheValue& value = it->second;
77 bool wasExpired = value.validity <= now;
78
79 if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass, tcp)) {
80 d_insertCollisions++;
81 return;
82 }
83
84 /* if the existing entry had a longer TTD, keep it */
85 if (newValidity <= value.validity) {
86 return;
87 }
88
89 value = newValue;
90}
91
2714396e 92void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode)
886e2cf2 93{
f87c4aff 94 if (responseLen < sizeof(dnsheader))
886e2cf2
RG
95 return;
96
0f08e82b 97 uint32_t minTTL;
886e2cf2 98
2714396e
RG
99 if (rcode == RCode::ServFail || rcode == RCode::Refused) {
100 minTTL = d_tempFailureTTL;
f4e5b47d
RG
101 if (minTTL == 0) {
102 return;
103 }
0f08e82b
RG
104 }
105 else {
106 minTTL = getMinTTL(response, responseLen);
a3824e43
RG
107
108 /* no TTL found, we don't want to cache this */
109 if (minTTL == std::numeric_limits<uint32_t>::max()) {
110 return;
111 }
112
113 if (minTTL > d_maxTTL) {
0f08e82b 114 minTTL = d_maxTTL;
a3824e43 115 }
0f08e82b 116
cc8cefe1
RG
117 if (minTTL < d_minTTL) {
118 d_ttlTooShorts++;
0f08e82b 119 return;
cc8cefe1 120 }
0f08e82b 121 }
886e2cf2 122
2b3eefc3
RG
123 uint32_t shardIndex = getShardIndex(key);
124
125 if (d_shards.at(shardIndex).d_entriesCount >= (d_maxEntries / d_shardCount)) {
126 return;
886e2cf2
RG
127 }
128
129 const time_t now = time(NULL);
886e2cf2
RG
130 time_t newValidity = now + minTTL;
131 CacheValue newValue;
132 newValue.qname = qname;
133 newValue.qtype = qtype;
134 newValue.qclass = qclass;
135 newValue.len = responseLen;
136 newValue.validity = newValidity;
137 newValue.added = now;
a176d205 138 newValue.tcp = tcp;
886e2cf2
RG
139 newValue.value = std::string(response, responseLen);
140
2b3eefc3
RG
141 auto& shard = d_shards.at(shardIndex);
142
143 if (d_deferrableInsertLock) {
144 TryWriteLock w(&shard.d_lock);
886e2cf2
RG
145
146 if (!w.gotIt()) {
147 d_deferredInserts++;
148 return;
149 }
2b3eefc3
RG
150 insertLocked(shard, key, qname, qtype, qclass, tcp, newValue, now, newValidity) ;
151 }
152 else {
153 WriteLock w(&shard.d_lock);
886e2cf2 154
2b3eefc3 155 insertLocked(shard, key, qname, qtype, qclass, tcp, newValue, now, newValidity) ;
886e2cf2
RG
156 }
157}
158
1ea747c0 159bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, uint32_t allowExpired, bool skipAging)
886e2cf2 160{
7acad2d5
RG
161 std::string dnsQName(dq.qname->toDNSString());
162 uint32_t key = getKey(dnsQName, consumed, (const unsigned char*)dq.dh, dq.len, dq.tcp);
886e2cf2
RG
163 if (keyOut)
164 *keyOut = key;
165
2b3eefc3 166 uint32_t shardIndex = getShardIndex(key);
886e2cf2
RG
167 time_t now = time(NULL);
168 time_t age;
1ea747c0 169 bool stale = false;
2b3eefc3
RG
170 auto& shard = d_shards.at(shardIndex);
171 auto& map = shard.d_map;
886e2cf2 172 {
2b3eefc3 173 TryReadLock r(&shard.d_lock);
886e2cf2
RG
174 if (!r.gotIt()) {
175 d_deferredLookups++;
176 return false;
177 }
178
2b3eefc3
RG
179 std::unordered_map<uint32_t,CacheValue>::const_iterator it = map.find(key);
180 if (it == map.end()) {
886e2cf2
RG
181 d_misses++;
182 return false;
183 }
184
185 const CacheValue& value = it->second;
186 if (value.validity < now) {
a1a0a75a 187 if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
1ea747c0
RG
188 d_misses++;
189 return false;
190 }
191 else {
192 stale = true;
193 }
886e2cf2
RG
194 }
195
39a21975 196 if (*responseLen < value.len || value.len < sizeof(dnsheader)) {
886e2cf2
RG
197 return false;
198 }
199
200 /* check for collision */
1ea747c0 201 if (!cachedValueMatches(value, *dq.qname, dq.qtype, dq.qclass, dq.tcp)) {
886e2cf2
RG
202 d_lookupCollisions++;
203 return false;
204 }
205
c8c3d4e4
RG
206 memcpy(response, &queryId, sizeof(queryId));
207 memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
208
209 if (value.len == sizeof(dnsheader)) {
210 /* DNS header only, our work here is done */
211 *responseLen = value.len;
212 d_hits++;
213 return true;
214 }
215
f87c4aff
RG
216 const size_t dnsQNameLen = dnsQName.length();
217 if (value.len < (sizeof(dnsheader) + dnsQNameLen)) {
218 return false;
219 }
220
f87c4aff
RG
221 memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen);
222 if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
223 memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen));
224 }
886e2cf2 225 *responseLen = value.len;
1ea747c0
RG
226 if (!stale) {
227 age = now - value.added;
228 }
229 else {
230 age = (value.validity - value.added) - d_staleTTL;
231 }
886e2cf2
RG
232 }
233
2b67180c 234 if (!d_dontAge && !skipAging) {
886e2cf2 235 ageDNSPacket(response, *responseLen, age);
1ea747c0
RG
236 }
237
886e2cf2
RG
238 d_hits++;
239 return true;
240}
241
4275aaba
RG
242/* Remove expired entries, until the cache has at most
243 upTo entries in it.
244*/
245void DNSDistPacketCache::purgeExpired(size_t upTo)
886e2cf2
RG
246{
247 time_t now = time(NULL);
2b3eefc3
RG
248 uint64_t size = getSize();
249
250 if (upTo >= size) {
886e2cf2 251 return;
4275aaba 252 }
886e2cf2 253
2b3eefc3 254 size_t toRemove = size - upTo;
886e2cf2 255
2b3eefc3
RG
256 size_t scannedMaps = 0;
257
258 do {
259 uint32_t shardIndex = (d_expungeIndex++ % d_shardCount);
260 WriteLock w(&d_shards.at(shardIndex).d_lock);
261 auto& map = d_shards[shardIndex].d_map;
262
263 for(auto it = map.begin(); toRemove > 0 && it != map.end(); ) {
264 const CacheValue& value = it->second;
265
266 if (value.validity < now) {
267 it = map.erase(it);
886e2cf2 268 --toRemove;
2b3eefc3
RG
269 d_shards[shardIndex].d_entriesCount--;
270 } else {
271 ++it;
272 }
886e2cf2 273 }
2b3eefc3
RG
274
275 scannedMaps++;
886e2cf2 276 }
2b3eefc3 277 while (toRemove > 0 && scannedMaps < d_shardCount);
886e2cf2
RG
278}
279
4275aaba
RG
280/* Remove all entries, keeping only upTo
281 entries in the cache */
282void DNSDistPacketCache::expunge(size_t upTo)
283{
2b3eefc3 284 const uint64_t size = getSize();
4275aaba 285
2b3eefc3 286 if (upTo >= size) {
4275aaba
RG
287 return;
288 }
289
2b3eefc3
RG
290 size_t toRemove = size - upTo;
291 size_t removed = 0;
292
293 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
294 WriteLock w(&d_shards.at(shardIndex).d_lock);
295 auto& map = d_shards[shardIndex].d_map;
296 auto beginIt = map.begin();
297 auto endIt = beginIt;
298 size_t removeFromThisShard = (toRemove - removed) / (d_shardCount - shardIndex);
299 if (map.size() >= removeFromThisShard) {
300 std::advance(endIt, removeFromThisShard);
301 map.erase(beginIt, endIt);
302 d_shards[shardIndex].d_entriesCount -= removeFromThisShard;
303 removed += removeFromThisShard;
304 }
305 else {
306 removed += map.size();
307 map.clear();
308 d_shards[shardIndex].d_entriesCount = 0;
309 }
310 }
4275aaba
RG
311}
312
490dc586 313void DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
886e2cf2 314{
2b3eefc3
RG
315 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
316 WriteLock w(&d_shards.at(shardIndex).d_lock);
317 auto& map = d_shards[shardIndex].d_map;
318
319 for(auto it = map.begin(); it != map.end(); ) {
320 const CacheValue& value = it->second;
321
322 if ((value.qname == name || (suffixMatch && value.qname.isPartOf(name))) && (qtype == QType::ANY || qtype == value.qtype)) {
323 it = map.erase(it);
324 d_shards[shardIndex].d_entriesCount--;
325 } else {
326 ++it;
327 }
886e2cf2
RG
328 }
329 }
330}
331
332bool DNSDistPacketCache::isFull()
333{
2b3eefc3
RG
334 return (getSize() >= d_maxEntries);
335}
336
337uint64_t DNSDistPacketCache::getSize()
338{
339 uint64_t count = 0;
340
341 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
342 count += d_shards.at(shardIndex).d_entriesCount;
343 }
344
345 return count;
886e2cf2
RG
346}
347
348uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length)
349{
0766890a 350 return getDNSPacketMinTTL(packet, length);
886e2cf2
RG
351}
352
7acad2d5 353uint32_t DNSDistPacketCache::getKey(const std::string& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
886e2cf2
RG
354{
355 uint32_t result = 0;
356 /* skip the query ID */
cceddbef
RG
357 if (packetLen < sizeof(dnsheader))
358 throw std::range_error("Computing packet cache key for an invalid packet size");
886e2cf2 359 result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
7acad2d5 360 string lc(toLower(qname));
886e2cf2 361 result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
cceddbef
RG
362 if (packetLen < sizeof(dnsheader) + consumed) {
363 throw std::range_error("Computing packet cache key for an invalid packet");
364 }
365 if (packetLen > ((sizeof(dnsheader) + consumed))) {
366 result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
367 }
a176d205 368 result = burtle((const unsigned char*) &tcp, sizeof(tcp), result);
886e2cf2
RG
369 return result;
370}
371
2b3eefc3
RG
372uint32_t DNSDistPacketCache::getShardIndex(uint32_t key) const
373{
374 return key % d_shardCount;
375}
376
886e2cf2
RG
377string DNSDistPacketCache::toString()
378{
2b3eefc3 379 return std::to_string(getSize()) + "/" + std::to_string(d_maxEntries);
886e2cf2 380}
9e9be156
RG
381
382uint64_t DNSDistPacketCache::getEntriesCount()
383{
2b3eefc3 384 return getSize();
9e9be156 385}