From 7bd50190ccb39019a2c733b722bdeb958d749065 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Mon, 17 May 2021 15:39:46 +0200 Subject: [PATCH] dnsdist: Properly handle ECS for queries with ancount or nscount > 0 --- pdns/dnsdist-ecs.cc | 9 ++--- pdns/test-dnsdist_cc.cc | 78 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index fc664da941..e7b6fdd0cc 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -149,7 +149,7 @@ static bool addOrReplaceECSOption(std::vector>& return true; } -static bool slowRewriteQueryWithExistingEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption) +static bool slowRewriteQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption) { assert(initialPacket.size() >= sizeof(dnsheader)); const struct dnsheader* dh = reinterpret_cast(initialPacket.data()); @@ -161,8 +161,8 @@ static bool slowRewriteQueryWithExistingEDNS(const PacketBuffer& initialPacket, return false; } - if (ntohs(dh->arcount) == 0) { - throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS"); + if (ntohs(dh->ancount) == 0 && ntohs(dh->nscount) == 0 && ntohs(dh->arcount) == 0) { + throw std::runtime_error("slowRewriteQueryWithRecords() should not be called for queries that have no records"); } PacketReader pr(pdns_string_view(reinterpret_cast(initialPacket.data()), initialPacket.size())); @@ -588,7 +588,7 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, cons PacketBuffer newContent; newContent.reserve(packet.size()); - if (!slowRewriteQueryWithExistingEDNS(packet, newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) { + if (!slowRewriteQueryWithRecords(packet, newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) { ednsAdded = false; ecsAdded = false; return false; @@ -611,6 +611,7 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, cons if (res != 0) { /* no EDNS but there might be another record in additional (TSIG?) */ + /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */ size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t); if (packet.size() > minimumPacketSize) { if (ntohs(dh->arcount) == 0) { diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index d17e9a0c60..9986213e08 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -238,6 +238,84 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) validateQuery(packet); } +BOOST_AUTO_TEST_CASE(addECSWithoutEDNSButWithAnswer) +{ + /* this might happen for NOTIFY queries where, according to rfc1996: + "If ANCOUNT>0, then the answer section represents an + unsecure hint at the new RRset for this ". + */ + bool ednsAdded = false; + bool ecsAdded = false; + ComboAddress remote("192.0.2.1"); + DNSName name("www.powerdns.com."); + string newECSOption; + generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + + PacketBuffer query; + GenericDNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); + pw.getHeader()->rd = 1; + pw.startRecord(name, QType::A, 60, QClass::IN, DNSResourceRecord::ANSWER, false); + pw.xfrIP(remote.sin4.sin_addr.s_addr); + pw.commit(); + uint16_t len = query.size(); + + /* large enough packet */ + PacketBuffer packet = query; + + unsigned int consumed = 0; + uint16_t qtype; + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, false, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); + BOOST_CHECK_EQUAL(ednsAdded, true); + BOOST_CHECK_EQUAL(ecsAdded, true); + validateQuery(packet, true, false, 0, 1); + validateECS(packet, remote); + PacketBuffer queryWithEDNS = packet; + + /* not large enough packet */ + packet = query; + + ednsAdded = false; + ecsAdded = false; + consumed = 0; + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, false, newECSOption)); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, false); + packet.resize(query.size()); + validateQuery(packet, false, false, 0, 1); + + /* packet with trailing data (overriding it) */ + packet = query; + ednsAdded = false; + ecsAdded = false; + consumed = 0; + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + /* add trailing data */ + const size_t trailingDataSize = 10; + /* Making sure we have enough room to allow for fake trailing data */ + packet.resize(packet.size() + trailingDataSize); + for (size_t idx = 0; idx < trailingDataSize; idx++) { + packet[len + idx] = 'A'; + } + + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, false, newECSOption)); + BOOST_REQUIRE_EQUAL(packet.size(), queryWithEDNS.size()); + BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet.data(), queryWithEDNS.size()), 0); + BOOST_CHECK_EQUAL(ednsAdded, true); + BOOST_CHECK_EQUAL(ecsAdded, true); + validateQuery(packet, true, false, 0, 1); +} + BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) { bool ednsAdded = false; -- 2.47.2