]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add Lua FFI accessors for EDNS version and extended rcode
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 21 Aug 2024 12:48:24 +0000 (14:48 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 21 Aug 2024 12:48:24 +0000 (14:48 +0200)
12 files changed:
pdns/dnsdistdist/dnsdist-cache.cc
pdns/dnsdistdist/dnsdist-ecs.cc
pdns/dnsdistdist/dnsdist-ecs.hh
pdns/dnsdistdist/dnsdist-lua-actions.cc
pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/dnsdist-svc.cc
pdns/dnsdistdist/dnsdist.cc
pdns/dnsdistdist/test-dnsdist_cc.cc
regression-tests.dnsdist/test_LuaFFI.py

index d6323d8b3224844bc78f0eeec935acaad077344e..5832585a896d69c42802d928b4c3dd063c868c82 100644 (file)
@@ -51,7 +51,7 @@ bool DNSDistPacketCache::getClientSubnet(const PacketBuffer& packet, size_t qnam
   uint16_t optRDPosition = 0;
   size_t remaining = 0;
 
-  int res = getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining);
+  int res = dnsdist::getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining);
 
   if (res == 0) {
     size_t ecsOptionStartPosition = 0;
index 10ddf4a98850a58dd0fda4b9979b4534cff1277e..2caeb67be156c93a198104d9ea5e7472ca691f3d 100644 (file)
@@ -380,8 +380,9 @@ int locateEDNSOptRR(const PacketBuffer& packet, uint16_t* optStart, size_t* optL
   return ENOENT;
 }
 
+namespace dnsdist {
 /* extract the start of the OPT RR in a QUERY packet if any */
-int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_t* optRDPosition, size_t* remaining)
+int getEDNSOptionsStart(const PacketBuffer& packet, const size_t qnameWireLength, uint16_t* optRDPosition, size_t* remaining)
 {
   if (optRDPosition == nullptr || remaining == nullptr) {
     throw std::runtime_error("Invalid values passed to getEDNSOptionsStart");
@@ -389,7 +390,7 @@ int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_
 
   const dnsheader_aligned dnsHeader(packet.data());
 
-  if (offset >= packet.size()) {
+  if (qnameWireLength >= packet.size()) {
     return ENOENT;
   }
 
@@ -397,7 +398,7 @@ int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_
     return ENOENT;
   }
 
-  size_t pos = sizeof(dnsheader) + offset;
+  size_t pos = sizeof(dnsheader) + qnameWireLength;
   pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
 
   if (pos >= packet.size()) {
@@ -428,6 +429,7 @@ int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_
 
   return 0;
 }
+}
 
 void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
 {
@@ -531,7 +533,7 @@ bool parseEDNSOptions(const DNSQuestion& dnsQuestion)
 
   size_t remaining = 0;
   uint16_t optRDPosition{};
-  int res = getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
+  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
 
   if (res == 0) {
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
@@ -616,7 +618,7 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, cons
   uint16_t optRDPosition = 0;
   size_t remaining = 0;
 
-  int res = getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining);
+  int res = dnsdist::getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining);
 
   if (res != 0) {
     /* no EDNS but there might be another record in additional (TSIG?) */
@@ -996,7 +998,7 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dnsQuestion)
   size_t remaining = 0;
 
   auto& packet = dnsQuestion.getMutableData();
-  int res = getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
+  int res = dnsdist::getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
 
   if (res != 0) {
     /* if the initial query did not have EDNS0, we are done */
@@ -1031,54 +1033,105 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dnsQuestion)
   return true;
 }
 
+namespace dnsdist {
+static std::optional<size_t> getEDNSRecordPosition(const DNSQuestion& dnsQuestion)
+{
+  try {
+    const auto& packet = dnsQuestion.getData();
+    if (packet.size() <= sizeof(dnsheader)) {
+      return std::nullopt;
+    }
+
+    uint16_t optRDPosition = 0;
+    size_t remaining = 0;
+    auto res = getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
+    if (res != 0) {
+      return std::nullopt;
+    }
+
+    if (optRDPosition < DNS_TTL_SIZE) {
+      return std::nullopt;
+    }
+
+    return optRDPosition - DNS_TTL_SIZE;
+  }
+  catch (...) {
+    return std::nullopt;
+  }
+}
+
 // goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
 int getEDNSZ(const DNSQuestion& dnsQuestion)
 {
   try {
-    const auto& dnsHeader = dnsQuestion.getHeader();
-    if (ntohs(dnsHeader->qdcount) != 1 || dnsHeader->ancount != 0 || ntohs(dnsHeader->arcount) != 1 || dnsHeader->nscount != 0) {
+    auto position = getEDNSRecordPosition(dnsQuestion);
+
+    if (!position) {
       return 0;
     }
 
-    if (dnsQuestion.getData().size() <= sizeof(dnsheader)) {
+    const auto& packet = dnsQuestion.getData();
+    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) {
       return 0;
     }
 
-    size_t pos = sizeof(dnsheader) + dnsQuestion.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
+    return 0x100 * packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) + packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1);
+  }
+  catch (...) {
+    return 0;
+  }
+}
+
+std::optional<uint8_t> getEDNSVersion(const DNSQuestion& dnsQuestion)
+{
+  try {
+    auto position = getEDNSRecordPosition(dnsQuestion);
 
-    if (dnsQuestion.getData().size() <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) {
-      return 0;
+    if (!position) {
+      return std::nullopt;
     }
 
     const auto& packet = dnsQuestion.getData();
-    if (packet.at(pos) != 0) {
-      /* not root, so not a valid OPT record */
-      return 0;
+    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) >= packet.size()) {
+      return std::nullopt;
     }
 
-    pos++;
+    return packet.at(*position + EDNS_EXTENDED_RCODE_SIZE);
+  }
+  catch (...) {
+    return std::nullopt;
+  }
+}
+
+std::optional<uint8_t> getEDNSExtendedRCode(const DNSQuestion& dnsQuestion)
+{
+  try {
+    auto position = getEDNSRecordPosition(dnsQuestion);
 
-    uint16_t qtype = packet.at(pos) * 256 + packet.at(pos + 1);
-    pos += DNS_TYPE_SIZE;
-    pos += DNS_CLASS_SIZE;
+    if (!position) {
+      return std::nullopt;
+    }
 
-    if (qtype != QType::OPT || (pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) {
-      return 0;
+    const auto& packet = dnsQuestion.getData();
+    if ((*position + EDNS_EXTENDED_RCODE_SIZE) >= packet.size()) {
+      return std::nullopt;
     }
 
-    return 0x100 * packet.at(pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) + packet.at(pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1);
+    return packet.at(*position);
   }
   catch (...) {
-    return 0;
+    return std::nullopt;
   }
 }
 
+}
+
 bool queryHasEDNS(const DNSQuestion& dnsQuestion)
 {
   uint16_t optRDPosition = 0;
   size_t ecsRemaining = 0;
 
-  int res = getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &ecsRemaining);
+  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &ecsRemaining);
   return res == 0;
 }
 
