]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
rec: Add an option to not override custom RPZ types with the default policy
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Feb 2019 15:16:29 +0000 (16:16 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Feb 2019 15:16:29 +0000 (16:16 +0100)
pdns/filterpo.cc
pdns/filterpo.hh
pdns/rec-lua-conf.cc
pdns/rec-lua-conf.hh
pdns/recursordist/docs/lua-config/rpz.rst
pdns/rpzloader.cc
pdns/rpzloader.hh
regression-tests.recursor-dnssec/test_RPZ.py

index dc7fc5e375777e4fefee5d62044fbc7cf1bdc79a..c3585ce031a3be21ab9ba01139c5ea29f7e650c2 100644 (file)
@@ -210,18 +210,18 @@ void DNSFilterEngine::Zone::addResponseTrigger(const Netmask& nm, Policy&& pol)
   d_postpolAddr.insert(nm).second=std::move(pol);
 }
 
-void DNSFilterEngine::Zone::addQNameTrigger(const DNSName& n, Policy&& pol)
+void DNSFilterEngine::Zone::addQNameTrigger(const DNSName& n, Policy&& pol, bool ignoreDuplicate)
 {
   auto it = d_qpolName.find(n);
 
   if (it != d_qpolName.end()) {
     auto& existingPol = it->second;
 
-    if (pol.d_kind != PolicyKind::Custom) {
+    if (pol.d_kind != PolicyKind::Custom && !ignoreDuplicate) {
       throw std::runtime_error("Adding a QName-based filter policy of kind " + getKindToString(pol.d_kind) + " but a policy of kind " + getKindToString(existingPol.d_kind) + " already exists for the following QName: " + n.toLogString());
     }
 
-    if (existingPol.d_kind != PolicyKind::Custom) {
+    if (existingPol.d_kind != PolicyKind::Custom && ignoreDuplicate) {
       throw std::runtime_error("Adding a QName-based filter policy of kind " + getKindToString(existingPol.d_kind) + " but there was already an existing policy for the following QName: " + n.toLogString());
     }
 
index 22900cb37d73e2e3b6d1c3bb153aae5db1f35c41..3d5aea5d86c42500414f098dc983e970edc92158 100644 (file)
@@ -158,7 +158,7 @@ public:
     void dump(FILE * fp) const;
 
     void addClientTrigger(const Netmask& nm, Policy&& pol);
-    void addQNameTrigger(const DNSName& nm, Policy&& pol);
+    void addQNameTrigger(const DNSName& nm, Policy&& pol, bool ignoreDuplicate=false);
     void addNSTrigger(const DNSName& dn, Policy&& pol);
     void addNSIPTrigger(const Netmask& nm, Policy&& pol);
     void addResponseTrigger(const Netmask& nm, Policy&& pol);
index a94c992a2f38f33c521d7bba18689b4225c23642..ea1d56ed6c9d42934aa71b5df2d93fe682c1580e 100644 (file)
@@ -51,31 +51,36 @@ typename C::value_type::second_type constGet(const C& c, const std::string& name
   return iter->second;
 }
 
+typedef std::unordered_map<std::string, boost::variant<bool, uint32_t, std::string > > rpzOptions_t;
 
-static void parseRPZParameters(const std::unordered_map<string,boost::variant<uint32_t, string> >& have, std::string& polName, boost::optional<DNSFilterEngine::Policy>& defpol, uint32_t& maxTTL, size_t& zoneSizeHint)
+static void parseRPZParameters(rpzOptions_t& have, std::string& polName, boost::optional<DNSFilterEngine::Policy>& defpol, bool& defpolOverrideLocal, uint32_t& maxTTL, size_t& zoneSizeHint)
 {
   if(have.count("policyName")) {
-    polName = boost::get<std::string>(constGet(have, "policyName"));
+    polName = boost::get<std::string>(have["policyName"]);
   }
   if(have.count("defpol")) {
     defpol=DNSFilterEngine::Policy();
-    defpol->d_kind = (DNSFilterEngine::PolicyKind)boost::get<uint32_t>(constGet(have, "defpol"));
+    defpol->d_kind = (DNSFilterEngine::PolicyKind)boost::get<uint32_t>(have["defpol"]);
     defpol->d_name = std::make_shared<std::string>(polName);
     if(defpol->d_kind == DNSFilterEngine::PolicyKind::Custom) {
       defpol->d_custom.push_back(DNSRecordContent::mastermake(QType::CNAME, QClass::IN,
-                                                              boost::get<string>(constGet(have,"defcontent"))));
+                                                              boost::get<string>(have["defcontent"])));
 
       if(have.count("defttl"))
-        defpol->d_ttl = static_cast<int32_t>(boost::get<uint32_t>(constGet(have, "defttl")));
+        defpol->d_ttl = static_cast<int32_t>(boost::get<uint32_t>(have["defttl"]));
       else
         defpol->d_ttl = -1; // get it from the zone
     }
+
+    if (have.count("defpolOverrideLocalData")) {
+      defpolOverrideLocal = boost::get<bool>(have["defpolOverrideLocalData"]);
+    }
   }
   if(have.count("maxTTL")) {
-    maxTTL = boost::get<uint32_t>(constGet(have, "maxTTL"));
+    maxTTL = boost::get<uint32_t>(have["maxTTL"]);
   }
   if(have.count("zoneSizeHint")) {
-    zoneSizeHint = static_cast<size_t>(boost::get<uint32_t>(constGet(have, "zoneSizeHint")));
+    zoneSizeHint = static_cast<size_t>(boost::get<uint32_t>(have["zoneSizeHint"]));
   }
 }
 
@@ -186,23 +191,24 @@ void loadRecursorLuaConfig(const std::string& fname, luaConfigDelayedThreads& de
   };
   Lua.writeVariable("Policy", pmap);
 
-  Lua.writeFunction("rpzFile", [&lci](const string& filename, const boost::optional<std::unordered_map<string,boost::variant<uint32_t, string>>>& options) {
+  Lua.writeFunction("rpzFile", [&lci](const string& filename, boost::optional<rpzOptions_t> options) {
       try {
         boost::optional<DNSFilterEngine::Policy> defpol;
+        bool defpolOverrideLocal = true;
         std::string polName("rpzFile");
         std::shared_ptr<DNSFilterEngine::Zone> zone = std::make_shared<DNSFilterEngine::Zone>();
         uint32_t maxTTL = std::numeric_limits<uint32_t>::max();
         if(options) {
           auto& have = *options;
           size_t zoneSizeHint = 0;
-          parseRPZParameters(have, polName, defpol, maxTTL, zoneSizeHint);
+          parseRPZParameters(have, polName, defpol, defpolOverrideLocal, maxTTL, zoneSizeHint);
           if (zoneSizeHint > 0) {
             zone->reserve(zoneSizeHint);
           }
         }
         g_log<<Logger::Warning<<"Loading RPZ from file '"<<filename<<"'"<<endl;
         zone->setName(polName);
-        loadRPZFromFile(filename, zone, defpol, maxTTL);
+        loadRPZFromFile(filename, zone, defpol, defpolOverrideLocal, maxTTL);
         lci.dfe.addZone(zone);
         g_log<<Logger::Warning<<"Done loading RPZ from file '"<<filename<<"'"<<endl;
       }
@@ -211,9 +217,10 @@ void loadRecursorLuaConfig(const std::string& fname, luaConfigDelayedThreads& de
       }
     });
 
-  Lua.writeFunction("rpzMaster", [&lci, &delayedThreads](const boost::variant<string, std::vector<std::pair<int, string> > >& masters_, const string& zoneName, const boost::optional<std::unordered_map<string,boost::variant<uint32_t, string>>>& options) {
+  Lua.writeFunction("rpzMaster", [&lci, &delayedThreads](const boost::variant<string, std::vector<std::pair<int, string> > >& masters_, const string& zoneName, boost::optional<rpzOptions_t> options) {
 
       boost::optional<DNSFilterEngine::Policy> defpol;
+      bool defpolOverrideLocal = true;
       std::shared_ptr<DNSFilterEngine::Zone> zone = std::make_shared<DNSFilterEngine::Zone>();
       TSIGTriplet tt;
       uint32_t refresh=0;
@@ -242,40 +249,40 @@ void loadRecursorLuaConfig(const std::string& fname, luaConfigDelayedThreads& de
         if (options) {
           auto& have = *options;
           size_t zoneSizeHint = 0;
-          parseRPZParameters(have, polName, defpol, maxTTL, zoneSizeHint);
+          parseRPZParameters(have, polName, defpol, defpolOverrideLocal, maxTTL, zoneSizeHint);
           if (zoneSizeHint > 0) {
             zone->reserve(zoneSizeHint);
           }
 
           if(have.count("tsigname")) {
-            tt.name=DNSName(toLower(boost::get<string>(constGet(have, "tsigname"))));
-            tt.algo=DNSName(toLower(boost::get<string>(constGet(have, "tsigalgo"))));
-            if(B64Decode(boost::get<string>(constGet(have, "tsigsecret")), tt.secret))
+            tt.name=DNSName(toLower(boost::get<string>(have["tsigname"])));
+            tt.algo=DNSName(toLower(boost::get<string>(have[ "tsigalgo"])));
+            if(B64Decode(boost::get<string>(have[ "tsigsecret"]), tt.secret))
               throw std::runtime_error("TSIG secret is not valid Base-64 encoded");
           }
 
           if(have.count("refresh")) {
-            refresh = boost::get<uint32_t>(constGet(have,"refresh"));
+            refresh = boost::get<uint32_t>(have["refresh"]);
           }
 
           if(have.count("maxReceivedMBytes")) {
-            maxReceivedXFRMBytes = static_cast<size_t>(boost::get<uint32_t>(constGet(have,"maxReceivedMBytes")));
+            maxReceivedXFRMBytes = static_cast<size_t>(boost::get<uint32_t>(have["maxReceivedMBytes"]));
           }
 
           if(have.count("localAddress")) {
-            localAddress = ComboAddress(boost::get<string>(constGet(have,"localAddress")));
+            localAddress = ComboAddress(boost::get<string>(have["localAddress"]));
           }
 
           if(have.count("axfrTimeout")) {
-            axfrTimeout = static_cast<uint16_t>(boost::get<uint32_t>(constGet(have, "axfrTimeout")));
+            axfrTimeout = static_cast<uint16_t>(boost::get<uint32_t>(have["axfrTimeout"]));
           }
 
           if(have.count("seedFile")) {
-            seedFile = boost::get<std::string>(constGet(have, "seedFile"));
+            seedFile = boost::get<std::string>(have["seedFile"]);
           }
 
           if(have.count("dumpFile")) {
-            dumpFile = boost::get<std::string>(constGet(have, "dumpFile"));
+            dumpFile = boost::get<std::string>(have["dumpFile"]);
           }
         }
 
@@ -297,7 +304,7 @@ void loadRecursorLuaConfig(const std::string& fname, luaConfigDelayedThreads& de
         if (!seedFile.empty()) {
           g_log<<Logger::Info<<"Pre-loading RPZ zone "<<zoneName<<" from seed file '"<<seedFile<<"'"<<endl;
           try {
-            sr = loadRPZFromFile(seedFile, zone, defpol, maxTTL);
+            sr = loadRPZFromFile(seedFile, zone, defpol, defpolOverrideLocal, maxTTL);
 
             if (zone->getDomain() != domain) {
               throw PDNSException("The RPZ zone " + zoneName + " loaded from the seed file (" + zone->getDomain().toString() + ") does not match the one passed in parameter (" + domain.toString() + ")");
@@ -321,7 +328,7 @@ void loadRecursorLuaConfig(const std::string& fname, luaConfigDelayedThreads& de
         exit(1);  // FIXME proper exit code?
       }
 
-      delayedThreads.rpzMasterThreads.push_back(std::make_tuple(masters, defpol, maxTTL, zoneIdx, tt, maxReceivedXFRMBytes, localAddress, axfrTimeout, sr, dumpFile));
+      delayedThreads.rpzMasterThreads.push_back(std::make_tuple(masters, defpol, defpolOverrideLocal, maxTTL, zoneIdx, tt, maxReceivedXFRMBytes, localAddress, axfrTimeout, sr, dumpFile));
     });
 
   typedef vector<pair<int,boost::variant<string, vector<pair<int, string> > > > > argvec_t;
@@ -518,7 +525,7 @@ void startLuaConfigDelayedThreads(const luaConfigDelayedThreads& delayedThreads,
 {
   for (const auto& rpzMaster : delayedThreads.rpzMasterThreads) {
     try {
-      std::thread t(RPZIXFRTracker, std::get<0>(rpzMaster), std::get<1>(rpzMaster), std::get<2>(rpzMaster), std::get<3>(rpzMaster), std::get<4>(rpzMaster), std::get<5>(rpzMaster) * 1024 * 1024, std::get<6>(rpzMaster), std::get<7>(rpzMaster), std::get<8>(rpzMaster), std::get<9>(rpzMaster), generation);
+      std::thread t(RPZIXFRTracker, std::get<0>(rpzMaster), std::get<1>(rpzMaster), std::get<2>(rpzMaster), std::get<3>(rpzMaster), std::get<4>(rpzMaster), std::get<5>(rpzMaster), std::get<6>(rpzMaster) * 1024 * 1024, std::get<7>(rpzMaster), std::get<8>(rpzMaster), std::get<9>(rpzMaster), std::get<10>(rpzMaster), generation);
       t.detach();
     }
     catch(const std::exception& e) {
index 6ce143076820e1fc2ea1c554327eb0fce07d0005..4323bd0c6cc9b5bbde4750bba9a2bb54b3856ff2 100644 (file)
@@ -69,7 +69,7 @@ extern GlobalStateHolder<LuaConfigItems> g_luaconfs;
 
 struct luaConfigDelayedThreads
 {
-  std::vector<std::tuple<std::vector<ComboAddress>, boost::optional<DNSFilterEngine::Policy>, uint32_t, size_t, TSIGTriplet, size_t, ComboAddress, uint16_t, std::shared_ptr<SOARecordContent>, std::string> > rpzMasterThreads;
+  std::vector<std::tuple<std::vector<ComboAddress>, boost::optional<DNSFilterEngine::Policy>, bool, uint32_t, size_t, TSIGTriplet, size_t, ComboAddress, uint16_t, std::shared_ptr<SOARecordContent>, std::string> > rpzMasterThreads;
 };
 
 void loadRecursorLuaConfig(const std::string& fname, luaConfigDelayedThreads& delayedThreads);
index 6f6d832a942a7bb0009bceea5e4e30d389f489ee..eba8cb8f0c382cfea724b8fb24f76680e968882e 100644 (file)
@@ -18,16 +18,31 @@ An RPZ can be loaded from file or slaved from a master. To load from file, use f
 
 .. code-block:: Lua
 
-    rpzFile("dblfilename", {defpol=Policy.Custom, defcontent="badserver.example.com"})
+    rpzFile("dblfilename")
 
 To slave from a master and start IXFR to get updates, use for example:
 
 .. code-block:: Lua
 
-    rpzMaster("192.0.2.4", "policy.rpz", {defpol=Policy.Drop})
+    rpzMaster("192.0.2.4", "policy.rpz")
 
 In this example, 'policy.rpz' denotes the name of the zone to query for.
 
+The action to be taken on a match is defined by the zone itself, but in some cases it might be interesting to be able to override it, and always apply the same action
+regardless of the one specified in the RPZ zone. To load from file and override the default action with a custom CNAME to badserver.example.com., use for example:
+
+.. code-block:: Lua
+
+    rpzFile("dblfilename", {defpol=Policy.Custom, defcontent="badserver.example.com"})
+
+To instead drop all queries matching a rule, while slaving from a master:
+
+.. code-block:: Lua
+
+    rpzMaster("192.0.2.4", "policy.rpz", {defpol=Policy.Drop})
+
+Note that since 4.2.0, it is possible for the override policy specified via 'defpol' to no longer applied to local data entries present in the zone by setting the 'defpolOverrideLocalData' parameter to false.
+
 As of version 4.2.0, the first parameter of :func:`rpzMaster` can be a list of addresses for failover:
 
     rpzMaster({"192.0.2.4","192.0.2.5:5301"}, "policy.rpz", {defpol=Policy.Drop})
@@ -61,13 +76,20 @@ RPZ settings
 
 These options can be set in the ``settings`` of both :func:`rpzMaster` and :func:`rpzFile`.
 
+defcontent
+^^^^^^^^^^
+CNAME field to return in case of defpol=Policy.Custom
+
 defpol
 ^^^^^^
 Default policy: `Policy.Custom`_, `Policy.Drop`_, `Policy.NXDOMAIN`_, `Policy.NODATA`_, `Policy.Truncate`_, `Policy.NoAction`_.
 
-defcontent
-^^^^^^^^^^
-CNAME field to return in case of defpol=Policy.Custom
+defpolOverrideLocalData
+^^^^^^^^^^^^^^^^^^^^^^^
+.. versionadded:: 4.2.0
+  Before 4.2.0 local data entries are always overridden by the default policy.
+
+Whether local data entries should be overridden by the default policy. Default is true.
 
 defttl
 ^^^^^^
index 8dc17311c5010d25a1fce100dab5dc6b0d95e5e1..116cc9f300feb9910bd8e3d9361c37f47b1f8545 100644 (file)
@@ -60,7 +60,7 @@ static Netmask makeNetmaskFromRPZ(const DNSName& name)
   return Netmask(v6);
 }
 
-void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zone> zone, bool addOrRemove, boost::optional<DNSFilterEngine::Policy> defpol, uint32_t maxTTL)
+static void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zone> zone, bool addOrRemove, boost::optional<DNSFilterEngine::Policy> defpol, bool defpolOverrideLocal, uint32_t maxTTL)
 {
   static const DNSName drop("rpz-drop."), truncate("rpz-tcp-only."), noaction("rpz-passthru.");
   static const DNSName rpzClientIP("rpz-client-ip"), rpzIP("rpz-ip"),
@@ -68,6 +68,7 @@ void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zon
   static const std::string rpzPrefix("rpz-");
 
   DNSFilterEngine::Policy pol;
+  bool defpolApplied = false;
 
   if(dr.d_class != QClass::IN) {
     return;
@@ -81,6 +82,7 @@ void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zon
     auto crcTarget=crc->getTarget();
     if(defpol) {
       pol=*defpol;
+      defpolApplied = true;
     }
     else if(crcTarget.isRoot()) {
       // cerr<<"Wants NXDOMAIN for "<<dr.d_name<<": ";
@@ -121,8 +123,9 @@ void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zon
     }
   }
   else {
-    if (defpol) {
+    if (defpol && defpolOverrideLocal) {
       pol=*defpol;
+      defpolApplied = true;
     }
     else {
       pol.d_kind = DNSFilterEngine::PolicyKind::Custom;
@@ -131,7 +134,7 @@ void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zon
     }
   }
 
-  if (!defpol || defpol->d_ttl < 0) {
+  if (!defpolApplied || defpol->d_ttl < 0) {
     pol.d_ttl = static_cast<int32_t>(std::min(maxTTL, dr.d_ttl));
   } else {
     pol.d_ttl = static_cast<int32_t>(std::min(maxTTL, static_cast<uint32_t>(pol.d_ttl)));
@@ -169,14 +172,19 @@ void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zon
     else
       zone->rmNSIPTrigger(nm, std::move(pol));
   } else {
-    if(addOrRemove)
-      zone->addQNameTrigger(dr.d_name, std::move(pol));
-    else
+    if(addOrRemove) {
+      /* if we did override the existing policy with the default policy,
+         we might turn two A or AAAA into a CNAME, which would trigger
+         an exception. Let's just ignore it. */
+      zone->addQNameTrigger(dr.d_name, std::move(pol), defpolApplied);
+    }
+    else {
       zone->rmQNameTrigger(dr.d_name, std::move(pol));
+    }
   }
 }
 
-static shared_ptr<SOARecordContent> loadRPZFromServer(const ComboAddress& master, const DNSName& zoneName, std::shared_ptr<DNSFilterEngine::Zone> zone, boost::optional<DNSFilterEngine::Policy> defpol, uint32_t maxTTL, const TSIGTriplet& tt, size_t maxReceivedBytes, const ComboAddress& localAddress, uint16_t axfrTimeout)
+static shared_ptr<SOARecordContent> loadRPZFromServer(const ComboAddress& master, const DNSName& zoneName, std::shared_ptr<DNSFilterEngine::Zone> zone, boost::optional<DNSFilterEngine::Policy> defpol, bool defpolOverrideLocal, uint32_t maxTTL, const TSIGTriplet& tt, size_t maxReceivedBytes, const ComboAddress& localAddress, uint16_t axfrTimeout)
 {
   g_log<<Logger::Warning<<"Loading RPZ zone '"<<zoneName<<"' from "<<master.toStringWithPort()<<endl;
   if(!tt.name.empty())
@@ -206,7 +214,7 @@ static shared_ptr<SOARecordContent> loadRPZFromServer(const ComboAddress& master
        continue;
       }
 
-      RPZRecordToPolicy(dr, zone, true, defpol, maxTTL);
+      RPZRecordToPolicy(dr, zone, true, defpol, defpolOverrideLocal, maxTTL);
       nrecords++;
     } 
     axfrNow = time(nullptr);
@@ -223,7 +231,7 @@ static shared_ptr<SOARecordContent> loadRPZFromServer(const ComboAddress& master
 }
 
 // this function is silent - you do the logging
-std::shared_ptr<SOARecordContent> loadRPZFromFile(const std::string& fname, std::shared_ptr<DNSFilterEngine::Zone> zone, boost::optional<DNSFilterEngine::Policy> defpol, uint32_t maxTTL)
+std::shared_ptr<SOARecordContent> loadRPZFromFile(const std::string& fname, std::shared_ptr<DNSFilterEngine::Zone> zone, boost::optional<DNSFilterEngine::Policy> defpol, bool defpolOverrideLocal, uint32_t maxTTL)
 {
   shared_ptr<SOARecordContent> sr = nullptr;
   ZoneParserTNG zpt(fname);
@@ -244,7 +252,7 @@ std::shared_ptr<SOARecordContent> loadRPZFromFile(const std::string& fname, std:
       }
       else {
        dr.d_name=dr.d_name.makeRelative(domain);
-       RPZRecordToPolicy(dr, zone, true, defpol, maxTTL);
+       RPZRecordToPolicy(dr, zone, true, defpol, defpolOverrideLocal, maxTTL);
       }
     }
     catch(const PDNSException& pe) {
@@ -338,7 +346,7 @@ static bool dumpZoneToDisk(const DNSName& zoneName, const std::shared_ptr<DNSFil
   return true;
 }
 
-void RPZIXFRTracker(const std::vector<ComboAddress>& masters, boost::optional<DNSFilterEngine::Policy> defpol, uint32_t maxTTL, size_t zoneIdx, const TSIGTriplet& tt, size_t maxReceivedBytes, const ComboAddress& localAddress, const uint16_t axfrTimeout, std::shared_ptr<SOARecordContent> sr, std::string dumpZoneFileName, uint64_t configGeneration)
+void RPZIXFRTracker(const std::vector<ComboAddress>& masters, boost::optional<DNSFilterEngine::Policy> defpol, bool defpolOverrideLocal, uint32_t maxTTL, size_t zoneIdx, const TSIGTriplet& tt, size_t maxReceivedBytes, const ComboAddress& localAddress, const uint16_t axfrTimeout, std::shared_ptr<SOARecordContent> sr, std::string dumpZoneFileName, uint64_t configGeneration)
 {
   setThreadName("pdns-r/RPZIXFR");
   bool isPreloaded = sr != nullptr;
@@ -360,7 +368,7 @@ void RPZIXFRTracker(const std::vector<ComboAddress>& masters, boost::optional<DN
     std::shared_ptr<DNSFilterEngine::Zone> newZone = std::make_shared<DNSFilterEngine::Zone>(*oldZone);
     for (const auto& master : masters) {
       try {
-        sr = loadRPZFromServer(master, zoneName, newZone, defpol, maxTTL, tt, maxReceivedBytes, localAddress, axfrTimeout);
+        sr = loadRPZFromServer(master, zoneName, newZone, defpol, defpolOverrideLocal, maxTTL, tt, maxReceivedBytes, localAddress, axfrTimeout);
         if(refresh == 0) {
           refresh = sr->d_st.refresh;
         }
@@ -473,7 +481,7 @@ void RPZIXFRTracker(const std::vector<ComboAddress>& masters, boost::optional<DN
        else {
           totremove++;
          g_log<<(g_logRPZChanges ? Logger::Info : Logger::Debug)<<"Had removal of "<<rr.d_name<<" from RPZ zone "<<zoneName<<endl;
-         RPZRecordToPolicy(rr, newZone, false, defpol, maxTTL);
+         RPZRecordToPolicy(rr, newZone, false, defpol, defpolOverrideLocal, maxTTL);
        }
       }
 
@@ -490,7 +498,7 @@ void RPZIXFRTracker(const std::vector<ComboAddress>& masters, boost::optional<DN
        else {
           totadd++;
          g_log<<(g_logRPZChanges ? Logger::Info : Logger::Debug)<<"Had addition of "<<rr.d_name<<" to RPZ zone "<<zoneName<<endl;
-         RPZRecordToPolicy(rr, newZone, true, defpol, maxTTL);
+         RPZRecordToPolicy(rr, newZone, true, defpol, defpolOverrideLocal, maxTTL);
        }
       }
     }
index b167ee7f19a2d0217ff33daa1d6a043112a85b1a..7a2047a2269338e664df61619e9ad76725b871bd 100644 (file)
@@ -26,9 +26,8 @@
 
 extern bool g_logRPZChanges;
 
-std::shared_ptr<SOARecordContent> loadRPZFromFile(const std::string& fname, std::shared_ptr<DNSFilterEngine::Zone> zone, boost::optional<DNSFilterEngine::Policy> defpol, uint32_t maxTTL);
-void RPZRecordToPolicy(const DNSRecord& dr, std::shared_ptr<DNSFilterEngine::Zone> zone, bool addOrRemove, boost::optional<DNSFilterEngine::Policy> defpol, uint32_t maxTTL);
-void RPZIXFRTracker(const std::vector<ComboAddress>& masters, boost::optional<DNSFilterEngine::Policy> defpol, uint32_t maxTTL, size_t zoneIdx, const TSIGTriplet& tt, size_t maxReceivedBytes, const ComboAddress& localAddress, const uint16_t axfrTimeout, shared_ptr<SOARecordContent> sr, std::string dumpZoneFileName, uint64_t configGeneration);
+std::shared_ptr<SOARecordContent> loadRPZFromFile(const std::string& fname, std::shared_ptr<DNSFilterEngine::Zone> zone, boost::optional<DNSFilterEngine::Policy> defpol, bool defpolOverrideLocal, uint32_t maxTTL);
+void RPZIXFRTracker(const std::vector<ComboAddress>& masters, boost::optional<DNSFilterEngine::Policy> defpol, bool defpolOverrideLocal, uint32_t maxTTL, size_t zoneIdx, const TSIGTriplet& tt, size_t maxReceivedBytes, const ComboAddress& localAddress, const uint16_t axfrTimeout, shared_ptr<SOARecordContent> sr, std::string dumpZoneFileName, uint64_t configGeneration);
 
 struct rpzStats
 {
index 39f8e410c4168b40ffcfd5f980b5a5bc0aaa462b..0ed6e08ba0bc3450a7b9042a8e364b4fe86c34e1 100644 (file)
@@ -176,19 +176,7 @@ class RPZServer(object):
                 print('Error in RPZ socket: %s' % str(e))
                 sock.close()
 
-rpzServerPort = 4250
-rpzServer = RPZServer(rpzServerPort)
-
 class RPZRecursorTest(RecursorTest):
-    """
-    This test makes sure that we correctly update RPZ zones via AXFR then IXFR
-    """
-
-    global rpzServerPort
-    _lua_config_file = """
-    -- The first server is a bogus one, to test that we correctly fail over to the second one
-    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
-    """ % (rpzServerPort)
     _wsPort = 8042
     _wsTimeout = 2
     _wsPassword = 'secretpassword'
@@ -213,21 +201,6 @@ webserver-address=127.0.0.1
 webserver-password=%s
 api-key=%s
 """ % (_confdir, _wsPort, _wsPassword, _apiKey)
-    _xfrDone = 0
-
-    @classmethod
-    def generateRecursorConfig(cls, confdir):
-        authzonepath = os.path.join(confdir, 'example.zone')
-        with open(authzonepath, 'w') as authzone:
-            authzone.write("""$ORIGIN example.
-@ 3600 IN SOA {soa}
-a 3600 IN A 192.0.2.42
-b 3600 IN A 192.0.2.42
-c 3600 IN A 192.0.2.42
-d 3600 IN A 192.0.2.42
-e 3600 IN A 192.0.2.42
-""".format(soa=cls._SOA))
-        super(RPZRecursorTest, cls).generateRecursorConfig(confdir)
 
     @classmethod
     def setUpClass(cls):
@@ -283,6 +256,16 @@ e 3600 IN A 192.0.2.42
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertEqual(len(res.answer), 0)
 
+    def checkNXD(self, qname, qtype):
+        query = dns.message.make_query(qname, qtype, want_dnssec=True)
+        query.flags |= dns.flags.CD
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            res = sender(query)
+            self.assertRcodeEqual(res, dns.rcode.NXDOMAIN)
+            self.assertEqual(len(res.answer), 0)
+            self.assertEqual(len(res.authority), 1)
+
     def checkTruncated(self, qname, qtype='A'):
         query = dns.message.make_query(qname, qtype, want_dnssec=True)
         query.flags |= dns.flags.CD
@@ -308,6 +291,66 @@ e 3600 IN A 192.0.2.42
             res = sender(query)
             self.assertEqual(res, None)
 
+    def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount):
+        headers = {'x-api-key': self._apiKey}
+        url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
+        r = requests.get(url, headers=headers, timeout=self._wsTimeout)
+        self.assertTrue(r)
+        self.assertEquals(r.status_code, 200)
+        self.assertTrue(r.json())
+        content = r.json()
+        self.assertIn('zone.rpz.', content)
+        zone = content['zone.rpz.']
+        for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
+            self.assertIn(key, zone)
+
+        self.assertEquals(zone['serial'], serial)
+        self.assertEquals(zone['records'], recordsCount)
+        self.assertEquals(zone['transfers_full'], fullXFRCount)
+        self.assertEquals(zone['transfers_success'], totalXFRCount)
+
+rpzServerPort = 4250
+rpzServer = RPZServer(rpzServerPort)
+
+class RPZXFRRecursorTest(RPZRecursorTest):
+    """
+    This test makes sure that we correctly update RPZ zones via AXFR then IXFR
+    """
+
+    global rpzServerPort
+    _lua_config_file = """
+    -- The first server is a bogus one, to test that we correctly fail over to the second one
+    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
+    """ % (rpzServerPort)
+    _confdir = 'RPZXFR'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+    _xfrDone = 0
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+        super(RPZRecursorTest, cls).generateRecursorConfig(confdir)
+
     def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
         global rpzServer
 
@@ -327,24 +370,6 @@ e 3600 IN A 192.0.2.42
 
         raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
 
-    def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount):
-        headers = {'x-api-key': self._apiKey}
-        url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
-        r = requests.get(url, headers=headers, timeout=self._wsTimeout)
-        self.assertTrue(r)
-        self.assertEquals(r.status_code, 200)
-        self.assertTrue(r.json())
-        content = r.json()
-        self.assertIn('zone.rpz.', content)
-        zone = content['zone.rpz.']
-        for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
-            self.assertIn(key, zone)
-
-        self.assertEquals(zone['serial'], serial)
-        self.assertEquals(zone['records'], recordsCount)
-        self.assertEquals(zone['transfers_full'], fullXFRCount)
-        self.assertEquals(zone['transfers_success'], totalXFRCount)
-
     def testRPZ(self):
         # first zone, only a should be blocked
         self.waitUntilCorrectSerialIsLoaded(1)
@@ -410,3 +435,192 @@ e 3600 IN A 192.0.2.42
         # check non-custom policies
         self.checkTruncated('tc.example.')
         self.checkDropped('drop.example.')
+
+class RPZFileRecursorTest(RPZRecursorTest):
+    """
+    This test makes sure that we correctly load RPZ zones from a file
+    """
+
+    _confdir = 'RPZFile'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _lua_config_file = """
+    rpzFile('configs/%s/zone.rpz', { policyName="zone.rpz." })
+    """ % (_confdir)
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+z 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+
+        rpzFilePath = os.path.join(confdir, 'zone.rpz')
+        with open(rpzFilePath, 'w') as rpzZone:
+            rpzZone.write("""$ORIGIN zone.rpz.
+@ 3600 IN SOA {soa}
+a.example.zone.rpz. 60 IN A 192.0.2.42
+a.example.zone.rpz. 60 IN A 192.0.2.43
+a.example.zone.rpz. 60 IN TXT "some text"
+drop.example.zone.rpz. 60 IN CNAME rpz-drop.
+z.example.zone.rpz. 60 IN A 192.0.2.1
+tc.example.zone.rpz. 60 IN CNAME rpz-tcp-only.
+""".format(soa=cls._SOA))
+        super(RPZFileRecursorTest, cls).generateRecursorConfig(confdir)
+
+    def testRPZ(self):
+        self.checkCustom('a.example.', 'A', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42', '192.0.2.43'))
+        self.checkCustom('a.example.', 'TXT', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'TXT', '"some text"'))
+        self.checkBlocked('z.example.')
+        self.checkNotBlocked('b.example.')
+        self.checkNotBlocked('c.example.')
+        self.checkNotBlocked('d.example.')
+        self.checkNotBlocked('e.example.')
+        # check that the policy is disabled for AD=1 queries
+        self.checkNotBlocked('z.example.', True)
+        # check non-custom policies
+        self.checkTruncated('tc.example.')
+        self.checkDropped('drop.example.')
+
+class RPZFileDefaultPolRecursorTest(RPZRecursorTest):
+    """
+    This test makes sure that we correctly load RPZ zones from a file with a default policy
+    """
+
+    _confdir = 'RPZFileDefaultPolicy'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _lua_config_file = """
+    rpzFile('configs/%s/zone.rpz', { policyName="zone.rpz.", defpol=Policy.NoAction })
+    """ % (_confdir)
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+drop 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+z 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+
+        rpzFilePath = os.path.join(confdir, 'zone.rpz')
+        with open(rpzFilePath, 'w') as rpzZone:
+            rpzZone.write("""$ORIGIN zone.rpz.
+@ 3600 IN SOA {soa}
+a.example.zone.rpz. 60 IN A 192.0.2.42
+drop.example.zone.rpz. 60 IN CNAME rpz-drop.
+z.example.zone.rpz. 60 IN A 192.0.2.1
+tc.example.zone.rpz. 60 IN CNAME rpz-tcp-only.
+""".format(soa=cls._SOA))
+        super(RPZFileDefaultPolRecursorTest, cls).generateRecursorConfig(confdir)
+
+    def testRPZ(self):
+        # local data entries are overridden by default
+        self.checkCustom('a.example.', 'A', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42'))
+        self.checkNoData('a.example.', 'TXT')
+        # will not be blocked because the default policy overrides local data entries by default
+        self.checkNotBlocked('z.example.')
+        self.checkNotBlocked('b.example.')
+        self.checkNotBlocked('c.example.')
+        self.checkNotBlocked('d.example.')
+        self.checkNotBlocked('e.example.')
+        # check non-local policies, they should be overridden by the default policy
+        self.checkNXD('tc.example.', 'A')
+        self.checkNotBlocked('drop.example.')
+
+class RPZFileDefaultPolNotOverrideLocalRecursorTest(RPZRecursorTest):
+    """
+    This test makes sure that we correctly load RPZ zones from a file with a default policy, not overriding local data entries
+    """
+
+    _confdir = 'RPZFileDefaultPolicyNotOverrideLocal'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _lua_config_file = """
+    rpzFile('configs/%s/zone.rpz', { policyName="zone.rpz.", defpol=Policy.NoAction, defpolOverrideLocalData=false })
+    """ % (_confdir)
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+drop 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+z 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+
+        rpzFilePath = os.path.join(confdir, 'zone.rpz')
+        with open(rpzFilePath, 'w') as rpzZone:
+            rpzZone.write("""$ORIGIN zone.rpz.
+@ 3600 IN SOA {soa}
+a.example.zone.rpz. 60 IN A 192.0.2.42
+a.example.zone.rpz. 60 IN A 192.0.2.43
+a.example.zone.rpz. 60 IN TXT "some text"
+drop.example.zone.rpz. 60 IN CNAME rpz-drop.
+z.example.zone.rpz. 60 IN A 192.0.2.1
+tc.example.zone.rpz. 60 IN CNAME rpz-tcp-only.
+""".format(soa=cls._SOA))
+        super(RPZFileDefaultPolNotOverrideLocalRecursorTest, cls).generateRecursorConfig(confdir)
+
+    def testRPZ(self):
+        # local data entries will not be overridden by the default polic
+        self.checkCustom('a.example.', 'A', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42', '192.0.2.43'))
+        self.checkCustom('a.example.', 'TXT', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'TXT', '"some text"'))
+        # will be blocked because the default policy does not override local data entries
+        self.checkBlocked('z.example.')
+        self.checkNotBlocked('b.example.')
+        self.checkNotBlocked('c.example.')
+        self.checkNotBlocked('d.example.')
+        self.checkNotBlocked('e.example.')
+        # check non-local policies, they should be overridden by the default policy
+        self.checkNXD('tc.example.', 'A')
+        self.checkNotBlocked('drop.example.')