pr.xfrBlob(blob);
pw.xfrBlob(blob);
} else {
+
pr.skip(ah.d_clen);
}
}
return 0;
}
+static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>& options, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+ for (auto it = options.begin(); it != options.end(); ) {
+ if (it->first == EDNSOptionCode::ECS) {
+ ecsAdded = false;
+
+ if (!overrideExisting) {
+ return false;
+ }
+
+ it = options.erase(it);
+ }
+ else {
+ ++it;
+ }
+ }
+
+ options.emplace_back(EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
+ return true;
+}
+
+static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+ assert(initialPacket.size() >= sizeof(dnsheader));
+ const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+
+ ecsAdded = false;
+ ednsAdded = true;
+
+ if (ntohs(dh->qdcount) == 0) {
+ return false;
+ }
+
+ if (ntohs(dh->arcount) == 0) {
+ throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
+ }
+
+ PacketReader pr(initialPacket);
+
+ size_t idx = 0;
+ DNSName rrname;
+ uint16_t qdcount = ntohs(dh->qdcount);
+ uint16_t ancount = ntohs(dh->ancount);
+ uint16_t nscount = ntohs(dh->nscount);
+ uint16_t arcount = ntohs(dh->arcount);
+ uint16_t rrtype;
+ uint16_t rrclass;
+ string blob;
+ struct dnsrecordheader ah;
+
+ rrname = pr.getName();
+ rrtype = pr.get16BitInt();
+ rrclass = pr.get16BitInt();
+
+ DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
+ pw.getHeader()->id=dh->id;
+ pw.getHeader()->qr=dh->qr;
+ pw.getHeader()->aa=dh->aa;
+ pw.getHeader()->tc=dh->tc;
+ pw.getHeader()->rd=dh->rd;
+ pw.getHeader()->ra=dh->ra;
+ pw.getHeader()->ad=dh->ad;
+ pw.getHeader()->cd=dh->cd;
+ pw.getHeader()->rcode=dh->rcode;
+
+ /* consume remaining qd if any */
+ if (qdcount > 1) {
+ for(idx = 1; idx < qdcount; idx++) {
+ rrname = pr.getName();
+ rrtype = pr.get16BitInt();
+ rrclass = pr.get16BitInt();
+ (void) rrtype;
+ (void) rrclass;
+ }
+ }
+
+ /* copy AN and NS */
+ for (idx = 0; idx < ancount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
+ pr.xfrBlob(blob);
+ pw.xfrBlob(blob);
+ }
+
+ for (idx = 0; idx < nscount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
+ pr.xfrBlob(blob);
+ pw.xfrBlob(blob);
+ }
+
+ /* consume AR, looking for OPT */
+ for (idx = 0; idx < arcount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+
+ if (ah.d_type != QType::OPT) {
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
+ pr.xfrBlob(blob);
+ pw.xfrBlob(blob);
+ } else {
+
+ ednsAdded = false;
+ pr.xfrBlob(blob);
+
+ std::vector<std::pair<uint16_t, std::string>> options;
+ getEDNSOptionsFromContent(blob, options);
+
+ EDNS0Record edns0;
+ static_assert(sizeof(edns0) == sizeof(ah.d_ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+ memcpy(&edns0, &ah.d_ttl, sizeof(edns0));
+
+ /* addOrReplaceECSOption will set it to false if there is already an existing option */
+ ecsAdded = true;
+ addOrReplaceECSOption(options, ecsAdded, overrideExisting, newECSOption);
+ pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
+ }
+ }
+
+ if (ednsAdded) {
+ pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
+ ecsAdded = true;
+ }
+
+ pw.commit();
+
+ return true;
+}
+
+static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::shared_ptr<std::map<uint16_t, EDNSOptionView> >& options)
+{
+ const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+ if (len < sizeof(dnsheader) || ntohs(dh->qdcount) == 0)
+ {
+ return false;
+ }
+
+ if (ntohs(dh->arcount) == 0) {
+ throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
+ }
+
+ try {
+ uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
+ DNSPacketMangler dpm(const_cast<char*>(packet), len);
+ uint64_t n;
+ for(n=0; n < ntohs(dh->qdcount) ; ++n) {
+ dpm.skipDomainName();
+ /* type and class */
+ dpm.skipBytes(4);
+ }
+
+ for(n=0; n < numrecords; ++n) {
+ dpm.skipDomainName();
+
+ uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
+ uint16_t dnstype = dpm.get16BitInt();
+ dpm.get16BitInt();
+ dpm.skipBytes(4); /* TTL */
+
+ if(section == 3 && dnstype == QType::OPT) {
+ uint32_t offset = dpm.getOffset();
+ if (offset >= len) {
+ return false;
+ }
+ /* if we survive this call, we can parse it safely */
+ dpm.skipRData();
+ return getEDNSOptions(packet + offset, len - offset, *options) == 0;
+ }
+ else {
+ dpm.skipRData();
+ }
+ }
+ }
+ catch(...)
+ {
+ return false;
+ }
+
+ return true;
+}
+
int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last)
{
assert(optStart != NULL);
dh.d_class = htons(udpPayloadSize);
static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
memcpy(&dh.d_ttl, &edns0, sizeof edns0);
- dh.d_clen = htons((uint16_t) optRData.length());
+ dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
- res.assign((const char *) &name, sizeof name);
- res.append((const char *) &dh, sizeof dh);
+ res.assign(reinterpret_cast<const char *>(&name), sizeof name);
+ res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
res.append(optRData.c_str(), optRData.length());
}
}
dq.ednsOptions = std::make_shared<std::map<uint16_t, EDNSOptionView> >();
+
+ if (ntohs(dq.dh->ancount) != 0 || ntohs(dq.dh->nscount) != 0 || (ntohs(dq.dh->arcount) != 0 && ntohs(dq.dh->arcount) != 1)) {
+ return slowParseEDNSOptions(reinterpret_cast<const char*>(dq.dh), dq.len, dq.ednsOptions);
+ }
+
const char* packet = reinterpret_cast<const char*>(dq.dh);
size_t remaining = 0;
return false;
}
-static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool* const ecsAdded)
+static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool& ecsAdded)
{
/* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
/* getEDNSOptionsStart has already checked that there is exactly one AR,
memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize);
*len += newECSOptionSize;
- *ecsAdded = true;
+ ecsAdded = true;
return true;
}
-static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded, bool preserveTrailingData)
+static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
{
/* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
string EDNSRR;
uint16_t arcount = ntohs(dh->arcount);
arcount++;
dh->arcount = htons(arcount);
- *ednsAdded = true;
+ ednsAdded = true;
+ ecsAdded = true;
memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size());
return true;
}
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
{
assert(packet != nullptr);
assert(len != nullptr);
assert(consumed <= (size_t) *len);
- assert(ednsAdded != nullptr);
- assert(ecsAdded != nullptr);
+
+ const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+ if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) {
+ vector<uint8_t> newContent;
+ newContent.reserve(packetSize);
+
+ if (!slowRewriteQueryWithExistingEDNS(std::string(packet, *len), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
+ ednsAdded = false;
+ ecsAdded = false;
+ return false;
+ }
+
+ if (newContent.size() > packetSize) {
+ ednsAdded = false;
+ ecsAdded = false;
+ return false;
+ }
+
+ memcpy(packet, &newContent.at(0), newContent.size());
+ *len = newContent.size();
+ return true;
+ }
+
uint16_t optRDPosition = 0;
size_t remaining = 0;
int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining);
if (res != 0) {
- return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, preserveTrailingData);
+ return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, ecsAdded, preserveTrailingData);
}
unsigned char* optRDLen = reinterpret_cast<unsigned char*>(packet) + optRDPosition;
return true;
}
-bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData)
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
{
assert(dq.remote != nullptr);
string newECSOption;
int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
{
- /* we need at least:
- root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
- if (*optLen < 11) {
+ if (*optLen < optRecordMinimumSize) {
return EINVAL;
}
const unsigned char* end = (const unsigned char*) optStart + *optLen;
bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
{
- /* we need at least:
- root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
- if (optLen < 11) {
+ if (optLen < optRecordMinimumSize) {
return false;
}
size_t p = optStart + 9;
- uint16_t rdLen = (0x100*packet.at(p) + packet.at(p+1));
+ uint16_t rdLen = (0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1)));
p += sizeof(rdLen);
- if (rdLen > (optLen - 11)) {
+ if (rdLen > (optLen - optRecordMinimumSize)) {
return false;
}
size_t rdEnd = p + rdLen;
while ((p + 4) <= rdEnd) {
- const uint16_t optionCode = 0x100*packet.at(p) + packet.at(p+1);
+ const uint16_t optionCode = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
p += sizeof(optionCode);
- const uint16_t optionLen = 0x100*packet.at(p) + packet.at(p+1);
+ const uint16_t optionLen = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
p += sizeof(optionLen);
if ((p + optionLen) > rdEnd) {
bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
{
- if (dh->arcount != 0) {
- return false;
- }
-
std::string optRecord;
generateOptRR(std::string(), optRecord, payloadSize, ednsrcode, dnssecOK);
char * optPtr = reinterpret_cast<char*>(dh) + len;
memcpy(optPtr, optRecord.data(), optRecord.size());
len += optRecord.size();
- dh->arcount = htons(1);
+ dh->arcount = htons(ntohs(dh->arcount) + 1);
+
+ return true;
+}
+
+/*
+ This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
+ generating a NXD or NODATA answer with a SOA record in the additional section.
+*/
+bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum)
+{
+ if (ntohs(dq.dh->qdcount) != 1) {
+ return false;
+ }
+
+ assert(dq.consumed == dq.qname->wirelength());
+ size_t queryPartSize = sizeof(dnsheader) + dq.consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
+ if (dq.len < queryPartSize) {
+ /* something is already wrong, don't build on flawed foundations */
+ return false;
+ }
+
+ size_t available = dq.size - queryPartSize;
+ uint16_t qtype = htons(QType::SOA);
+ uint16_t qclass = htons(QClass::IN);
+ uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
+ size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
+
+ if (soaSize > available) {
+ /* not enough space left to add the SOA, sorry! */
+ return false;
+ }
+
+ bool hadEDNS = false;
+ bool dnssecOK = false;
+
+ if (g_addEDNSToSelfGeneratedResponses) {
+ uint16_t payloadSize = 0;
+ uint16_t z = 0;
+ hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &payloadSize, &z);
+ if (hadEDNS) {
+ dnssecOK = z & EDNS_HEADER_FLAG_DO;
+ }
+ }
+
+ /* chop off everything after the question */
+ dq.len = queryPartSize;
+ if (nxd) {
+ dq.dh->rcode = RCode::NXDomain;
+ }
+ else {
+ dq.dh->rcode = RCode::NoError;
+ }
+ dq.dh->qr = true;
+ dq.dh->ancount = 0;
+ dq.dh->nscount = 0;
+ dq.dh->arcount = 0;
+
+ rdLength = htons(rdLength);
+ ttl = htonl(ttl);
+ serial = htonl(serial);
+ refresh = htonl(refresh);
+ retry = htonl(retry);
+ expire = htonl(expire);
+ minimum = htonl(minimum);
+
+ std::string soa;
+ soa.reserve(soaSize);
+ soa.append(zone.toDNSString());
+ soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
+ soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
+ soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
+ soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
+ soa.append(mname.toDNSString());
+ soa.append(rname.toDNSString());
+ soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
+ soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
+ soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
+ soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
+ soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
+
+ if (soa.size() != soaSize) {
+ throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
+ }
+
+ memcpy(reinterpret_cast<char*>(dq.dh) + queryPartSize, soa.c_str(), soa.size());
+
+ dq.len += soa.size();
+
+ dq.dh->arcount = htons(1);
+
+ if (g_addEDNSToSelfGeneratedResponses) {
+ /* now we need to add a new OPT record */
+ return addEDNS(dq.dh, dq.len, dq.size, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
+ }
return true;
}
return false;
}
+
+bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
+{
+ uint16_t optStart;
+ size_t optLen = 0;
+ bool last = false;
+ const char * packet = reinterpret_cast<const char*>(dq.dh);
+ std::string packetStr(packet, dq.len);
+ int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
+ if (res != 0) {
+ // no EDNS OPT RR
+ return false;
+ }
+
+ if (optLen < optRecordMinimumSize) {
+ return false;
+ }
+
+ if (optStart < dq.len && packetStr.at(optStart) != 0) {
+ // OPT RR Name != '.'
+ return false;
+ }
+
+ static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+ // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
+ memcpy(&edns0, packet + optStart + 5, sizeof edns0);
+ return true;
+}