]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-cache.cc
rec: ensure correct service user on debian
[thirdparty/pdns.git] / pdns / dnsdist-cache.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include <cinttypes>
23
24 #include "dnsdist.hh"
25 #include "dolog.hh"
26 #include "dnsparser.hh"
27 #include "dnsdist-cache.hh"
28 #include "dnsdist-ecs.hh"
29 #include "ednsoptions.hh"
30 #include "ednssubnet.hh"
31
32 DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t tempFailureTTL, uint32_t maxNegativeTTL, uint32_t staleTTL, bool dontAge, uint32_t shards, bool deferrableInsertLock, bool parseECS): d_maxEntries(maxEntries), d_shardCount(shards), d_maxTTL(maxTTL), d_tempFailureTTL(tempFailureTTL), d_maxNegativeTTL(maxNegativeTTL), d_minTTL(minTTL), d_staleTTL(staleTTL), d_dontAge(dontAge), d_deferrableInsertLock(deferrableInsertLock), d_parseECS(parseECS)
33 {
34 d_shards.resize(d_shardCount);
35
36 /* we reserve maxEntries + 1 to avoid rehashing from occurring
37 when we get to maxEntries, as it means a load factor of 1 */
38 for (auto& shard : d_shards) {
39 shard.setSize((maxEntries / d_shardCount) + 1);
40 }
41 }
42
43 DNSDistPacketCache::~DNSDistPacketCache()
44 {
45 try {
46 vector<std::unique_ptr<WriteLock>> locks;
47 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
48 locks.push_back(std::unique_ptr<WriteLock>(new WriteLock(&d_shards.at(shardIndex).d_lock)));
49 }
50 }
51 catch(...) {
52 }
53 }
54
55 bool DNSDistPacketCache::getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional<Netmask>& subnet)
56 {
57 uint16_t optRDPosition;
58 size_t remaining = 0;
59
60 int res = getEDNSOptionsStart(const_cast<char*>(packet), consumed, len, &optRDPosition, &remaining);
61
62 if (res == 0) {
63 char * ecsOptionStart = nullptr;
64 size_t ecsOptionSize = 0;
65
66 res = getEDNSOption(const_cast<char*>(packet) + optRDPosition, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
67
68 if (res == 0 && ecsOptionSize > (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) {
69
70 EDNSSubnetOpts eso;
71 if (getEDNSSubnetOptsFromString(ecsOptionStart + (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), ecsOptionSize - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), &eso) == true) {
72 subnet = eso.source;
73 return true;
74 }
75 }
76 }
77
78 return false;
79 }
80
81 bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, bool dnssecOK, const boost::optional<Netmask>& subnet) const
82 {
83 if (cachedValue.queryFlags != queryFlags || cachedValue.dnssecOK != dnssecOK || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) {
84 return false;
85 }
86
87 if (d_parseECS && cachedValue.subnet != subnet) {
88 return false;
89 }
90
91 return true;
92 }
93
94 void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, CacheValue& newValue)
95 {
96 auto& map = shard.d_map;
97 /* check again now that we hold the lock to prevent a race */
98 if (map.size() >= (d_maxEntries / d_shardCount)) {
99 return;
100 }
101
102 std::unordered_map<uint32_t,CacheValue>::iterator it;
103 bool result;
104 tie(it, result) = map.insert({key, newValue});
105
106 if (result) {
107 shard.d_entriesCount++;
108 return;
109 }
110
111 /* in case of collision, don't override the existing entry
112 except if it has expired */
113 CacheValue& value = it->second;
114 bool wasExpired = value.validity <= newValue.added;
115
116 if (!wasExpired && !cachedValueMatches(value, newValue.queryFlags, newValue.qname, newValue.qtype, newValue.qclass, newValue.tcp, newValue.dnssecOK, newValue.subnet)) {
117 d_insertCollisions++;
118 return;
119 }
120
121 /* if the existing entry had a longer TTD, keep it */
122 if (newValue.validity <= value.validity) {
123 return;
124 }
125
126 value = newValue;
127 }
128
129 void DNSDistPacketCache::insert(uint32_t key, const boost::optional<Netmask>& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL)
130 {
131 if (responseLen < sizeof(dnsheader)) {
132 return;
133 }
134
135 uint32_t minTTL;
136
137 if (rcode == RCode::ServFail || rcode == RCode::Refused) {
138 minTTL = tempFailureTTL == boost::none ? d_tempFailureTTL : *tempFailureTTL;
139 if (minTTL == 0) {
140 return;
141 }
142 }
143 else {
144 bool seenAuthSOA = false;
145 minTTL = getMinTTL(response, responseLen, &seenAuthSOA);
146
147 /* no TTL found, we don't want to cache this */
148 if (minTTL == std::numeric_limits<uint32_t>::max()) {
149 return;
150 }
151
152 if (rcode == RCode::NXDomain || (rcode == RCode::NoError && seenAuthSOA)) {
153 minTTL = std::min(minTTL, d_maxNegativeTTL);
154 }
155 else if (minTTL > d_maxTTL) {
156 minTTL = d_maxTTL;
157 }
158
159 if (minTTL < d_minTTL) {
160 d_ttlTooShorts++;
161 return;
162 }
163 }
164
165 uint32_t shardIndex = getShardIndex(key);
166
167 if (d_shards.at(shardIndex).d_entriesCount >= (d_maxEntries / d_shardCount)) {
168 return;
169 }
170
171 const time_t now = time(nullptr);
172 time_t newValidity = now + minTTL;
173 CacheValue newValue;
174 newValue.qname = qname;
175 newValue.qtype = qtype;
176 newValue.qclass = qclass;
177 newValue.queryFlags = queryFlags;
178 newValue.len = responseLen;
179 newValue.validity = newValidity;
180 newValue.added = now;
181 newValue.tcp = tcp;
182 newValue.dnssecOK = dnssecOK;
183 newValue.value = std::string(response, responseLen);
184 newValue.subnet = subnet;
185
186 auto& shard = d_shards.at(shardIndex);
187
188 if (d_deferrableInsertLock) {
189 TryWriteLock w(&shard.d_lock);
190
191 if (!w.gotIt()) {
192 d_deferredInserts++;
193 return;
194 }
195 insertLocked(shard, key, newValue);
196 }
197 else {
198 WriteLock w(&shard.d_lock);
199
200 insertLocked(shard, key, newValue);
201 }
202 }
203
204 bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional<Netmask>& subnet, bool dnssecOK, uint32_t allowExpired, bool skipAging)
205 {
206 std::string dnsQName(dq.qname->toDNSString());
207 uint32_t key = getKey(dnsQName, consumed, reinterpret_cast<const unsigned char*>(dq.dh), dq.len, dq.tcp);
208
209 if (keyOut)
210 *keyOut = key;
211
212 if (d_parseECS) {
213 getClientSubnet(reinterpret_cast<const char*>(dq.dh), consumed, dq.len, subnet);
214 }
215
216 uint32_t shardIndex = getShardIndex(key);
217 time_t now = time(nullptr);
218 time_t age;
219 bool stale = false;
220 auto& shard = d_shards.at(shardIndex);
221 auto& map = shard.d_map;
222 {
223 TryReadLock r(&shard.d_lock);
224 if (!r.gotIt()) {
225 d_deferredLookups++;
226 return false;
227 }
228
229 std::unordered_map<uint32_t,CacheValue>::const_iterator it = map.find(key);
230 if (it == map.end()) {
231 d_misses++;
232 return false;
233 }
234
235 const CacheValue& value = it->second;
236 if (value.validity <= now) {
237 if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
238 d_misses++;
239 return false;
240 }
241 else {
242 stale = true;
243 }
244 }
245
246 if (*responseLen < value.len || value.len < sizeof(dnsheader)) {
247 return false;
248 }
249
250 /* check for collision */
251 if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp, dnssecOK, subnet)) {
252 d_lookupCollisions++;
253 return false;
254 }
255
256 memcpy(response, &queryId, sizeof(queryId));
257 memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
258
259 if (value.len == sizeof(dnsheader)) {
260 /* DNS header only, our work here is done */
261 *responseLen = value.len;
262 d_hits++;
263 return true;
264 }
265
266 const size_t dnsQNameLen = dnsQName.length();
267 if (value.len < (sizeof(dnsheader) + dnsQNameLen)) {
268 return false;
269 }
270
271 memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen);
272 if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
273 memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen));
274 }
275 *responseLen = value.len;
276 if (!stale) {
277 age = now - value.added;
278 }
279 else {
280 age = (value.validity - value.added) - d_staleTTL;
281 }
282 }
283
284 if (!d_dontAge && !skipAging) {
285 ageDNSPacket(response, *responseLen, age);
286 }
287
288 d_hits++;
289 return true;
290 }
291
292 /* Remove expired entries, until the cache has at most
293 upTo entries in it.
294 */
295 size_t DNSDistPacketCache::purgeExpired(size_t upTo)
296 {
297 size_t removed = 0;
298 uint64_t size = getSize();
299
300 if (size == 0 || upTo >= size) {
301 return removed;
302 }
303
304 size_t toRemove = size - upTo;
305
306 size_t scannedMaps = 0;
307
308 const time_t now = time(nullptr);
309 do {
310 uint32_t shardIndex = (d_expungeIndex++ % d_shardCount);
311 WriteLock w(&d_shards.at(shardIndex).d_lock);
312 auto& map = d_shards[shardIndex].d_map;
313
314 for(auto it = map.begin(); toRemove > 0 && it != map.end(); ) {
315 const CacheValue& value = it->second;
316
317 if (value.validity <= now) {
318 it = map.erase(it);
319 --toRemove;
320 d_shards[shardIndex].d_entriesCount--;
321 ++removed;
322 } else {
323 ++it;
324 }
325 }
326
327 scannedMaps++;
328 }
329 while (toRemove > 0 && scannedMaps < d_shardCount);
330
331 return removed;
332 }
333
334 /* Remove all entries, keeping only upTo
335 entries in the cache */
336 size_t DNSDistPacketCache::expunge(size_t upTo)
337 {
338 size_t removed = 0;
339 const uint64_t size = getSize();
340
341 if (upTo >= size) {
342 return removed;
343 }
344
345 size_t toRemove = size - upTo;
346
347 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
348 WriteLock w(&d_shards.at(shardIndex).d_lock);
349 auto& map = d_shards[shardIndex].d_map;
350 auto beginIt = map.begin();
351 auto endIt = beginIt;
352 size_t removeFromThisShard = (toRemove - removed) / (d_shardCount - shardIndex);
353 if (map.size() >= removeFromThisShard) {
354 std::advance(endIt, removeFromThisShard);
355 map.erase(beginIt, endIt);
356 d_shards[shardIndex].d_entriesCount -= removeFromThisShard;
357 removed += removeFromThisShard;
358 }
359 else {
360 removed += map.size();
361 map.clear();
362 d_shards[shardIndex].d_entriesCount = 0;
363 }
364 }
365
366 return removed;
367 }
368
369 size_t DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
370 {
371 size_t removed = 0;
372
373 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
374 WriteLock w(&d_shards.at(shardIndex).d_lock);
375 auto& map = d_shards[shardIndex].d_map;
376
377 for(auto it = map.begin(); it != map.end(); ) {
378 const CacheValue& value = it->second;
379
380 if ((value.qname == name || (suffixMatch && value.qname.isPartOf(name))) && (qtype == QType::ANY || qtype == value.qtype)) {
381 it = map.erase(it);
382 d_shards[shardIndex].d_entriesCount--;
383 ++removed;
384 } else {
385 ++it;
386 }
387 }
388 }
389
390 return removed;
391 }
392
393 bool DNSDistPacketCache::isFull()
394 {
395 return (getSize() >= d_maxEntries);
396 }
397
398 uint64_t DNSDistPacketCache::getSize()
399 {
400 uint64_t count = 0;
401
402 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
403 count += d_shards.at(shardIndex).d_entriesCount;
404 }
405
406 return count;
407 }
408
409 uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length, bool* seenNoDataSOA)
410 {
411 return getDNSPacketMinTTL(packet, length, seenNoDataSOA);
412 }
413
414 uint32_t DNSDistPacketCache::getKey(const std::string& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
415 {
416 uint32_t result = 0;
417 /* skip the query ID */
418 if (packetLen < sizeof(dnsheader))
419 throw std::range_error("Computing packet cache key for an invalid packet size");
420 result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
421 string lc(toLower(qname));
422 result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
423 if (packetLen < sizeof(dnsheader) + consumed) {
424 throw std::range_error("Computing packet cache key for an invalid packet");
425 }
426 if (packetLen > ((sizeof(dnsheader) + consumed))) {
427 result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
428 }
429 result = burtle((const unsigned char*) &tcp, sizeof(tcp), result);
430 return result;
431 }
432
433 uint32_t DNSDistPacketCache::getShardIndex(uint32_t key) const
434 {
435 return key % d_shardCount;
436 }
437
438 string DNSDistPacketCache::toString()
439 {
440 return std::to_string(getSize()) + "/" + std::to_string(d_maxEntries);
441 }
442
443 uint64_t DNSDistPacketCache::getEntriesCount()
444 {
445 return getSize();
446 }
447
448 uint64_t DNSDistPacketCache::dump(int fd)
449 {
450 FILE * fp = fdopen(dup(fd), "w");
451 if (fp == nullptr) {
452 return 0;
453 }
454
455 fprintf(fp, "; dnsdist's packet cache dump follows\n;\n");
456
457 uint64_t count = 0;
458 time_t now = time(nullptr);
459 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
460 ReadLock w(&d_shards.at(shardIndex).d_lock);
461 auto& map = d_shards[shardIndex].d_map;
462
463 for(const auto entry : map) {
464 const CacheValue& value = entry.second;
465 count++;
466
467 try {
468 fprintf(fp, "%s %" PRId64 " %s ; key %" PRIu32 ", length %" PRIu16 ", tcp %d, added %" PRId64 "\n", value.qname.toString().c_str(), static_cast<int64_t>(value.validity - now), QType(value.qtype).getName().c_str(), entry.first, value.len, value.tcp, static_cast<int64_t>(value.added));
469 }
470 catch(...) {
471 fprintf(fp, "; error printing '%s'\n", value.qname.empty() ? "EMPTY" : value.qname.toString().c_str());
472 }
473 }
474 }
475
476 fclose(fp);
477 return count;
478 }