]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/iputils.hh
5e74d0ddc9b94da612af55890e5345b450bc99f4
[thirdparty/pdns.git] / pdns / iputils.hh
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #pragma once
23 #include <string>
24 #include <sys/socket.h>
25 #include <netinet/in.h>
26 #include <arpa/inet.h>
27 #include <iostream>
28 #include <cstdio>
29 #include <functional>
30 #include "pdnsexception.hh"
31 #include "misc.hh"
32 #include <netdb.h>
33 #include <sstream>
34 #include <sys/un.h>
35
36 #include "namespaces.hh"
37
38 #ifdef __APPLE__
39 #include <libkern/OSByteOrder.h>
40
41 #define htobe16(x) OSSwapHostToBigInt16(x)
42 #define htole16(x) OSSwapHostToLittleInt16(x)
43 #define be16toh(x) OSSwapBigToHostInt16(x)
44 #define le16toh(x) OSSwapLittleToHostInt16(x)
45
46 #define htobe32(x) OSSwapHostToBigInt32(x)
47 #define htole32(x) OSSwapHostToLittleInt32(x)
48 #define be32toh(x) OSSwapBigToHostInt32(x)
49 #define le32toh(x) OSSwapLittleToHostInt32(x)
50
51 #define htobe64(x) OSSwapHostToBigInt64(x)
52 #define htole64(x) OSSwapHostToLittleInt64(x)
53 #define be64toh(x) OSSwapBigToHostInt64(x)
54 #define le64toh(x) OSSwapLittleToHostInt64(x)
55
56 #if defined(CONNECT_DATA_IDEMPOTENT) && defined(CONNECT_RESUME_ON_READ_WRITE)
57 #define CONNECTX_FASTOPEN 1
58 #endif
59
60 #endif
61
62 #ifdef __sun
63
64 #define htobe16(x) BE_16(x)
65 #define htole16(x) LE_16(x)
66 #define be16toh(x) BE_IN16(&(x))
67 #define le16toh(x) LE_IN16(&(x))
68
69 #define htobe32(x) BE_32(x)
70 #define htole32(x) LE_32(x)
71 #define be32toh(x) BE_IN32(&(x))
72 #define le32toh(x) LE_IN32(&(x))
73
74 #define htobe64(x) BE_64(x)
75 #define htole64(x) LE_64(x)
76 #define be64toh(x) BE_IN64(&(x))
77 #define le64toh(x) LE_IN64(&(x))
78
79 #endif
80
81 #ifdef __FreeBSD__
82 #include <sys/endian.h>
83 #endif
84
85 #if defined(__NetBSD__) && defined(IP_PKTINFO) && !defined(IP_SENDSRCADDR)
86 // The IP_PKTINFO option in NetBSD was incompatible with Linux until a
87 // change that also introduced IP_SENDSRCADDR for FreeBSD compatibility.
88 #undef IP_PKTINFO
89 #endif
90
91 union ComboAddress
92 {
93 sockaddr_in sin4{};
94 sockaddr_in6 sin6;
95
96 bool operator==(const ComboAddress& rhs) const
97 {
98 if (std::tie(sin4.sin_family, sin4.sin_port) != std::tie(rhs.sin4.sin_family, rhs.sin4.sin_port)) {
99 return false;
100 }
101 if (sin4.sin_family == AF_INET) {
102 return sin4.sin_addr.s_addr == rhs.sin4.sin_addr.s_addr;
103 }
104 return memcmp(&sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(sin6.sin6_addr.s6_addr)) == 0;
105 }
106
107 bool operator!=(const ComboAddress& rhs) const
108 {
109 return (!operator==(rhs));
110 }
111
112 bool operator<(const ComboAddress& rhs) const
113 {
114 if (sin4.sin_family == 0) {
115 return false;
116 }
117 if (std::tie(sin4.sin_family, sin4.sin_port) < std::tie(rhs.sin4.sin_family, rhs.sin4.sin_port)) {
118 return true;
119 }
120 if (std::tie(sin4.sin_family, sin4.sin_port) > std::tie(rhs.sin4.sin_family, rhs.sin4.sin_port)) {
121 return false;
122 }
123 if (sin4.sin_family == AF_INET) {
124 return sin4.sin_addr.s_addr < rhs.sin4.sin_addr.s_addr;
125 }
126 return memcmp(&sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(sin6.sin6_addr.s6_addr)) < 0;
127 }
128
129 bool operator>(const ComboAddress& rhs) const
130 {
131 return rhs.operator<(*this);
132 }
133
134 struct addressPortOnlyHash
135 {
136 uint32_t operator()(const ComboAddress& address) const
137 {
138 // NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
139 if (address.sin4.sin_family == AF_INET) {
140 const auto* start = reinterpret_cast<const unsigned char*>(&address.sin4.sin_addr.s_addr);
141 auto tmp = burtle(start, 4, 0);
142 return burtle(reinterpret_cast<const uint8_t*>(&address.sin4.sin_port), 2, tmp);
143 }
144 const auto* start = reinterpret_cast<const unsigned char*>(&address.sin6.sin6_addr.s6_addr);
145 auto tmp = burtle(start, 16, 0);
146 return burtle(reinterpret_cast<const unsigned char*>(&address.sin6.sin6_port), 2, tmp);
147 // NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)
148 }
149 };
150
151 struct addressOnlyHash
152 {
153 uint32_t operator()(const ComboAddress& address) const
154 {
155 const unsigned char* start = nullptr;
156 uint32_t len = 0;
157 // NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
158 if (address.sin4.sin_family == AF_INET) {
159 start = reinterpret_cast<const unsigned char*>(&address.sin4.sin_addr.s_addr);
160 len = 4;
161 }
162 else {
163 start = reinterpret_cast<const unsigned char*>(&address.sin6.sin6_addr.s6_addr);
164 len = 16;
165 }
166 // NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)
167 return burtle(start, len, 0);
168 }
169 };
170
171 struct addressOnlyLessThan
172 {
173 bool operator()(const ComboAddress& lhs, const ComboAddress& rhs) const
174 {
175 if (lhs.sin4.sin_family < rhs.sin4.sin_family) {
176 return true;
177 }
178 if (lhs.sin4.sin_family > rhs.sin4.sin_family) {
179 return false;
180 }
181 if (lhs.sin4.sin_family == AF_INET) {
182 return lhs.sin4.sin_addr.s_addr < rhs.sin4.sin_addr.s_addr;
183 }
184 return memcmp(&lhs.sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(lhs.sin6.sin6_addr.s6_addr)) < 0;
185 }
186 };
187
188 struct addressOnlyEqual
189 {
190 bool operator()(const ComboAddress& lhs, const ComboAddress& rhs) const
191 {
192 if (lhs.sin4.sin_family != rhs.sin4.sin_family) {
193 return false;
194 }
195 if (lhs.sin4.sin_family == AF_INET) {
196 return lhs.sin4.sin_addr.s_addr == rhs.sin4.sin_addr.s_addr;
197 }
198 return memcmp(&lhs.sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(lhs.sin6.sin6_addr.s6_addr)) == 0;
199 }
200 };
201
202 [[nodiscard]] socklen_t getSocklen() const
203 {
204 if (sin4.sin_family == AF_INET) {
205 return sizeof(sin4);
206 }
207 return sizeof(sin6);
208 }
209
210 ComboAddress()
211 {
212 sin4.sin_family = AF_INET;
213 sin4.sin_addr.s_addr = 0;
214 sin4.sin_port = 0;
215 sin6.sin6_scope_id = 0;
216 sin6.sin6_flowinfo = 0;
217 }
218
219 ComboAddress(const struct sockaddr* socketAddress, socklen_t salen)
220 {
221 setSockaddr(socketAddress, salen);
222 };
223
224 ComboAddress(const struct sockaddr_in6* socketAddress)
225 {
226 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
227 setSockaddr(reinterpret_cast<const struct sockaddr*>(socketAddress), sizeof(struct sockaddr_in6));
228 };
229
230 ComboAddress(const struct sockaddr_in* socketAddress)
231 {
232 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
233 setSockaddr(reinterpret_cast<const struct sockaddr*>(socketAddress), sizeof(struct sockaddr_in));
234 };
235
236 void setSockaddr(const struct sockaddr* socketAddress, socklen_t salen)
237 {
238 if (salen > sizeof(struct sockaddr_in6)) {
239 throw PDNSException("ComboAddress can't handle other than sockaddr_in or sockaddr_in6");
240 }
241 memcpy(this, socketAddress, salen);
242 }
243
244 // 'port' sets a default value in case 'str' does not set a port
245 explicit ComboAddress(const string& str, uint16_t port = 0)
246 {
247 memset(&sin6, 0, sizeof(sin6));
248 sin4.sin_family = AF_INET;
249 sin4.sin_port = 0;
250 if (makeIPv4sockaddr(str, &sin4) != 0) {
251 sin6.sin6_family = AF_INET6;
252 if (makeIPv6sockaddr(str, &sin6) < 0) {
253 throw PDNSException("Unable to convert presentation address '" + str + "'");
254 }
255 }
256 if (sin4.sin_port == 0) { // 'str' overrides port!
257 sin4.sin_port = htons(port);
258 }
259 }
260
261 [[nodiscard]] bool isIPv6() const
262 {
263 return sin4.sin_family == AF_INET6;
264 }
265 [[nodiscard]] bool isIPv4() const
266 {
267 return sin4.sin_family == AF_INET;
268 }
269
270 [[nodiscard]] bool isMappedIPv4() const
271 {
272 if (sin4.sin_family != AF_INET6) {
273 return false;
274 }
275
276 int iter = 0;
277 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
278 const auto* ptr = reinterpret_cast<const unsigned char*>(&sin6.sin6_addr.s6_addr);
279 for (iter = 0; iter < 10; ++iter) {
280 if (ptr[iter] != 0) { // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
281 return false;
282 }
283 }
284 for (; iter < 12; ++iter) {
285 if (ptr[iter] != 0xff) { // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
286 return false;
287 }
288 }
289 return true;
290 }
291
292 [[nodiscard]] bool isUnspecified() const
293 {
294 static const ComboAddress unspecifiedV4("0.0.0.0:0");
295 static const ComboAddress unspecifiedV6("[::]:0");
296 const auto compare = ComboAddress::addressOnlyEqual();
297 return compare(*this, unspecifiedV4) || compare(*this, unspecifiedV6);
298 }
299
300 [[nodiscard]] ComboAddress mapToIPv4() const
301 {
302 if (!isMappedIPv4()) {
303 throw PDNSException("ComboAddress can't map non-mapped IPv6 address back to IPv4");
304 }
305 ComboAddress ret;
306 ret.sin4.sin_family = AF_INET;
307 ret.sin4.sin_port = sin4.sin_port;
308
309 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
310 const auto* ptr = reinterpret_cast<const unsigned char*>(&sin6.sin6_addr.s6_addr);
311 ptr += (sizeof(sin6.sin6_addr.s6_addr) - sizeof(ret.sin4.sin_addr.s_addr)); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
312 memcpy(&ret.sin4.sin_addr.s_addr, ptr, sizeof(ret.sin4.sin_addr.s_addr));
313 return ret;
314 }
315
316 [[nodiscard]] string toString() const
317 {
318 std::array<char, 1024> host{};
319 if (sin4.sin_family != 0) {
320 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
321 int retval = getnameinfo(reinterpret_cast<const struct sockaddr*>(this), getSocklen(), host.data(), host.size(), nullptr, 0, NI_NUMERICHOST);
322 if (retval == 0) {
323 return host.data();
324 }
325 return "invalid " + string(gai_strerror(retval));
326 }
327 return "invalid";
328 }
329
330 //! Ignores any interface specifiers possibly available in the sockaddr data.
331 [[nodiscard]] string toStringNoInterface() const
332 {
333 std::array<char, 1024> host{};
334 if (sin4.sin_family == AF_INET) {
335 const auto* ret = inet_ntop(sin4.sin_family, &sin4.sin_addr, host.data(), host.size());
336 if (ret != nullptr) {
337 return host.data();
338 }
339 }
340 else if (sin4.sin_family == AF_INET6) {
341 const auto* ret = inet_ntop(sin4.sin_family, &sin6.sin6_addr, host.data(), host.size());
342 if (ret != nullptr) {
343 return host.data();
344 }
345 }
346 else {
347 return "invalid";
348 }
349 return "invalid " + stringerror();
350 }
351
352 [[nodiscard]] string toStringReversed() const
353 {
354 if (isIPv4()) {
355 const auto address = ntohl(sin4.sin_addr.s_addr);
356 auto aaa = (address >> 0) & 0xFF;
357 auto bbb = (address >> 8) & 0xFF;
358 auto ccc = (address >> 16) & 0xFF;
359 auto ddd = (address >> 24) & 0xFF;
360 return std::to_string(aaa) + "." + std::to_string(bbb) + "." + std::to_string(ccc) + "." + std::to_string(ddd);
361 }
362 const auto* addr = &sin6.sin6_addr;
363 std::stringstream res{};
364 res << std::hex;
365 for (int i = 15; i >= 0; i--) {
366 auto byte = addr->s6_addr[i]; // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
367 res << ((byte >> 0) & 0xF) << ".";
368 res << ((byte >> 4) & 0xF);
369 if (i != 0) {
370 res << ".";
371 }
372 }
373 return res.str();
374 }
375
376 [[nodiscard]] string toStringWithPort() const
377 {
378 if (sin4.sin_family == AF_INET) {
379 return toString() + ":" + std::to_string(ntohs(sin4.sin_port));
380 }
381 return "[" + toString() + "]:" + std::to_string(ntohs(sin4.sin_port));
382 }
383
384 [[nodiscard]] string toStringWithPortExcept(int port) const
385 {
386 if (ntohs(sin4.sin_port) == port) {
387 return toString();
388 }
389 if (sin4.sin_family == AF_INET) {
390 return toString() + ":" + std::to_string(ntohs(sin4.sin_port));
391 }
392 return "[" + toString() + "]:" + std::to_string(ntohs(sin4.sin_port));
393 }
394
395 [[nodiscard]] string toLogString() const
396 {
397 return toStringWithPortExcept(53);
398 }
399
400 [[nodiscard]] string toStructuredLogString() const
401 {
402 return toStringWithPort();
403 }
404
405 [[nodiscard]] string toByteString() const
406 {
407 // NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
408 if (isIPv4()) {
409 return {reinterpret_cast<const char*>(&sin4.sin_addr.s_addr), sizeof(sin4.sin_addr.s_addr)};
410 }
411 return {reinterpret_cast<const char*>(&sin6.sin6_addr.s6_addr), sizeof(sin6.sin6_addr.s6_addr)};
412 // NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)
413 }
414
415 void truncate(unsigned int bits) noexcept;
416
417 [[nodiscard]] uint16_t getNetworkOrderPort() const noexcept
418 {
419 return sin4.sin_port;
420 }
421 [[nodiscard]] uint16_t getPort() const noexcept
422 {
423 return ntohs(getNetworkOrderPort());
424 }
425 void setPort(uint16_t port)
426 {
427 sin4.sin_port = htons(port);
428 }
429
430 void reset()
431 {
432 memset(&sin6, 0, sizeof(sin6));
433 }
434
435 //! Get the total number of address bits (either 32 or 128 depending on IP version)
436 [[nodiscard]] uint8_t getBits() const
437 {
438 if (isIPv4()) {
439 return 32;
440 }
441 if (isIPv6()) {
442 return 128;
443 }
444 return 0;
445 }
446 /** Get the value of the bit at the provided bit index. When the index >= 0,
447 the index is relative to the LSB starting at index zero. When the index < 0,
448 the index is relative to the MSB starting at index -1 and counting down.
449 */
450 [[nodiscard]] bool getBit(int index) const
451 {
452 if (isIPv4()) {
453 if (index >= 32) {
454 return false;
455 }
456 if (index < 0) {
457 if (index < -32) {
458 return false;
459 }
460 index = 32 + index;
461 }
462
463 uint32_t ls_addr = ntohl(sin4.sin_addr.s_addr);
464
465 return ((ls_addr & (1U << index)) != 0x00000000);
466 }
467 if (isIPv6()) {
468 if (index >= 128) {
469 return false;
470 }
471 if (index < 0) {
472 if (index < -128) {
473 return false;
474 }
475 index = 128 + index;
476 }
477
478 const auto* ls_addr = reinterpret_cast<const uint8_t*>(sin6.sin6_addr.s6_addr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
479 uint8_t byte_idx = index / 8;
480 uint8_t bit_idx = index % 8;
481
482 return ((ls_addr[15 - byte_idx] & (1U << bit_idx)) != 0x00); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
483 }
484 return false;
485 }
486
487 /*! Returns a comma-separated string of IP addresses
488 *
489 * \param c An stl container with ComboAddresses
490 * \param withPort Also print the port (default true)
491 * \param portExcept Print the port, except when this is the port (default 53)
492 */
493 template <template <class...> class Container, class... Args>
494 static string caContainerToString(const Container<ComboAddress, Args...>& container, const bool withPort = true, const uint16_t portExcept = 53)
495 {
496 vector<string> strs;
497 for (const auto& address : container) {
498 if (withPort) {
499 strs.push_back(address.toStringWithPortExcept(portExcept));
500 continue;
501 }
502 strs.push_back(address.toString());
503 }
504 return boost::join(strs, ",");
505 };
506 };
507
508 union SockaddrWrapper
509 {
510 sockaddr_in sin4{};
511 sockaddr_in6 sin6;
512 sockaddr_un sinun;
513
514 [[nodiscard]] socklen_t getSocklen() const
515 {
516 if (sin4.sin_family == AF_INET) {
517 return sizeof(sin4);
518 }
519 if (sin6.sin6_family == AF_INET6) {
520 return sizeof(sin6);
521 }
522 if (sinun.sun_family == AF_UNIX) {
523 return sizeof(sinun);
524 }
525 return 0;
526 }
527
528 SockaddrWrapper()
529 {
530 sin4.sin_family = AF_INET;
531 sin4.sin_addr.s_addr = 0;
532 sin4.sin_port = 0;
533 }
534
535 SockaddrWrapper(const struct sockaddr* socketAddress, socklen_t salen)
536 {
537 setSockaddr(socketAddress, salen);
538 };
539
540 SockaddrWrapper(const struct sockaddr_in6* socketAddress)
541 {
542 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
543 setSockaddr(reinterpret_cast<const struct sockaddr*>(socketAddress), sizeof(struct sockaddr_in6));
544 };
545
546 SockaddrWrapper(const struct sockaddr_in* socketAddress)
547 {
548 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
549 setSockaddr(reinterpret_cast<const struct sockaddr*>(socketAddress), sizeof(struct sockaddr_in));
550 };
551
552 SockaddrWrapper(const struct sockaddr_un* socketAddress)
553 {
554 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
555 setSockaddr(reinterpret_cast<const struct sockaddr*>(socketAddress), sizeof(struct sockaddr_un));
556 };
557
558 void setSockaddr(const struct sockaddr* socketAddress, socklen_t salen)
559 {
560 if (salen > sizeof(struct sockaddr_un)) {
561 throw PDNSException("ComboAddress can't handle other than sockaddr_in, sockaddr_in6 or sockaddr_un");
562 }
563 memcpy(this, socketAddress, salen);
564 }
565
566 explicit SockaddrWrapper(const string& str, uint16_t port = 0)
567 {
568 memset(&sinun, 0, sizeof(sinun));
569 sin4.sin_family = AF_INET;
570 sin4.sin_port = 0;
571 if (str == "\"\"" || str == "''") {
572 throw PDNSException("Stray quotation marks in address.");
573 }
574 if (makeIPv4sockaddr(str, &sin4) != 0) {
575 sin6.sin6_family = AF_INET6;
576 if (makeIPv6sockaddr(str, &sin6) < 0) {
577 sinun.sun_family = AF_UNIX;
578 // only attempt Unix socket address if address candidate does not contain a port
579 if (str.find(':') != string::npos || makeUNsockaddr(str, &sinun) < 0) {
580 throw PDNSException("Unable to convert presentation address '" + str + "'");
581 }
582 }
583 }
584 if (sinun.sun_family != AF_UNIX && sin4.sin_port == 0) { // 'str' overrides port!
585 sin4.sin_port = htons(port);
586 }
587 }
588
589 [[nodiscard]] bool isIPv6() const
590 {
591 return sin4.sin_family == AF_INET6;
592 }
593 [[nodiscard]] bool isIPv4() const
594 {
595 return sin4.sin_family == AF_INET;
596 }
597 [[nodiscard]] bool isUnixSocket() const
598 {
599 return sin4.sin_family == AF_UNIX;
600 }
601
602 [[nodiscard]] string toString() const
603 {
604 if (sinun.sun_family == AF_UNIX) {
605 return sinun.sun_path;
606 }
607 std::array<char, 1024> host{};
608 if (sin4.sin_family != 0) {
609 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
610 int retval = getnameinfo(reinterpret_cast<const struct sockaddr*>(this), getSocklen(), host.data(), host.size(), nullptr, 0, NI_NUMERICHOST);
611 if (retval == 0) {
612 return host.data();
613 }
614 return "invalid " + string(gai_strerror(retval));
615 }
616 return "invalid";
617 }
618
619 [[nodiscard]] string toStringWithPort() const
620 {
621 if (sinun.sun_family == AF_UNIX) {
622 return toString();
623 }
624 if (sin4.sin_family == AF_INET) {
625 return toString() + ":" + std::to_string(ntohs(sin4.sin_port));
626 }
627 return "[" + toString() + "]:" + std::to_string(ntohs(sin4.sin_port));
628 }
629
630 void reset()
631 {
632 memset(&sinun, 0, sizeof(sinun));
633 }
634 };
635
636 /** This exception is thrown by the Netmask class and by extension by the NetmaskGroup class */
637 class NetmaskException : public PDNSException
638 {
639 public:
640 NetmaskException(const string& arg) :
641 PDNSException(arg) {}
642 };
643
644 inline ComboAddress makeComboAddress(const string& str)
645 {
646 ComboAddress address;
647 address.sin4.sin_family = AF_INET;
648 if (inet_pton(AF_INET, str.c_str(), &address.sin4.sin_addr) <= 0) {
649 address.sin4.sin_family = AF_INET6;
650 if (makeIPv6sockaddr(str, &address.sin6) < 0) {
651 throw NetmaskException("Unable to convert '" + str + "' to a netmask");
652 }
653 }
654 return address;
655 }
656
657 inline ComboAddress makeComboAddressFromRaw(uint8_t version, const char* raw, size_t len)
658 {
659 ComboAddress address;
660
661 if (version == 4) {
662 address.sin4.sin_family = AF_INET;
663 if (len != sizeof(address.sin4.sin_addr)) {
664 throw NetmaskException("invalid raw address length");
665 }
666 memcpy(&address.sin4.sin_addr, raw, sizeof(address.sin4.sin_addr));
667 }
668 else if (version == 6) {
669 address.sin6.sin6_family = AF_INET6;
670 if (len != sizeof(address.sin6.sin6_addr)) {
671 throw NetmaskException("invalid raw address length");
672 }
673 memcpy(&address.sin6.sin6_addr, raw, sizeof(address.sin6.sin6_addr));
674 }
675 else {
676 throw NetmaskException("invalid address family");
677 }
678
679 return address;
680 }
681
682 inline ComboAddress makeComboAddressFromRaw(uint8_t version, const string& str)
683 {
684 return makeComboAddressFromRaw(version, str.c_str(), str.size());
685 }
686
687 /** This class represents a netmask and can be queried to see if a certain
688 IP address is matched by this mask */
689 class Netmask
690 {
691 public:
692 Netmask()
693 {
694 d_network.sin4.sin_family = 0; // disable this doing anything useful
695 d_network.sin4.sin_port = 0; // this guarantees d_network compares identical
696 }
697
698 Netmask(const ComboAddress& network, uint8_t bits = 0xff) :
699 d_network(network)
700 {
701 d_network.sin4.sin_port = 0;
702 setBits(bits);
703 }
704
705 Netmask(const sockaddr_in* network, uint8_t bits = 0xff) :
706 d_network(network)
707 {
708 d_network.sin4.sin_port = 0;
709 setBits(bits);
710 }
711 Netmask(const sockaddr_in6* network, uint8_t bits = 0xff) :
712 d_network(network)
713 {
714 d_network.sin4.sin_port = 0;
715 setBits(bits);
716 }
717 void setBits(uint8_t value)
718 {
719 d_bits = d_network.isIPv4() ? std::min(value, static_cast<uint8_t>(32U)) : std::min(value, static_cast<uint8_t>(128U));
720
721 if (d_bits < 32) {
722 d_mask = ~(0xFFFFFFFF >> d_bits);
723 }
724 else {
725 // note that d_mask is unused for IPv6
726 d_mask = 0xFFFFFFFF;
727 }
728
729 if (isIPv4()) {
730 d_network.sin4.sin_addr.s_addr = htonl(ntohl(d_network.sin4.sin_addr.s_addr) & d_mask);
731 }
732 else if (isIPv6()) {
733 uint8_t bytes = d_bits / 8;
734 auto* address = reinterpret_cast<uint8_t*>(&d_network.sin6.sin6_addr.s6_addr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
735 uint8_t bits = d_bits % 8;
736 auto mask = static_cast<uint8_t>(~(0xFF >> bits));
737
738 if (bytes < sizeof(d_network.sin6.sin6_addr.s6_addr)) {
739 address[bytes] &= mask; // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
740 }
741
742 for (size_t idx = bytes + 1; idx < sizeof(d_network.sin6.sin6_addr.s6_addr); ++idx) {
743 address[idx] = 0; // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
744 }
745 }
746 }
747
748 enum stringType
749 {
750 humanString,
751 byteString,
752 };
753 //! Constructor supplies the mask, which cannot be changed
754 Netmask(const string& mask, stringType type = humanString)
755 {
756 if (type == byteString) {
757 uint8_t afi = mask.at(0);
758 size_t len = afi == 4 ? 4 : 16;
759 uint8_t bits = mask.at(len + 1);
760
761 d_network = makeComboAddressFromRaw(afi, mask.substr(1, len));
762
763 setBits(bits);
764 }
765 else {
766 pair<string, string> split = splitField(mask, '/');
767 d_network = makeComboAddress(split.first);
768
769 if (!split.second.empty()) {
770 setBits(pdns::checked_stoi<uint8_t>(split.second));
771 }
772 else if (d_network.sin4.sin_family == AF_INET) {
773 setBits(32);
774 }
775 else {
776 setBits(128);
777 }
778 }
779 }
780
781 [[nodiscard]] bool match(const ComboAddress& address) const
782 {
783 return match(&address);
784 }
785
786 //! If this IP address in socket address matches
787 bool match(const ComboAddress* address) const
788 {
789 if (d_network.sin4.sin_family != address->sin4.sin_family) {
790 return false;
791 }
792 if (d_network.sin4.sin_family == AF_INET) {
793 return match4(htonl((unsigned int)address->sin4.sin_addr.s_addr));
794 }
795 if (d_network.sin6.sin6_family == AF_INET6) {
796 uint8_t bytes = d_bits / 8;
797 uint8_t index = 0;
798 // NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
799 const auto* lhs = reinterpret_cast<const uint8_t*>(&d_network.sin6.sin6_addr.s6_addr);
800 const auto* rhs = reinterpret_cast<const uint8_t*>(&address->sin6.sin6_addr.s6_addr);
801 // NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)
802
803 // NOLINTBEGIN(cppcoreguidelines-pro-bounds-pointer-arithmetic)
804 for (index = 0; index < bytes; ++index) {
805 if (lhs[index] != rhs[index]) {
806 return false;
807 }
808 }
809 // still here, now match remaining bits
810 uint8_t bits = d_bits % 8;
811 auto mask = static_cast<uint8_t>(~(0xFF >> bits));
812
813 return ((lhs[index]) == (rhs[index] & mask));
814 // NOLINTEND(cppcoreguidelines-pro-bounds-pointer-arithmetic)
815 }
816 return false;
817 }
818
819 //! If this ASCII IP address matches
820 [[nodiscard]] bool match(const string& arg) const
821 {
822 ComboAddress address = makeComboAddress(arg);
823 return match(&address);
824 }
825
826 //! If this IP address in native format matches
827 [[nodiscard]] bool match4(uint32_t arg) const
828 {
829 return (arg & d_mask) == (ntohl(d_network.sin4.sin_addr.s_addr));
830 }
831
832 [[nodiscard]] string toString() const
833 {
834 return d_network.toStringNoInterface() + "/" + std::to_string((unsigned int)d_bits);
835 }
836
837 [[nodiscard]] string toStringNoMask() const
838 {
839 return d_network.toStringNoInterface();
840 }
841
842 [[nodiscard]] string toByteString() const
843 {
844 ostringstream tmp;
845
846 tmp << (d_network.isIPv4() ? "\x04" : "\x06")
847 << d_network.toByteString()
848 << getBits();
849
850 return tmp.str();
851 }
852
853 [[nodiscard]] const ComboAddress& getNetwork() const
854 {
855 return d_network;
856 }
857
858 [[nodiscard]] const ComboAddress& getMaskedNetwork() const
859 {
860 return getNetwork();
861 }
862
863 [[nodiscard]] uint8_t getBits() const
864 {
865 return d_bits;
866 }
867
868 [[nodiscard]] bool isIPv6() const
869 {
870 return d_network.sin6.sin6_family == AF_INET6;
871 }
872
873 [[nodiscard]] bool isIPv4() const
874 {
875 return d_network.sin4.sin_family == AF_INET;
876 }
877
878 bool operator<(const Netmask& rhs) const
879 {
880 if (empty() && !rhs.empty()) {
881 return false;
882 }
883 if (!empty() && rhs.empty()) {
884 return true;
885 }
886 if (d_bits > rhs.d_bits) {
887 return true;
888 }
889 if (d_bits < rhs.d_bits) {
890 return false;
891 }
892
893 return d_network < rhs.d_network;
894 }
895
896 bool operator>(const Netmask& rhs) const
897 {
898 return rhs.operator<(*this);
899 }
900
901 bool operator==(const Netmask& rhs) const
902 {
903 return std::tie(d_network, d_bits) == std::tie(rhs.d_network, rhs.d_bits);
904 }
905
906 bool operator!=(const Netmask& rhs) const
907 {
908 return !operator==(rhs);
909 }
910
911 [[nodiscard]] bool empty() const
912 {
913 return d_network.sin4.sin_family == 0;
914 }
915
916 //! Get normalized version of the netmask. This means that all address bits below the network bits are zero.
917 [[nodiscard]] Netmask getNormalized() const
918 {
919 return {getMaskedNetwork(), d_bits};
920 }
921 //! Get Netmask for super network of this one (i.e. with fewer network bits)
922 [[nodiscard]] Netmask getSuper(uint8_t bits) const
923 {
924 return {d_network, std::min(d_bits, bits)};
925 }
926
927 //! Get the total number of address bits for this netmask (either 32 or 128 depending on IP version)
928 [[nodiscard]] uint8_t getFullBits() const
929 {
930 return d_network.getBits();
931 }
932
933 /** Get the value of the bit at the provided bit index. When the index >= 0,
934 the index is relative to the LSB starting at index zero. When the index < 0,
935 the index is relative to the MSB starting at index -1 and counting down.
936 When the index points outside the network bits, it always yields zero.
937 */
938 [[nodiscard]] bool getBit(int bit) const
939 {
940 if (bit < -d_bits) {
941 return false;
942 }
943 if (bit >= 0) {
944 if (isIPv4()) {
945 if (bit >= 32 || bit < (32 - d_bits)) {
946 return false;
947 }
948 }
949 if (isIPv6()) {
950 if (bit >= 128 || bit < (128 - d_bits)) {
951 return false;
952 }
953 }
954 }
955 return d_network.getBit(bit);
956 }
957
958 struct Hash
959 {
960 size_t operator()(const Netmask& netmask) const
961 {
962 return burtle(&netmask.d_bits, 1, ComboAddress::addressOnlyHash()(netmask.d_network));
963 }
964 };
965
966 private:
967 ComboAddress d_network;
968 uint32_t d_mask{0};
969 uint8_t d_bits{0};
970 };
971
972 namespace std
973 {
974 template <>
975 struct hash<Netmask>
976 {
977 auto operator()(const Netmask& netmask) const
978 {
979 return Netmask::Hash{}(netmask);
980 }
981 };
982 }
983
984 /** Binary tree map implementation with <Netmask,T> pair.
985 *
986 * This is a binary tree implementation for storing attributes for IPv4 and IPv6 prefixes.
987 * The most simple use case is simple NetmaskTree<bool> used by NetmaskGroup, which only
988 * wants to know if given IP address is matched in the prefixes stored.
989 *
990 * This element is useful for anything that needs to *STORE* prefixes, and *MATCH* IP addresses
991 * to a *LIST* of *PREFIXES*. Not the other way round.
992 *
993 * You can store IPv4 and IPv6 addresses to same tree, separate payload storage is kept per AFI.
994 * Network prefixes (Netmasks) are always recorded in normalized fashion, meaning that only
995 * the network bits are set. This is what is returned in the insert() and lookup() return
996 * values.
997 *
998 * Use swap if you need to move the tree to another NetmaskTree instance, it is WAY faster
999 * than using copy ctor or assignment operator, since it moves the nodes and tree root to
1000 * new home instead of actually recreating the tree.
1001 *
1002 * Please see NetmaskGroup for example of simple use case. Other usecases can be found
1003 * from GeoIPBackend and Sortlist, and from dnsdist.
1004 */
1005 template <typename T, class K = Netmask>
1006 class NetmaskTree
1007 {
1008 public:
1009 class Iterator;
1010
1011 using key_type = K;
1012 using value_type = T;
1013 using node_type = std::pair<const key_type, value_type>;
1014 using size_type = size_t;
1015 using iterator = class Iterator;
1016
1017 private:
1018 /** Single node in tree, internal use only.
1019 */
1020 class TreeNode : boost::noncopyable
1021 {
1022 public:
1023 explicit TreeNode() noexcept :
1024 parent(nullptr), node(), assigned(false), d_bits(0)
1025 {
1026 }
1027 explicit TreeNode(const key_type& key) :
1028 parent(nullptr), node({key.getNormalized(), value_type()}), assigned(false), d_bits(key.getFullBits())
1029 {
1030 }
1031
1032 //<! Makes a left leaf node with specified key.
1033 TreeNode* make_left(const key_type& key)
1034 {
1035 d_bits = node.first.getBits();
1036 left = make_unique<TreeNode>(key);
1037 left->parent = this;
1038 return left.get();
1039 }
1040
1041 //<! Makes a right leaf node with specified key.
1042 TreeNode* make_right(const key_type& key)
1043 {
1044 d_bits = node.first.getBits();
1045 right = make_unique<TreeNode>(key);
1046 right->parent = this;
1047 return right.get();
1048 }
1049
1050 //<! Splits branch at indicated bit position by inserting key
1051 TreeNode* split(const key_type& key, int bits)
1052 {
1053 if (parent == nullptr) {
1054 // not to be called on the root node
1055 throw std::logic_error(
1056 "NetmaskTree::TreeNode::split(): must not be called on root node");
1057 }
1058
1059 // determine reference from parent
1060 unique_ptr<TreeNode>& parent_ref = (parent->left.get() == this ? parent->left : parent->right);
1061 if (parent_ref.get() != this) {
1062 throw std::logic_error(
1063 "NetmaskTree::TreeNode::split(): parent node reference is invalid");
1064 }
1065
1066 // create new tree node for the new key and
1067 // attach the new node under our former parent
1068 auto new_intermediate_node = make_unique<TreeNode>(key);
1069 new_intermediate_node->d_bits = bits;
1070 new_intermediate_node->parent = parent;
1071 auto* new_intermediate_node_raw = new_intermediate_node.get();
1072
1073 // hereafter new_intermediate points to "this"
1074 // ie the child of the new intermediate node
1075 std::swap(parent_ref, new_intermediate_node);
1076 // and we now assign this to current_node so
1077 // it's clear it no longer refers to the new
1078 // intermediate node
1079 std::unique_ptr<TreeNode> current_node = std::move(new_intermediate_node);
1080
1081 // attach "this" node below the new node
1082 // (left or right depending on bit)
1083 // technically the raw pointer escapes the duration of the
1084 // unique pointer, but just below we store the unique pointer
1085 // in the parent, so it lives as long as necessary
1086 // coverity[escape]
1087 current_node->parent = new_intermediate_node_raw;
1088 if (current_node->node.first.getBit(-1 - bits)) {
1089 new_intermediate_node_raw->right = std::move(current_node);
1090 }
1091 else {
1092 new_intermediate_node_raw->left = std::move(current_node);
1093 }
1094
1095 return new_intermediate_node_raw;
1096 }
1097
1098 //<! Forks branch for new key at indicated bit position
1099 TreeNode* fork(const key_type& key, int bits)
1100 {
1101 if (parent == nullptr) {
1102 // not to be called on the root node
1103 throw std::logic_error(
1104 "NetmaskTree::TreeNode::fork(): must not be called on root node");
1105 }
1106
1107 // determine reference from parent
1108 unique_ptr<TreeNode>& parent_ref = (parent->left.get() == this ? parent->left : parent->right);
1109 if (parent_ref.get() != this) {
1110 throw std::logic_error(
1111 "NetmaskTree::TreeNode::fork(): parent node reference is invalid");
1112 }
1113
1114 // create new tree node for the branch point
1115
1116 // the current node will now be a child of the new branch node
1117 // (hereafter new_child1 points to "this")
1118 unique_ptr<TreeNode> new_child1 = std::move(parent_ref);
1119 // attach the branch node under our former parent
1120 parent_ref = make_unique<TreeNode>(node.first.getSuper(bits));
1121 auto* branch_node = parent_ref.get();
1122 branch_node->d_bits = bits;
1123 branch_node->parent = parent;
1124
1125 // create second new leaf node for the new key
1126 unique_ptr<TreeNode> new_child2 = make_unique<TreeNode>(key);
1127 TreeNode* new_node = new_child2.get();
1128
1129 // attach the new child nodes below the branch node
1130 // (left or right depending on bit)
1131 new_child1->parent = branch_node;
1132 new_child2->parent = branch_node;
1133 if (new_child1->node.first.getBit(-1 - bits)) {
1134 branch_node->right = std::move(new_child1);
1135 branch_node->left = std::move(new_child2);
1136 }
1137 else {
1138 branch_node->right = std::move(new_child2);
1139 branch_node->left = std::move(new_child1);
1140 }
1141 // now we have attached the new unique pointers to the tree:
1142 // - branch_node is below its parent
1143 // - new_child1 (ourselves) is below branch_node
1144 // - new_child2, the new leaf node, is below branch_node as well
1145
1146 return new_node;
1147 }
1148
1149 //<! Traverse left branch depth-first
1150 TreeNode* traverse_l()
1151 {
1152 TreeNode* tnode = this;
1153
1154 while (tnode->left) {
1155 tnode = tnode->left.get();
1156 }
1157 return tnode;
1158 }
1159
1160 //<! Traverse tree depth-first and in-order (L-N-R)
1161 TreeNode* traverse_lnr()
1162 {
1163 TreeNode* tnode = this;
1164
1165 // precondition: descended left as deep as possible
1166 if (tnode->right) {
1167 // descend right
1168 tnode = tnode->right.get();
1169 // descend left as deep as possible and return next node
1170 return tnode->traverse_l();
1171 }
1172
1173 // ascend to parent
1174 while (tnode->parent != nullptr) {
1175 TreeNode* prev_child = tnode;
1176 tnode = tnode->parent;
1177
1178 // return this node, but only when we come from the left child branch
1179 if (tnode->left && tnode->left.get() == prev_child) {
1180 return tnode;
1181 }
1182 }
1183 return nullptr;
1184 }
1185
1186 //<! Traverse only assigned nodes
1187 TreeNode* traverse_lnr_assigned()
1188 {
1189 TreeNode* tnode = traverse_lnr();
1190
1191 while (tnode != nullptr && !tnode->assigned) {
1192 tnode = tnode->traverse_lnr();
1193 }
1194 return tnode;
1195 }
1196
1197 unique_ptr<TreeNode> left;
1198 unique_ptr<TreeNode> right;
1199 TreeNode* parent;
1200
1201 node_type node;
1202 bool assigned; //<! Whether this node is assigned-to by the application
1203
1204 int d_bits; //<! How many bits have been used so far
1205 };
1206
1207 void cleanup_tree(TreeNode* node)
1208 {
1209 // only cleanup this node if it has no children and node not assigned
1210 if (!(node->left || node->right || node->assigned)) {
1211 // get parent node ptr
1212 TreeNode* pparent = node->parent;
1213 // delete this node
1214 if (pparent) {
1215 if (pparent->left.get() == node) {
1216 pparent->left.reset();
1217 }
1218 else {
1219 pparent->right.reset();
1220 }
1221 // now recurse up to the parent
1222 cleanup_tree(pparent);
1223 }
1224 }
1225 }
1226
1227 void copyTree(const NetmaskTree& rhs)
1228 {
1229 try {
1230 TreeNode* node = rhs.d_root.get();
1231 if (node != nullptr) {
1232 node = node->traverse_l();
1233 }
1234 while (node != nullptr) {
1235 if (node->assigned) {
1236 insert(node->node.first).second = node->node.second;
1237 }
1238 node = node->traverse_lnr();
1239 }
1240 }
1241 catch (const NetmaskException&) {
1242 abort();
1243 }
1244 catch (const std::logic_error&) {
1245 abort();
1246 }
1247 }
1248
1249 public:
1250 class Iterator
1251 {
1252 public:
1253 using value_type = node_type;
1254 using reference = node_type&;
1255 using pointer = node_type*;
1256 using iterator_category = std::forward_iterator_tag;
1257 using difference_type = size_type;
1258
1259 private:
1260 friend class NetmaskTree;
1261
1262 const NetmaskTree* d_tree;
1263 TreeNode* d_node;
1264
1265 Iterator(const NetmaskTree* tree, TreeNode* node) :
1266 d_tree(tree), d_node(node)
1267 {
1268 }
1269
1270 public:
1271 Iterator() :
1272 d_tree(nullptr), d_node(nullptr) {}
1273
1274 Iterator& operator++() // prefix
1275 {
1276 if (d_node == nullptr) {
1277 throw std::logic_error(
1278 "NetmaskTree::Iterator::operator++: iterator is invalid");
1279 }
1280 d_node = d_node->traverse_lnr_assigned();
1281 return *this;
1282 }
1283 Iterator operator++(int) // postfix
1284 {
1285 Iterator tmp(*this);
1286 operator++();
1287 return tmp;
1288 }
1289
1290 reference operator*()
1291 {
1292 if (d_node == nullptr) {
1293 throw std::logic_error(
1294 "NetmaskTree::Iterator::operator*: iterator is invalid");
1295 }
1296 return d_node->node;
1297 }
1298
1299 pointer operator->()
1300 {
1301 if (d_node == nullptr) {
1302 throw std::logic_error(
1303 "NetmaskTree::Iterator::operator->: iterator is invalid");
1304 }
1305 return &d_node->node;
1306 }
1307
1308 bool operator==(const Iterator& rhs)
1309 {
1310 return (d_tree == rhs.d_tree && d_node == rhs.d_node);
1311 }
1312 bool operator!=(const Iterator& rhs)
1313 {
1314 return !(*this == rhs);
1315 }
1316 };
1317
1318 NetmaskTree() noexcept :
1319 d_root(new TreeNode()), d_left(nullptr)
1320 {
1321 }
1322
1323 NetmaskTree(const NetmaskTree& rhs) :
1324 d_root(new TreeNode()), d_left(nullptr)
1325 {
1326 copyTree(rhs);
1327 }
1328
1329 ~NetmaskTree() = default;
1330
1331 NetmaskTree& operator=(const NetmaskTree& rhs)
1332 {
1333 if (this != &rhs) {
1334 clear();
1335 copyTree(rhs);
1336 }
1337 return *this;
1338 }
1339
1340 NetmaskTree(NetmaskTree&&) noexcept = default;
1341 NetmaskTree& operator=(NetmaskTree&&) noexcept = default;
1342
1343 [[nodiscard]] iterator begin() const
1344 {
1345 return Iterator(this, d_left);
1346 }
1347 [[nodiscard]] iterator end() const
1348 {
1349 return Iterator(this, nullptr);
1350 }
1351 iterator begin()
1352 {
1353 return Iterator(this, d_left);
1354 }
1355 iterator end()
1356 {
1357 return Iterator(this, nullptr);
1358 }
1359
1360 node_type& insert(const string& mask)
1361 {
1362 return insert(key_type(mask));
1363 }
1364
1365 //<! Creates new value-pair in tree and returns it.
1366 node_type& insert(const key_type& key)
1367 {
1368 TreeNode* node{};
1369 bool is_left = true;
1370
1371 // we turn left on IPv4 and right on IPv6
1372 if (key.isIPv4()) {
1373 node = d_root->left.get();
1374 if (node == nullptr) {
1375
1376 d_root->left = make_unique<TreeNode>(key);
1377 node = d_root->left.get();
1378 node->assigned = true;
1379 node->parent = d_root.get();
1380 d_size++;
1381 d_left = node;
1382 return node->node;
1383 }
1384 }
1385 else if (key.isIPv6()) {
1386 node = d_root->right.get();
1387 if (node == nullptr) {
1388
1389 d_root->right = make_unique<TreeNode>(key);
1390 node = d_root->right.get();
1391 node->assigned = true;
1392 node->parent = d_root.get();
1393 d_size++;
1394 if (!d_root->left) {
1395 d_left = node;
1396 }
1397 return node->node;
1398 }
1399 if (d_root->left) {
1400 is_left = false;
1401 }
1402 }
1403 else {
1404 throw NetmaskException("invalid address family");
1405 }
1406
1407 // we turn left on 0 and right on 1
1408 int bits = 0;
1409 for (; bits < key.getBits(); bits++) {
1410 bool vall = key.getBit(-1 - bits);
1411
1412 if (bits >= node->d_bits) {
1413 // the end of the current node is reached; continue with the next
1414 if (vall) {
1415 if (node->left || node->assigned) {
1416 is_left = false;
1417 }
1418 if (!node->right) {
1419 // the right branch doesn't exist yet; attach our key here
1420 node = node->make_right(key);
1421 break;
1422 }
1423 node = node->right.get();
1424 }
1425 else {
1426 if (!node->left) {
1427 // the left branch doesn't exist yet; attach our key here
1428 node = node->make_left(key);
1429 break;
1430 }
1431 node = node->left.get();
1432 }
1433 continue;
1434 }
1435 if (bits >= node->node.first.getBits()) {
1436 // the matching branch ends here, yet the key netmask has more bits; add a
1437 // child node below the existing branch leaf.
1438 if (vall) {
1439 if (node->assigned) {
1440 is_left = false;
1441 }
1442 node = node->make_right(key);
1443 }
1444 else {
1445 node = node->make_left(key);
1446 }
1447 break;
1448 }
1449 bool valr = node->node.first.getBit(-1 - bits);
1450 if (vall != valr) {
1451 if (vall) {
1452 is_left = false;
1453 }
1454 // the branch matches just upto this point, yet continues in a different
1455 // direction; fork the branch.
1456 node = node->fork(key, bits);
1457 break;
1458 }
1459 }
1460
1461 if (node->node.first.getBits() > key.getBits()) {
1462 // key is a super-network of the matching node; split the branch and
1463 // insert a node for the key above the matching node.
1464 node = node->split(key, key.getBits());
1465 }
1466
1467 if (node->left) {
1468 is_left = false;
1469 }
1470
1471 node_type& value = node->node;
1472
1473 if (!node->assigned) {
1474 // only increment size if not assigned before
1475 d_size++;
1476 // update the pointer to the left-most tree node
1477 if (is_left) {
1478 d_left = node;
1479 }
1480 node->assigned = true;
1481 }
1482 else {
1483 // tree node exists for this value
1484 if (is_left && d_left != node) {
1485 throw std::logic_error(
1486 "NetmaskTree::insert(): lost track of left-most node in tree");
1487 }
1488 }
1489
1490 return value;
1491 }
1492
1493 //<! Creates or updates value
1494 void insert_or_assign(const key_type& mask, const value_type& value)
1495 {
1496 insert(mask).second = value;
1497 }
1498
1499 void insert_or_assign(const string& mask, const value_type& value)
1500 {
1501 insert(key_type(mask)).second = value;
1502 }
1503
1504 //<! check if given key is present in TreeMap
1505 [[nodiscard]] bool has_key(const key_type& key) const
1506 {
1507 const node_type* ptr = lookup(key);
1508 return ptr && ptr->first == key;
1509 }
1510
1511 //<! Returns "best match" for key_type, which might not be value
1512 [[nodiscard]] node_type* lookup(const key_type& value) const
1513 {
1514 uint8_t max_bits = value.getBits();
1515 return lookupImpl(value, max_bits);
1516 }
1517
1518 //<! Perform best match lookup for value, using at most max_bits
1519 [[nodiscard]] node_type* lookup(const ComboAddress& value, int max_bits = 128) const
1520 {
1521 uint8_t addr_bits = value.getBits();
1522 if (max_bits < 0 || max_bits > addr_bits) {
1523 max_bits = addr_bits;
1524 }
1525
1526 return lookupImpl(key_type(value, max_bits), max_bits);
1527 }
1528
1529 //<! Removes key from TreeMap.
1530 void erase(const key_type& key)
1531 {
1532 TreeNode* node = nullptr;
1533
1534 if (key.isIPv4()) {
1535 node = d_root->left.get();
1536 }
1537 else if (key.isIPv6()) {
1538 node = d_root->right.get();
1539 }
1540 else {
1541 throw NetmaskException("invalid address family");
1542 }
1543 // no tree, no value
1544 if (node == nullptr) {
1545 return;
1546 }
1547 int bits = 0;
1548 for (; node && bits < key.getBits(); bits++) {
1549 bool vall = key.getBit(-1 - bits);
1550 if (bits >= node->d_bits) {
1551 // the end of the current node is reached; continue with the next
1552 if (vall) {
1553 node = node->right.get();
1554 }
1555 else {
1556 node = node->left.get();
1557 }
1558 continue;
1559 }
1560 if (bits >= node->node.first.getBits()) {
1561 // the matching branch ends here
1562 if (key.getBits() != node->node.first.getBits()) {
1563 node = nullptr;
1564 }
1565 break;
1566 }
1567 bool valr = node->node.first.getBit(-1 - bits);
1568 if (vall != valr) {
1569 // the branch matches just upto this point, yet continues in a different
1570 // direction
1571 node = nullptr;
1572 break;
1573 }
1574 }
1575 if (node) {
1576 if (d_size == 0) {
1577 throw std::logic_error(
1578 "NetmaskTree::erase(): size of tree is zero before erase");
1579 }
1580 d_size--;
1581 node->assigned = false;
1582 node->node.second = value_type();
1583
1584 if (node == d_left) {
1585 d_left = d_left->traverse_lnr_assigned();
1586 }
1587 cleanup_tree(node);
1588 }
1589 }
1590
1591 void erase(const string& key)
1592 {
1593 erase(key_type(key));
1594 }
1595
1596 //<! checks whether the container is empty.
1597 [[nodiscard]] bool empty() const
1598 {
1599 return (d_size == 0);
1600 }
1601
1602 //<! returns the number of elements
1603 [[nodiscard]] size_type size() const
1604 {
1605 return d_size;
1606 }
1607
1608 //<! See if given ComboAddress matches any prefix
1609 [[nodiscard]] bool match(const ComboAddress& value) const
1610 {
1611 return (lookup(value) != nullptr);
1612 }
1613
1614 [[nodiscard]] bool match(const std::string& value) const
1615 {
1616 return match(ComboAddress(value));
1617 }
1618
1619 //<! Clean out the tree
1620 void clear()
1621 {
1622 d_root = make_unique<TreeNode>();
1623 d_left = nullptr;
1624 d_size = 0;
1625 }
1626
1627 //<! swaps the contents with another NetmaskTree
1628 void swap(NetmaskTree& rhs) noexcept
1629 {
1630 std::swap(d_root, rhs.d_root);
1631 std::swap(d_left, rhs.d_left);
1632 std::swap(d_size, rhs.d_size);
1633 }
1634
1635 private:
1636 [[nodiscard]] node_type* lookupImpl(const key_type& value, uint8_t max_bits) const
1637 {
1638 TreeNode* node = nullptr;
1639
1640 if (value.isIPv4()) {
1641 node = d_root->left.get();
1642 }
1643 else if (value.isIPv6()) {
1644 node = d_root->right.get();
1645 }
1646 else {
1647 throw NetmaskException("invalid address family");
1648 }
1649 if (node == nullptr) {
1650 return nullptr;
1651 }
1652
1653 node_type* ret = nullptr;
1654
1655 int bits = 0;
1656 for (; bits < max_bits; bits++) {
1657 bool vall = value.getBit(-1 - bits);
1658 if (bits >= node->d_bits) {
1659 // the end of the current node is reached; continue with the next
1660 // (we keep track of last assigned node)
1661 if (node->assigned && bits == node->node.first.getBits()) {
1662 ret = &node->node;
1663 }
1664 if (vall) {
1665 if (!node->right) {
1666 break;
1667 }
1668 node = node->right.get();
1669 }
1670 else {
1671 if (!node->left) {
1672 break;
1673 }
1674 node = node->left.get();
1675 }
1676 continue;
1677 }
1678 if (bits >= node->node.first.getBits()) {
1679 // the matching branch ends here
1680 break;
1681 }
1682 bool valr = node->node.first.getBit(-1 - bits);
1683 if (vall != valr) {
1684 // the branch matches just upto this point, yet continues in a different
1685 // direction
1686 break;
1687 }
1688 }
1689 // needed if we did not find one in loop
1690 if (node->assigned && bits == node->node.first.getBits()) {
1691 ret = &node->node;
1692 }
1693 // this can be nullptr.
1694 return ret;
1695 }
1696
1697 unique_ptr<TreeNode> d_root; //<! Root of our tree
1698 TreeNode* d_left;
1699 size_type d_size{0};
1700 };
1701
1702 /** This class represents a group of supplemental Netmask classes. An IP address matches
1703 if it is matched by one or more of the Netmask objects within.
1704 */
1705 class NetmaskGroup
1706 {
1707 public:
1708 NetmaskGroup() noexcept = default;
1709
1710 //! If this IP address is matched by any of the classes within
1711
1712 bool match(const ComboAddress* address) const
1713 {
1714 const auto& ret = tree.lookup(*address);
1715 if (ret != nullptr) {
1716 return ret->second;
1717 }
1718 return false;
1719 }
1720
1721 [[nodiscard]] bool match(const ComboAddress& address) const
1722 {
1723 return match(&address);
1724 }
1725
1726 bool lookup(const ComboAddress* address, Netmask* nmp) const
1727 {
1728 const auto& ret = tree.lookup(*address);
1729 if (ret != nullptr) {
1730 if (nmp != nullptr) {
1731 *nmp = ret->first;
1732 }
1733 return ret->second;
1734 }
1735 return false;
1736 }
1737
1738 bool lookup(const ComboAddress& address, Netmask* nmp) const
1739 {
1740 return lookup(&address, nmp);
1741 }
1742
1743 //! Add this string to the list of possible matches
1744 void addMask(const string& address, bool positive = true)
1745 {
1746 if (!address.empty() && address[0] == '!') {
1747 addMask(Netmask(address.substr(1)), false);
1748 }
1749 else {
1750 addMask(Netmask(address), positive);
1751 }
1752 }
1753
1754 //! Add this Netmask to the list of possible matches
1755 void addMask(const Netmask& netmask, bool positive = true)
1756 {
1757 tree.insert(netmask).second = positive;
1758 }
1759
1760 void addMasks(const NetmaskGroup& group, boost::optional<bool> positive)
1761 {
1762 for (const auto& entry : group.tree) {
1763 addMask(entry.first, positive ? *positive : entry.second);
1764 }
1765 }
1766
1767 //! Delete this Netmask from the list of possible matches
1768 void deleteMask(const Netmask& netmask)
1769 {
1770 tree.erase(netmask);
1771 }
1772
1773 void deleteMasks(const NetmaskGroup& group)
1774 {
1775 for (const auto& entry : group.tree) {
1776 deleteMask(entry.first);
1777 }
1778 }
1779
1780 void deleteMask(const std::string& address)
1781 {
1782 if (!address.empty()) {
1783 deleteMask(Netmask(address));
1784 }
1785 }
1786
1787 void clear()
1788 {
1789 tree.clear();
1790 }
1791
1792 [[nodiscard]] bool empty() const
1793 {
1794 return tree.empty();
1795 }
1796
1797 [[nodiscard]] size_t size() const
1798 {
1799 return tree.size();
1800 }
1801
1802 [[nodiscard]] string toString() const
1803 {
1804 ostringstream str;
1805 for (auto iter = tree.begin(); iter != tree.end(); ++iter) {
1806 if (iter != tree.begin()) {
1807 str << ", ";
1808 }
1809 if (!(iter->second)) {
1810 str << "!";
1811 }
1812 str << iter->first.toString();
1813 }
1814 return str.str();
1815 }
1816
1817 [[nodiscard]] std::vector<std::string> toStringVector() const
1818 {
1819 std::vector<std::string> out;
1820 out.reserve(tree.size());
1821 for (const auto& entry : tree) {
1822 out.push_back((entry.second ? "" : "!") + entry.first.toString());
1823 }
1824 return out;
1825 }
1826
1827 void toMasks(const string& ips)
1828 {
1829 vector<string> parts;
1830 stringtok(parts, ips, ", \t");
1831
1832 for (const auto& part : parts) {
1833 addMask(part);
1834 }
1835 }
1836
1837 private:
1838 NetmaskTree<bool> tree;
1839 };
1840
1841 struct SComboAddress
1842 {
1843 SComboAddress(const ComboAddress& orig) :
1844 ca(orig) {}
1845 ComboAddress ca;
1846 bool operator<(const SComboAddress& rhs) const
1847 {
1848 return ComboAddress::addressOnlyLessThan()(ca, rhs.ca);
1849 }
1850 operator const ComboAddress&() const
1851 {
1852 return ca;
1853 }
1854 };
1855
/** Exception type for network-level failures; what() returns the reason
 *  given at construction, "Network Error" by default.
 */
class NetworkError : public std::runtime_error
{
public:
  //! Builds the error from a std::string reason.
  NetworkError(const std::string& why) :
    std::runtime_error(why)
  {}
  //! Builds the error from a C string reason. Only this overload carries the
  //! default argument: with a default on both constructors a plain
  //! NetworkError() was ambiguous and did not compile. A null pointer is
  //! treated like the default reason.
  NetworkError(const char* why = "Network Error") :
    std::runtime_error(why != nullptr ? why : "Network Error")
  {}
};
1866
1867 class AddressAndPortRange
1868 {
1869 public:
1870 AddressAndPortRange() :
1871 d_addrMask(0), d_portMask(0)
1872 {
1873 d_addr.sin4.sin_family = 0; // disable this doing anything useful
1874 d_addr.sin4.sin_port = 0; // this guarantees d_network compares identical
1875 }
1876
1877 AddressAndPortRange(ComboAddress address, uint8_t addrMask, uint8_t portMask = 0) :
1878 d_addr(address), d_addrMask(addrMask), d_portMask(portMask)
1879 {
1880 if (!d_addr.isIPv4()) {
1881 d_portMask = 0;
1882 }
1883
1884 uint16_t port = d_addr.getPort();
1885 if (d_portMask < 16) {
1886 auto mask = static_cast<uint16_t>(~(0xFFFF >> d_portMask));
1887 port = port & mask;
1888 }
1889
1890 if (d_addrMask < d_addr.getBits()) {
1891 if (d_portMask > 0) {
1892 throw std::runtime_error("Trying to create a AddressAndPortRange with a reduced address mask (" + std::to_string(d_addrMask) + ") and a port range (" + std::to_string(d_portMask) + ")");
1893 }
1894 d_addr = Netmask(d_addr, d_addrMask).getMaskedNetwork();
1895 }
1896 d_addr.setPort(port);
1897 }
1898
1899 [[nodiscard]] uint8_t getFullBits() const
1900 {
1901 return d_addr.getBits() + 16;
1902 }
1903
1904 [[nodiscard]] uint8_t getBits() const
1905 {
1906 if (d_addrMask < d_addr.getBits()) {
1907 return d_addrMask;
1908 }
1909
1910 return d_addr.getBits() + d_portMask;
1911 }
1912
1913 /** Get the value of the bit at the provided bit index. When the index >= 0,
1914 the index is relative to the LSB starting at index zero. When the index < 0,
1915 the index is relative to the MSB starting at index -1 and counting down.
1916 */
1917 [[nodiscard]] bool getBit(int index) const
1918 {
1919 if (index >= getFullBits()) {
1920 return false;
1921 }
1922 if (index < 0) {
1923 index = getFullBits() + index;
1924 }
1925
1926 if (index < 16) {
1927 /* we are into the port bits */
1928 uint16_t port = d_addr.getPort();
1929 return ((port & (1U << index)) != 0x0000);
1930 }
1931
1932 index -= 16;
1933
1934 return d_addr.getBit(index);
1935 }
1936
1937 [[nodiscard]] bool isIPv4() const
1938 {
1939 return d_addr.isIPv4();
1940 }
1941
1942 [[nodiscard]] bool isIPv6() const
1943 {
1944 return d_addr.isIPv6();
1945 }
1946
1947 [[nodiscard]] AddressAndPortRange getNormalized() const
1948 {
1949 return {d_addr, d_addrMask, d_portMask};
1950 }
1951
1952 [[nodiscard]] AddressAndPortRange getSuper(uint8_t bits) const
1953 {
1954 if (bits <= d_addrMask) {
1955 return {d_addr, bits, 0};
1956 }
1957 if (bits <= d_addrMask + d_portMask) {
1958 return {d_addr, d_addrMask, static_cast<uint8_t>(d_portMask - (bits - d_addrMask))};
1959 }
1960
1961 return {d_addr, d_addrMask, d_portMask};
1962 }
1963
1964 [[nodiscard]] const ComboAddress& getNetwork() const
1965 {
1966 return d_addr;
1967 }
1968
1969 [[nodiscard]] string toString() const
1970 {
1971 if (d_addrMask < d_addr.getBits() || d_portMask == 0) {
1972 return d_addr.toStringNoInterface() + "/" + std::to_string(d_addrMask);
1973 }
1974 return d_addr.toStringNoInterface() + ":" + std::to_string(d_addr.getPort()) + "/" + std::to_string(d_portMask);
1975 }
1976
1977 [[nodiscard]] bool empty() const
1978 {
1979 return d_addr.sin4.sin_family == 0;
1980 }
1981
1982 bool operator==(const AddressAndPortRange& rhs) const
1983 {
1984 return std::tie(d_addr, d_addrMask, d_portMask) == std::tie(rhs.d_addr, rhs.d_addrMask, rhs.d_portMask);
1985 }
1986
1987 bool operator<(const AddressAndPortRange& rhs) const
1988 {
1989 if (empty() && !rhs.empty()) {
1990 return false;
1991 }
1992
1993 if (!empty() && rhs.empty()) {
1994 return true;
1995 }
1996
1997 if (d_addrMask > rhs.d_addrMask) {
1998 return true;
1999 }
2000
2001 if (d_addrMask < rhs.d_addrMask) {
2002 return false;
2003 }
2004
2005 if (d_addr < rhs.d_addr) {
2006 return true;
2007 }
2008
2009 if (d_addr > rhs.d_addr) {
2010 return false;
2011 }
2012
2013 if (d_portMask > rhs.d_portMask) {
2014 return true;
2015 }
2016
2017 if (d_portMask < rhs.d_portMask) {
2018 return false;
2019 }
2020
2021 return d_addr.getPort() < rhs.d_addr.getPort();
2022 }
2023
2024 bool operator>(const AddressAndPortRange& rhs) const
2025 {
2026 return rhs.operator<(*this);
2027 }
2028
2029 struct hash
2030 {
2031 uint32_t operator()(const AddressAndPortRange& apr) const
2032 {
2033 ComboAddress::addressOnlyHash hashOp;
2034 uint16_t port = apr.d_addr.getPort();
2035 /* it's fine to hash the whole address and port because the non-relevant parts have
2036 been masked to 0 */
2037 return burtle(reinterpret_cast<const unsigned char*>(&port), sizeof(port), hashOp(apr.d_addr)); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
2038 }
2039 };
2040
2041 private:
2042 ComboAddress d_addr;
2043 uint8_t d_addrMask;
2044 /* only used for v4 addresses */
2045 uint8_t d_portMask;
2046 };
2047
2048 int SSocket(int family, int type, int flags);
2049 int SConnect(int sockfd, bool fastopen, const ComboAddress& remote);
/* Tries to connect to remote for a maximum of timeout seconds.
   sockfd should be set to non-blocking beforehand.
   Returns 0 on success (the socket is writable), throws a
   runtime_error otherwise. */
2054 int SConnectWithTimeout(int sockfd, bool fastopen, const ComboAddress& remote, const struct timeval& timeout);
2055 int SBind(int sockfd, const ComboAddress& local);
2056 int SAccept(int sockfd, ComboAddress& remote);
2057 int SListen(int sockfd, int limit);
2058 int SSetsockopt(int sockfd, int level, int opname, int value);
2059 void setSocketIgnorePMTU(int sockfd, int family);
2060 void setSocketForcePMTU(int sockfd, int family);
2061 bool setReusePort(int sockfd);
2062
2063 #if defined(IP_PKTINFO)
2064 #define GEN_IP_PKTINFO IP_PKTINFO
2065 #elif defined(IP_RECVDSTADDR)
2066 #define GEN_IP_PKTINFO IP_RECVDSTADDR
2067 #endif
2068
2069 bool IsAnyAddress(const ComboAddress& addr);
2070 bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destination);
2071 bool HarvestTimestamp(struct msghdr* msgh, struct timeval* timeval);
2072 void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, cmsgbuf_aligned* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr);
2073 int sendOnNBSocket(int fileDesc, const struct msghdr* msgh);
2074 size_t sendMsgWithOptions(int socketDesc, const void* buffer, size_t len, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int flags);
2075
2076 /* requires a non-blocking, connected TCP socket */
2077 bool isTCPSocketUsable(int sock);
2078
2079 extern template class NetmaskTree<bool>;
2080 ComboAddress parseIPAndPort(const std::string& input, uint16_t port);
2081
2082 std::set<std::string> getListOfNetworkInterfaces();
2083 std::vector<ComboAddress> getListOfAddressesOfNetworkInterface(const std::string& itf);
2084 std::vector<Netmask> getListOfRangesOfNetworkInterface(const std::string& itf);
2085
2086 /* These functions throw if the value was already set to a higher value,
2087 or on error */
2088 void setSocketBuffer(int fileDesc, int optname, uint32_t size);
2089 void setSocketReceiveBuffer(int fileDesc, uint32_t size);
2090 void setSocketSendBuffer(int fileDesc, uint32_t size);
2091 uint32_t raiseSocketReceiveBufferToMax(int socket);
2092 uint32_t raiseSocketSendBufferToMax(int socket);