index ee901bee97b6ff9019224dd5f1c98c18da2e13e4..0c6a4780ed29e1427099c9316b66d6ea19bb08ca 100644 (file)
@@ -38,7 +38,6 @@ bool generateOptRR(const std::string& optRData, PacketBuffer& res, size_t maximu
 void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength);
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove);
 int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent);
-int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_t* optRDPosition, size_t* remaining);
 bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart = nullptr, uint16_t* optContentLen = nullptr);
 bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode);
 bool addEDNSToQueryTurnedResponse(DNSQuestion& dnsQuestion);
@@ -49,7 +48,6 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, size_t maximumSize, size_t qna
 
 bool parseEDNSOptions(const DNSQuestion& dnsQuestion);
 
-int getEDNSZ(const DNSQuestion& dnsQuestion);
 bool queryHasEDNS(const DNSQuestion& dnsQuestion);
 bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0);
 
@@ -59,4 +57,12 @@ struct InternalQueryState;
 namespace dnsdist
 {
 bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uint8_t rcode, bool clearAnswers);
+/* this method only works for queries (qdcount == 1, ancount == nscount == 0, arcount == 1) */
+int getEDNSOptionsStart(const PacketBuffer& packet, const size_t qnameWireLength, uint16_t* optRDPosition, size_t* remaining);
+/* this method only works for queries (qdcount == 1, ancount == nscount == 0, arcount == 1) */
+int getEDNSZ(const DNSQuestion& dnsQuestion);
+/* this method only works for queries (qdcount == 1, ancount == nscount == 0, arcount == 1) */
+std::optional<uint8_t> getEDNSVersion(const DNSQuestion& dnsQuestion);
+/* this method only works for queries (qdcount == 1, ancount == nscount == 0, arcount == 1) */
+std::optional<uint8_t> getEDNSExtendedRCode(const DNSQuestion& dnsQuestion);
 }
