return true;
}
-static bool slowRewriteQueryWithExistingEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+static bool slowRewriteQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
{
assert(initialPacket.size() >= sizeof(dnsheader));
const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
return false;
}
- if (ntohs(dh->arcount) == 0) {
- throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
+ if (ntohs(dh->ancount) == 0 && ntohs(dh->nscount) == 0 && ntohs(dh->arcount) == 0) {
+ throw std::runtime_error("slowRewriteQueryWithRecords() should not be called for queries that have no records");
}
PacketReader pr(pdns_string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
PacketBuffer newContent;
newContent.reserve(packet.size());
- if (!slowRewriteQueryWithExistingEDNS(packet, newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
+ if (!slowRewriteQueryWithRecords(packet, newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
ednsAdded = false;
ecsAdded = false;
return false;
if (res != 0) {
/* no EDNS but there might be another record in additional (TSIG?) */
+ /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */
size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t);
if (packet.size() > minimumPacketSize) {
if (ntohs(dh->arcount) == 0) {
validateQuery(packet);
}
+BOOST_AUTO_TEST_CASE(addECSWithoutEDNSButWithAnswer)
+{
+ /* this might happen for NOTIFY queries where, according to rfc1996:
+ "If ANCOUNT>0, then the answer section represents an
+ unsecure hint at the new RRset for this <QNAME,QCLASS,QTYPE>".
+ */
+ bool ednsAdded = false;
+ bool ecsAdded = false;
+ ComboAddress remote("192.0.2.1");
+ DNSName name("www.powerdns.com.");
+ string newECSOption;
+ generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+ PacketBuffer query;
+ GenericDNSPacketWriter<PacketBuffer> pw(query, name, QType::A, QClass::IN, 0);
+ pw.getHeader()->rd = 1;
+ pw.startRecord(name, QType::A, 60, QClass::IN, DNSResourceRecord::ANSWER, false);
+ pw.xfrIP(remote.sin4.sin_addr.s_addr);
+ pw.commit();
+ uint16_t len = query.size();
+
+ /* large enough packet */
+ PacketBuffer packet = query;
+
+ unsigned int consumed = 0;
+ uint16_t qtype;
+ DNSName qname(reinterpret_cast<char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, false, newECSOption));
+ BOOST_CHECK(packet.size() > query.size());
+ BOOST_CHECK_EQUAL(ednsAdded, true);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
+ validateQuery(packet, true, false, 0, 1);
+ validateECS(packet, remote);
+ PacketBuffer queryWithEDNS = packet;
+
+ /* not large enough packet */
+ packet = query;
+
+ ednsAdded = false;
+ ecsAdded = false;
+ consumed = 0;
+ qname = DNSName(reinterpret_cast<char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+
+ BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, false, newECSOption));
+ BOOST_CHECK_EQUAL(ednsAdded, false);
+ BOOST_CHECK_EQUAL(ecsAdded, false);
+ packet.resize(query.size());
+ validateQuery(packet, false, false, 0, 1);
+
+ /* packet with trailing data (overriding it) */
+ packet = query;
+ ednsAdded = false;
+ ecsAdded = false;
+ consumed = 0;
+ qname = DNSName(reinterpret_cast<char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed);
+ BOOST_CHECK_EQUAL(qname, name);
+ BOOST_CHECK(qtype == QType::A);
+ /* add trailing data */
+ const size_t trailingDataSize = 10;
+ /* Making sure we have enough room to allow for fake trailing data */
+ packet.resize(packet.size() + trailingDataSize);
+ for (size_t idx = 0; idx < trailingDataSize; idx++) {
+ packet[len + idx] = 'A';
+ }
+
+ BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, false, newECSOption));
+ BOOST_REQUIRE_EQUAL(packet.size(), queryWithEDNS.size());
+ BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet.data(), queryWithEDNS.size()), 0);
+ BOOST_CHECK_EQUAL(ednsAdded, true);
+ BOOST_CHECK_EQUAL(ecsAdded, true);
+ validateQuery(packet, true, false, 0, 1);
+}
+
BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
{
bool ednsAdded = false;