]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add Lua FFI bindings to generate SVC responses
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 4 Jun 2024 15:36:23 +0000 (17:36 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 6 Jun 2024 15:04:09 +0000 (17:04 +0200)
pdns/dnsdistdist/dnsdist-dnsparser.cc
pdns/dnsdistdist/dnsdist-dnsparser.hh
pdns/dnsdistdist/dnsdist-lua-actions.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-lua.hh
pdns/dnsdistdist/dnsdist-svc.cc
pdns/dnsdistdist/dnsdist-svc.hh
pdns/dnsdistdist/test-dnsdist-lua-ffi.cc
regression-tests.dnsdist/test_SVCB.py

index a15f2d5e9f54e99a59aedb5aace7f5d473dfe8ee..bfe0be3e07ed28bb51defbb6397b3c1d3bf0c6ec 100644 (file)
@@ -214,4 +214,23 @@ namespace PacketMangling
     return true;
   }
 }
+
+void setResponseHeadersFromConfig(dnsheader& dnsheader, const ResponseConfig& config)
+{
+  if (config.setAA) {
+    dnsheader.aa = *config.setAA;
+  }
+  if (config.setAD) {
+    dnsheader.ad = *config.setAD;
+  }
+  else {
+    dnsheader.ad = false;
+  }
+  if (config.setRA) {
+    dnsheader.ra = *config.setRA;
+  }
+  else {
+    dnsheader.ra = dnsheader.rd; // for good measure
+  }
+}
 }
index 4f7cdaded409702940563ef87f5e313ae8eeec7e..67d74a344205cacac42f67f1663fff9d99efc3a0 100644 (file)
@@ -60,4 +60,13 @@ namespace PacketMangling
   bool editDNSHeaderFromPacket(PacketBuffer& packet, const std::function<bool(dnsheader& header)>& editFunction);
   bool editDNSHeaderFromRawPacket(void* packet, const std::function<bool(dnsheader& header)>& editFunction);
 }
+
+struct ResponseConfig
+{
+  boost::optional<bool> setAA{boost::none};
+  boost::optional<bool> setAD{boost::none};
+  boost::optional<bool> setRA{boost::none};
+  uint32_t ttl{60};
+};
+void setResponseHeadersFromConfig(dnsheader& dnsheader, const ResponseConfig& config);
 }
index 51348375f016fec54ffa17801f30f98f0f186d74..ebb82507a4b5658e0967e523d1ff10644e1c1698 100644 (file)
@@ -383,13 +383,13 @@ public:
   {
     return "set rcode " + std::to_string(d_rcode);
   }
-  [[nodiscard]] ResponseConfig& getResponseConfig()
+  [[nodiscard]] dnsdist::ResponseConfig& getResponseConfig()
   {
     return d_responseConfig;
   }
 
 private:
-  ResponseConfig d_responseConfig;
+  dnsdist::ResponseConfig d_responseConfig;
   uint8_t d_rcode;
 };
 
@@ -413,13 +413,13 @@ public:
   {
     return "set ercode " + ERCode::to_s(d_rcode);
   }
-  [[nodiscard]] ResponseConfig& getResponseConfig()
+  [[nodiscard]] dnsdist::ResponseConfig& getResponseConfig()
   {
     return d_responseConfig;
   }
 
 private:
-  ResponseConfig d_responseConfig;
+  dnsdist::ResponseConfig d_responseConfig;
   uint8_t d_rcode;
 };
 
@@ -436,7 +436,6 @@ public:
         throw std::runtime_error("Unable to generate a valid SVC record from the supplied parameters");
       }
 
