pr.xfrBlob(blob);
pw.xfrBlob(blob);
} else {
+
pr.skip(ah.d_clen);
}
}
return 0;
}
+static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>& options, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+ for (auto it = options.begin(); it != options.end(); ) {
+ if (it->first == EDNSOptionCode::ECS) {
+ ecsAdded = false;
+
+ if (!overrideExisting) {
+ return false;
+ }
+
+ it = options.erase(it);
+ }
+ else {
+ ++it;
+ }
+ }
+
+ options.emplace_back(EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
+ return true;
+}
+
+static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+ assert(initialPacket.size() >= sizeof(dnsheader));
+ const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+
+ ecsAdded = false;
+ ednsAdded = true;
+
+ if (ntohs(dh->qdcount) == 0) {
+ return false;
+ }
+
+ if (ntohs(dh->arcount) == 0) {
+ throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
+ }
+
+ PacketReader pr(initialPacket);
+
+ size_t idx = 0;
+ DNSName rrname;
+ uint16_t qdcount = ntohs(dh->qdcount);
+ uint16_t ancount = ntohs(dh->ancount);
+ uint16_t nscount = ntohs(dh->nscount);
+ uint16_t arcount = ntohs(dh->arcount);
+ uint16_t rrtype;
+ uint16_t rrclass;
+ string blob;
+ struct dnsrecordheader ah;
+
+ rrname = pr.getName();
+ rrtype = pr.get16BitInt();
+ rrclass = pr.get16BitInt();
+
+ DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
+ pw.getHeader()->id=dh->id;
+ pw.getHeader()->qr=dh->qr;
+ pw.getHeader()->aa=dh->aa;
+ pw.getHeader()->tc=dh->tc;
+ pw.getHeader()->rd=dh->rd;
+ pw.getHeader()->ra=dh->ra;
+ pw.getHeader()->ad=dh->ad;
+ pw.getHeader()->cd=dh->cd;
+ pw.getHeader()->rcode=dh->rcode;
+
+ /* consume remaining qd if any */
+ if (qdcount > 1) {
+ for(idx = 1; idx < qdcount; idx++) {
+ rrname = pr.getName();
+ rrtype = pr.get16BitInt();
+ rrclass = pr.get16BitInt();
+ (void) rrtype;
+ (void) rrclass;
+ }
+ }
+
+ /* copy AN and NS */
+ for (idx = 0; idx < ancount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
+ pr.xfrBlob(blob);
+ pw.xfrBlob(blob);
+ }
+
+ for (idx = 0; idx < nscount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
+ pr.xfrBlob(blob);
+ pw.xfrBlob(blob);
+ }
+
+ /* consume AR, looking for OPT */
+ for (idx = 0; idx < arcount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+
+ if (ah.d_type != QType::OPT) {
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
+ pr.xfrBlob(blob);
+ pw.xfrBlob(blob);
+ } else {
+
+ ednsAdded = false;
+ pr.xfrBlob(blob);
+
+ std::vector<std::pair<uint16_t, std::string>> options;
+ getEDNSOptionsFromContent(blob, options);
+
+ EDNS0Record edns0;
+ static_assert(sizeof(edns0) == sizeof(ah.d_ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+ memcpy(&edns0, &ah.d_ttl, sizeof(edns0));
+
+ /* addOrReplaceECSOption will set it to false if there is already an existing option */
+ ecsAdded = true;
+ addOrReplaceECSOption(options, ecsAdded, overrideExisting, newECSOption);
+ pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
+ }
+ }
+
+ if (ednsAdded) {
+ pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
+ ecsAdded = true;
+ }
+
+ pw.commit();
+
+ return true;
+}
+
+static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::shared_ptr<std::map<uint16_t, EDNSOptionView> >& options)
+{
+ const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+ if (len < sizeof(dnsheader) || ntohs(dh->qdcount) == 0)
+ {
+ return false;
+ }
+
+ if (ntohs(dh->arcount) == 0) {
+ throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
+ }
+
+ try {
+ uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
+ DNSPacketMangler dpm(const_cast<char*>(packet), len);
+ uint64_t n;
+ for(n=0; n < ntohs(dh->qdcount) ; ++n) {
+ dpm.skipDomainName();
+ /* type and class */
+ dpm.skipBytes(4);
+ }
+
+ for(n=0; n < numrecords; ++n) {
+ dpm.skipDomainName();
+
+ uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
+ uint16_t dnstype = dpm.get16BitInt();
+ dpm.get16BitInt();
+ dpm.skipBytes(4); /* TTL */
+
+ if(section == 3 && dnstype == QType::OPT) {
+ uint32_t offset = dpm.getOffset();
+ if (offset >= len) {
+ return false;
+ }
+ /* if we survive this call, we can parse it safely */
+ dpm.skipRData();
+ return getEDNSOptions(packet + offset, len - offset, *options) == 0;
+ }
+ else {
+ dpm.skipRData();
+ }
+ }
+ }
+ catch(...)
+ {
+ return false;
+ }
+
+ return true;
+}
+
int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last)
{
assert(optStart != NULL);
}
dq.ednsOptions = std::make_shared<std::map<uint16_t, EDNSOptionView> >();
+
+ if (ntohs(dq.dh->arcount) != 0 && ntohs(dq.dh->arcount) != 1) {
+ return slowParseEDNSOptions(reinterpret_cast<const char*>(dq.dh), dq.len, dq.ednsOptions);
+ }
+
const char* packet = reinterpret_cast<const char*>(dq.dh);
size_t remaining = 0;
return false;
}
-static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool* const ecsAdded)
+static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool& ecsAdded)
{
/* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
/* getEDNSOptionsStart has already checked that there is exactly one AR,
memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize);
*len += newECSOptionSize;
- *ecsAdded = true;
+ ecsAdded = true;
return true;
}
-static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded, bool preserveTrailingData)
+static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
{
/* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
string EDNSRR;
uint16_t arcount = ntohs(dh->arcount);
arcount++;
dh->arcount = htons(arcount);
- *ednsAdded = true;
+ ednsAdded = true;
+ ecsAdded = true;
memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size());
return true;
}
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
{
assert(packet != nullptr);
assert(len != nullptr);
assert(consumed <= (size_t) *len);
- assert(ednsAdded != nullptr);
- assert(ecsAdded != nullptr);
+
+ const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+ if (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1) {
+ vector<uint8_t> newContent;
+ newContent.reserve(packetSize);
+
+ if (!slowRewriteQueryWithExistingEDNS(std::string(packet, *len), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
+ ednsAdded = false;
+ ecsAdded = false;
+ return false;
+ }
+
+ if (newContent.size() > packetSize) {
+ ednsAdded = false;
+ ecsAdded = false;
+ return false;
+ }
+
+ memcpy(packet, &newContent.at(0), newContent.size());
+ *len = newContent.size();
+ return true;
+ }
+
uint16_t optRDPosition = 0;
size_t remaining = 0;
int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining);
if (res != 0) {
- return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, preserveTrailingData);
+ return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, ecsAdded, preserveTrailingData);
}
unsigned char* optRDLen = reinterpret_cast<unsigned char*>(packet) + optRDPosition;
return true;
}
-bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData)
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
{
assert(dq.remote != nullptr);
string newECSOption;
static const uint16_t ECSSourcePrefixV4 = 24;
static const uint16_t ECSSourcePrefixV6 = 56;
-static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false)
+static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false, uint16_t additionals=0)
{
MOADNSParser mdp(true, packet, packetSize);
BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1U);
BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0U);
BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0U);
- uint16_t expectedARCount = 0 + (hasEdns ? 1 : 0) + (hasXPF ? 1 : 0);
+ uint16_t expectedARCount = additionals + (hasEdns ? 1U : 0U) + (hasXPF ? 1U : 0U);
BOOST_CHECK_EQUAL(mdp.d_header.arcount, expectedARCount);
}
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
BOOST_CHECK(static_cast<size_t>(len) > query.size());
BOOST_CHECK_EQUAL(ednsAdded, true);
- BOOST_CHECK_EQUAL(ecsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
validateQuery(packet, len);
validateECS(packet, len, remote);
vector<uint8_t> queryWithEDNS;
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
BOOST_CHECK_EQUAL(static_cast<size_t>(len), query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
packet[len + idx] = 'A';
}
len += trailingDataSize;
- BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
BOOST_REQUIRE_EQUAL(static_cast<size_t>(len), queryWithEDNS.size());
BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0);
BOOST_CHECK_EQUAL(ednsAdded, true);
- BOOST_CHECK_EQUAL(ecsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
validateQuery(packet, len);
/* packet with trailing data (preserving trailing data) */
packet[len + idx] = 'A';
}
len += trailingDataSize;
- BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, true));
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, true));
BOOST_REQUIRE_EQUAL(static_cast<size_t>(len), queryWithEDNS.size() + trailingDataSize);
BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0);
for (size_t idx = 0; idx < trailingDataSize; idx++) {
BOOST_CHECK_EQUAL(packet[queryWithEDNS.size() + idx], 'A');
}
BOOST_CHECK_EQUAL(ednsAdded, true);
- BOOST_CHECK_EQUAL(ecsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
validateQuery(packet, len);
}
BOOST_CHECK(!parseEDNSOptions(dq));
/* And now we add our own ECS */
- BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
+ BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false));
BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
BOOST_CHECK_EQUAL(ednsAdded, true);
- BOOST_CHECK_EQUAL(ecsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
validateQuery(packet, dq.len);
validateECS(packet, dq.len, remote);
BOOST_CHECK(qclass == QClass::IN);
DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
- BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false));
+ BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false));
BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
BOOST_CHECK((size_t) len > query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, true);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
BOOST_CHECK_EQUAL((size_t) len, query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
BOOST_CHECK(parseEDNSOptions(dq));
/* And now we add our own ECS */
- BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
+ BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false));
BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, true);
BOOST_CHECK(qclass == QClass::IN);
DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
- BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false));
+ BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false));
BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
BOOST_CHECK_EQUAL((size_t) len, query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
BOOST_CHECK(parseEDNSOptions(dq));
/* And now we add our own ECS */
- BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
+ BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false));
BOOST_CHECK_EQUAL(static_cast<size_t>(dq.len), query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
BOOST_CHECK((size_t) len < query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
BOOST_CHECK((size_t) len > query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
BOOST_CHECK_EQUAL(qname, name);
BOOST_CHECK(qtype == QType::A);
- BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
BOOST_CHECK_EQUAL((size_t) len, query.size());
BOOST_CHECK_EQUAL(ednsAdded, false);
BOOST_CHECK_EQUAL(ecsAdded, false);
validateQuery(reinterpret_cast<char*>(query.data()), len);
}
+BOOST_AUTO_TEST_CASE(replaceECSFollowedByTSIG) {
+ bool ednsAdded = false;
+ bool ecsAdded = false;
+ ComboAddress remote("192.168.1.25");
+ DNSName name("www.powerdns.com.");
+ ComboAddress origRemote("127.0.0.1");
+ string newECSOption;
+ generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+ vector<uint8_t> query;
+ DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+ pw.getHeader()->rd = 1;
+ EDNSSubnetOpts ecsOpts;
+ ecsOpts.source = Netmask(origRemote, 8);
+ string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+ DNSPacketWriter::optvect_t opts;
+ opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption));
+ pw.addOpt(512, 0, 0, opts);
+ pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+ pw.commit();
+ uint16_t len = query.size();
+
+ /* large enough packet */
+ char packet[1500];
+ memcpy(packet, query.data(), query.size());
+
+ unsigned int consumed = 0;
+ uint16_t qtype;
+ DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK((size_t) len > query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ validateQuery(packet, len, true, false, 1);
+ validateECS(packet, len, remote);
+
+ /* not large enough packet */
+ ednsAdded = false;
+ ecsAdded = false;
+ consumed = 0;
+ len = query.size();
+ qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK_EQUAL((size_t) len, query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ validateQuery(reinterpret_cast<char*>(query.data()), len, true, false, 1);
+}
+
+BOOST_AUTO_TEST_CASE(replaceECSBetweenTwoRecords) {
+ bool ednsAdded = false;
+ bool ecsAdded = false;
+ ComboAddress remote("192.168.1.25");
+ DNSName name("www.powerdns.com.");
+ ComboAddress origRemote("127.0.0.1");
+ string newECSOption;
+ generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+ vector<uint8_t> query;
+ DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+ pw.getHeader()->rd = 1;
+ EDNSSubnetOpts ecsOpts;
+ ecsOpts.source = Netmask(origRemote, 8);
+ string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+ DNSPacketWriter::optvect_t opts;
+ opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption));
+ pw.startRecord(DNSName("additional"), QType::A, 0, QClass::IN, DNSResourceRecord::ADDITIONAL, false);
+ pw.xfr32BitInt(0x01020304);
+ pw.addOpt(512, 0, 0, opts);
+ pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+ pw.commit();
+ uint16_t len = query.size();
+
+ /* large enough packet */
+ char packet[1500];
+ memcpy(packet, query.data(), query.size());
+
+ unsigned int consumed = 0;
+ uint16_t qtype;
+ DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK((size_t) len > query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ validateQuery(packet, len, true, false, 2);
+ validateECS(packet, len, remote);
+
+ /* not large enough packet */
+ ednsAdded = false;
+ ecsAdded = false;
+ consumed = 0;
+ len = query.size();
+ qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK_EQUAL((size_t) len, query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ validateQuery(reinterpret_cast<char*>(query.data()), len, true, false, 2);
+}
+
+BOOST_AUTO_TEST_CASE(insertECSInEDNSBetweenTwoRecords) {
+ bool ednsAdded = false;
+ bool ecsAdded = false;
+ ComboAddress remote("192.168.1.25");
+ DNSName name("www.powerdns.com.");
+ ComboAddress origRemote("127.0.0.1");
+ string newECSOption;
+ generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+ vector<uint8_t> query;
+ DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+ pw.getHeader()->rd = 1;
+ pw.startRecord(DNSName("additional"), QType::A, 0, QClass::IN, DNSResourceRecord::ADDITIONAL, false);
+ pw.xfr32BitInt(0x01020304);
+ pw.addOpt(512, 0, 0);
+ pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+ pw.commit();
+ uint16_t len = query.size();
+
+ /* large enough packet */
+ char packet[1500];
+ memcpy(packet, query.data(), query.size());
+
+ unsigned int consumed = 0;
+ uint16_t qtype;
+ DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK((size_t) len > query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
+ validateQuery(packet, len, true, false, 2);
+ validateECS(packet, len, remote);
+
+ /* not large enough packet */
+ ednsAdded = false;
+ ecsAdded = false;
+ consumed = 0;
+ len = query.size();
+ qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK_EQUAL((size_t) len, query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ validateQuery(reinterpret_cast<char*>(query.data()), len, true, false, 2);
+}
+
+BOOST_AUTO_TEST_CASE(insertECSAfterTSIG) {
+ bool ednsAdded = false;
+ bool ecsAdded = false;
+ ComboAddress remote("192.168.1.25");
+ DNSName name("www.powerdns.com.");
+ ComboAddress origRemote("127.0.0.1");
+ string newECSOption;
+ generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+ vector<uint8_t> query;
+ DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+ pw.getHeader()->rd = 1;
+ pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+ pw.commit();
+ uint16_t len = query.size();
+
+ /* large enough packet */
+ char packet[1500];
+ memcpy(packet, query.data(), query.size());
+
+ unsigned int consumed = 0;
+ uint16_t qtype;
+ DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK((size_t) len > query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, true);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
+ /* the MOADNSParser does not allow anything except XPF after a TSIG */
+ BOOST_CHECK_THROW(validateQuery(packet, len, true, false, 1), MOADNSException);
+ validateECS(packet, len, remote);
+
+ /* not large enough packet */
+ ednsAdded = false;
+ ecsAdded = false;
+ consumed = 0;
+ len = query.size();
+ qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+ BOOST_CHECK_EQUAL((size_t) len, query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ validateQuery(reinterpret_cast<char*>(query.data()), len, true, false);
+}
+
BOOST_AUTO_TEST_CASE(removeEDNSWhenFirst) {
DNSName name("www.powerdns.com.");
ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24)
response.use_edns(edns=True, payload=4096, options=[ecoResponse, ecsoResponse])
expectedResponse = dns.message.make_response(query)
+ expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse])
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24)
response.use_edns(edns=True, payload=4096, options=[ecsoResponse, ecoResponse])
expectedResponse = dns.message.make_response(query, our_payload=4096)
+ expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse])
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24)
response.use_edns(edns=True, payload=4096, options=[ecoResponse, ecsoResponse, ecoResponse])
expectedResponse = dns.message.make_response(query, our_payload=4096)
+ expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse, ecoResponse])
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
self.checkResponseEDNSWithECS(response, receivedResponse)
+ def testWithECSFollowedByAnother(self):
+ """
+ ECS: Existing EDNS with ECS, followed by another record
+
+ Send a query with EDNS and an existing ECS value.
+ The OPT record is not the last one in the query
+ and is followed by another one.
+ Check that the query received by the responder
+ has a valid ECS value and that the response
+ received from dnsdist contains an EDNS pseudo-RR.
+ """
+ name = 'withecs-followedbyanother.ecs.tests.powerdns.com.'
+ ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24)
+ eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+ rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,ecso,eco])
+ # I would have loved to use a TSIG here but I can't find how to make dnspython ignore
+ # it while parsing the message in the receiver :-/
+ query.additional.append(rrset)
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,eco,rewrittenEcso])
+ expectedQuery.additional.append(rrset)
+
+ response = dns.message.make_response(expectedQuery)
+ response.use_edns(edns=True, payload=4096, options=[eco, ecso, eco])
+ expectedResponse = dns.message.make_response(query)
+ expectedResponse.use_edns(edns=True, payload=4096, options=[eco, ecso, eco])
+ response.answer.append(rrset)
+ response.additional.append(rrset)
+ expectedResponse.answer.append(rrset)
+ expectedResponse.additional.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = expectedQuery.id
+ self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 2)
+ self.checkResponseEDNSWithECS(expectedResponse, receivedResponse, 2)
+
+ def testWithEDNSNoECSFollowedByAnother(self):
+ """
+ ECS: Existing EDNS without ECS, followed by another record
+
+ Send a query with EDNS but no ECS value.
+ The OPT record is not the last one in the query
+ and is followed by another one.
+ Check that the query received by the responder
+ has a valid ECS value and that the response
+ received from dnsdist contains an EDNS pseudo-RR.
+ """
+ name = 'withedns-no-ecs-followedbyanother.ecs.tests.powerdns.com.'
+ eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+ rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
+ # I would have loved to use a TSIG here but I can't find how to make dnspython ignore
+ # it while parsing the message in the receiver :-/
+ query.additional.append(rrset)
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,rewrittenEcso])
+ expectedQuery.additional.append(rrset)
+
+ response = dns.message.make_response(expectedQuery)
+ response.use_edns(edns=True, payload=4096, options=[eco, rewrittenEcso, eco])
+ expectedResponse = dns.message.make_response(query)
+ expectedResponse.use_edns(edns=True, payload=4096, options=[eco, eco])
+ response.answer.append(rrset)
+ response.additional.append(rrset)
+ expectedResponse.answer.append(rrset)
+ expectedResponse.additional.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = expectedQuery.id
+ self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
+ self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, 2)
+
class TestECSDisabledByRuleOrLua(DNSDistTest):
"""
dnsdist is configured to add the EDNS0 Client Subnet