]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/lua-auth4.cc
dnsdist: Add HTTPStatusAction to return a specific HTTP response
[thirdparty/pdns.git] / pdns / lua-auth4.cc
index 773d71090d8c73995f6ef6ce5ff4513e4130b7f5..dae104b14032b784aa4c6f6782cf235bccbf7312 100644 (file)
@@ -1,3 +1,7 @@
+#include "config.h"
+#if defined(HAVE_LUA)
+#include "ext/luawrapper/include/LuaContext.hpp"
+#endif
 #include "lua-auth4.hh"
 #include "stubresolver.hh"
 #include <fstream>
 #include "namespaces.hh"
 #include "ednssubnet.hh"
 #include <unordered_set>
+#include "sstuff.hh"
+#include <thread>
+#include <mutex>
+
+#include "ueberbackend.hh"
+
+AuthLua4::AuthLua4() { prepareContext(); }
 
 #if !defined(HAVE_LUA)
 
-AuthLua4::AuthLua4(const std::string& fname) { }
 bool AuthLua4::updatePolicy(const DNSName &qname, QType qtype, const DNSName &zonename, DNSPacket *packet) { return false; }
 bool AuthLua4::axfrfilter(const ComboAddress& remote, const DNSName& zone, const DNSResourceRecord& in, vector<DNSResourceRecord>& out) { return false; }
