]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/lua-recursor4.cc
spelling: target
[thirdparty/pdns.git] / pdns / lua-recursor4.cc
index 7c96133f5369095864caedaec7e220c41f5bf081..445092448ed0985ddaa58e3c17e0a1703c4dc357 100644 (file)
 
 RecursorLua4::RecursorLua4() { prepareContext(); }
 
-static int followCNAMERecords(vector<DNSRecord>& ret, const QType& qtype)
-{
-  vector<DNSRecord> resolved;
-  DNSName target;
-  for(const DNSRecord& rr :  ret) {
-    if(rr.d_type == QType::CNAME) {
-      auto rec = getRR<CNAMERecordContent>(rr);
-      if(rec) {
-        target=rec->getTarget();
-        break;
-      }
-    }
-  }
-  if(target.empty())
-    return 0;
-  
-  int rcode=directResolve(target, qtype, 1, resolved); // 1 == class
-  
-  for(const DNSRecord& rr :  resolved) {
-    ret.push_back(rr);
-  }
-  return rcode;
-}
-
-static int getFakeAAAARecords(const DNSName& qname, const std::string& prefix, vector<DNSRecord>& ret)
-{
-  int rcode=directResolve(qname, QType(QType::A), 1, ret);
-
-  ComboAddress prefixAddress(prefix);
-
-  // Remove double CNAME records
-  std::set<DNSName> seenCNAMEs;
-  ret.erase(std::remove_if(
-        ret.begin(),
-        ret.end(),
-        [&seenCNAMEs](DNSRecord& rr) {
-          if (rr.d_type == QType::CNAME) {
-            auto target = getRR<CNAMERecordContent>(rr);
-            if (target == nullptr) {
-              return false;
-            }
-            if (seenCNAMEs.count(target->getTarget()) > 0) {
-              // We've had this CNAME before, remove it
-              return true;
-            }
-            seenCNAMEs.insert(target->getTarget());
-          }
-          return false;
-        }),
-      ret.end());
-
-  bool seenA = false;
-  for(DNSRecord& rr :  ret)
-  {
-    if(rr.d_type == QType::A && rr.d_place==DNSResourceRecord::ANSWER) {
-      if(auto rec = getRR<ARecordContent>(rr)) {
-        ComboAddress ipv4(rec->getCA());
-        uint32_t tmp;
-        memcpy((void*)&tmp, &ipv4.sin4.sin_addr.s_addr, 4);
-        // tmp=htonl(tmp);
-        memcpy(((char*)&prefixAddress.sin6.sin6_addr.s6_addr)+12, &tmp, 4);
-        rr.d_content = std::make_shared<AAAARecordContent>(prefixAddress);
-        rr.d_type = QType::AAAA;
-      }
-      seenA = true;
-    }
-  }
-
-  if (seenA) {
-    // We've seen an A in the ANSWER section, so there is no need to keep any
-    // SOA in the AUTHORITY section as this is not a NODATA response.
-    ret.erase(std::remove_if(
-          ret.begin(),
-          ret.end(),
-          [](DNSRecord& rr) {
-            return (rr.d_type == QType::SOA && rr.d_place==DNSResourceRecord::AUTHORITY);
-          }),
-        ret.end());
-  }
-  return rcode;
-}
-
-static int getFakePTRRecords(const DNSName& qname, const std::string& prefix, vector<DNSRecord>& ret)
-{
-  /* qname has a reverse ordered IPv6 address, need to extract the underlying IPv4 address from it
-     and turn it into an IPv4 in-addr.arpa query */
-  ret.clear();
-  vector<string> parts = qname.getRawLabels();
-
-  if(parts.size() < 8)
-    return -1;
-
-  string newquery;
-  for(int n = 0; n < 4; ++n) {
-    newquery +=
-      std::to_string(stoll(parts[n*2], 0, 16) + 16*stoll(parts[n*2+1], 0, 16));
-    newquery.append(1,'.');
-  }
-  newquery += "in-addr.arpa.";
-
-
-  int rcode = directResolve(DNSName(newquery), QType(QType::PTR), 1, ret);
-  for(DNSRecord& rr :  ret)
-  {
-    if(rr.d_type == QType::PTR && rr.d_place==DNSResourceRecord::ANSWER) {
-      rr.d_name = qname;
-    }
-  }
-  return rcode;
-
-}
-
 boost::optional<dnsheader> RecursorLua4::DNSQuestion::getDH() const
 {
   if (dh)
@@ -207,6 +94,20 @@ boost::optional<Netmask>  RecursorLua4::DNSQuestion::getEDNSSubnet() const
   return boost::optional<Netmask>();
 }
 