index 9133aa77e2c054aca37681905593138570ca9ebb..76a87bc4ab5476f067ec9fc83f6f4cee752aee09 100644 (file)
@@ -909,7 +909,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string*
   bool hadEDNS = false;
   if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses && queryHasEDNS(*dnsquestion)) {
     hadEDNS = true;
-    dnssecOK = ((getEDNSZ(*dnsquestion) & EDNS_HEADER_FLAG_DO) != 0);
+    dnssecOK = ((dnsdist::getEDNSZ(*dnsquestion) & EDNS_HEADER_FLAG_DO) != 0);
   }
 
   auto& data = dnsquestion->getMutableData();
index 8889fea674588d76576234daa3cb3692fe1a941b..bf6c47361c9eb52ad2690cf3b9abd89bf04f254d 100644 (file)
@@ -122,7 +122,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     }
     dnsQuestion.ids.d_protoBufData->d_requestorID = newValue; });
   luaCtx.registerFunction<bool (DNSQuestion::*)() const>("getDO", [](const DNSQuestion& dnsQuestion) {
-    return getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO;
+    return dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO;
   });
   luaCtx.registerFunction<std::string (DNSQuestion::*)() const>("getContent", [](const DNSQuestion& dnsQuestion) {
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
@@ -423,7 +423,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     editDNSPacketTTL(reinterpret_cast<char*>(dnsResponse.getMutableData().data()), dnsResponse.getData().size(), editFunc);
   });
   luaCtx.registerFunction<bool (DNSResponse::*)() const>("getDO", [](const DNSResponse& dnsQuestion) {
-    return getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO;
+    return dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO;
   });
   luaCtx.registerFunction<std::string (DNSResponse::*)() const>("getContent", [](const DNSResponse& dnsQuestion) {
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
index a98cf1d9c663694dfb03c2236909c19601ce9e52..9a1b7e4cb2b4c1fe3141d278d4c2a6319072a8fd 100644 (file)
@@ -84,6 +84,8 @@ uint16_t dnsdist_ffi_dnsquestion_get_ecs_prefix_length(const dnsdist_ffi_dnsques
 bool dnsdist_ffi_dnsquestion_is_temp_failure_ttl_set(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint8_t dnsdist_ffi_dnsquestion_get_edns_version(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint8_t dnsdist_ffi_dnsquestion_get_edns_extended_rcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize) __attribute__ ((visibility ("default")));
 const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* label) __attribute__ ((visibility ("default")));
 size_t dnsdist_ffi_dnsquestion_get_tag_raw(const dnsdist_ffi_dnsquestion_t* dq, const char* label, char* buffer, size_t bufferSize) __attribute__ ((visibility ("default")));
index c59182750dc5d752c32f7889b9bdefdbd99a9b2b..617fb9ce9534cf928e246aa968111be4b1818057 100644 (file)
@@ -237,7 +237,19 @@ uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquest
 
 bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq)
 {
-  return getEDNSZ(*dq->dq) & EDNS_HEADER_FLAG_DO;
+  return dnsdist::getEDNSZ(*dq->dq) & EDNS_HEADER_FLAG_DO;
+}
+
+uint8_t dnsdist_ffi_dnsquestion_get_edns_version(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  auto version = dnsdist::getEDNSVersion(*dq->dq);
+  return version ? *version : 0U;
+}
+
+uint8_t dnsdist_ffi_dnsquestion_get_edns_extended_rcode(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  auto rcode = dnsdist::getEDNSExtendedRCode(*dq->dq);
+  return rcode ? *rcode : 0U;
 }
 
 void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize)
index 83ea08dbd7063d9648ea1388f6c5bdc8a827a376..129da7c8b96fffbbbb2b6d0181c7759eebc5b15e 100644 (file)
@@ -416,7 +416,7 @@ public:
   }
   bool matches(const DNSQuestion* dq) const override
   {
-    return dq->getHeader()->cd || (getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
+    return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
   }
 
   string toString() const override
index a92fdb2072f93810010d235f3d9101e70d34979b..574cc9eea1e3f52a4c193329077055df54b0dfd3 100644 (file)
@@ -177,7 +177,7 @@ bool generateSVCResponse(DNSQuestion& dnsQuestion, const std::vector<std::vector
 
   const auto& runtimeConfig = dnsdist::configuration::getCurrentRuntimeConfiguration();
   if (runtimeConfig.d_addEDNSToSelfGeneratedResponses && queryHasEDNS(dnsQuestion)) {
-    bool dnssecOK = ((getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
+    bool dnssecOK = ((dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
     packetWriter.addOpt(runtimeConfig.d_payloadSizeSelfGenAnswers, 0, dnssecOK ? EDNS_HEADER_FLAG_DO : 0);
     packetWriter.commit();
   }
index 4186ac787575227b9535924f5a5f1da7d71d8dbc..0194682d0fb5fd2ed078ea77a619a2c2bc4338ce 100644 (file)
@@ -1422,7 +1422,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
     uint32_t allowExpired = selectedBackend ? 0 : dnsdist::configuration::getCurrentRuntimeConfiguration().d_staleCacheEntriesTTL;
 
     if (dnsQuestion.ids.packetCache && !dnsQuestion.ids.skipCache) {
-      dnsQuestion.ids.dnssecOK = (getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0;
+      dnsQuestion.ids.dnssecOK = (dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0;
     }
 
     if (dnsQuestion.useECS && ((selectedBackend && selectedBackend->d_config.useECS) || (!selectedBackend && serverPool->getECS()))) {
index 0ea258fbc8f1ba438519aa6fa4f40f1bbeb9637e..44a3cf75ee66aa5c1cdf963bd33fa00f0eef9aa6 100644 (file)
@@ -1489,7 +1489,7 @@ static int getZ(const DNSName& qname, const uint16_t qtype, const uint16_t qclas
 
   auto dnsQuestion = DNSQuestion(ids, query);
 
-  return getEDNSZ(dnsQuestion);
+  return dnsdist::getEDNSZ(dnsQuestion);
 }
 
 BOOST_AUTO_TEST_CASE(test_getEDNSZ)
@@ -1593,6 +1593,131 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ)
   }
 }
 
+BOOST_AUTO_TEST_CASE(test_getEDNSVersion)
+{
+  const DNSName qname("www.powerdns.com.");
+  const uint16_t qtype = QType::A;
+  const uint16_t qclass = QClass::IN;
+  const GenericDNSPacketWriter<PacketBuffer>::optvect_t opts;
+
+  auto getVersion = [&qname, qtype, qclass](PacketBuffer& query) {
+    InternalQueryState ids;
+    ids.protocol = dnsdist::Protocol::DoUDP;
+    ids.qname = qname;
+    ids.qtype = qtype;
+    ids.qclass = qclass;
+    ids.origDest = ComboAddress("127.0.0.1");
+    ids.origRemote = ComboAddress("127.0.0.1");
+    ids.queryRealTime.start();
+
+    auto dnsQuestion = DNSQuestion(ids, query);
+
+    return dnsdist::getEDNSVersion(dnsQuestion);
+  };
+
+  {
+    /* no EDNS */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.commit();
+
+    BOOST_CHECK(getVersion(query) == std::nullopt);
+  }
+
+  {
+    /* truncated EDNS */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.addOpt(512, 0, EDNS_HEADER_FLAG_DO);
+    packetWriter.commit();
+
+    query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* TTL */ 2));
+    BOOST_CHECK(getVersion(query) == std::nullopt);
+  }
+
+  {
+    /* valid EDNS, no options */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.addOpt(512, 0, 0);
+    packetWriter.commit();
+
+    BOOST_CHECK_EQUAL(*getVersion(query), 0U);
+  }
+
+  {
+    /* EDNS version 255 */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.addOpt(512, 0, EDNS_HEADER_FLAG_DO, opts, 255U);
+    packetWriter.commit();
+
+    BOOST_CHECK_EQUAL(*getVersion(query), 255U);
+  }
+}
+
+BOOST_AUTO_TEST_CASE(test_getEDNSExtendedRCode)
+{
+  const DNSName qname("www.powerdns.com.");
+  const uint16_t qtype = QType::A;
+  const uint16_t qclass = QClass::IN;
+
+  auto getExtendedRCode = [&qname, qtype, qclass](PacketBuffer& query) {
+    InternalQueryState ids;
+    ids.protocol = dnsdist::Protocol::DoUDP;
+    ids.qname = qname;
+    ids.qtype = qtype;
+    ids.qclass = qclass;
+    ids.origDest = ComboAddress("127.0.0.1");
+    ids.origRemote = ComboAddress("127.0.0.1");
+    ids.queryRealTime.start();
+
+    auto dnsQuestion = DNSQuestion(ids, query);
+
+    return dnsdist::getEDNSExtendedRCode(dnsQuestion);
+  };
+
+  {
+    /* no EDNS */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.commit();
+
+    BOOST_CHECK(getExtendedRCode(query) == std::nullopt);
+  }
+
+  {
+    /* truncated EDNS */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.addOpt(512, 0, EDNS_HEADER_FLAG_DO);
+    packetWriter.commit();
+
+    query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* TTL */ 2));
+    BOOST_CHECK(getExtendedRCode(query) == std::nullopt);
+  }
+
+  {
+    /* valid EDNS, no options */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.addOpt(512, 0, 0);
+    packetWriter.commit();
+
+    BOOST_CHECK_EQUAL(*getExtendedRCode(query), 0U);
+  }
+
+  {
+    /* EDNS extended RCode 4095 (15 for the normal RCode, 255 for the EDNS part) */
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> packetWriter(query, qname, qtype, qclass, 0);
+    packetWriter.addOpt(512, 4095U, EDNS_HEADER_FLAG_DO);
+    packetWriter.commit();
+
+    BOOST_CHECK_EQUAL(*getExtendedRCode(query), 255U);
+  }
+}
+
 BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse)
 {
   InternalQueryState ids;
@@ -1622,7 +1747,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse)
     packetWriter.commit();
 
     auto dnsQuestion = turnIntoResponse(ids, query);
-    BOOST_CHECK_EQUAL(getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK_EQUAL(dnsdist::getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK(dnsdist::getEDNSVersion(dnsQuestion) == std::nullopt);
+    BOOST_CHECK(dnsdist::getEDNSExtendedRCode(dnsQuestion) == std::nullopt);
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dnsQuestion.getData().data()), dnsQuestion.getData().size(), &udpPayloadSize, &zValue), false);
     BOOST_CHECK_EQUAL(zValue, 0);