-AuthLua4::~AuthLua4() { }
-
-#else
-
-#undef L
-#include "ext/luawrapper/include/LuaContext.hpp"
-
-AuthLua4::AuthLua4(const std::string& fname) {
-  d_lw = std::unique_ptr<LuaContext>(new LuaContext);
-  stubParseResolveConf();
-  d_lw->registerFunction<int(dnsheader::*)()>("getID", [](dnsheader& dh) { return dh.id; });
-  d_lw->registerFunction<bool(dnsheader::*)()>("getCD", [](dnsheader& dh) { return dh.cd; });
-  d_lw->registerFunction<bool(dnsheader::*)()>("getTC", [](dnsheader& dh) { return dh.tc; });
-  d_lw->registerFunction<bool(dnsheader::*)()>("getRA", [](dnsheader& dh) { return dh.ra; });
-  d_lw->registerFunction<bool(dnsheader::*)()>("getAD", [](dnsheader& dh) { return dh.ad; });
-  d_lw->registerFunction<bool(dnsheader::*)()>("getAA", [](dnsheader& dh) { return dh.aa; });
-  d_lw->registerFunction<bool(dnsheader::*)()>("getRD", [](dnsheader& dh) { return dh.rd; });
-  d_lw->registerFunction<int(dnsheader::*)()>("getRCODE", [](dnsheader& dh) { return dh.rcode; });
-  d_lw->registerFunction<int(dnsheader::*)()>("getOPCODE", [](dnsheader& dh) { return dh.opcode; });
-  d_lw->registerFunction<int(dnsheader::*)()>("getQDCOUNT", [](dnsheader& dh) { return ntohs(dh.qdcount); });
-  d_lw->registerFunction<int(dnsheader::*)()>("getANCOUNT", [](dnsheader& dh) { return ntohs(dh.ancount); });
-  d_lw->registerFunction<int(dnsheader::*)()>("getNSCOUNT", [](dnsheader& dh) { return ntohs(dh.nscount); });
-  d_lw->registerFunction<int(dnsheader::*)()>("getARCOUNT", [](dnsheader& dh) { return ntohs(dh.arcount); });
-
-  d_lw->writeFunction("newDN", [](const std::string& dom){ return DNSName(dom); });
-  d_lw->registerFunction("isPartOf", &DNSName::isPartOf);
-  d_lw->registerFunction<bool(DNSName::*)(const std::string&)>("equal",
-                                                              [](const DNSName& lhs, const std::string& rhs) { return lhs==DNSName(rhs); });
-  d_lw->registerFunction("__eq", &DNSName::operator==);
-
-  d_lw->registerFunction("__eq", &DNSResourceRecord::operator==);
-  d_lw->registerFunction("__lt", &DNSResourceRecord::operator<);
-
-  d_lw->registerFunction<string(DNSResourceRecord::*)()>("toString", [](const DNSResourceRecord& rec) { return rec.getZoneRepresentation();} );
-
-  d_lw->registerFunction<DNSName(DNSResourceRecord::*)()>("qname", [](DNSResourceRecord& rec) { return rec.qname; });
-  d_lw->registerFunction<DNSName(DNSResourceRecord::*)()>("wildcardname", [](DNSResourceRecord& rec) { return rec.wildcardname; });
-  d_lw->registerFunction<string(DNSResourceRecord::*)()>("content", [](DNSResourceRecord& rec) { return rec.content; });
-  d_lw->registerFunction<time_t(DNSResourceRecord::*)()>("last_modified", [](DNSResourceRecord& rec) { return rec.last_modified; });
-  d_lw->registerFunction<uint32_t(DNSResourceRecord::*)()>("ttl", [](DNSResourceRecord& rec) { return rec.ttl; });
-  d_lw->registerFunction<uint32_t(DNSResourceRecord::*)()>("signttl", [](DNSResourceRecord& rec) { return rec.signttl; });
-  d_lw->registerFunction<int(DNSResourceRecord::*)()>("domain_id", [](DNSResourceRecord& rec) { return rec.domain_id; });
-  d_lw->registerFunction<uint16_t(DNSResourceRecord::*)()>("qtype", [](DNSResourceRecord& rec) { return rec.qtype.getCode(); });
-  d_lw->registerFunction<uint16_t(DNSResourceRecord::*)()>("qclass", [](DNSResourceRecord& rec) { return rec.qclass; });
-  d_lw->registerFunction<uint8_t(DNSResourceRecord::*)()>("scopeMask", [](DNSResourceRecord& rec) { return rec.scopeMask; });
-  d_lw->registerFunction<bool(DNSResourceRecord::*)()>("auth", [](DNSResourceRecord& rec) { return rec.auth; });
-  d_lw->registerFunction<bool(DNSResourceRecord::*)()>("disabled", [](DNSResourceRecord& rec) { return rec.disabled; });
-
-  d_lw->registerFunction<string(ComboAddress::*)()>("toString", [](const ComboAddress& ca) { return ca.toString(); });
-  d_lw->registerFunction<string(ComboAddress::*)()>("toStringWithPort", [](const ComboAddress& ca) { return ca.toStringWithPort(); });
-  d_lw->registerFunction<uint16_t(ComboAddress::*)()>("getPort", [](const ComboAddress& ca) { return ntohs(ca.sin4.sin_port); } );
-  d_lw->registerFunction<string(ComboAddress::*)()>("getRaw", [](const ComboAddress& ca) {
-      if(ca.sin4.sin_family == AF_INET) {
-        auto t=ca.sin4.sin_addr.s_addr; return string((const char*)&t, 4);
-      }
-      else
-        return string((const char*)&ca.sin6.sin6_addr.s6_addr, 16);
-    } );
-
-  d_lw->writeFunction("newCA", [](const std::string& a) { return ComboAddress(a); });
-  typedef std::unordered_set<ComboAddress,ComboAddress::addressOnlyHash,ComboAddress::addressOnlyEqual> cas_t;
-  d_lw->writeFunction("newCAS", []{ return cas_t(); });
-
-
-  d_lw->registerFunction<void(cas_t::*)(boost::variant<string,ComboAddress, vector<pair<unsigned int,string> > >)>("add",
-                                                                                   [](cas_t& cas, const boost::variant<string,ComboAddress,vector<pair<unsigned int,string> > >& in)
-                                                                                   {
-                                                                                     try {
-                                                                                     if(auto s = boost::get<string>(&in)) {
-                                                                                       cas.insert(ComboAddress(*s));
-                                                                                     }
-                                                                                     else if(auto v = boost::get<vector<pair<unsigned int, string> > >(&in)) {
-                                                                                       for(const auto& str : *v)
-                                                                                         cas.insert(ComboAddress(str.second));
-                                                                                     }
-                                                                                     else
-                                                                                       cas.insert(boost::get<ComboAddress>(in));
-                                                                                     }
-                                                                                     catch(std::exception& e) { theL() <<Logger::Error<<e.what()<<endl; }
-                                                                                   });
-
-  d_lw->registerFunction<bool(cas_t::*)(const ComboAddress&)>("check",[](const cas_t& cas, const ComboAddress&ca) {
-      return (bool)cas.count(ca);
-    });
-
-
-
-  d_lw->registerFunction<bool(ComboAddress::*)(const ComboAddress&)>("equal", [](const ComboAddress& lhs, const ComboAddress& rhs) {
-      return ComboAddress::addressOnlyEqual()(lhs, rhs);
-    });
-
-
-  d_lw->registerFunction<ComboAddress(Netmask::*)()>("getNetwork", [](const Netmask& nm) { return nm.getNetwork(); } ); // const reference makes this necessary
-  d_lw->registerFunction("toString", &Netmask::toString);
-  d_lw->registerFunction("empty", &Netmask::empty);
-
-  d_lw->writeFunction("newNMG", []() { return NetmaskGroup(); });
-  d_lw->registerFunction<void(NetmaskGroup::*)(const std::string&mask)>("addMask", [](NetmaskGroup&nmg, const std::string& mask)
-                         {
-                           nmg.addMask(mask);
-                         });
+LuaContext* AuthLua4::getLua() { return 0; }
+DNSPacket *AuthLua4::prequery(DNSPacket *q) { return NULL; }
 
-  d_lw->registerFunction<void(NetmaskGroup::*)(const vector<pair<unsigned int, std::string>>&)>("addMasks", [](NetmaskGroup&nmg, const vector<pair<unsigned int, std::string>>& masks)
-                         {
-                           for(const auto& mask: masks)
-                             nmg.addMask(mask.second);
-                         });
-
-
-  d_lw->registerFunction("match", (bool (NetmaskGroup::*)(const ComboAddress&) const)&NetmaskGroup::match);
-  d_lw->registerFunction<string(DNSName::*)()>("toString", [](const DNSName&dn ) { return dn.toString(); });
-  d_lw->registerFunction<string(DNSName::*)()>("toStringNoDot", [](const DNSName&dn ) { return dn.toStringNoDot(); });
-  d_lw->registerFunction<bool(DNSName::*)()>("chopOff", [](DNSName&dn ) { return dn.chopOff(); });
-  d_lw->registerMember("name", &DNSRecord::d_name);
-  d_lw->registerMember("type", &DNSRecord::d_type);
-  d_lw->registerMember("ttl", &DNSRecord::d_ttl);
-
-
-  d_lw->registerFunction<string(DNSRecord::*)()>("getContent", [](const DNSRecord& dr) { return dr.d_content->getZoneRepresentation(); });
-  d_lw->registerFunction<boost::optional<ComboAddress>(DNSRecord::*)()>("getCA", [](const DNSRecord& dr) {
-      boost::optional<ComboAddress> ret;
-
-      if(auto rec = std::dynamic_pointer_cast<ARecordContent>(dr.d_content))
-        ret=rec->getCA(53);
-      else if(auto aaaarec = std::dynamic_pointer_cast<AAAARecordContent>(dr.d_content))
-        ret=aaaarec->getCA(53);
-      return ret;
-    });
-
-
-  d_lw->registerFunction<void(DNSRecord::*)(const std::string&)>("changeContent", [](DNSRecord& dr, const std::string& newContent) { dr.d_content = DNSRecordContent::mastermake(dr.d_type, 1, newContent); });
-
-  d_lw->writeFunction("pdnslog", [](const std::string& msg, boost::optional<int> loglevel) {
-      theL() << (Logger::Urgency)loglevel.get_value_or(Logger::Warning) << msg<<endl;
-    });
-  typedef vector<pair<string, int> > in_t;
-  vector<pair<string, boost::variant<int, in_t, struct timeval* > > >  pd{
-    {"PASS", (int)PolicyDecision::PASS}, {"DROP",  (int)PolicyDecision::DROP},
-    {"TRUNCATE", (int)PolicyDecision::TRUNCATE}
-  };
-
-  vector<pair<string, int> > rcodes = {{"NOERROR",  RCode::NoError  },
-                                       {"FORMERR",  RCode::FormErr  },
-                                       {"SERVFAIL", RCode::ServFail },
-                                       {"NXDOMAIN", RCode::NXDomain },
-                                       {"NOTIMP",   RCode::NotImp   },
-                                       {"REFUSED",  RCode::Refused  },
-                                       {"YXDOMAIN", RCode::YXDomain },
-                                       {"YXRRSET",  RCode::YXRRSet  },
-                                       {"NXRRSET",  RCode::NXRRSet  },
-                                       {"NOTAUTH",  RCode::NotAuth  },
-                                       {"NOTZONE",  RCode::NotZone  }};
-  for(const auto& rcode : rcodes)
-    pd.push_back({rcode.first, rcode.second});
+AuthLua4::~AuthLua4() { }
 
-  pd.push_back({"place", in_t{
-    {"QUESTION", 0},
-    {"ANSWER", 1},
-    {"AUTHORITY", 2},
-    {"ADDITIONAL", 3}
-  }});
+void AuthLua4::postPrepareContext()
+{
+}
 
-  pd.push_back({"loglevels", in_t{
-        {"Alert", LOG_ALERT},
-        {"Critical", LOG_CRIT},
-        {"Debug", LOG_DEBUG},
-        {"Emergency", LOG_EMERG},
-        {"Info", LOG_INFO},
-        {"Notice", LOG_NOTICE},
-        {"Warning", LOG_WARNING},
-        {"Error", LOG_ERR}
-          }});
+void AuthLua4::postLoad()
+{
+}
 
-  for(const auto& n : QType::names)
-    pd.push_back({n.first, n.second});
-  d_lw->registerMember("tv_sec", &timeval::tv_sec);
-  d_lw->registerMember("tv_usec", &timeval::tv_usec);
+#else
 
-  d_lw->writeVariable("pdns", pd);
+LuaContext* AuthLua4::getLua()
+{
+  return d_lw.get();
+}
 
+void AuthLua4::postPrepareContext() {
   d_lw->writeFunction("resolve", [](const std::string& qname, uint16_t qtype) {
       std::vector<DNSZoneRecord> ret;
       std::unordered_map<int, DNSResourceRecord> luaResult;
@@ -204,6 +55,46 @@ AuthLua4::AuthLua4(const std::string& fname) {
       return luaResult;
   });
 
+/* DNSPacket */
+  d_lw->writeFunction("newDNSPacket", [](bool isQuery) { return new DNSPacket(isQuery); });
+  d_lw->writeFunction("dupDNSPacket", [](const DNSPacket &orig) { return new DNSPacket(orig); });
+  d_lw->registerFunction<DNSPacket, int(const char *, size_t)>("noparse", [](DNSPacket &p, const char *mesg, size_t len){ return p.noparse(mesg, len); });
+  d_lw->registerFunction<DNSPacket, int(const char *, size_t)>("parse", [](DNSPacket &p, const char *mesg, size_t len){ return p.parse(mesg, len); });
+  d_lw->registerFunction<DNSPacket, const std::string()>("getString", [](DNSPacket &p) { return p.getString(); });
+  d_lw->registerFunction<DNSPacket, void(const ComboAddress&)>("setRemote", [](DNSPacket &p, const ComboAddress &ca) { p.setRemote(&ca); });
+  d_lw->registerFunction<DNSPacket, ComboAddress()>("getRemote", [](DNSPacket &p) { return p.getRemote(); });
+  d_lw->registerFunction<DNSPacket, Netmask()>("getRealRemote", [](DNSPacket &p) { return p.getRealRemote(); });
+  d_lw->registerFunction<DNSPacket, ComboAddress()>("getLocal", [](DNSPacket &p) { return p.getLocal(); });
+  d_lw->registerFunction<DNSPacket, unsigned int()>("getRemotePort", [](DNSPacket &p) { return p.getRemotePort(); });
+  d_lw->registerFunction<DNSPacket, std::tuple<const std::string, unsigned int>()>("getQuestion", [](DNSPacket &p) { return std::make_tuple(p.qdomain.toString(), static_cast<unsigned int>(p.qtype.getCode())); });
+  d_lw->registerFunction<DNSPacket, void(bool)>("setA", [](DNSPacket &p, bool a) { return p.setA(a); });
+  d_lw->registerFunction<DNSPacket, void(unsigned int)>("setID", [](DNSPacket &p, unsigned int id) { return p.setID(static_cast<uint16_t>(id)); });
+  d_lw->registerFunction<DNSPacket, void(bool)>("setRA", [](DNSPacket &p, bool ra) { return p.setRA(ra); });
+  d_lw->registerFunction<DNSPacket, void(bool)>("setRD", [](DNSPacket &p, bool rd) { return p.setRD(rd); });
+  d_lw->registerFunction<DNSPacket, void(bool)>("setAnswer", [](DNSPacket &p, bool answer) { return p.setAnswer(answer); });
+  d_lw->registerFunction<DNSPacket, void(unsigned int)>("setOpCode", [](DNSPacket &p, unsigned int opcode) { return p.setOpcode(static_cast<uint16_t>(opcode)); });
+  d_lw->registerFunction<DNSPacket, void(int)>("setRcode", [](DNSPacket &p, int rcode) { return p.setRcode(rcode); });
+  d_lw->registerFunction<DNSPacket, void()>("clearRecords",[](DNSPacket &p){p.clearRecords();});
+  d_lw->registerFunction<DNSPacket, void(DNSRecord&, bool)>("addRecord", [](DNSPacket &p, DNSRecord &dr, bool auth) { DNSZoneRecord dzr; dzr.dr = dr; dzr.auth = auth; p.addRecord(dzr); });
+  d_lw->registerFunction<DNSPacket, void(const vector<pair<unsigned int, DNSRecord> >&)>("addRecords", [](DNSPacket &p, const vector<pair<unsigned int, DNSRecord> >& records){ for(const auto &dr: records){ DNSZoneRecord dzr; dzr.dr = std::get<1>(dr); dzr.auth = true; p.addRecord(dzr); }});
+  d_lw->registerFunction<DNSPacket, void(unsigned int, const DNSName&, const std::string&)>("setQuestion", [](DNSPacket &p, unsigned int opcode, const DNSName &name, const string &type){ QType qtype; qtype = type; p.setQuestion(static_cast<int>(opcode), name, static_cast<int>(qtype.getCode())); });
+  d_lw->registerFunction<DNSPacket, bool()>("isEmpty", [](DNSPacket &p){return p.isEmpty();});
+  d_lw->registerFunction<DNSPacket, DNSPacket*()>("replyPacket",[](DNSPacket& p){ return p.replyPacket();});
+  d_lw->registerFunction<DNSPacket, bool()>("hasEDNSSubnet", [](DNSPacket &p){return p.hasEDNSSubnet();});
+  d_lw->registerFunction<DNSPacket, bool()>("hasEDNS",[](DNSPacket &p){return p.hasEDNS();});
+  d_lw->registerFunction<DNSPacket, unsigned int()>("getEDNSVersion",[](DNSPacket &p){return p.getEDNSVersion();});
+  d_lw->registerFunction<DNSPacket, void(unsigned int)>("setEDNSRcode",[](DNSPacket &p, unsigned int extRCode){p.setEDNSRcode(static_cast<uint16_t>(extRCode));});
+  d_lw->registerFunction<DNSPacket, unsigned int()>("getEDNSRcode",[](DNSPacket &p){return p.getEDNSRCode();});
+  d_lw->registerFunction<DNSPacket, DNSName()>("getTSIGKeyname",[](DNSPacket &p){ return p.getTSIGKeyname();});
+  d_lw->registerFunction<DNSPacket, std::unordered_map<unsigned int, DNSRecord>()>("getRRS", [](DNSPacket &p){ std::unordered_map<unsigned int, DNSRecord> ret; unsigned int i = 0; for(const auto &rec: p.getRRS()) { ret.insert({i++, rec.dr}); } return ret;});
+  d_lw->registerMember<DNSPacket, DNSName>("qdomain", [](const DNSPacket &p) -> DNSName { return p.qdomain; }, [](DNSPacket &p, const DNSName& name) { p.qdomain = name; });
+  d_lw->registerMember<DNSPacket, DNSName>("qdomainwild", [](const DNSPacket &p) -> DNSName { return p.qdomainwild; }, [](DNSPacket &p, const DNSName& name) { p.qdomainwild = name; });
+  d_lw->registerMember<DNSPacket, DNSName>("qdomainzone", [](const DNSPacket &p) -> DNSName { return p.qdomainzone; }, [](DNSPacket &p, const DNSName& name) { p.qdomainzone = name; });
+
+  d_lw->registerMember<DNSPacket, std::string>("d_peer_principal", [](const DNSPacket &p) -> std::string { return p.d_peer_principal; }, [](DNSPacket &p, const std::string &princ) { p.d_peer_principal = princ; });
+  d_lw->registerMember<DNSPacket, const std::string>("qtype", [](const DNSPacket &p) ->  const std::string { return p.qtype.getName(); }, [](DNSPacket &p, const std::string &type) { p.qtype = type; });
+/* End of DNSPacket */
+
 
 /* update policy */
   d_lw->registerFunction<DNSName(UpdatePolicyQuery::*)()>("getQName", [](UpdatePolicyQuery& upq) { return upq.qname; });
@@ -215,17 +106,12 @@ AuthLua4::AuthLua4(const std::string& fname) {
   d_lw->registerFunction<DNSName(UpdatePolicyQuery::*)()>("getTsigName", [](UpdatePolicyQuery& upq) { return upq.tsigName; });
   d_lw->registerFunction<std::string(UpdatePolicyQuery::*)()>("getPeerPrincipal", [](UpdatePolicyQuery& upq) { return upq.peerPrincipal; });
 /* end of update policy */
+}
 
-  ifstream ifs(fname);
-  if(!ifs) {
-    theL()<<Logger::Error<<"Unable to read configuration file from '"<<fname<<"': "<<strerror(errno)<<endl;
-    return;
-  }
-  d_lw->executeCode(ifs);
-
+void AuthLua4::postLoad() {
   d_update_policy = d_lw->readVariable<boost::optional<luacall_update_policy_t>>("updatepolicy").get_value_or(0);
   d_axfr_filter = d_lw->readVariable<boost::optional<luacall_axfr_filter_t>>("axfrfilter").get_value_or(0);
-
+  d_prequery = d_lw->readVariable<boost::optional<luacall_prequery_t>>("prequery").get_value_or(0);
 }
 
 bool AuthLua4::axfrfilter(const ComboAddress& remote, const DNSName& zone, const DNSResourceRecord& in, vector<DNSResourceRecord>& out) {
@@ -236,10 +122,17 @@ bool AuthLua4::axfrfilter(const ComboAddress& remote, const DNSName& zone, const
 
   ret = d_axfr_filter(remote, zone, in);
   rcode = std::get<0>(ret);
-  if (rcode < 0)
+  if (rcode < 0) {
+    // no modification, handle normally
     return false;
-  else if (rcode == 1)
+  }
+  else if (rcode == 0) {
+    // replace the matching record by the filtered record(s)
+  }
+  else if (rcode == 1) {
+    // append the filtered record(s) after the matching record
     out.push_back(in);
+  }
   else
     throw PDNSException("Cannot understand return code "+std::to_string(rcode)+" in axfr filter response");
 
@@ -267,6 +160,9 @@ bool AuthLua4::axfrfilter(const ComboAddress& remote, const DNSName& zone, const
 
 
 bool AuthLua4::updatePolicy(const DNSName &qname, QType qtype, const DNSName &zonename, DNSPacket *packet) {
+  // default decision is all goes
+  if (d_update_policy == NULL) return true;
+
   UpdatePolicyQuery upq;
   upq.qname = qname;
   upq.qtype = qtype.getCode();
@@ -280,6 +176,17 @@ bool AuthLua4::updatePolicy(const DNSName &qname, QType qtype, const DNSName &zo
   return d_update_policy(upq);
 }
 
+DNSPacket *AuthLua4::prequery(DNSPacket *q) {
+  if (d_prequery == NULL) return NULL;
+
+  DNSPacket *r = q->replyPacket();
+  if (d_prequery(r))
+    return r;
+  delete r;
+  return NULL;
+}
+
 AuthLua4::~AuthLua4() { }
 
+
 #endif