]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
auth: basic protobuf emission including test
authorPeter van Dijk <peter.van.dijk@powerdns.com>
Fri, 20 Mar 2026 07:23:12 +0000 (08:23 +0100)
committerMiod Vallat <miod.vallat@powerdns.com>
Wed, 8 Apr 2026 06:51:18 +0000 (08:51 +0200)
Signed-off-by: Miod Vallat <miod.vallat@powerdns.com>
15 files changed:
docs/settings.rst
meson.build
modules/remotebackend/Makefile.am
modules/remotebackend/meson.build
pdns/Makefile.am
pdns/auth-main.cc
pdns/auth-main.hh
pdns/packethandler.cc
pdns/packethandler.hh
pdns/remote_logger.cc
regression-tests.auth-py/requirements.in
regression-tests.auth-py/requirements.txt
regression-tests.auth-py/runtests
regression-tests.auth-py/test_Protobuf.py [new file with mode: 0644]
tasks.py

index 85e7c9312311dad797f6228e637a17562cd87101..352c0515e25f53f6bec23f077c85b4318b7e1334 100644 (file)
@@ -1530,6 +1530,19 @@ prevent-self-notification to "no".
 
 Turn on operating as a primary. See :ref:`primary-operation`.
 
+.. _setting-protobuf-servers:
+
+``protobuf-servers``
+-----------
+
+.. versionadded:: 5.1.0
+
+-  IP addresses with ports, separated by commas
+-  Default: empty
+
+Servers to send Protobuf logging to.
+This currently sends both questions and responses, but without answer data.
+
 .. _setting-proxy-protocol-from:
 
 ``proxy-protocol-from``