-      d_totalPayloadsSize += payload.size();
       d_payloads.push_back(std::move(payload));
 
       for (const auto& hint : param.second.ipv4hints) {
@@ -451,72 +450,27 @@ public:
 
   DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
   {
-    /* it will likely be a bit bigger than that because of additionals */
-    auto numberOfRecords = d_payloads.size();
-    const auto qnameWireLength = dnsquestion->ids.qname.wirelength();
-    if (dnsquestion->getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + d_totalPayloadsSize)) {
+    if (!dnsdist::svc::generateSVCResponse(*dnsquestion, d_payloads, d_additionals4, d_additionals6, d_responseConfig)) {
       return Action::None;
     }
-
-    PacketBuffer newPacket;
-    newPacket.reserve(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + d_totalPayloadsSize);
-    GenericDNSPacketWriter<PacketBuffer> packetWriter(newPacket, dnsquestion->ids.qname, dnsquestion->ids.qtype);
-    for (const auto& payload : d_payloads) {
-      packetWriter.startRecord(dnsquestion->ids.qname, dnsquestion->ids.qtype, d_responseConfig.ttl);
-      packetWriter.xfrBlob(payload);
-      packetWriter.commit();
-    }
-
-    if (newPacket.size() < dnsquestion->getMaximumSize()) {
-      for (const auto& additional : d_additionals4) {
-        packetWriter.startRecord(additional.first.isRoot() ? dnsquestion->ids.qname : additional.first, QType::A, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL);
-        packetWriter.xfrCAWithoutPort(4, additional.second);
-        packetWriter.commit();
-      }
-    }
-
-    if (newPacket.size() < dnsquestion->getMaximumSize()) {
-      for (const auto& additional : d_additionals6) {
-        packetWriter.startRecord(additional.first.isRoot() ? dnsquestion->ids.qname : additional.first, QType::AAAA, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL);
-        packetWriter.xfrCAWithoutPort(6, additional.second);
-        packetWriter.commit();
-      }
-    }
-
-    if (g_addEDNSToSelfGeneratedResponses && queryHasEDNS(*dnsquestion)) {
-      bool dnssecOK = ((getEDNSZ(*dnsquestion) & EDNS_HEADER_FLAG_DO) != 0);
-      packetWriter.addOpt(g_PayloadSizeSelfGenAnswers, 0, dnssecOK ? EDNS_HEADER_FLAG_DO : 0);
-      packetWriter.commit();
-    }
-
-    if (newPacket.size() >= dnsquestion->getMaximumSize()) {
-      /* sorry! */
-      return Action::None;
-    }
-
-    packetWriter.getHeader()->id = dnsquestion->getHeader()->id;
-    packetWriter.getHeader()->qr = true; // for good measure
-    setResponseHeadersFromConfig(*packetWriter.getHeader(), d_responseConfig);
-    dnsquestion->getMutableData() = std::move(newPacket);
-
     return Action::HeaderModify;
   }
+
   [[nodiscard]] std::string toString() const override
   {
     return "spoof SVC record ";
   }
 
-  [[nodiscard]] ResponseConfig& getResponseConfig()
+  [[nodiscard]] dnsdist::ResponseConfig& getResponseConfig()
   {
     return d_responseConfig;
   }
 
 private:
-  ResponseConfig d_responseConfig;
+  dnsdist::ResponseConfig d_responseConfig;
   std::vector<std::vector<uint8_t>> d_payloads{};
   std::set<std::pair<DNSName, ComboAddress>> d_additionals4{};
   std::set<std::pair<DNSName, ComboAddress>> d_additionals6{};
-  size_t d_totalPayloadsSize{0};
 };
 
 class TCAction : public DNSAction
@@ -2094,13 +2048,13 @@ public:
     return "return an HTTP status of " + std::to_string(d_code);
   }
 
-  [[nodiscard]] ResponseConfig& getResponseConfig()
+  [[nodiscard]] dnsdist::ResponseConfig& getResponseConfig()
   {
     return d_responseConfig;
   }
 
 private:
-  ResponseConfig d_responseConfig;
+  dnsdist::ResponseConfig d_responseConfig;
   PacketBuffer d_body;
   std::string d_contentType;
   int d_code;
@@ -2260,13 +2214,13 @@ public:
   {
     return std::string(d_nxd ? "NXD " : "NODATA") + " with SOA";
   }
-  [[nodiscard]] ResponseConfig& getResponseConfig()
+  [[nodiscard]] dnsdist::ResponseConfig& getResponseConfig()
   {
     return d_responseConfig;
   }
 
 private:
-  ResponseConfig d_responseConfig;
+  dnsdist::ResponseConfig d_responseConfig;
 
   DNSName d_zone;
   DNSName d_mname;