@@ -1638,7 +1765,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse)
 
     query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* last byte of TTL / Z */ 1));
     auto dnsQuestion = turnIntoResponse(ids, query, false);
-    BOOST_CHECK_EQUAL(getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK_EQUAL(dnsdist::getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK(dnsdist::getEDNSVersion(dnsQuestion) == std::nullopt);
+    BOOST_CHECK(dnsdist::getEDNSExtendedRCode(dnsQuestion) == std::nullopt);
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dnsQuestion.getData().data()), dnsQuestion.getData().size(), &udpPayloadSize, &zValue), false);
     BOOST_CHECK_EQUAL(zValue, 0);
@@ -1653,7 +1782,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse)
     packetWriter.commit();
 
     auto dnsQuestion = turnIntoResponse(ids, query);
-    BOOST_CHECK_EQUAL(getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK_EQUAL(dnsdist::getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSVersion(dnsQuestion), 0U);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSExtendedRCode(dnsQuestion), 0U);
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dnsQuestion.getData().data()), dnsQuestion.getData().size(), &udpPayloadSize, &zValue), true);
     BOOST_CHECK_EQUAL(zValue, 0);
@@ -1668,7 +1799,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse)
     packetWriter.commit();
 
     auto dnsQuestion = turnIntoResponse(ids, query);
