editDNSPacketTTL(reinterpret_cast<char*>(packet.data()), packet.size(), visitor);
}
+ void restoreFlags(struct dnsheader* dnsHeader, uint16_t origFlags)
+ {
+ static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
+ static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
+ static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
+ uint16_t* flags = getFlagsFromDNSHeader(dnsHeader);
+ /* clear the flags we are about to restore */
+ *flags &= restoreFlagsMask;
+ /* only keep the flags we want to restore */
+ origFlags &= ~restoreFlagsMask;
+ /* set the saved flags as they were */
+ *flags |= origFlags;
+ }
}
namespace RecordParsers
bool editDNSHeaderFromPacket(PacketBuffer& packet, const std::function<bool(dnsheader& header)>& editFunction);
bool editDNSHeaderFromRawPacket(void* packet, const std::function<bool(dnsheader& header)>& editFunction);
void restrictDNSPacketTTLs(PacketBuffer& packet, uint32_t minimumValue, uint32_t maximumValue = std::numeric_limits<uint32_t>::max(), const std::unordered_set<QType>& types = {});
+ void restoreFlags(struct dnsheader* dnsHeader, uint16_t origFlags);
}
namespace RecordParsers
}
}
-static void restoreFlags(struct dnsheader* dnsHeader, uint16_t origFlags)
-{
- static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
- static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
- static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
- uint16_t* flags = getFlagsFromDNSHeader(dnsHeader);
- /* clear the flags we are about to restore */
- *flags &= restoreFlagsMask;
- /* only keep the flags we want to restore */
- origFlags &= ~restoreFlagsMask;
- /* set the saved flags as they were */
- *flags |= origFlags;
-}
-
static bool fixUpQueryTurnedResponse(DNSQuestion& dnsQuestion, const uint16_t origFlags)
{
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [origFlags](dnsheader& header) {
- restoreFlags(&header, origFlags);
+ dnsdist::PacketMangling::restoreFlags(&header, origFlags);
return true;
});
}
dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [origFlags](dnsheader& header) {
- restoreFlags(&header, origFlags);
+ dnsdist::PacketMangling::restoreFlags(&header, origFlags);
return true;
});
if (serverPool.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, dnsQuestion.ids.protocol == dnsdist::Protocol::DoH ? &dnsQuestion.ids.cacheKeyTCP : &dnsQuestion.ids.cacheKey, dnsQuestion.ids.subnet, *dnsQuestion.ids.dnssecOK, dnsQuestion.ids.protocol != dnsdist::Protocol::DoH && willBeForwardedOverUDP, allowExpired, false, true, dnsQuestion.ids.protocol != dnsdist::Protocol::DoH || !willBeForwardedOverUDP)) {
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [flags = dnsQuestion.ids.origFlags](dnsheader& header) {
- restoreFlags(&header, flags);
+ dnsdist::PacketMangling::restoreFlags(&header, flags);
return true;
});
dnsdist::PacketMangling::editDNSHeaderFromPacket(payload, [&ids](dnsheader& header) {
memset(&header, 0, sizeof(header));
header.id = ids.origID;
- restoreFlags(&header, ids.origFlags);
+ dnsdist::PacketMangling::restoreFlags(&header, ids.origFlags);
// set QR=1 since this is a response rule
header.qr = 1;
// do not set the qdcount, otherwise the protobuf code will choke on it