@@ -2439,7 +2393,7 @@ static void addAction(GlobalStateHolder<vector<T>>* someRuleActions, const luadn
 
 using responseParams_t = std::unordered_map<std::string, boost::variant<bool, uint32_t>>;
 
-static void parseResponseConfig(boost::optional<responseParams_t>& vars, ResponseConfig& config)
+static void parseResponseConfig(boost::optional<responseParams_t>& vars, dnsdist::ResponseConfig& config)
 {
   getOptionalValue<uint32_t>(vars, "ttl", config.ttl);
   getOptionalValue<bool>(vars, "aa", config.setAA);
@@ -2447,25 +2401,6 @@ static void parseResponseConfig(boost::optional<responseParams_t>& vars, Respons
   getOptionalValue<bool>(vars, "ra", config.setRA);
 }
 
-void setResponseHeadersFromConfig(dnsheader& dnsheader, const ResponseConfig& config)
-{
-  if (config.setAA) {
-    dnsheader.aa = *config.setAA;
-  }
-  if (config.setAD) {
-    dnsheader.ad = *config.setAD;
-  }
-  else {
-    dnsheader.ad = false;
-  }
-  if (config.setRA) {
-    dnsheader.ra = *config.setRA;
-  }
-  else {
-    dnsheader.ra = dnsheader.rd; // for good measure
-  }
-}
-
 // NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
 void setupLuaActions(LuaContext& luaCtx)
 {
index 03ea1b33151531f409aab17a668ff0d8309d0903..a98cf1d9c663694dfb03c2236909c19601ce9e52 100644 (file)
@@ -286,3 +286,16 @@ const dnsdist_ffi_dynamic_block_entry_t* dnsdist_ffi_dynamic_blocks_list_get(con
 void dnsdist_ffi_dynamic_blocks_list_free(dnsdist_ffi_dynamic_blocks_list_t*) __attribute__ ((visibility ("default")));
 
 uint32_t dnsdist_ffi_hash(uint32_t seed, const unsigned char* data, size_t dataSize, bool caseInsensitive) __attribute__ ((visibility ("default")));
+
+typedef struct dnsdist_ffi_svc_record_parameters dnsdist_ffi_svc_record_parameters;
+bool dnsdist_ffi_svc_record_parameters_new(const char* targetName, uint16_t priority, bool noDefaultALPN, dnsdist_ffi_svc_record_parameters** out) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_set_port(dnsdist_ffi_svc_record_parameters* parameters, uint16_t port) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_set_ech(dnsdist_ffi_svc_record_parameters* parameters, const char* ech, size_t echLen) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_set_additional_param(dnsdist_ffi_svc_record_parameters* parameters, uint16_t key, const char* value, size_t valueLen) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_add_mandatory_param(dnsdist_ffi_svc_record_parameters* parameters, uint16_t key) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_add_alpn(dnsdist_ffi_svc_record_parameters* parameters, const char* value, size_t valueLen) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_add_ipv4_hint(dnsdist_ffi_svc_record_parameters* parameters, const char* value, size_t valueLen) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_add_ipv6_hint(dnsdist_ffi_svc_record_parameters* parameters, const char* value, size_t valueLen) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_svc_record_parameters_free(dnsdist_ffi_svc_record_parameters* parameters) __attribute__ ((visibility ("default")));
+
+bool dnsdist_ffi_dnsquestion_generate_svc_response(dnsdist_ffi_dnsquestion_t* dnsQuestion, const dnsdist_ffi_svc_record_parameters** parametersList, size_t parametersListSize, uint32_t ttl) __attribute__ ((visibility ("default")));
index e399c438c5dd576272a05eb9942755fe83a0b22a..1bbb9c2fdfce25e7f5a8425a24bede967a45002b 100644 (file)
@@ -31,6 +31,7 @@
 #include "dnsdist-lua.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-rings.hh"
+#include "dnsdist-svc.hh"
 #include "dolog.hh"
 
 uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq)
@@ -2046,3 +2047,133 @@ uint32_t dnsdist_ffi_hash(uint32_t seed, const unsigned char* data, size_t dataS
 
   return burtle(data, dataSize, seed);
 }
+
+struct dnsdist_ffi_svc_record_parameters
+{
+  SVCRecordParameters parameters;
+};
+
+bool dnsdist_ffi_svc_record_parameters_new(const char* targetName, uint16_t priority, bool noDefaultALPN, dnsdist_ffi_svc_record_parameters** out)
+{
+  if (targetName == nullptr || out == nullptr) {
+    return false;
+  }
+  try {
+    auto parameters = std::make_unique<dnsdist_ffi_svc_record_parameters>();
+    parameters->parameters.target = DNSName(targetName);
+    parameters->parameters.priority = priority;
+    parameters->parameters.noDefaultAlpn = noDefaultALPN;
+    *out = parameters.release();
+    return true;
+  }
+  catch (const std::exception& exp) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_new: %s", exp.what());
+  }
+  catch (const PDNSException& exp) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_new: %s", exp.reason);
+  }
+  catch (...) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_new");
+  }
+
+  return false;
+}
+
+void dnsdist_ffi_svc_record_parameters_set_port(dnsdist_ffi_svc_record_parameters* parameters, uint16_t port)
+{
+  if (parameters == nullptr) {
+    return;
+  }
+  parameters->parameters.port = port;
+}
+
+void dnsdist_ffi_svc_record_parameters_set_ech(dnsdist_ffi_svc_record_parameters* parameters, const char* ech, size_t echLen)
+{
+  if (parameters == nullptr || ech == nullptr || echLen == 0) {
+    return;
+  }
+  parameters->parameters.ech = std::string(ech, echLen);
+}
+
+void dnsdist_ffi_svc_record_parameters_set_additional_param(dnsdist_ffi_svc_record_parameters* parameters, uint16_t key, const char* value, size_t valueLen)
+{
+  if (parameters == nullptr || (value == nullptr && valueLen != 0)) {
+    return;
+  }
+  parameters->parameters.additionalParams.emplace_back(key, std::string(value, valueLen));
+}
+
+void dnsdist_ffi_svc_record_parameters_add_mandatory_param(dnsdist_ffi_svc_record_parameters* parameters, uint16_t key)
+{
+  if (parameters == nullptr) {
+    return;
+  }
+  parameters->parameters.mandatoryParams.insert(key);
+}
+
+void dnsdist_ffi_svc_record_parameters_add_alpn(dnsdist_ffi_svc_record_parameters* parameters, const char* value, size_t valueLen)
+{
+  if (parameters == nullptr || value == nullptr || valueLen == 0) {
+    return;
+  }
+  parameters->parameters.alpns.emplace_back(value, valueLen);
+}
+
+void dnsdist_ffi_svc_record_parameters_add_ipv4_hint(dnsdist_ffi_svc_record_parameters* parameters, const char* value, size_t valueLen)
+{
+  if (parameters == nullptr || value == nullptr || valueLen == 0) {
+    return;
+  }
+  try {
+    parameters->parameters.ipv4hints.emplace_back(ComboAddress(std::string(value, valueLen)));
+  }
+  catch (const std::exception& exp) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_add_ipv4_hint: %s", exp.what());
+  }
+  catch (const PDNSException& exp) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_add_ipv4_hint: %s", exp.reason);
+  }
+  catch (...) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_add_ipv4_hint");
+  }
+}
+
+void dnsdist_ffi_svc_record_parameters_add_ipv6_hint(dnsdist_ffi_svc_record_parameters* parameters, const char* value, size_t valueLen)
+{
+  if (parameters == nullptr || value == nullptr || valueLen == 0) {
+    return;
+  }
+  try {
+    parameters->parameters.ipv6hints.emplace_back(ComboAddress(std::string(value, valueLen)));
+  }
+  catch (const std::exception& exp) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_add_ipv4_hint: %s", exp.what());
+  }
+  catch (const PDNSException& exp) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_add_ipv4_hint: %s", exp.reason);
+  }
+  catch (...) {
+    errlog("Exception in dnsdist_ffi_svc_record_parameters_add_ipv4_hint");
+  }
+}
+
+bool dnsdist_ffi_dnsquestion_generate_svc_response(dnsdist_ffi_dnsquestion_t* dnsQuestion, const dnsdist_ffi_svc_record_parameters** parametersList, size_t parametersListSize, uint32_t ttl)
+{
+  if (dnsQuestion == nullptr || parametersList == nullptr || parametersListSize == 0) {
+    return false;
+  }
+  std::vector<SVCRecordParameters> parameters;
+  parameters.reserve(parametersListSize);
+  for (size_t idx = 0; idx < parametersListSize; idx++) {
+    if (parametersList[idx] == nullptr) {
+      return false;
+    }
+    parameters.push_back(parametersList[idx]->parameters);
+  }
+  return dnsdist::svc::generateSVCResponse(*dnsQuestion->dq, ttl, parameters);
+}
+
+void dnsdist_ffi_svc_record_parameters_free(dnsdist_ffi_svc_record_parameters* parameters)
+{
+  delete parameters;
+}
index 8a6363300a5c13689aff273e203d7715788b5d2b..0e076880a1d9e32f9b7afa9e44759224e32072fb 100644 (file)
 
 #include "dolog.hh"
 #include "dnsdist.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsparser.hh"
 #include <random>
 