index d9f1a60902ed8de60c4d3eab3f3faf15242a0ba3..8af212e94bfac07ba3e9843f9e4352cc7214d84b 100644 (file)
@@ -585,6 +585,8 @@ common_sources += files(
   src_dir / 'packethandler.cc',
   src_dir / 'packethandler.hh',
   src_dir / 'pdnsexception.hh',
+  src_dir / 'protozero.cc',
+  src_dir / 'protozero.hh',
   src_dir / 'proxy-protocol.cc',
   src_dir / 'proxy-protocol.hh',
   src_dir / 'qtype.cc',
@@ -593,6 +595,8 @@ common_sources += files(
   src_dir / 'query-local-address.hh',
   src_dir / 'rcpgenerator.cc',
   src_dir / 'rcpgenerator.hh',
+  src_dir / 'remote_logger.cc',
+  src_dir / 'remote_logger.hh',
   src_dir / 'resolver.cc',
   src_dir / 'resolver.hh',
   src_dir / 'responsestats-auth.cc',
index 8e9556f9e94f1738d8859e03a09a5a8e04c09d96..e66566135e93e9de25aa2d841c04cc48918b128b 100644 (file)
@@ -1,5 +1,6 @@
 AM_CPPFLAGS += \
        -I$(top_srcdir)/ext/json11 \
+       -I$(top_srcdir)/ext/protozero/include \
        $(YAHTTP_CFLAGS) \
        $(LIBCRYPTO_CFLAGS) \
        $(LIBCRYPTO_INCLUDES) \
index e4790932d8d5fde2f7d8a1d88cb31baba674bb89..778bb56554905b042927b9e0363d36950e8737f2 100644 (file)
@@ -10,7 +10,7 @@ module_extras = files(
   'remotebackend.hh',
 )
 
-module_deps = [deps, dep_zeromq]
+module_deps = [deps, dep_protozero, dep_zeromq]
 
 if get_option('unit-tests-backends')
   module_remotebackend_testrunner = files('testrunner.sh')[0]
index 67665ba2813999000bbe4fb8318bdffe42a03876..020768e764a51ac0780a6e9357ba48158d0f988e 100644 (file)
@@ -257,10 +257,12 @@ pdns_server_SOURCES = \
        packetcache.hh \
        packethandler.cc packethandler.hh \
        pdnsexception.hh \
+       protozero.cc protozero.hh \
        proxy-protocol.cc proxy-protocol.hh \
        qtype.cc qtype.hh \
        query-local-address.hh query-local-address.cc \
        rcpgenerator.cc \
+       remote_logger.cc remote_logger.hh \
        resolver.cc resolver.hh \
        responsestats.cc responsestats.hh responsestats-auth.cc \
        rfc2136handler.cc \
index 6da017a84e542bc7eada993976fd7d6e702e06dc..a658f13c59fac2eafea8882c1951365a9180243f 100644 (file)
@@ -126,6 +126,8 @@ StatBag S; //!< Statistics are gathered across PDNS via the StatBag class S
 AuthPacketCache PC; //!< This is the main PacketCache, shared across all threads
 AuthQueryCache QC;
 AuthZoneCache g_zoneCache;
+std::vector<std::unique_ptr<RemoteLogger>> g_remote_loggers;
+
 std::unique_ptr<DNSProxy> DP{nullptr};
 static std::unique_ptr<DynListener> s_dynListener{nullptr};
 CommunicatorClass Communicator;
@@ -342,6 +344,8 @@ static void declareArguments()
 
   ::arg().set("default-catalog-zone", "Catalog zone to assign newly created primary zones (via the API) to") = "";
 
+  ::arg().set("protobuf-servers", "Servers to send protobuf logging to");
+
 #ifdef ENABLE_GSS_TSIG
   ::arg().setSwitch("enable-gss-tsig", "Enable GSS TSIG processing") = "no";
 #endif
@@ -944,6 +948,15 @@ static void mainthread()
 
   pdns::parseTrustedNotificationProxy(::arg()["trusted-notification-proxy"]);
 
+  {
+    vector<string> addrs;
+    stringtok(addrs, ::arg()["protobuf-servers"], ", ;");
+
+    for (const string& addr : addrs) {
+      g_remote_loggers.emplace_back(make_unique<RemoteLogger>(ComboAddress(addr)));
+    }
+  }
+
   UeberBackend::go();
 
   // Setup the zone cache
index 3b92157a4dae744536939411967d09cfeba87cf9..1dad9dc3050a7ffeb6aa03e4422e2651d7245bc6 100644 (file)
@@ -34,6 +34,7 @@
 #include "statbag.hh"
 #include "tcpreceiver.hh"
 #include "dnsseckeeper.hh"
+#include "remote_logger.hh"
 
 extern time_t g_starttime;
 extern ArgvMap theArg;
@@ -57,3 +58,4 @@ extern time_t g_luaConsistentHashesExpireDelay;
 extern time_t g_luaConsistentHashesCleanupInterval;
 #endif // HAVE_LUA_RECORDS
 extern bool g_views;
+extern std::vector<std::unique_ptr<RemoteLogger>> g_remote_loggers;
index 890242093f3ecae721c11ef9bf0f5cb2d35c2bd4..67c8273dfd840e1a0a54cbcb3b7243a554f662c9 100644 (file)
@@ -48,6 +48,8 @@
 #include "auth-main.hh"
 #include "trusted-notification-proxy.hh"
 #include "gss_context.hh"
+#include "gettime.hh"
+#include "protozero.hh"
 
 #if 0
 #undef DLOG
@@ -2087,11 +2089,54 @@ bool PacketHandler::opcodeQueryInner2(DNSPacket& pkt, queryState &state, bool re
   return true;
 }
 
+static void fillProtoZeroMessageFromDNSPacket(pdns::ProtoZero::Message& msg, DNSPacket& pkt)
+{
+  struct timeval now{};
+
+  gettimeofday(&now, nullptr);
+  msg.setRequest(getUniqueID(), pkt.getRemote(), pkt.getLocal(), pkt.qdomain, pkt.qtype, pkt.qclass, pkt.d.id, pkt.d_tcp ? pdns::ProtoZero::Message::TransportProtocol::TCP : pdns::ProtoZero::Message::TransportProtocol::UDP, pkt.getString().length());
+
+  if (pkt.hasEDNS()) {
+    msg.setEDNSVersion(pkt.getEDNSVersion());
+  }
+
+  msg.setTime(now.tv_sec, now.tv_usec);
+  msg.setHeaderFlags(*getFlagsFromDNSHeader(&pkt.d));
+
+  if (pkt.d.qr == 0) {
+    msg.setType(pdns::ProtoZero::Message::MessageType::DNSQueryType);
+  }
+  else {
+    msg.setType(pdns::ProtoZero::Message::MessageType::DNSResponseType);
+  }
+}
+
+static bool mustSendProtoBuf()
+{
+  return !g_remote_loggers.empty();
+}
+
+static void sendProtobuf(const std::string& data)
+{
+  for (const auto& logger : g_remote_loggers) {
+    std::ignore = logger->queueData(data);
+  }
+}
+
 std::unique_ptr<DNSPacket> PacketHandler::opcodeQuery(DNSPacket& pkt, bool noCache)
 {
   queryState state;
   state.noCache = noCache;
 
+  if (mustSendProtoBuf()) {
+    std::string data;
+    // data.reserve()
+    pdns::ProtoZero::Message msg{data};
+
+    fillProtoZeroMessageFromDNSPacket(msg, pkt);
+    sendProtobuf(data);
+  }
+
   if (opcodeQueryInner(pkt, state)) {
     doAdditionalProcessing(pkt, state.r);
 
@@ -2112,7 +2157,15 @@ std::unique_ptr<DNSPacket> PacketHandler::opcodeQuery(DNSPacket& pkt, bool noCac
     if (PC.enabled() && !state.noCache && pkt.couldBeCached()) {
       PC.insert(pkt, *state.r, state.r->getMinTTL(), pkt.d_view); // in the packet cache
     }
-  }
+
+    if (mustSendProtoBuf()) {
+      std::string data;
+      pdns::ProtoZero::Message msg{data};
+
+      fillProtoZeroMessageFromDNSPacket(msg, *state.r);
+      sendProtobuf(data);
+    }
+}
 
   return std::move(state.r);
 }
index 4f9f5d6275d76a5c5607282ffdbea06473709c15..26f925a3b3966c23884dc5eb004ff2d4e25917e6 100644 (file)
@@ -23,6 +23,7 @@
 #include <sys/socket.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
+#include "protozero.hh"
 #include "ueberbackend.hh"
 #include "dnspacket.hh"
 #include "packetcache.hh"
index b00a26a63da352c8d9232dddb7626454a01f0211..61248f6303fa1748afc1507a81407ed6e4fffa90 100644 (file)
 
 #include "threadname.hh"
 
-#ifdef RECURSOR
+#ifndef DNSDIST // PDNS_AUTH or RECURSOR
 #include "logger.hh"
-#else /* !RECURSOR */
+#else
 #include "dolog.hh"
-#if defined(DNSDIST)
 #include "dnsdist-logging.hh"
-#endif /* DNSDIST */
-#endif /* !RECURSOR */
+#endif
 #include "logging.hh"
 
 bool CircularWriteBuffer::hasRoomFor(const std::string& str) const
@@ -161,7 +159,7 @@ bool RemoteLogger::reconnect()
     }
   }
   catch (const std::exception& e) {
-#ifdef RECURSOR
+#ifndef DNSDIST // PDNS_AUTH or RECURSOR
     SLOG(g_log << Logger::Warning << "Error connecting to remote logger " << d_remote.toStringWithPort() << ": " << e.what() << std::endl,
          g_slog->withName("protobuf")->error(Logr::Error, e.what(), "Exception while connecting to remote logger", "address", Logging::Loggable(d_remote)));
 #else
@@ -228,6 +226,8 @@ void RemoteLogger::maintenanceThread()
   try {
 #ifdef RECURSOR
     string threadName = "rec/remlog";
+#elif defined(PDNS_AUTH)
+    string threadName = "auth/remlog";
 #else
     string threadName = "dnsdist/remLog";
 #endif
@@ -273,7 +273,7 @@ void RemoteLogger::maintenanceThread()
     }
   }
   catch (const std::exception& e) {
-#ifdef RECURSOR
+#ifndef DNSDIST // PDNS_AUTH or RECURSOR
     SLOG(cerr << "Remote Logger's maintenance thread died on: " << e.what() << endl,
          g_slog->withName("protobuf")->error(Logr::Error, e.what(), "Remote Logger's maintenance thread died"));
 #else
@@ -282,7 +282,7 @@ void RemoteLogger::maintenanceThread()
 #endif
   }
   catch (...) {
-#ifdef RECURSOR
+#ifndef DNSDIST // PDNS_AUTH or RECURSOR
     SLOG(cerr << "Remote Logger's maintenance thread died on unknown exception" << endl,
          g_slog->withName("protobuf")->info(Logr::Error, "Remote Logger's maintenance thread died"));
 #else
index f0e3fbf698f8b96f6705e8b76529499bbb4c1770..8aaa42d3f08e1271cb570a50ea9ba0ce101e0fe1 100644 (file)
@@ -8,3 +8,4 @@ Twisted>0.15.0
 requests>=2.18.4
 https://github.com/PowerDNS/xfrserver/archive/refs/tags/0.3.zip
 setuptools<82
+protobuf>=3.0
index 15796a7ea2c3aa5147cf46b76cb76ae9ccc2f213..d7e46ac5963264e871d871b0eb97a6c1f1b2e91c 100644 (file)
@@ -2,7 +2,7 @@
 # This file is autogenerated by pip-compile with Python 3.13
 # by the following command:
 #
-#    pip-compile --allow-unsafe --generate-hashes requirements.in
+#    pip-compile --generate-hashes requirements.in
 #
 attrs==26.1.0 \
     --hash=sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309 \
@@ -186,6 +186,16 @@ pluggy==1.6.0 \
     --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \
     --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746
     # via pytest
+protobuf==7.34.0 \
+    --hash=sha256:3871a3df67c710aaf7bb8d214cc997342e63ceebd940c8c7fc65c9b3d697591a \
+    --hash=sha256:4a72a8ec94e7a9f7ef7fe818ed26d073305f347f8b3b5ba31e22f81fd85fca02 \
+    --hash=sha256:8e329966799f2c271d5e05e236459fe1cbfdb8755aaa3b0914fa60947ddea408 \
+    --hash=sha256:964cf977e07f479c0697964e83deda72bcbc75c3badab506fb061b352d991b01 \
+    --hash=sha256:9d7a5005fb96f3c1e64f397f91500b0eb371b28da81296ae73a6b08a5b76cdd6 \
+    --hash=sha256:9f9079f1dde4e32342ecbd1c118d76367090d4aaa19da78230c38101c5b3dd40 \
+    --hash=sha256:e3b914dd77fa33fa06ab2baa97937746ab25695f389869afdf03e81f34e45dc7 \
+    --hash=sha256:f791ec509707a1d91bd02e07df157e75e4fb9fbdad12a81b7396201ec244e2e3
+    # via -r requirements.in
 pygments==2.20.0 \
     --hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
     --hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
@@ -251,8 +261,7 @@ zope-interface==8.2 \
     --hash=sha256:f777e68c76208503609c83ca021a6864902b646530a1a39abb9ed310d1100664
     # via twisted
 
-# The following packages are considered to be unsafe in a requirements file:
-setuptools==81.0.0 \
-    --hash=sha256:487b53915f52501f0a79ccfd0c02c165ffe06631443a886740b91af4b7a5845a \
-    --hash=sha256:fdd925d5c5d9f62e4b74b30d6dd7828ce236fd6ed998a08d81de62ce5a6310d6
-    # via -r requirements.in
+# WARNING: The following packages were not pinned, but pip requires them to be
+# pinned when the requirements file includes hashes and the requirement is not
+# satisfied by a package already installed. Consider using the --allow-unsafe flag.
+# setuptools
index 3dff4b6cdcb11595f0c73b8983e925cbcb8e68f8..a3c4aed83dc3466d2dc1a51a72d3997237e9501f 100755 (executable)
@@ -13,6 +13,8 @@ mkdir -p configs
 
 [ -f ./vars ] && . ./vars
 
+protoc -I=../pdns/ --python_out=. ../pdns/dnsmessage.proto
+
 if [ -z "$PDNS_BUILD_PATH" ]; then
   # PDNS_BUILD_PATH is unset or empty. Assume an autotools build.
   PDNS_BUILD_PATH=${PWD}/../pdns
diff --git a/regression-tests.auth-py/test_Protobuf.py b/regression-tests.auth-py/test_Protobuf.py
new file mode 100644 (file)
index 0000000..d265785
--- /dev/null
@@ -0,0 +1,368 @@
+# carefully plagiarised from regression-tests.recursor-dnssec/test_Protobuf.py
+# if we add more features, we can grab more inspiration there
+
+import dns
+import dnsmessage_pb2
+import os
+import socket
+import struct
+import sys
+import threading
+import time
+import clientsubnetoption
+from queue import Queue
+
+from authtests import AuthTest
+
+
+def ProtobufConnectionHandler(queue, conn):
+    data = None
+    while True:
+        data = conn.recv(2)
+        if not data:
+            break
+        (datalen,) = struct.unpack("!H", data)
+        data = conn.recv(datalen)
+        if not data:
+            break
+
+        queue.put_nowait(data)
+
+    conn.close()
+
+
+def ProtobufListener(queue, port):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    try:
+        sock.bind(("127.0.0.1", port))
+    except socket.error as e:
+        print("Error binding in the protobuf listener: %s" % str(e))
+        sys.exit(1)
+
+    sock.listen(100)
+    while True:
+        try:
+            (conn, _) = sock.accept()
+            thread = threading.Thread(name="Connection Handler", target=ProtobufConnectionHandler, args=[queue, conn])
+            thread.daemon = True
+            thread.start()
+
+        except socket.error as e:
+            print("Error in protobuf socket: %s" % str(e))
+
+    sock.close()
+
+
+class ProtobufServerParams:
+    def __init__(self, port):
+        self.queue = Queue()
+        self.port = port
+
+
+protobufServersParameters = [ProtobufServerParams(4243)]
+protobufListeners = []
+for param in protobufServersParameters:
+    listener = threading.Thread(name="Protobuf Listener", target=ProtobufListener, args=[param.queue, param.port])
+    listener.daemon = True
+    listener.start()
+    protobufListeners.append(listener)
+
+
+class TestAuthProtobuf(AuthTest):
+    _config_template = """
+expand-alias=yes
+launch={backend}
+protobuf-servers=127.0.0.1:%s
+""" % (protobufServersParameters[0].port,)
+
+    _zones = {
+        "example": """
+example.                 3600 IN SOA  {soa}
+example.                 3600 IN NS   ns1.example.
+example.                 3600 IN NS   ns2.example.
+ns1.example.             3600 IN A    {prefix}.10
+ns2.example.             3600 IN A    {prefix}.11
+
+a.example.               3600 IN A    192.0.2.80
+        """,
+    }
+
+    def getFirstProtobufMessage(self, retries=10, waitTime=0.1):
+        msg = None
+        # print("in getFirstProtobufMessage")
+        for param in protobufServersParameters:
+            failed = 0
+
+            while param.queue.empty():
+                if failed >= retries:
+                    break
+                failed = failed + 1
+                # print(str(failed) + '...')
+                time.sleep(waitTime)
+
+            # print(str(failed) + ' ' + str(param.queue.empty()))
+            self.assertFalse(param.queue.empty())
+            data = param.queue.get(False)
+            self.assertTrue(data)
+            oldmsg = msg
+            msg = dnsmessage_pb2.PBDNSMessage()
+            msg.ParseFromString(data)
+            if oldmsg is not None:
+                self.assertEqual(msg, oldmsg)
+        return msg
+
+    def emptyProtoBufQueue(self):
+        for param in protobufServersParameters:
+            while not param.queue.empty():
+                param.queue.get(False)
+
+    def checkNoRemainingMessage(self):
+        for param in protobufServersParameters:
+            self.assertTrue(param.queue.empty())
+
+    def checkProtobufBase(
+        self, msg, protocol, query, initiator, normalQueryResponse=True, expectedECS=None, receivedSize=None
+    ):
+        self.assertTrue(msg)
+        self.assertTrue(msg.HasField("timeSec"))
+        self.assertTrue(msg.HasField("socketFamily"))
+        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
+        self.assertTrue(msg.HasField("from"))
+        fromvalue = getattr(msg, "from")
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
+        self.assertTrue(msg.HasField("socketProtocol"))
+        self.assertEqual(msg.socketProtocol, protocol)
+        self.assertTrue(msg.HasField("messageId"))
+        self.assertTrue(msg.HasField("id"))
+        self.assertEqual(msg.id, query.id)
+        self.assertTrue(msg.HasField("inBytes"))
+        if normalQueryResponse:
+            # compare inBytes with length of query/response
+            # Note that for responses, the size we received might differ
+            # because dnspython might compress labels differently from
+            # the recursor
+            if receivedSize:
+                self.assertEqual(msg.inBytes, receivedSize)
+            else:
+                self.assertEqual(msg.inBytes, len(query.to_wire()))
+        if expectedECS is not None:
+            self.assertTrue(msg.HasField("originalRequestorSubnet"))
+            # v4 only for now
+            self.assertEqual(len(msg.originalRequestorSubnet), 4)
+            self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), "127.0.0.1")
+
+    def checkOutgoingProtobufBase(self, msg, protocol, query, initiator, length=None, expectedECS=None):
+        self.assertTrue(msg)
+        self.assertTrue(msg.HasField("timeSec"))
+        self.assertTrue(msg.HasField("socketFamily"))
+        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
+        self.assertTrue(msg.HasField("socketProtocol"))
+        self.assertEqual(msg.socketProtocol, protocol)
+        self.assertTrue(msg.HasField("messageId"))
+        self.assertTrue(msg.HasField("id"))
+        self.assertNotEqual(msg.id, query.id)
+        self.assertTrue(msg.HasField("inBytes"))
+        if length is not None:
+            self.assertEqual(msg.inBytes, length)
+        else:
+            # compare inBytes with length of query/response
+            self.assertEqual(msg.inBytes, len(query.to_wire()))
+        if expectedECS is not None:
+            self.assertTrue(msg.HasField("originalRequestorSubnet"))
+            # v4 only for now
+            self.assertEqual(len(msg.originalRequestorSubnet), 4)
+            self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), expectedECS)
+
+    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator="127.0.0.1", to="127.0.0.1"):
+        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
+        self.checkProtobufBase(msg, protocol, query, initiator)
+        # dnsdist doesn't fill the responder field for responses
+        # because it doesn't keep the information around.
+        self.assertTrue(msg.HasField("to"))
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to), to)
+        self.assertTrue(msg.HasField("question"))
+        self.assertTrue(msg.question.HasField("qClass"))
+        self.assertEqual(msg.question.qClass, qclass)
+        self.assertTrue(msg.question.HasField("qType"))
+        self.assertEqual(msg.question.qClass, qtype)
+        self.assertTrue(msg.question.HasField("qName"))
+        self.assertEqual(msg.question.qName, qname)
+
+    # This method takes wire format values to check
+    def checkProtobufHeaderFlagsAndEDNSVersion(self, msg, flags, ednsVersion):
+        self.assertTrue(msg.HasField("headerFlags"))
+        self.assertEqual(msg.headerFlags, socket.htons(flags))
+        self.assertTrue(msg.HasField("ednsVersion"))
+        self.assertEqual(msg.ednsVersion, socket.htonl(ednsVersion))
+
+    def checkProtobufResponse(self, msg, protocol, response, initiator="127.0.0.1", receivedSize=None):
+        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+        self.checkProtobufBase(msg, protocol, response, initiator, receivedSize=receivedSize)
+
+    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl, checkTTL=True):
+        self.assertTrue(record.HasField("class"))
+        self.assertEqual(getattr(record, "class"), rclass)
+        self.assertTrue(record.HasField("type"))
+        self.assertEqual(record.type, rtype)
+        self.assertTrue(record.HasField("name"))
+        self.assertEqual(record.name, rname)
+        self.assertTrue(record.HasField("ttl"))
+        if checkTTL:
+            self.assertEqual(record.ttl, rttl)
+        self.assertTrue(record.HasField("rdata"))
+
+    def checkProtobufPolicy(self, msg, policyType, reason, trigger, hit, kind):
+        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+        self.assertTrue(msg.response.HasField("appliedPolicyType"))
+        self.assertTrue(msg.response.HasField("appliedPolicy"))
+        self.assertTrue(msg.response.HasField("appliedPolicyTrigger"))
+        self.assertTrue(msg.response.HasField("appliedPolicyHit"))
+        self.assertTrue(msg.response.HasField("appliedPolicyKind"))
+        self.assertEqual(msg.response.appliedPolicy, reason)
+        self.assertEqual(msg.response.appliedPolicyType, policyType)
+        self.assertEqual(msg.response.appliedPolicyTrigger, trigger)
+        self.assertEqual(msg.response.appliedPolicyHit, hit)
+        self.assertEqual(msg.response.appliedPolicyKind, kind)
+
+    def checkProtobufTags(self, msg, tags):
+        # print(tags)
+        # print('---')
+        # print(msg.response.tags)
+        self.assertEqual(len(msg.response.tags), len(tags))
+        for tag in msg.response.tags:
+            self.assertTrue(tag in tags)
+
+    def checkProtobufMetas(self, msg, metas):
+        # print(metas)
+        # print('---')
+        # print(msg.meta)
+        self.assertEqual(len(msg.meta), len(metas))
+        for m in msg.meta:
+            self.assertTrue(m.HasField("key"))
+            self.assertTrue(m.HasField("value"))
+            self.assertTrue(m.key in metas)
+            for i in m.value.intVal:
+                self.assertTrue(i in metas[m.key]["intVal"])
+            for s in m.value.stringVal:
+                self.assertTrue(s in metas[m.key]["stringVal"])
+
+    def checkProtobufOutgoingQuery(
+        self, msg, protocol, query, qclass, qtype, qname, initiator="127.0.0.1", length=None, expectedECS=None
+    ):
+        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSOutgoingQueryType)
+        self.checkOutgoingProtobufBase(msg, protocol, query, initiator, length=length, expectedECS=expectedECS)
+        self.assertTrue(msg.HasField("to"))
+        self.assertTrue(msg.HasField("question"))
+        self.assertTrue(msg.question.HasField("qClass"))
+        self.assertEqual(msg.question.qClass, qclass)
+        self.assertTrue(msg.question.HasField("qType"))
+        self.assertEqual(msg.question.qType, qtype)
+        self.assertTrue(msg.question.HasField("qName"))
+        self.assertEqual(msg.question.qName, qname)
+
+    def checkProtobufIncomingResponse(self, msg, protocol, response, initiator="127.0.0.1", length=None):
+        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSIncomingResponseType)
+        self.checkOutgoingProtobufBase(msg, protocol, response, initiator, length=length)
+        self.assertTrue(msg.HasField("response"))
+        self.assertTrue(msg.response.HasField("rcode"))
+        self.assertTrue(msg.response.HasField("queryTimeSec"))
+
+    def checkProtobufIncomingNetworkErrorResponse(self, msg, protocol, response, initiator="127.0.0.1"):
+        self.checkProtobufIncomingResponse(msg, protocol, response, initiator, length=0)
+        self.assertEqual(msg.response.rcode, 65536)
+
+    def checkProtobufIdentity(self, msg, requestorId, deviceId, deviceName):
+        # print(msg)
+        self.assertTrue((requestorId == "") == (not msg.HasField("requestorId")))
+        self.assertTrue((deviceId == b"") == (not msg.HasField("deviceId")))
+        self.assertTrue((deviceName == "") == (not msg.HasField("deviceName")))
+        self.assertEqual(msg.requestorId, requestorId)
+        self.assertEqual(msg.deviceId, deviceId)
+        self.assertEqual(msg.deviceName, deviceName)
+
+    def setUp(self):
+        super(TestAuthProtobuf, self).setUp()
+        # Make sure the queue is empty, in case
+        # a previous test failed
+        self.emptyProtoBufQueue()
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, "example.zone")
+        with open(authzonepath, "w") as authzone:
+            authzone.write(
+                """$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+tagged 3600 IN A 192.0.2.84
+taggedtcp 3600 IN A 192.0.2.87
+meta 3600 IN A 192.0.2.85
+query-selected 3600 IN A 192.0.2.84
+answer-selected 3600 IN A 192.0.2.84
+types 3600 IN A 192.0.2.84
+types 3600 IN AAAA 2001:DB8::1
+types 3600 IN TXT "Lorem ipsum dolor sit amet"
+types 3600 IN MX 10 a.example.
+types 3600 IN SPF "v=spf1 -all"
+types 3600 IN SRV 10 20 443 a.example.
+cname 3600 IN CNAME a.example.
+
+""".format(soa=cls._SOA)
+            )
+        super(TestAuthProtobuf, cls).generateRecursorConfig(confdir)
+
+
+class ProtobufDefaultTest(TestAuthProtobuf):
+    """
+    This test makes sure that we correctly export queries and response over protobuf.
+    """
+
+    _confdir = "ProtobufDefault"
+
+    def testA(self):
+        name = "a.example."
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, "A", "192.0.2.42")
+        query = dns.message.make_query(name, "A", want_dnssec=True)
+
+        res = self.sendUDPQuery(query)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        self.checkProtobufHeaderFlagsAndEDNSVersion(msg, 0x0100, 0x00000000)
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, "127.0.0.1")
+        self.checkNoRemainingMessage()
+        #
+        # again, for a PC cache hit
+        #
+        res = self.sendUDPQuery(query)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        self.checkProtobufHeaderFlagsAndEDNSVersion(msg, 0x0100, 0x00000000)
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, "127.0.0.1")
+        self.checkNoRemainingMessage()
+
+    def testCNAME(self):
+        name = "cname.example."
+        expectedCNAME = dns.rrset.from_text(name, 0, dns.rdataclass.IN, "CNAME", "a.example.")
+        expectedA = dns.rrset.from_text("a.example.", 0, dns.rdataclass.IN, "A", "192.0.2.42")
+        query = dns.message.make_query(name, "A", want_dnssec=True)
+        query.flags |= dns.flags.CD
+        raw = self.sendUDPQuery(query, decode=False)
+        res = dns.message.from_wire(raw)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        # but first let the protobuf messages the time to get there
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, "127.0.0.1", receivedSize=len(raw))
+        self.checkNoRemainingMessage()
index d438d7669ddfd728eb97cc9d09865290f3b9c9d0..2df18d9a3eda710d74584c31604b8438456abe6b 100644 (file)
--- a/tasks.py
+++ b/tasks.py
@@ -126,6 +126,7 @@ auth_test_deps = [  # FIXME: we should be generating some of these from shlibdep
     "libzmq3-dev",
     "lmdb-utils",
     "prometheus",
+    "protobuf-compiler",
     "python3-venv",
     "socat",
     "softhsm2",