-    BOOST_CHECK_EQUAL(getEDNSZ(dnsQuestion), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(dnsdist::getEDNSZ(dnsQuestion), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSVersion(dnsQuestion), 0U);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSExtendedRCode(dnsQuestion), 0U);
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dnsQuestion.getData().data()), dnsQuestion.getData().size(), &udpPayloadSize, &zValue), true);
     BOOST_CHECK_EQUAL(zValue, EDNS_HEADER_FLAG_DO);
@@ -1683,7 +1816,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse)
     packetWriter.commit();
 
     auto dnsQuestion = turnIntoResponse(ids, query);
-    BOOST_CHECK_EQUAL(getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK_EQUAL(dnsdist::getEDNSZ(dnsQuestion), 0);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSVersion(dnsQuestion), 0U);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSExtendedRCode(dnsQuestion), 0U);
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dnsQuestion.getData().data()), dnsQuestion.getData().size(), &udpPayloadSize, &zValue), true);
     BOOST_CHECK_EQUAL(zValue, 0);
@@ -1698,7 +1833,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse)
     packetWriter.commit();
 
     auto dnsQuestion = turnIntoResponse(ids, query);
-    BOOST_CHECK_EQUAL(getEDNSZ(dnsQuestion), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(dnsdist::getEDNSZ(dnsQuestion), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSVersion(dnsQuestion), 0U);
+    BOOST_CHECK_EQUAL(*dnsdist::getEDNSExtendedRCode(dnsQuestion), 0U);
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dnsQuestion.getData().data()), dnsQuestion.getData().size(), &udpPayloadSize, &zValue), true);
     BOOST_CHECK_EQUAL(zValue, EDNS_HEADER_FLAG_DO);