-struct ResponseConfig
-{
-  boost::optional<bool> setAA{boost::none};
-  boost::optional<bool> setAD{boost::none};
-  boost::optional<bool> setRA{boost::none};
-  uint32_t ttl{60};
-};
-void setResponseHeadersFromConfig(dnsheader& dnsheader, const ResponseConfig& config);
-
 class SpoofAction : public DNSAction
 {
 public:
@@ -88,13 +80,13 @@ public:
     return ret;
   }
 
-  [[nodiscard]] ResponseConfig& getResponseConfig()
+  [[nodiscard]] dnsdist::ResponseConfig& getResponseConfig()
   {
     return d_responseConfig;
   }
 
 private:
-  ResponseConfig d_responseConfig;
+  dnsdist::ResponseConfig d_responseConfig;
   static thread_local std::default_random_engine t_randomEngine;
   std::vector<ComboAddress> d_addrs;
   std::unordered_set<uint16_t> d_types;
index ffd42fd649433594226f768fcf51e432874dc909..6f73000ea8c42b358392ccbbf1a87f7d548c083d 100644 (file)
@@ -20,6 +20,9 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 #include "dnsdist-svc.hh"
+#include "dnsdist.hh"
+#include "dnsdist-ecs.hh"
+#include "dnsdist-lua.hh"
 #include "dnswriter.hh"
 #include "svc-records.hh"
 
@@ -131,3 +134,93 @@ struct SVCRecordParameters parseSVCParameters(const svcParamsLua_t& params)
   }
   return parameters;
 }