+std::vector<std::pair<int, ProxyProtocolValue>> RecursorLua4::DNSQuestion::getProxyProtocolValues() const
+{
+  std::vector<std::pair<int, ProxyProtocolValue>> result;
+  if (proxyProtocolValues) {
+    result.reserve(proxyProtocolValues->size());
+
+    int idx = 1;
+    for (const auto& value: *proxyProtocolValues) {
+      result.push_back({ idx++, value });
+    }
+  }
+
+  return result;
+}
 
 vector<pair<int, DNSRecord> > RecursorLua4::DNSQuestion::getRecords() const
 {
@@ -279,14 +180,13 @@ void RecursorLua4::postPrepareContext()
   d_lw->registerMember("appliedPolicy", &DNSQuestion::appliedPolicy);
   d_lw->registerMember<DNSFilterEngine::Policy, std::string>("policyName",
     [](const DNSFilterEngine::Policy& pol) -> std::string {
-      if(pol.d_name)
-        return *pol.d_name;
-      return std::string();
+      return pol.getName();
     },
     [](DNSFilterEngine::Policy& pol, const std::string& name) {
-      pol.d_name = std::make_shared<std::string>(name);
+      pol.setName(name);
     });
   d_lw->registerMember("policyKind", &DNSFilterEngine::Policy::d_kind);
+  d_lw->registerMember("policyType", &DNSFilterEngine::Policy::d_type);
   d_lw->registerMember("policyTTL", &DNSFilterEngine::Policy::d_ttl);
   d_lw->registerMember<DNSFilterEngine::Policy, std::string>("policyCustom",
     [](const DNSFilterEngine::Policy& pol) -> std::string {
@@ -314,6 +214,7 @@ void RecursorLua4::postPrepareContext()
   d_lw->registerFunction("getEDNSOptions", &DNSQuestion::getEDNSOptions);
   d_lw->registerFunction("getEDNSOption", &DNSQuestion::getEDNSOption);
   d_lw->registerFunction("getEDNSSubnet", &DNSQuestion::getEDNSSubnet);
+  d_lw->registerFunction("getProxyProtocolValues", &DNSQuestion::getProxyProtocolValues);
   d_lw->registerFunction("getEDNSFlags", &DNSQuestion::getEDNSFlags);
   d_lw->registerFunction("getEDNSFlag", &DNSQuestion::getEDNSFlag);
   d_lw->registerMember("name", &DNSRecord::d_name);
@@ -359,6 +260,8 @@ void RecursorLua4::postPrepareContext()
       return ret;
     });
 
+  d_lw->registerFunction<const ProxyProtocolValue, std::string()>("getContent", [](const ProxyProtocolValue& value) { return value.content; });
+  d_lw->registerFunction<const ProxyProtocolValue, uint8_t()>("getType", [](const ProxyProtocolValue& value) { return value.type; });
 
   d_lw->registerFunction<void(DNSRecord::*)(const std::string&)>("changeContent", [](DNSRecord& dr, const std::string& newContent) { dr.d_content = DNSRecordContent::mastermake(dr.d_type, 1, newContent); });
   d_lw->registerFunction("addAnswer", &DNSQuestion::addAnswer);
@@ -366,12 +269,13 @@ void RecursorLua4::postPrepareContext()
   d_lw->registerFunction("getRecords", &DNSQuestion::getRecords);
   d_lw->registerFunction("setRecords", &DNSQuestion::setRecords);
 
-  d_lw->registerFunction<void(DNSQuestion::*)(const std::string&)>("addPolicyTag", [](DNSQuestion& dq, const std::string& tag) { if (dq.policyTags) { dq.policyTags->push_back(tag); } });
+  d_lw->registerFunction<void(DNSQuestion::*)(const std::string&)>("addPolicyTag", [](DNSQuestion& dq, const std::string& tag) { if (dq.policyTags) { dq.policyTags->insert(tag); } });
   d_lw->registerFunction<void(DNSQuestion::*)(const std::vector<std::pair<int, std::string> >&)>("setPolicyTags", [](DNSQuestion& dq, const std::vector<std::pair<int, std::string> >& tags) {
       if (dq.policyTags) {
         dq.policyTags->clear();
+        dq.policyTags->reserve(tags.size());
         for (const auto& tag : tags) {
-          dq.policyTags->push_back(tag.second);
+          dq.policyTags->insert(tag.second);
         }
       }
     });
@@ -379,6 +283,7 @@ void RecursorLua4::postPrepareContext()
       std::vector<std::pair<int, std::string> > ret;
       if (dq.policyTags) {
         int count = 1;
+        ret.reserve(dq.policyTags->size());
         for (const auto& tag : *dq.policyTags) {
           ret.push_back({count++, tag});
         }
@@ -485,6 +390,14 @@ void RecursorLua4::postLoad() {
   d_gettag_ffi = d_lw->readVariable<boost::optional<gettag_ffi_t>>("gettag_ffi").get_value_or(0);
 }
 
+void RecursorLua4::getFeatures(Features & features) {
+  // Add key-values pairs below.
+  // Make sure you add string values explicitly converted to string.
+  // e.g. features.push_back(make_pair("somekey", string("stringvalue"));
+  // Both int and double end up as a lua number type.
+   features.push_back(make_pair("PR8001_devicename", true));
+}
+
 void RecursorLua4::maintenance() const
 {
   if (d_maintenance) {
@@ -535,16 +448,24 @@ bool RecursorLua4::ipfilter(const ComboAddress& remote, const ComboAddress& loca
   return false; // don't block
 }
 
-unsigned int RecursorLua4::gettag(const ComboAddress& remote, const Netmask& ednssubnet, const ComboAddress& local, const DNSName& qname, uint16_t qtype, std::vector<std::string>* policyTags, LuaContext::LuaObject& data, const EDNSOptionViewMap& ednsOptions, bool tcp, std::string& requestorId, std::string& deviceId, std::string& deviceName) const
+unsigned int RecursorLua4::gettag(const ComboAddress& remote, const Netmask& ednssubnet, const ComboAddress& local, const DNSName& qname, uint16_t qtype, std::unordered_set<std::string>* policyTags, LuaContext::LuaObject& data, const EDNSOptionViewMap& ednsOptions, bool tcp, std::string& requestorId, std::string& deviceId, std::string& deviceName, std::string& routingTag, const std::vector<ProxyProtocolValue>& proxyProtocolValues) const
 {
   if(d_gettag) {
-    auto ret = d_gettag(remote, ednssubnet, local, qname, qtype, ednsOptions, tcp);
+    std::vector<std::pair<int, const ProxyProtocolValue*>> proxyProtocolValuesMap;
+    proxyProtocolValuesMap.reserve(proxyProtocolValues.size());
+    int num = 1;
+    for (const auto& value : proxyProtocolValues) {
+      proxyProtocolValuesMap.emplace_back(num++, &value);
+    }
+
+    auto ret = d_gettag(remote, ednssubnet, local, qname, qtype, ednsOptions, tcp, proxyProtocolValuesMap);
 
     if (policyTags) {
       const auto& tags = std::get<1>(ret);
       if (tags) {
+        policyTags->reserve(policyTags->size() + tags->size());
         for (const auto& tag : *tags) {
-          policyTags->push_back(tag.second);
+          policyTags->insert(tag.second);
         }
       }
     }
@@ -565,6 +486,12 @@ unsigned int RecursorLua4::gettag(const ComboAddress& remote, const Netmask& edn
     if (deviceNameret) {
       deviceName = *deviceNameret;
     }
+
+    const auto routingTarget = std::get<6>(ret);
+    if (routingTarget) {
+      routingTag = *routingTarget;
+    }
+
     return std::get<0>(ret);
   }
   return 0;
@@ -573,7 +500,7 @@ unsigned int RecursorLua4::gettag(const ComboAddress& remote, const Netmask& edn
 struct pdns_ffi_param
 {
 public:
-  pdns_ffi_param(const DNSName& qname_, uint16_t qtype_, const ComboAddress& local_, const ComboAddress& remote_, const Netmask& ednssubnet_, std::vector<std::string>& policyTags_, const EDNSOptionViewMap& ednsOptions_, std::string& requestorId_, std::string& deviceId_, std::string& deviceName_, uint32_t& ttlCap_, bool& variable_, bool tcp_, bool& logQuery_): qname(qname_), local(local_), remote(remote_), ednssubnet(ednssubnet_), policyTags(policyTags_), ednsOptions(ednsOptions_), requestorId(requestorId_), deviceId(deviceId_), deviceName(deviceName_), ttlCap(ttlCap_), variable(variable_), logQuery(logQuery_), qtype(qtype_), tcp(tcp_)
+  pdns_ffi_param(const DNSName& qname_, uint16_t qtype_, const ComboAddress& local_, const ComboAddress& remote_, const Netmask& ednssubnet_, std::unordered_set<std::string>& policyTags_, std::vector<DNSRecord>& records_, const EDNSOptionViewMap& ednsOptions_, const std::vector<ProxyProtocolValue>& proxyProtocolValues_, std::string& requestorId_, std::string& deviceId_, std::string& deviceName_, std::string& routingTag_, boost::optional<int>& rcode_, uint32_t& ttlCap_, bool& variable_, bool tcp_, bool& logQuery_, bool& logResponse_, bool& followCNAMERecords_): qname(qname_), local(local_), remote(remote_), ednssubnet(ednssubnet_), policyTags(policyTags_), records(records_), ednsOptions(ednsOptions_), proxyProtocolValues(proxyProtocolValues_), requestorId(requestorId_), deviceId(deviceId_), deviceName(deviceName_), routingTag(routingTag_), rcode(rcode_), ttlCap(ttlCap_), variable(variable_), logQuery(logQuery_), logResponse(logResponse_), followCNAMERecords(followCNAMERecords_), qtype(qtype_), tcp(tcp_)
   {
   }
 
@@ -582,29 +509,36 @@ public:
   std::unique_ptr<std::string> remoteStr{nullptr};
   std::unique_ptr<std::string> ednssubnetStr{nullptr};
   std::vector<pdns_ednsoption_t> ednsOptionsVect;
+  std::vector<pdns_proxyprotocol_value_t> proxyProtocolValuesVect;
 
   const DNSName& qname;
   const ComboAddress& local;
   const ComboAddress& remote;
   const Netmask& ednssubnet;
-  std::vector<std::string>& policyTags;
+  std::unordered_set<std::string>& policyTags;
+  std::vector<DNSRecord>& records;
   const EDNSOptionViewMap& ednsOptions;
+  const std::vector<ProxyProtocolValue>& proxyProtocolValues;
   std::string& requestorId;
   std::string& deviceId;
   std::string& deviceName;
+  std::string& routingTag;
+  boost::optional<int>& rcode;
   uint32_t& ttlCap;
   bool& variable;
   bool& logQuery;
+  bool& logResponse;
+  bool& followCNAMERecords;
 
   unsigned int tag{0};
   uint16_t qtype;
   bool tcp;
 };
 
-unsigned int RecursorLua4::gettag_ffi(const ComboAddress& remote, const Netmask& ednssubnet, const ComboAddress& local, const DNSName& qname, uint16_t qtype, std::vector<std::string>* policyTags, LuaContext::LuaObject& data, const EDNSOptionViewMap& ednsOptions, bool tcp, std::string& requestorId, std::string& deviceId, std::string& deviceName, uint32_t& ttlCap, bool& variable, bool& logQuery) const
+unsigned int RecursorLua4::gettag_ffi(const ComboAddress& remote, const Netmask& ednssubnet, const ComboAddress& local, const DNSName& qname, uint16_t qtype, std::unordered_set<std::string>* policyTags, std::vector<DNSRecord>& records, LuaContext::LuaObject& data, const EDNSOptionViewMap& ednsOptions, bool tcp, const std::vector<ProxyProtocolValue>& proxyProtocolValues, std::string& requestorId, std::string& deviceId, std::string& deviceName, std::string& routingTag, boost::optional<int>& rcode, uint32_t& ttlCap, bool& variable, bool& logQuery, bool& logResponse, bool& followCNAMERecords) const
 {
   if (d_gettag_ffi) {
-    pdns_ffi_param_t param(qname, qtype, local, remote, ednssubnet, *policyTags, ednsOptions, requestorId, deviceId, deviceName, ttlCap, variable, tcp, logQuery);
+    pdns_ffi_param_t param(qname, qtype, local, remote, ednssubnet, *policyTags, records, ednsOptions, proxyProtocolValues, requestorId, deviceId, deviceName, routingTag, rcode, ttlCap, variable, tcp, logQuery, logResponse, followCNAMERecords);
 
     auto ret = d_gettag_ffi(&param);
     if (ret) {
@@ -646,10 +580,10 @@ loop:;
         ret = followCNAMERecords(dq.records, QType(dq.qtype));
       }
       else if(dq.followupFunction=="getFakeAAAARecords") {
-        ret=getFakeAAAARecords(dq.followupName, dq.followupPrefix, dq.records);
+        ret=getFakeAAAARecords(dq.followupName, ComboAddress(dq.followupPrefix), dq.records);
       }
       else if(dq.followupFunction=="getFakePTRRecords") {
-        ret=getFakePTRRecords(dq.followupName, dq.followupPrefix, dq.records);
+        ret=getFakePTRRecords(dq.followupName, dq.records);
       }
       else if(dq.followupFunction=="udpQueryResponse") {
         dq.udpAnswer = GenUDPQueryResponse(dq.udpQueryDest, dq.udpQuery);
@@ -834,6 +768,30 @@ size_t pdns_ffi_param_get_edns_options_by_code(pdns_ffi_param_t* ref, uint16_t o
   return pos;
 }
 
+size_t pdns_ffi_param_get_proxy_protocol_values(pdns_ffi_param_t* ref, const pdns_proxyprotocol_value_t** out)
+{
+  if (ref->proxyProtocolValues.empty()) {
+    return 0;
+  }
+
+  ref->proxyProtocolValuesVect.resize(ref->proxyProtocolValues.size());
+
+  size_t pos = 0;
+  for (const auto& value : ref->proxyProtocolValues) {
+    auto& dest = ref->proxyProtocolValuesVect.at(pos);
+    dest.type = value.type;
+    dest.len = value.content.size();
+    if (dest.len > 0) {
+      dest.data = value.content.data();
+    }
+    pos++;
+  }
+
+  *out = ref->proxyProtocolValuesVect.data();
+
+  return ref->proxyProtocolValuesVect.size();
+}
+
 void pdns_ffi_param_set_tag(pdns_ffi_param_t* ref, unsigned int tag)
 {
   ref->tag = tag;
@@ -841,7 +799,7 @@ void pdns_ffi_param_set_tag(pdns_ffi_param_t* ref, unsigned int tag)
 
 void pdns_ffi_param_add_policytag(pdns_ffi_param_t *ref, const char* name)
 {
-  ref->policyTags.push_back(std::string(name));
+  ref->policyTags.insert(std::string(name));
 }
 
 void pdns_ffi_param_set_requestorid(pdns_ffi_param_t* ref, const char* name)
@@ -859,6 +817,11 @@ void pdns_ffi_param_set_deviceid(pdns_ffi_param_t* ref, size_t len, const void*
   ref->deviceId = std::string(reinterpret_cast<const char*>(name), len);
 }
 
+void pdns_ffi_param_set_routingtag(pdns_ffi_param_t* ref, const char* rtag)
+{
+  ref->routingTag = std::string(rtag);
+}
+
 void pdns_ffi_param_set_variable(pdns_ffi_param_t* ref, bool variable)
 {
   ref->variable = variable;
@@ -873,3 +836,38 @@ void pdns_ffi_param_set_log_query(pdns_ffi_param_t* ref, bool logQuery)
 {
   ref->logQuery = logQuery;
 }
+
+void pdns_ffi_param_set_log_response(pdns_ffi_param_t* ref, bool logResponse)
+{
+  ref->logResponse = logResponse;
+}
+
+void pdns_ffi_param_set_rcode(pdns_ffi_param_t* ref, int rcode)
+{
+  ref->rcode = rcode;
+}
+
+void pdns_ffi_param_set_follow_cname_records(pdns_ffi_param_t* ref, bool follow)
+{
+  ref->followCNAMERecords = follow;
+}
+
+bool pdns_ffi_param_add_record(pdns_ffi_param_t *ref, const char* name, uint16_t type, uint32_t ttl, const char* content, size_t contentSize, pdns_record_place_t place)
+{
+  try {
+    DNSRecord dr;
+    dr.d_name = name != nullptr ? DNSName(name) : ref->qname;
+    dr.d_ttl = ttl;
+    dr.d_type = type;
+    dr.d_class = QClass::IN;
+    dr.d_place = DNSResourceRecord::Place(place);
+    dr.d_content = DNSRecordContent::mastermake(type, QClass::IN, std::string(content, contentSize));
+    ref->records.push_back(std::move(dr));
+
+    return true;
+  }
+  catch (const std::exception& e) {
+    g_log<<Logger::Error<<"Error attempting to add a record from Lua via pdns_ffi_param_add_record(): "<<e.what()<<endl;
+    return false;
+  }
+}