@@ -1730,13 +1867,13 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart)
     packetWriter.getHeader()->rcode = RCode::NXDomain;
     packetWriter.commit();
 
-    int res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
+    int res = dnsdist::getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
 
     BOOST_CHECK_EQUAL(res, ENOENT);
 
     /* truncated packet (should not matter) */
     query.resize(query.size() - 1);
-    res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
+    res = dnsdist::getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
 
     BOOST_CHECK_EQUAL(res, ENOENT);
   }
@@ -1748,7 +1885,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart)
     packetWriter.addOpt(512, 0, 0);
     packetWriter.commit();
 
-    int res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
+    int res = dnsdist::getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
 
     BOOST_CHECK_EQUAL(res, 0);
     BOOST_CHECK_EQUAL(optRDPosition, optRDExpectedOffset);
@@ -1757,7 +1894,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart)
     /* truncated packet */
     query.resize(query.size() - 1);
 
-    res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
+    res = dnsdist::getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
     BOOST_CHECK_EQUAL(res, ENOENT);
   }
 
@@ -1768,7 +1905,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart)
     packetWriter.addOpt(512, 0, 0, opts);
     packetWriter.commit();
 
-    int res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
+    int res = dnsdist::getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
 
     BOOST_CHECK_EQUAL(res, 0);
     BOOST_CHECK_EQUAL(optRDPosition, optRDExpectedOffset);
@@ -1776,7 +1913,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart)
 
     /* truncated options (should not matter for this test) */
     query.resize(query.size() - 1);
-    res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
+    res = dnsdist::getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining);
     BOOST_CHECK_EQUAL(res, 0);
     BOOST_CHECK_EQUAL(optRDPosition, optRDExpectedOffset);
     BOOST_CHECK_EQUAL(remaining, query.size() - optRDExpectedOffset);
index 517db15c3265ff1f6121b143ee7318b908d5a5ab..f298d2a3cece94a09b5d970cc8f070bdfa1b4826 100644 (file)
@@ -70,6 +70,18 @@ class TestAdvancedLuaFFI(DNSDistTest):
         return false
       end
 
+      local ednsVersion = ffi.C.dnsdist_ffi_dnsquestion_get_edns_version(dq)
+      if ednsVersion ~= 0 then
+        print('invalid EDNS version')
+        return false
+      end
+
+      local ednsExtendedRCode = ffi.C.dnsdist_ffi_dnsquestion_get_edns_extended_rcode(dq)
+      if ednsExtendedRCode ~= 0 then
+        print('invalid EDNS Extended RCode')
+        return false
+      end
+
       local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq)
       if len ~= 52 then
         print('invalid length')
@@ -226,7 +238,19 @@ class TestAdvancedLuaFFIPerThread(DNSDistTest):
           return false
         end
 
-        local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq)
+        local ednsVersion = ffi.C.dnsdist_ffi_dnsquestion_get_edns_version(dq)
+        if ednsVersion ~= 0 then
+          print('invalid EDNS version')
+          return false
+        end
+
+        local ednsExtendedRCode = ffi.C.dnsdist_ffi_dnsquestion_get_edns_extended_rcode(dq)
+        if ednsExtendedRCode ~= 0 then
+          print('invalid EDNS Extended RCode')
+          return false
+        end
+
+      local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq)
         if len ~= 61 then
           print('invalid length')
           print(len)