+
+namespace dnsdist::svc
+{
+bool generateSVCResponse(DNSQuestion& dnsQuestion, const std::vector<std::vector<uint8_t>>& svcRecordPayloads, const std::set<std::pair<DNSName, ComboAddress>>& additionals4, const std::set<std::pair<DNSName, ComboAddress>>& additionals6, const ResponseConfig& responseConfig)
+{
+  /* it will likely be a bit bigger than that because of additionals */
+  auto totalPayloadsSize = 0;
+  for (const auto& payload : svcRecordPayloads) {
+    totalPayloadsSize += payload.size();
+  }
+  const auto numberOfRecords = svcRecordPayloads.size();
+  const auto qnameWireLength = dnsQuestion.ids.qname.wirelength();
+  if (dnsQuestion.getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + totalPayloadsSize)) {
+    return false;
+  }
+
+  PacketBuffer newPacket;
+  newPacket.reserve(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + totalPayloadsSize);
+  GenericDNSPacketWriter<PacketBuffer> packetWriter(newPacket, dnsQuestion.ids.qname, dnsQuestion.ids.qtype);
+  for (const auto& payload : svcRecordPayloads) {
+    packetWriter.startRecord(dnsQuestion.ids.qname, dnsQuestion.ids.qtype, responseConfig.ttl);
+    packetWriter.xfrBlob(payload);
+    packetWriter.commit();
+  }
+
+  if (newPacket.size() < dnsQuestion.getMaximumSize()) {
+    for (const auto& additional : additionals4) {
+      packetWriter.startRecord(additional.first.isRoot() ? dnsQuestion.ids.qname : additional.first, QType::A, responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL);
+      packetWriter.xfrCAWithoutPort(4, additional.second);
+      packetWriter.commit();
+    }
+  }
+
+  if (newPacket.size() < dnsQuestion.getMaximumSize()) {
+    for (const auto& additional : additionals6) {
+      packetWriter.startRecord(additional.first.isRoot() ? dnsQuestion.ids.qname : additional.first, QType::AAAA, responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL);
+      packetWriter.xfrCAWithoutPort(6, additional.second);
+      packetWriter.commit();
+    }
+  }
+
+  if (g_addEDNSToSelfGeneratedResponses && queryHasEDNS(dnsQuestion)) {
+    bool dnssecOK = ((getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
+    packetWriter.addOpt(g_PayloadSizeSelfGenAnswers, 0, dnssecOK ? EDNS_HEADER_FLAG_DO : 0);
+    packetWriter.commit();
+  }
+
+  if (newPacket.size() >= dnsQuestion.getMaximumSize()) {
+    /* sorry! */
+    return false;
+  }
+
+  packetWriter.getHeader()->id = dnsQuestion.getHeader()->id;
+  packetWriter.getHeader()->qr = true; // for good measure
+  setResponseHeadersFromConfig(*packetWriter.getHeader(), responseConfig);
+  dnsQuestion.getMutableData() = std::move(newPacket);
+
+  return true;
+}
+
+bool generateSVCResponse(DNSQuestion& dnsQuestion, uint32_t ttl, const std::vector<SVCRecordParameters>& parameters)
+{
+  std::vector<std::vector<uint8_t>> payloads;
+  std::set<std::pair<DNSName, ComboAddress>> additionals4;
+  std::set<std::pair<DNSName, ComboAddress>> additionals6;
+  ResponseConfig responseConfig;
+  responseConfig.setAA = true;
+  responseConfig.ttl = ttl;
+
+  payloads.reserve(parameters.size());
+  for (const auto& parameter : parameters) {
+    std::vector<uint8_t> payload;
+    if (!generateSVCPayload(payload, parameter)) {
+      throw std::runtime_error("Unable to generate a valid SVC record from the supplied parameters");
+    }
+
+    payloads.push_back(std::move(payload));
+
+    for (const auto& hint : parameter.ipv4hints) {
+      additionals4.insert({parameter.target, ComboAddress(hint)});
+    }
+
+    for (const auto& hint : parameter.ipv6hints) {
+      additionals6.insert({parameter.target, ComboAddress(hint)});
+    }
+  }
+
+  return generateSVCResponse(dnsQuestion, payloads, additionals4, additionals6, responseConfig);
+}
+}
index d0a1a8c961b3da0660234518c7545a22d0ef274d..742683c3901a571ba48d66c01084e034721351fe 100644 (file)
@@ -49,18 +49,29 @@ struct SVCRecordParameters
   bool noDefaultAlpn{false};
 };
 
-typedef std::unordered_map<
+using svcParamsLua_t = std::unordered_map<
   std::string,
   boost::variant<
     uint16_t,
     bool,
     std::string,
     std::vector<std::pair<int, std::string>>,
