]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-cache.cc
3e085bd7d9ef27eebfb317577e08d1253f829ba4
[thirdparty/pdns.git] / pdns / dnsdist-cache.cc
1 #include "dnsdist.hh"
2 #include "dolog.hh"
3 #include "dnsparser.hh"
4 #include "dnsdist-cache.hh"
5
6 DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t servFailTTL, uint32_t staleTTL): d_maxEntries(maxEntries), d_maxTTL(maxTTL), d_servFailTTL(servFailTTL), d_minTTL(minTTL), d_staleTTL(staleTTL)
7 {
8 pthread_rwlock_init(&d_lock, 0);
9 /* we reserve maxEntries + 1 to avoid rehashing from occuring
10 when we get to maxEntries, as it means a load factor of 1 */
11 d_map.reserve(maxEntries + 1);
12 }
13
14 DNSDistPacketCache::~DNSDistPacketCache()
15 {
16 WriteLock l(&d_lock);
17 }
18
19 bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp)
20 {
21 if (cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname)
22 return false;
23 return true;
24 }
25
26 void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, bool servFail)
27 {
28 if (responseLen < sizeof(dnsheader))
29 return;
30
31 uint32_t minTTL;
32
33 if (servFail) {
34 minTTL = d_servFailTTL;
35 }
36 else {
37 minTTL = getMinTTL(response, responseLen);
38 if (minTTL > d_maxTTL)
39 minTTL = d_maxTTL;
40
41 if (minTTL < d_minTTL) {
42 d_ttlTooShorts++;
43 return;
44 }
45 }
46
47 {
48 TryReadLock r(&d_lock);
49 if (!r.gotIt()) {
50 d_deferredInserts++;
51 return;
52 }
53 if (d_map.size() >= d_maxEntries) {
54 return;
55 }
56 }
57
58 const time_t now = time(NULL);
59 std::unordered_map<uint32_t,CacheValue>::iterator it;
60 bool result;
61 time_t newValidity = now + minTTL;
62 CacheValue newValue;
63 newValue.qname = qname;
64 newValue.qtype = qtype;
65 newValue.qclass = qclass;
66 newValue.len = responseLen;
67 newValue.validity = newValidity;
68 newValue.added = now;
69 newValue.tcp = tcp;
70 newValue.value = std::string(response, responseLen);
71
72 {
73 TryWriteLock w(&d_lock);
74
75 if (!w.gotIt()) {
76 d_deferredInserts++;
77 return;
78 }
79
80 tie(it, result) = d_map.insert({key, newValue});
81
82 if (result) {
83 return;
84 }
85
86 /* in case of collision, don't override the existing entry
87 except if it has expired */
88 CacheValue& value = it->second;
89 bool wasExpired = value.validity <= now;
90
91 if (!wasExpired && !cachedValueMatches(value, qname, qtype, qclass, tcp)) {
92 d_insertCollisions++;
93 return;
94 }
95
96 /* if the existing entry had a longer TTD, keep it */
97 if (newValidity <= value.validity) {
98 return;
99 }
100
101 value = newValue;
102 }
103 }
104
105 bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, uint32_t allowExpired, bool skipAging)
106 {
107 uint32_t key = getKey(*dq.qname, consumed, (const unsigned char*)dq.dh, dq.len, dq.tcp);
108 if (keyOut)
109 *keyOut = key;
110
111 time_t now = time(NULL);
112 time_t age;
113 bool stale = false;
114 {
115 TryReadLock r(&d_lock);
116 if (!r.gotIt()) {
117 d_deferredLookups++;
118 return false;
119 }
120
121 std::unordered_map<uint32_t,CacheValue>::const_iterator it = d_map.find(key);
122 if (it == d_map.end()) {
123 d_misses++;
124 return false;
125 }
126
127 const CacheValue& value = it->second;
128 if (value.validity < now) {
129 if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
130 d_misses++;
131 return false;
132 }
133 else {
134 stale = true;
135 }
136 }
137
138 if (*responseLen < value.len) {
139 return false;
140 }
141
142 /* check for collision */
143 if (!cachedValueMatches(value, *dq.qname, dq.qtype, dq.qclass, dq.tcp)) {
144 d_lookupCollisions++;
145 return false;
146 }
147
148 string dnsQName(dq.qname->toDNSString());
149 const size_t dnsQNameLen = dnsQName.length();
150 if (value.len < (sizeof(dnsheader) + dnsQNameLen)) {
151 return false;
152 }
153
154 memcpy(response, &queryId, sizeof(queryId));
155 memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
156 memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen);
157 if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
158 memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen));
159 }
160 *responseLen = value.len;
161 if (!stale) {
162 age = now - value.added;
163 }
164 else {
165 age = (value.validity - value.added) - d_staleTTL;
166 }
167 }
168
169 if (!skipAging) {
170 ageDNSPacket(response, *responseLen, age);
171 }
172
173 d_hits++;
174 return true;
175 }
176
177 /* Remove expired entries, until the cache has at most
178 upTo entries in it.
179 */
180 void DNSDistPacketCache::purgeExpired(size_t upTo)
181 {
182 time_t now = time(NULL);
183 WriteLock w(&d_lock);
184 if (upTo >= d_map.size()) {
185 return;
186 }
187
188 size_t toRemove = d_map.size() - upTo;
189 for(auto it = d_map.begin(); toRemove > 0 && it != d_map.end(); ) {
190 const CacheValue& value = it->second;
191
192 if (value.validity < now) {
193 it = d_map.erase(it);
194 --toRemove;
195 } else {
196 ++it;
197 }
198 }
199 }
200
201 /* Remove all entries, keeping only upTo
202 entries in the cache */
203 void DNSDistPacketCache::expunge(size_t upTo)
204 {
205 WriteLock w(&d_lock);
206
207 if (upTo >= d_map.size()) {
208 return;
209 }
210
211 size_t toRemove = d_map.size() - upTo;
212 auto beginIt = d_map.begin();
213 auto endIt = beginIt;
214 std::advance(endIt, toRemove);
215 d_map.erase(beginIt, endIt);
216 }
217
218 void DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype)
219 {
220 WriteLock w(&d_lock);
221
222 for(auto it = d_map.begin(); it != d_map.end(); ) {
223 const CacheValue& value = it->second;
224 uint16_t cqtype = 0;
225 uint16_t cqclass = 0;
226 DNSName cqname(value.value.c_str(), value.len, sizeof(dnsheader), false, &cqtype, &cqclass, nullptr);
227
228 if (cqname == name && (qtype == QType::ANY || qtype == cqtype)) {
229 it = d_map.erase(it);
230 } else {
231 ++it;
232 }
233 }
234 }
235
236 bool DNSDistPacketCache::isFull()
237 {
238 ReadLock r(&d_lock);
239 return (d_map.size() >= d_maxEntries);
240 }
241
242 uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length)
243 {
244 return getDNSPacketMinTTL(packet, length);
245 }
246
247 uint32_t DNSDistPacketCache::getKey(const DNSName& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
248 {
249 uint32_t result = 0;
250 /* skip the query ID */
251 if (packetLen < sizeof(dnsheader))
252 throw std::range_error("Computing packet cache key for an invalid packet size");
253 result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
254 string lc(qname.toDNSStringLC());
255 result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
256 if (packetLen < sizeof(dnsheader) + consumed) {
257 throw std::range_error("Computing packet cache key for an invalid packet");
258 }
259 if (packetLen > ((sizeof(dnsheader) + consumed))) {
260 result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
261 }
262 result = burtle((const unsigned char*) &tcp, sizeof(tcp), result);
263 return result;
264 }
265
266 string DNSDistPacketCache::toString()
267 {
268 ReadLock r(&d_lock);
269 return std::to_string(d_map.size()) + "/" + std::to_string(d_maxEntries);
270 }
271
272 uint64_t DNSDistPacketCache::getEntriesCount()
273 {
274 ReadLock r(&d_lock);
275 return d_map.size();
276 }