-    std::vector<std::pair<int, ComboAddress>>>>
-  svcParamsLua_t;
+    std::vector<std::pair<int, ComboAddress>>>>;
 
 struct SVCRecordParameters parseSVCParameters(const svcParamsLua_t& params);
 
 bool generateSVCPayload(std::vector<uint8_t>& payload, uint16_t priority, const DNSName& target, const std::set<uint16_t>& mandatoryParams, const std::vector<std::string>& alpns, bool noDefaultAlpn, std::optional<uint16_t> port, const std::string& ech, const std::vector<ComboAddress>& ipv4hints, const std::vector<ComboAddress>& ipv6hints, const std::vector<std::pair<uint16_t, std::string>>& additionalParams);
 
 bool generateSVCPayload(std::vector<uint8_t>& payload, const SVCRecordParameters& parameters);
+
+struct DNSQuestion;
+namespace dnsdist
+{
+struct ResponseConfig;
+}
+
+namespace dnsdist::svc
+{
+bool generateSVCResponse(DNSQuestion& dnsQuestion, const std::vector<std::vector<uint8_t>>& svcRecordPayloads, const std::set<std::pair<DNSName, ComboAddress>>& additionals4, const std::set<std::pair<DNSName, ComboAddress>>& additionals6, const dnsdist::ResponseConfig& d_responseConfig);
+bool generateSVCResponse(DNSQuestion& dnsQuestion, uint32_t ttl, const std::vector<SVCRecordParameters>& parameters);
+}
index d0cd68ea5563bdf71c2a29f25f1e99dda38153f9..0429e5c560908a0ce71a00cdb8a134450e551286 100644 (file)
@@ -886,4 +886,64 @@ BOOST_AUTO_TEST_CASE(test_hash)
   }
 }
 
+BOOST_AUTO_TEST_CASE(test_SVC_Generation)
+{
+  dnsdist_ffi_svc_record_parameters* parameters{nullptr};
+
+  {
+    /* invalid parameters */
+    BOOST_CHECK_EQUAL(dnsdist_ffi_svc_record_parameters_new(nullptr, 0, false, &parameters), false);
+    BOOST_CHECK_EQUAL(dnsdist_ffi_svc_record_parameters_new("powerdns.com.", 0, false, nullptr), false);
+  }
+
+  BOOST_REQUIRE_EQUAL(dnsdist_ffi_svc_record_parameters_new("powerdns.com.", 1, true, &parameters), true);
+  BOOST_REQUIRE(parameters != nullptr);
+
+  {
+    /* invalid parameters */
+    dnsdist_ffi_svc_record_parameters_set_port(nullptr, 0);
+    dnsdist_ffi_svc_record_parameters_set_ech(nullptr, "alpn", 4);
+    dnsdist_ffi_svc_record_parameters_set_additional_param(nullptr, 7, "/dns-query{?dns}", 16);
+    dnsdist_ffi_svc_record_parameters_set_additional_param(parameters, 7, nullptr, 0);
+    dnsdist_ffi_svc_record_parameters_add_mandatory_param(nullptr, 0);
+    dnsdist_ffi_svc_record_parameters_add_alpn(nullptr, "h2", 2);
+    dnsdist_ffi_svc_record_parameters_add_alpn(parameters, nullptr, 0);
+    dnsdist_ffi_svc_record_parameters_add_ipv4_hint(parameters, nullptr, 0);
+    dnsdist_ffi_svc_record_parameters_add_ipv4_hint(nullptr, nullptr, 0);
+    dnsdist_ffi_svc_record_parameters_add_ipv6_hint(parameters, nullptr, 0);
+    dnsdist_ffi_svc_record_parameters_add_ipv6_hint(nullptr, nullptr, 0);
+    dnsdist_ffi_dnsquestion_generate_svc_response(nullptr, nullptr, 0, 0);
+  }
+
+  dnsdist_ffi_svc_record_parameters_set_port(parameters, 443);
+  dnsdist_ffi_svc_record_parameters_set_ech(parameters, "binary", 6);
+  dnsdist_ffi_svc_record_parameters_set_additional_param(parameters, 7, "/dns-query{?dns}", 16);
+  dnsdist_ffi_svc_record_parameters_add_mandatory_param(parameters, 7);
+  dnsdist_ffi_svc_record_parameters_add_alpn(parameters, "h2", 2);
+  dnsdist_ffi_svc_record_parameters_add_ipv4_hint(parameters, "9.9.9.9", 8);
+  dnsdist_ffi_svc_record_parameters_add_ipv6_hint(parameters, "2620:fe::fe", 11);
+
+  {
+    InternalQueryState ids;
+    ids.origRemote = ComboAddress("192.0.2.1:4242");
+    ids.origDest = ComboAddress("192.0.2.255:53");
+    ids.qtype = QType::A;
+    ids.qclass = QClass::IN;
+    ids.protocol = dnsdist::Protocol::DoUDP;
+    ids.qname = DNSName("www.powerdns.com.");
+    ids.queryRealTime.start();
+    PacketBuffer query;
+    GenericDNSPacketWriter<PacketBuffer> pwQ(query, ids.qname, QType::A, QClass::IN, 0);
+    pwQ.getHeader()->rd = 1;
+    pwQ.getHeader()->id = htons(42);
+
+    DNSQuestion dnsQuestion(ids, query);
+    dnsdist_ffi_dnsquestion_t lightDQ(&dnsQuestion);
+    std::array<const dnsdist_ffi_svc_record_parameters*, 1> list = {parameters};
+    BOOST_CHECK_EQUAL(dnsdist_ffi_dnsquestion_generate_svc_response(&lightDQ, list.data(), list.size(), 42), true);
+  }
+
+  dnsdist_ffi_svc_record_parameters_free(parameters);
+}
+
 BOOST_AUTO_TEST_SUITE_END();
index 9056924231a6f53865801a359074f2d58941df25..12c92809144e1b7973696c9e00860d749eb6a5e3 100644 (file)
@@ -101,3 +101,166 @@ class TestSVCB(DNSDistTest):
             self.assertEqual(len(receivedResponse.additional), 2)
             self.assertEqual(receivedResponse.additional[0], dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.2'))
             self.assertEqual(receivedResponse.additional[1], dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.AAAA, '2001:db8::2'))
+
+class TestSVCBViaFFI(DNSDistTest):
+
+    _config_template = """
+    local ffi = require("ffi")
+
+    function setSVC(record, port, mandatoryParam, alpn, v4Hint, v6Hint)
+      ffi.C.dnsdist_ffi_svc_record_parameters_set_port(record, port)
+      ffi.C.dnsdist_ffi_svc_record_parameters_add_mandatory_param(record, mandatoryParam)
+      ffi.C.dnsdist_ffi_svc_record_parameters_add_alpn(record, alpn, #alpn)
+      if v4Hint then
+        ffi.C.dnsdist_ffi_svc_record_parameters_add_ipv4_hint(record, v4Hint, #v4Hint)
+      end
+      if v6Hint then
+        ffi.C.dnsdist_ffi_svc_record_parameters_add_ipv6_hint(record, v6Hint, #v6Hint)
+      end
+    end
+
+    function generateSVC(target, priority, port, alpn, noDefaultALPN, v4Hint, v6Hint)
+      local recordPtr = ffi.new("dnsdist_ffi_svc_record_parameters* [1]")
+      local recordPtrOut = ffi.cast("dnsdist_ffi_svc_record_parameters**", recordPtr)
+      ffi.C.dnsdist_ffi_svc_record_parameters_new(target, priority, noDefaultALPN, recordPtrOut)
+      ffi.gc(recordPtrOut[0], ffi.C.dnsdist_ffi_svc_record_parameters_free)
+      -- 3 is the port parameter
+      setSVC(recordPtrOut[0], port, 3, alpn, v4Hint, v6Hint)
+      return recordPtrOut[0]
+    end
+
+    function basicSVC(dq)
+      local SVCrecords = ffi.new("dnsdist_ffi_svc_record_parameters* [2]")
+      SVCrecords[0] = generateSVC("dot.powerdns.com.", 1, 853, "dot", true, "192.0.2.1", "2001:db8::1")
+      SVCrecords[1] = generateSVC("doh.powerdns.com.", 2, 443, "h2", false, "192.0.2.2", "2001:db8::2")
+      local path = "/dns-query{?dns}"
+      ffi.C.dnsdist_ffi_svc_record_parameters_set_additional_param(SVCrecords[1], 7, path, #path)
+      local SVCrecordsPtr = ffi.cast("const dnsdist_ffi_svc_record_parameters**", SVCrecords)
+      if not ffi.C.dnsdist_ffi_dnsquestion_generate_svc_response(dq, SVCrecordsPtr, 2, 60) then
+        return DNSAction.ServFail
+      end
+      return DNSAction.HeaderModify
+    end
+
+    addAction(AndRule{QTypeRule(64), SuffixMatchNodeRule("basic.svcb.tests.powerdns.com.")}, LuaFFIAction(basicSVC))
+
+    function noHintsSVC(dq)
+      local SVCrecords = ffi.new("dnsdist_ffi_svc_record_parameters* [2]")
+      SVCrecords[0] = generateSVC("dot.powerdns.com.", 1, 853, "dot", true, nil, nil)
+      SVCrecords[1] = generateSVC("doh.powerdns.com.", 2, 443, "h2", false, nil, nil)
+      local path = "/dns-query{?dns}"
+      ffi.C.dnsdist_ffi_svc_record_parameters_set_additional_param(SVCrecords[1], 7, path, #path)
+      local SVCrecordsPtr = ffi.cast("const dnsdist_ffi_svc_record_parameters**", SVCrecords)
+      if not ffi.C.dnsdist_ffi_dnsquestion_generate_svc_response(dq, SVCrecordsPtr, 2, 60) then
+        return DNSAction.ServFail
+      end
+      return DNSAction.HeaderModify
+    end
+
+    addAction(AndRule{QTypeRule(64), SuffixMatchNodeRule("no-hints.svcb.tests.powerdns.com.")}, LuaFFIAction(noHintsSVC))
+
+    function effectiveTargetSVC(dq)
+      local SVCrecords = ffi.new("dnsdist_ffi_svc_record_parameters* [2]")
+      SVCrecords[0] = generateSVC(".", 1, 853, "dot", true, "192.0.2.1", "2001:db8::1")
+      SVCrecords[1] = generateSVC(".", 2, 443, "h2", false, "192.0.2.1", "2001:db8::1")
+      local path = "/dns-query{?dns}"
+      ffi.C.dnsdist_ffi_svc_record_parameters_set_additional_param(SVCrecords[1], 7, path, #path)
+      local SVCrecordsPtr = ffi.cast("const dnsdist_ffi_svc_record_parameters**", SVCrecords)
+      if not ffi.C.dnsdist_ffi_dnsquestion_generate_svc_response(dq, SVCrecordsPtr, 2, 60) then
+        return DNSAction.ServFail
+      end
+      return DNSAction.HeaderModify
+    end
+
+    addAction(AndRule{QTypeRule(64), SuffixMatchNodeRule("effective-target.svcb.tests.powerdns.com.")}, LuaFFIAction(effectiveTargetSVC))
+
+    function httpsSVC(dq)
+      local SVCrecords = ffi.new("dnsdist_ffi_svc_record_parameters* [1]")
+      SVCrecords[0] = generateSVC(".", 1, 8002, "h2", false, "192.0.2.2", "2001:db8::2")
+      local SVCrecordsPtr = ffi.cast("const dnsdist_ffi_svc_record_parameters**", SVCrecords)
+      if not ffi.C.dnsdist_ffi_dnsquestion_generate_svc_response(dq, SVCrecordsPtr, 1, 60) then
+        return DNSAction.ServFail
+      end
+      return DNSAction.HeaderModify
+    end
+
+    addAction(AndRule{QTypeRule(65), SuffixMatchNodeRule("https.svcb.tests.powerdns.com.")}, LuaFFIAction(httpsSVC))
+
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testBasic(self):
+        """
+        SVCB: Basic service binding
+        """
+        name = 'basic.svcb.tests.powerdns.com.'
+        query = dns.message.make_query(name, 64, 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(len(receivedResponse.answer), 1)
+            self.assertEqual(receivedResponse.answer[0].rdtype, 64)
+            self.assertEqual(len(receivedResponse.additional), 4)
+            self.assertEqual(receivedResponse.additional[0], dns.rrset.from_text("doh.powerdns.com.", 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.2'))
+            self.assertEqual(receivedResponse.additional[1], dns.rrset.from_text("dot.powerdns.com.", 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'))
+            self.assertEqual(receivedResponse.additional[2], dns.rrset.from_text("doh.powerdns.com.", 60, dns.rdataclass.IN, dns.rdatatype.AAAA, '2001:db8::2'))
+            self.assertEqual(receivedResponse.additional[3], dns.rrset.from_text("dot.powerdns.com.", 60, dns.rdataclass.IN, dns.rdatatype.AAAA, '2001:db8::1'))
+
+    def testNoHints(self):
+        """
+        SVCB: No hints
+        """
+        name = 'no-hints.svcb.tests.powerdns.com.'
+        query = dns.message.make_query(name, 64, 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(len(receivedResponse.answer), 1)
+            self.assertEqual(receivedResponse.answer[0].rdtype, 64)
+            self.assertEqual(len(receivedResponse.additional), 0)
+
+    def testEffectiveTarget(self):
+        """
+        SVCB: Effective target
+        """
+        name = 'effective-target.svcb.tests.powerdns.com.'
+        query = dns.message.make_query(name, 64, 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(len(receivedResponse.answer), 1)
+            self.assertEqual(receivedResponse.answer[0].rdtype, 64)
+            self.assertEqual(len(receivedResponse.additional), 2)
+            self.assertEqual(receivedResponse.additional[0], dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'))
+            self.assertEqual(receivedResponse.additional[1], dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.AAAA, '2001:db8::1'))
+
+    def testHTTPS(self):
+        """
+        SVCB: HTTPS
+        """
+        name = 'https.svcb.tests.powerdns.com.'
+        query = dns.message.make_query(name, 65, 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(len(receivedResponse.answer), 1)
+            self.assertEqual(receivedResponse.answer[0].rdtype, 65)
+            self.assertEqual(len(receivedResponse.additional), 2)
+            self.assertEqual(receivedResponse.additional[0], dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.2'))
+            self.assertEqual(receivedResponse.additional[1], dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.AAAA, '2001:db8::2'))