]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Incoming Proxy Protocol support
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 14 Oct 2020 14:47:49 +0000 (16:47 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Jan 2021 09:22:00 +0000 (10:22 +0100)
20 files changed:
pdns/dnsdist-console.cc
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-idstate.cc
pdns/dnsdistdist/dnsdist-proxy-protocol.cc
pdns/dnsdistdist/dnsdist-proxy-protocol.hh
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/docs/advanced/ecs.rst
pdns/dnsdistdist/docs/advanced/proxyprotocol.rst
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/docs/reference/dq.rst
pdns/dnsdistdist/docs/rules-actions.rst
pdns/proxy-protocol.cc
pdns/proxy-protocol.hh

index ecf2d08d034fec5c092555f6e85393d681bb4a20..008ca01840a9bd73c1dfd6f9b372b342e90e2531 100644 (file)
@@ -352,6 +352,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "addDynBlockSMT", true, "names, message[, seconds [, action]]", "block the set of names with message `msg`, for `seconds` seconds (10 by default), applying `action` (default to the one set with `setDynBlocksAction()`)" },
   { "addLocal", true, "addr [, {doTCP=true, reusePort=false, tcpFastOpenQueueSize=0, interface=\"\", cpus={}}]", "add `addr` to the list of addresses we listen on" },
   { "addCacheHitResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\", name=\"name\"}}]", "add a cache hit response rule" },
+  { "AddProxyProtocolValueAction", true, "type, value", "Add a Proxy Protocol TLV value of this type" },
   { "addResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\", name=\"name\"}}]", "add a response rule" },
   { "addSelfAnsweredResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\", name=\"name\"}}]", "add a self-answered response rule" },
   { "addTLSLocal", true, "addr, certFile(s), keyFile(s) [,params]", "listen to incoming DNS over TLS queries on the specified address using the specified certificate (or list of) and key (or list of). The last parameter is a table" },
@@ -484,6 +485,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "PoolAvailableRule", true, "poolname", "Check whether a pool has any servers available to handle queries" },
   { "printDNSCryptProviderFingerprint", true, "\"/path/to/providerPublic.key\"", "display the fingerprint of the provided resolver public key" },
   { "ProbaRule", true, "probability", "Matches queries with a given probability. 1.0 means always" },
+  { "ProxyProtocolValueRule", true, "type [, value]", "matches queries with a specified Proxy Protocol TLV value of that type, optionnally matching the content of the option as well" },
   { "QClassRule", true, "qclass", "Matches queries with the specified qclass. class can be specified as an integer or as one of the built-in DNSClass" },
   { "QNameLabelsCountRule", true, "min, max", "matches if the qname has less than `min` or more than `max` labels" },
   { "QNameRule", true, "qname", "matches queries with the specified qname" },
@@ -542,7 +544,9 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setPoolServerPolicyLua", true, "name, function, pool", "set the server selection policy for this pool to one named 'name' and provided by 'function'" },
   { "setPoolServerPolicyLuaFFI", true, "name, function, pool", "set the server selection policy for this pool to one named 'name' and provided by 'function'" },
   { "setPoolServerPolicyLuaFFIPerThread", true, "name, code", "set server selection policy for this pool to one named 'name' and returned by the Lua FFI code passed in 'code'" },
-  { "setPreserveTrailingData", true, "bool", "set whether trailing data should be preserved while adding ECS or XPF records to incoming queries" },
+  { "setProxyProtocolACL", true, "{netmask, netmask}", "Set the netmasks who are allowed to send Proxy Protocol headers in front of queries/connections" },
+  { "setProxyProtocolApplyACLToProxiedClients", true, "apply", "Whether the general ACL should be applied to the source IP address gathered from a Proxy Protocol header, in addition to being first applied to the source address seen by dnsdist" },
+  { "setProxyProtocolMaximumPayloadSize", true, "max", "Set the maximum size of a Proxy Protocol payload, in bytes" },
   { "setQueryCount", true, "bool", "set whether queries should be counted" },
   { "setQueryCountFilter", true, "func", "filter queries that would be counted, where `func` is a function with parameter `dq` which decides whether a query should and how it should be counted" },
   { "setRingBuffersLockRetries", true, "n", "set the number of attempts to get a non-blocking lock to a ringbuffer shard before blocking" },
index 4a234bf1aa597e1a55334a2107d1da2968b3a7f0..9102feac4db4585cf13100a37a2820ea858718da 100644 (file)
@@ -1444,6 +1444,34 @@ private:
   std::vector<ProxyProtocolValue> d_values;
 };
 
+class AddProxyProtocolValueAction : public DNSAction
+{
+public:
+  AddProxyProtocolValueAction(uint8_t type, const std::string& value): d_value(value), d_type(type)
+  {
+  }
+
+  DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+  {
+    if (!dq->proxyProtocolValues) {
+      dq->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
+    }
+
+    dq->proxyProtocolValues->push_back({ d_value, d_type });
+
+    return Action::None;
+  }
+
+  std::string toString() const override
+  {
+    return "add a Proxy-Protocol value of type " + std::to_string(d_type);
+  }
+
+private:
+  std::string d_value;
+  uint8_t d_type;
+};
+
 template<typename T, typename ActionT>
 static void addAction(GlobalStateHolder<vector<T> > *someRulActions, const luadnsrule_t& var, const std::shared_ptr<ActionT>& action, boost::optional<luaruleparams_t>& params) {
   setLuaSideEffect();
@@ -1832,4 +1860,8 @@ void setupLuaActions(LuaContext& luaCtx)
   luaCtx.writeFunction("SetProxyProtocolValuesAction", [](const std::vector<std::pair<uint8_t, std::string>>& values) {
       return std::shared_ptr<DNSAction>(new SetProxyProtocolValuesAction(values));
     });
+
+  luaCtx.writeFunction("AddProxyProtocolValueAction", [](uint8_t type, const std::string& value) {
+    return std::shared_ptr<DNSAction>(new AddProxyProtocolValueAction(type, value));
+  });
 }
index 34f52accb81e2a2ee9d41eefd9eaebbbcfd0e0f9..2a254ccd4f5e265af6877891c5d47788bb01f3c1 100644 (file)
@@ -118,16 +118,38 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     });
 
   luaCtx.registerFunction<void(DNSQuestion::*)(std::vector<std::pair<uint8_t, std::string>>)>("setProxyProtocolValues", [](DNSQuestion& dq, const std::vector<std::pair<uint8_t, std::string>>& values) {
-      if (!dq.proxyProtocolValues) {
-        dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
-      }
+    if (!dq.proxyProtocolValues) {
+      dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
+    }
 
-      dq.proxyProtocolValues->clear();
-      dq.proxyProtocolValues->reserve(values.size());
-      for (const auto& value : values) {
-        dq.proxyProtocolValues->push_back({value.second, value.first});
-      }
-    });
+    dq.proxyProtocolValues->clear();
+    dq.proxyProtocolValues->reserve(values.size());
+    for (const auto& value : values) {
+      dq.proxyProtocolValues->push_back({value.second, value.first});
+    }
+  });
+
+  luaCtx.registerFunction<void(DNSQuestion::*)(uint8_t, std::string)>("addProxyProtocolValue", [](DNSQuestion& dq, uint8_t type, std::string value) {
+    if (!dq.proxyProtocolValues) {
+      dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
+    }
+
+    dq.proxyProtocolValues->push_back({value, type});
+  });
+
+  luaCtx.registerFunction<std::vector<std::pair<uint8_t, std::string>>(DNSQuestion::*)()>("getProxyProtocolValues", [](const DNSQuestion& dq) {
+    if (!dq.proxyProtocolValues) {
+      return std::vector<std::pair<uint8_t, std::string>>();
+    }
+
+    std::vector<std::pair<uint8_t, std::string>> result;
+    result.resize(dq.proxyProtocolValues->size());
+    for (const auto& value : *dq.proxyProtocolValues) {
+      result.push_back({ value.type, value.content });
+    }
+
+    return result;
+  });
 
   /* LuaWrapper doesn't support inheritance */
   luaCtx.registerMember<const ComboAddress (DNSResponse::*)>("localaddr", [](const DNSResponse& dq) -> const ComboAddress { return *dq.local; }, [](DNSResponse& dq, const ComboAddress newLocal) { (void) newLocal; });
index 68dee6bf96f6ebc836f8308f049412c3c223d41a..d9bcc141f5d289f25eeb2c33c4cc1d58ae3dabb4 100644 (file)
@@ -599,4 +599,8 @@ void setupLuaRules(LuaContext& luaCtx)
   luaCtx.writeFunction("LuaFFIRule", [](LuaFFIRule::func_t func) {
       return std::shared_ptr<DNSRule>(new LuaFFIRule(func));
     });
+
+  luaCtx.writeFunction("ProxyProtocolValueRule", [](uint8_t type, boost::optional<std::string> value) {
+      return std::shared_ptr<DNSRule>(new ProxyProtocolValueRule(type, value));
+    });
 }
index 8a13a0f0052a215dac6c6c57c8134ea1275bc26b..97a748ac44815bd33410b1ebcae72492c1f1a1de 100644 (file)
@@ -40,6 +40,7 @@
 #ifdef LUAJIT_VERSION
 #include "dnsdist-lua-ffi.hh"
 #endif /* LUAJIT_VERSION */
+#include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rings.hh"
 #include "dnsdist-secpoll.hh"
 #include "dnsdist-web.hh"
@@ -1893,6 +1894,45 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       g_consoleOutputMsgMaxSize = size;
     });
 
+  luaCtx.writeFunction("setProxyProtocolACL", [](boost::variant<string,vector<pair<int, string>>> inp) {
+    if (g_configurationDone) {
+      errlog("setProxyProtocolACL() cannot be used at runtime!");
+      g_outputBuffer="setProxyProtocolACL() cannot be used at runtime!\n";
+      return;
+    }
+    setLuaSideEffect();
+    NetmaskGroup nmg;
+    if (auto str = boost::get<string>(&inp)) {
+      nmg.addMask(*str);
+    }
+    else {
+      for(const auto& p : boost::get<vector<pair<int,string>>>(inp)) {
+       nmg.addMask(p.second);
+      }
+    }
+    g_proxyProtocolACL = std::move(nmg);
+  });
+
+  luaCtx.writeFunction("setProxyProtocolApplyACLToProxiedClients", [](bool apply) {
+    if (g_configurationDone) {
+      errlog("setProxyProtocolApplyACLToProxiedClients() cannot be used at runtime!");
+      g_outputBuffer="setProxyProtocolApplyACLToProxiedClients() cannot be used at runtime!\n";
+      return;
+    }
+    setLuaSideEffect();
+    g_applyACLToProxiedClients = apply;
+  });
+
+  luaCtx.writeFunction("setProxyProtocolMaximumPayloadSize", [](size_t size) {
+    if (g_configurationDone) {
+      errlog("setProxyProtocolMaximumPayloadSize() cannot be used at runtime!");
+      g_outputBuffer="setProxyProtocolMaximumPayloadSize() cannot be used at runtime!\n";
+      return;
+    }
+    setLuaSideEffect();
+    g_proxyProtocolMaximumSize = std::max(static_cast<size_t>(16), size);
+  });
+
   luaCtx.writeFunction("setUDPMultipleMessagesVectorSize", [](size_t vSize) {
       if (g_configurationDone) {
         errlog("setUDPMultipleMessagesVectorSize() cannot be used at runtime!");
index 93b8790cd4f7c920ec40b1198f90fd330881154f..cdb2b7d8d0ff744a8dbf6d3cfb07a7c5770fc546 100644 (file)
@@ -411,6 +411,7 @@ void IncomingTCPConnectionState::sendOrQueueResponse(std::shared_ptr<IncomingTCP
   // if we were already reading a query (not the query size, mind you), or sending a response we need to queue the response
   // otherwise we can start sending it right away
   if (state->d_state == IncomingTCPConnectionState::State::idle ||
+      state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
       state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
 
     auto iostate = sendResponse(state, now, std::move(response));
@@ -536,9 +537,12 @@ static IOState handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, c
   uint16_t qtype, qclass;
   unsigned int qnameWireLength = 0;
   DNSName qname(reinterpret_cast<const char*>(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
-  DNSQuestion dq(&qname, qtype, qclass, &state->d_origDest, &state->d_ci.remote, state->d_buffer, true, &queryRealTime);
+  DNSQuestion dq(&qname, qtype, qclass, &state->d_proxiedDestination, &state->d_proxiedRemote, state->d_buffer, true, &queryRealTime);
   dq.dnsCryptQuery = std::move(dnsCryptQuery);
   dq.sni = state->d_handler.getServerNameIndication();
+  if (state->d_proxyProtocolValues) {
+    dq.proxyProtocolValues = std::move(state->d_proxyProtocolValues);
+  }
 
   state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
   if (state->d_isXFR) {
@@ -607,7 +611,7 @@ static IOState handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, c
   }
 
   ++state->d_currentQueriesCount;
-  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
+  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).getName(), state->d_proxiedRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
   downstreamConnection->queueQuery(TCPQuery(std::move(state->d_buffer), std::move(ids)), downstreamConnection);
 
   return IOState::NeedRead;
@@ -664,13 +668,64 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
           }
 
           state->d_handshakeDoneTime = now;
-          state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+          if (expectProxyProtocolFrom(state->d_ci.remote)) {
+            state->d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
+            state->d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+            state->d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+          }
+          else {
+            state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+          }
         }
         else {
           wouldBlock = true;
         }
       }
 
+      if (state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
+        DEBUGLOG("reading proxy protocol header");
+        do {
+          iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_proxyProtocolNeed);
+          if (iostate == IOState::Done) {
+            state->d_buffer.resize(state->d_currentPos);
+            ssize_t remaining = isProxyHeaderComplete(state->d_buffer);
+            if (remaining == 0) {
+              vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", state->d_ci.remote.toStringWithPort());
+              ++g_stats.proxyProtocolInvalid;
+              break;
+            }
+            else if (remaining < 0) {
+              state->d_proxyProtocolNeed += -remaining;
+              state->d_buffer.resize(state->d_currentPos + state->d_proxyProtocolNeed);
+              /* we need to keep reading, since we might have buffered data */
+              iostate = IOState::NeedRead;
+            }
+            else {
+              /* proxy header received */
+              std::vector<ProxyProtocolValue> proxyProtocolValues;
+              if (!handleProxyProtocol(state->d_ci.remote, true, *state->d_threadData.holders.acl, state->d_buffer, state->d_proxiedRemote, state->d_proxiedDestination, proxyProtocolValues)) {
+                vinfolog("Error handling the Proxy Protocol received from TCP client %s", state->d_ci.remote.toStringWithPort());
+                break;
+              }
+
+              if (!proxyProtocolValues.empty()) {
+                state->d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
+              }
+
+              state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+              state->d_buffer.resize(sizeof(uint16_t));
+              state->d_currentPos = 0;
+              state->d_proxyProtocolNeed = 0;
+              break;
+            }
+          }
+          else {
+            wouldBlock = true;
+          }
+        }
+        while (!wouldBlock);
+      }
+
       if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
         DEBUGLOG("reading query size");
         iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
@@ -750,6 +805,7 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
 
       if (state->d_state != IncomingTCPConnectionState::State::idle &&
           state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
+          state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
           state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
           state->d_state != IncomingTCPConnectionState::State::readingQuery &&
           state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
@@ -763,6 +819,7 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
       */
       if (state->d_state == IncomingTCPConnectionState::State::idle ||
           state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
+          state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
           state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
           state->d_state == IncomingTCPConnectionState::State::readingQuery) {
         ++state->d_ci.cs->tcpDiedReadingQuery;
index 09585afc22bc7b629093b96912e4e8688dd523cc..568fd0c2db3cd3ba83fa8318c9205cd124a46ef2 100644 (file)
@@ -474,6 +474,30 @@ bool processResponse(PacketBuffer& response, LocalStateHolder<vector<DNSDistResp
   return true;
 }
 
+static size_t getInitialUDPPacketBufferSize()
+{
+  static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "The incoming buffer size should not be larger than s_initialUDPPacketBufferSize");
+
+  if (g_proxyProtocolACL.empty()) {
+    return s_initialUDPPacketBufferSize;
+  }
+
+  return s_initialUDPPacketBufferSize + g_proxyProtocolMaximumSize;
+}
+
+static size_t getMaximumIncomingPacketSize(const ClientState& cs)
+{
+  if (cs.dnscryptCtx) {
+    return getInitialUDPPacketBufferSize();
+  }
+
+  if (g_proxyProtocolACL.empty()) {
+    return s_udpIncomingBufferSize;
+  }
+
+  return s_udpIncomingBufferSize + g_proxyProtocolMaximumSize;
+}
+
 static bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
   if(delayMsec && g_delay) {
@@ -482,7 +506,7 @@ static bool sendUDPResponse(int origFD, const PacketBuffer& response, const int
   }
   else {
     ssize_t res;
-    if(origDest.sin4.sin_family == 0) {
+    if (origDest.sin4.sin_family == 0) {
       res = sendto(origFD, response.data(), response.size(), 0, reinterpret_cast<const struct sockaddr*>(&origRemote), origRemote.getSocklen());
     }
     else {
@@ -497,7 +521,6 @@ static bool sendUDPResponse(int origFD, const PacketBuffer& response, const int
   return true;
 }
 
-
 int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state)
 {
   return state->sockets[state->socketsOffset++ % state->sockets.size()];
@@ -524,7 +547,8 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
   try {
   setThreadName("dnsdist/respond");
   auto localRespRulactions = g_resprulactions.getLocal();
-  PacketBuffer response(s_initialUDPPacketBufferSize);
+  const size_t initialBufferSize = getInitialUDPPacketBufferSize();
+  PacketBuffer response(initialBufferSize);
 
   /* when the answer is encrypted in place, we need to get a copy
      of the original header before encryption to fill the ring buffer */
@@ -541,7 +565,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
       }
 
       for (const auto& fd : sockets) {
-        response.resize(s_initialUDPPacketBufferSize);
+        response.resize(initialBufferSize);
         ssize_t got = recv(fd, response.data(), response.size(), 0);
 
         if (got == 0 && dss->isStopped()) {
@@ -647,9 +671,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
           else {
             ComboAddress empty;
             empty.sin4.sin_family = 0;
-            /* if ids->destHarvested is false, origDest holds the listening address.
-               We don't want to use that as a source since it could be 0.0.0.0 for example. */
-            sendUDPResponse(origFD, response, dr.delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
+            sendUDPResponse(origFD, response, dr.delayMsec, ids->hopLocal, ids->hopRemote);
           }
         }
 
@@ -1010,7 +1032,7 @@ ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss
   return result;
 }
 
-static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest)
+static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, bool& expectProxyProtocol)
 {
   if (msgh->msg_flags & MSG_TRUNC) {
     /* message was too large for our buffer */
@@ -1019,15 +1041,13 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
     return false;
   }
 
-  if(!holders.acl->match(remote)) {
+  expectProxyProtocol = expectProxyProtocolFrom(remote);
+  if (!holders.acl->match(remote) && !expectProxyProtocol) {
     vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort());
     ++g_stats.aclDrops;
     return false;
   }
 
-  cs.queries++;
-  ++g_stats.queries;
-
   if (HarvestDestinationAddress(msgh, &dest)) {
     /* we don't get the port, only the address */
     dest.sin4.sin_port = cs.local.sin4.sin_port;
@@ -1036,6 +1056,9 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
     dest.sin4.sin_family = 0;
   }
 
+  cs.queries++;
+  ++g_stats.queries;
+
   return true;
 }
 
@@ -1253,9 +1276,19 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 {
   assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
   uint16_t queryId = 0;
+  ComboAddress proxiedRemote = remote;
+  ComboAddress proxiedDestination = dest;
 
   try {
-    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
+    bool expectProxyProtocol = false;
+    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest, expectProxyProtocol)) {
+      return;
+    }
+    /* dest might have been updated, if we managed to harvest the destination address */
+    proxiedDestination = dest;
+
+    std::vector<ProxyProtocolValue> proxyProtocolValues;
+    if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, proxiedRemote, proxiedDestination, proxyProtocolValues)) {
       return;
     }
 
@@ -1282,8 +1315,13 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     uint16_t qtype, qclass;
     unsigned int qnameWireLength = 0;
     DNSName qname(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
-    DNSQuestion dq(&qname, qtype, qclass, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, query, false, &queryRealTime);
+    DNSQuestion dq(&qname, qtype, qclass, proxiedDestination.sin4.sin_family != 0 ? &proxiedDestination : &cs.local, &proxiedRemote, query, false, &queryRealTime);
     dq.dnsCryptQuery = std::move(dnsCryptQuery);
+    if (!proxyProtocolValues.empty()) {
+      dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
+    }
+    dq.hopRemote = &remote;
+    dq.hopLocal = &dest;
     std::shared_ptr<DownstreamState> ss{nullptr};
     auto result = processQuery(dq, cs, holders, ss);
 
@@ -1296,13 +1334,13 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     if (result == ProcessQueryResult::SendAnswer) {
 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
       if (dq.delayMsec == 0 && responsesVect != nullptr) {
-        queueResponse(cs, query, *dq.local, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+        queueResponse(cs, query, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
         (*queuedResponses)++;
         return;
       }
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
       /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */
-      sendUDPResponse(cs.udpFD, query, dq.delayMsec, dest, *dq.remote);
+      sendUDPResponse(cs.udpFD, query, dq.delayMsec, dest, remote);
       return;
     }
 
@@ -1343,19 +1381,11 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     ids->origID = dh->id;
     setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
 
-    /* If we couldn't harvest the real dest addr, still
-       write down the listening addr since it will be useful
-       (especially if it's not an 'any' one).
-       We need to keep track of which one it is since we may
-       want to use the real but not the listening addr to reply.
-    */
     if (dest.sin4.sin_family != 0) {
       ids->origDest = dest;
-      ids->destHarvested = true;
     }
     else {
       ids->origDest = cs.local;
-      ids->destHarvested = false;
     }
 
     dh = dq.getHeader();
@@ -1373,10 +1403,10 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ++g_stats.downstreamSendErrors;
     }
 
-    vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toLogString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
+    vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toLogString(), QType(ids->qtype).getName(), proxiedRemote.toStringWithPort(), ss->getName());
   }
   catch(const std::exception& e){
-    vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
+    vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", proxiedRemote.toStringWithPort(), queryId, e.what());
   }
 }
 
@@ -1393,22 +1423,24 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
     cmsgbuf_aligned cbuf;
   };
   const size_t vectSize = g_udpVectorSize;
+
+  auto recvData = std::unique_ptr<MMReceiver[]>(new MMReceiver[vectSize]);
+  auto msgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
+  auto outMsgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
+
   /* the actual buffer is larger because:
      - we may have to add EDNS and/or ECS
      - we use it for self-generated responses (from rule or cache)
      but we only accept incoming payloads up to that size
   */
-  static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "the incoming buffer size should not be larger than s_initialUDPPacketBufferSize");
-
-  auto recvData = std::unique_ptr<MMReceiver[]>(new MMReceiver[vectSize]);
-  auto msgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
-  auto outMsgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
+  const size_t initialBufferSize = getInitialUDPPacketBufferSize();
+  const size_t maxIncomingPacketSize = getMaximumIncomingPacketSize(*cs);
 
   /* initialize the structures needed to receive our messages */
   for (size_t idx = 0; idx < vectSize; idx++) {
     recvData[idx].remote.sin4.sin_family = cs->local.sin4.sin_family;
-    recvData[idx].packet.resize(s_initialUDPPacketBufferSize);
-    fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, &recvData[idx].cbuf, sizeof(recvData[idx].cbuf), reinterpret_cast<char*>(&recvData[idx].packet.at(0)), cs->dnscryptCtx ? recvData[idx].packet.size() : s_udpIncomingBufferSize, &recvData[idx].remote);
+    recvData[idx].packet.resize(initialBufferSize);
+    fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, &recvData[idx].cbuf, sizeof(recvData[idx].cbuf), reinterpret_cast<char*>(&recvData[idx].packet.at(0)), maxIncomingPacketSize, &recvData[idx].remote);
   }
 
   /* go now */
@@ -1417,7 +1449,7 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
     /* reset the IO vector, since it's also used to send the vector of responses
        to avoid having to copy the data around */
     for (size_t idx = 0; idx < vectSize; idx++) {
-      recvData[idx].packet.resize(s_initialUDPPacketBufferSize);
+      recvData[idx].packet.resize(initialBufferSize);
       recvData[idx].iov.iov_base = &recvData[idx].packet.at(0);
       recvData[idx].iov.iov_len = recvData[idx].packet.size();
     }
@@ -1465,65 +1497,70 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
 
 // listens to incoming queries, sends out to downstream servers, noting the intended return path
 static void udpClientThread(ClientState* cs)
-try
 {
-  setThreadName("dnsdist/udpClie");
-  LocalHolders holders;
+  try
+  {
+    setThreadName("dnsdist/udpClie");
+    LocalHolders holders;
 
 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
-  if (g_udpVectorSize > 1) {
-    MultipleMessagesUDPClientThread(cs, holders);
-
-  }
-  else
+    if (g_udpVectorSize > 1) {
+      MultipleMessagesUDPClientThread(cs, holders);
+    }
+    else
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
-  {
-    PacketBuffer packet(s_initialUDPPacketBufferSize);
-    /* the actual buffer is larger because:
-       - we may have to add EDNS and/or ECS
-       - we use it for self-generated responses (from rule or cache)
-       but we only accept incoming payloads up to that size
-    */
-    static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)");
-    struct msghdr msgh;
-    struct iovec iov;
-    /* used by HarvestDestinationAddress */
-    cmsgbuf_aligned cbuf;
+    {
+      /* the actual buffer is larger because:
+         - we may have to add EDNS and/or ECS
+         - we use it for self-generated responses (from rule or cache)
+         but we only accept incoming payloads up to that size
+      */
+      const size_t initialBufferSize = getInitialUDPPacketBufferSize();
+      const size_t maxIncomingPacketSize = getMaximumIncomingPacketSize(*cs);
+      PacketBuffer packet(initialBufferSize);
 
-    ComboAddress remote;
-    ComboAddress dest;
-    remote.sin4.sin_family = cs->local.sin4.sin_family;
-    fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), reinterpret_cast<char*>(&packet.at(0)), cs->dnscryptCtx ? packet.size() : s_udpIncomingBufferSize, &remote);
+      struct msghdr msgh;
+      struct iovec iov;
+      /* used by HarvestDestinationAddress */
+      cmsgbuf_aligned cbuf;
 
-    for(;;) {
-      packet.resize(s_initialUDPPacketBufferSize);
-      iov.iov_base = &packet.at(0);
-      iov.iov_len = packet.size();
+      ComboAddress remote;
+      ComboAddress dest;
+      remote.sin4.sin_family = cs->local.sin4.sin_family;
+      fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), reinterpret_cast<char*>(&packet.at(0)), maxIncomingPacketSize, &remote);
 
-      ssize_t got = recvmsg(cs->udpFD, &msgh, 0);
+      for(;;) {
+        packet.resize(initialBufferSize);
+        iov.iov_base = &packet.at(0);
+        iov.iov_len = packet.size();
 
-      if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
-        ++g_stats.nonCompliantQueries;
-        continue;
-      }
+        ssize_t got = recvmsg(cs->udpFD, &msgh, 0);
 
-      packet.resize(static_cast<size_t>(got));
-      processUDPQuery(*cs, holders, &msgh, remote, dest, packet, nullptr, nullptr, nullptr, nullptr);
+        if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
+          ++g_stats.nonCompliantQueries;
+          continue;
+        }
+
+        packet.resize(static_cast<size_t>(got));
+
+        processUDPQuery(*cs, holders, &msgh, remote, dest, packet, nullptr, nullptr, nullptr, nullptr);
+      }
     }
   }
+  catch(const std::exception &e)
+  {
+    errlog("UDP client thread died because of exception: %s", e.what());
+  }
+  catch(const PDNSException &e)
+  {
+    errlog("UDP client thread died because of PowerDNS exception: %s", e.reason);
+  }
+  catch(...)
+  {
+    errlog("UDP client thread died because of an exception: %s", "unknown");
+  }
 }
-catch(const std::exception &e)
-{
-  errlog("UDP client thread died because of exception: %s", e.what());
-}
-catch(const PDNSException &e)
-{
-  errlog("UDP client thread died because of PowerDNS exception: %s", e.reason);
-}
-catch(...)
-{
-  errlog("UDP client thread died because of an exception: %s", "unknown");
-}
+
 
 uint16_t getRandomDNSID()
 {
index bd9387dc131d8b937144a89f02c624006fec5901..a572ff6a068529e9756e48f9868622b15c802449 100644 (file)
@@ -124,6 +124,11 @@ public:
   const DNSName* qname{nullptr};
   const ComboAddress* local{nullptr};
   const ComboAddress* remote{nullptr};
+  /* this is the address dnsdist received the packet on,
+     which might not match local when support for incoming proxy protocol
+     is enabled */
+  const ComboAddress* hopLocal{nullptr};  /* the address dnsdist received the packet from, see above */
+  const ComboAddress* hopRemote{nullptr};
   std::shared_ptr<QTag> qTag{nullptr};
   std::unique_ptr<std::vector<ProxyProtocolValue>> proxyProtocolValues{nullptr};
   mutable std::shared_ptr<std::map<uint16_t, EDNSOptionView> > ednsOptions;
@@ -322,6 +327,7 @@ struct DNSDistStats
   stat_t securityStatus{0};
   stat_t dohQueryPipeFull{0};
   stat_t dohResponsePipeFull{0};
+  stat_t proxyProtocolInvalid{0};
 
   double latencyAvg100{0}, latencyAvg1000{0}, latencyAvg10000{0}, latencyAvg1000000{0};
   typedef std::function<uint64_t(const std::string&)> statfunction_t;
@@ -644,6 +650,8 @@ struct IDState
   std::atomic<uint32_t> generation{0}; // increased every time a state is used, to be able to detect an ABA issue    // 4
   ComboAddress origRemote;                                    // 28
   ComboAddress origDest;                                      // 28
+  ComboAddress hopRemote;
+  ComboAddress hopLocal;
   StopWatch sentTime;                                         // 16
   DNSName qname;                                              // 80
   std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
index ef91d77d96dc31d651fc8eb1387c4fda922667c1..ad124179f4b9f8cfb7ab6efeaebf56582fdbc4cc 100644 (file)
@@ -23,6 +23,9 @@ DNSResponse makeDNSResponseFromIDState(IDState& ids, PacketBuffer& data, bool is
     dr.dnsCryptQuery = std::move(ids.dnsCryptQuery);
   }
 
+  dr.hopRemote = &ids.hopRemote;
+  dr.hopLocal = &ids.hopLocal;
+
   return dr;
 }
 
@@ -49,5 +52,19 @@ void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname)
   ids.dnssecOK = dq.dnssecOK;
   ids.uniqueId = std::move(dq.uniqueId);
 
+  if (dq.hopRemote) {
+    ids.hopRemote = *dq.hopRemote;
+  }
+  else {
+    ids.hopRemote.sin4.sin_family = 0;
+  }
+
+  if (dq.hopLocal) {
+    ids.hopLocal = *dq.hopLocal;
+  }
+  else {
+    ids.hopLocal.sin4.sin_family = 0;
+  }
+
   ids.dnsCryptQuery = std::move(dq.dnsCryptQuery);
 }
index 0a3d330e90a088b43d3aee4a09c6189be96af334..b187f848dc13159a5c96da64cb20c3690daaa848 100644 (file)
  */
 
 #include "dnsdist-proxy-protocol.hh"
+#include "dolog.hh"
+
+NetmaskGroup g_proxyProtocolACL;
+size_t g_proxyProtocolMaximumSize = 512;
+bool g_applyACLToProxiedClients = false;
 
 std::string getProxyProtocolPayload(const DNSQuestion& dq)
 {
@@ -59,3 +64,44 @@ bool addProxyProtocol(PacketBuffer& buffer, bool tcp, const ComboAddress& source
   auto payload = makeProxyHeader(tcp, source, destination, values);
   return addProxyProtocol(buffer, payload);
 }
+
+bool expectProxyProtocolFrom(const ComboAddress& remote)
+{
+  return g_proxyProtocolACL.match(remote);
+}
+
+bool handleProxyProtocol(const ComboAddress& remote, bool isTCP, const NetmaskGroup& acl, PacketBuffer& query, ComboAddress& realRemote, ComboAddress& realDestination, std::vector<ProxyProtocolValue>& values)
+{
+  bool tcp;
+  bool proxyProto;
+
+  ssize_t used = parseProxyHeader(query, proxyProto, realRemote, realDestination, tcp, values);
+  if (used <= 0) {
+    ++g_stats.proxyProtocolInvalid;
+    vinfolog("Ignoring invalid proxy protocol (%d, %d) query over %s from %s", query.size(), used, (isTCP ? "TCP" : "UDP"), remote.toStringWithPort());
+    return false;
+  }
+  else if (static_cast<size_t>(used) > g_proxyProtocolMaximumSize) {
+    vinfolog("Proxy protocol header in %s packet from %s is larger than proxy-protocol-maximum-size (%d), dropping", (isTCP ? "TCP" : "UDP"), remote.toStringWithPort(), used);
+    ++g_stats.proxyProtocolInvalid;
+    return false;
+  }
+
+  query.erase(query.begin(), query.begin() + used);
+
+  /* on TCP we have not read the actual query yet */
+  if (!isTCP && query.size() < sizeof(struct dnsheader)) {
+    ++g_stats.nonCompliantQueries;
+    return false;
+  }
+
+  if (proxyProto && g_applyACLToProxiedClients) {
+    if (!acl.match(realRemote)) {
+      vinfolog("Query from %s dropped because of ACL", realRemote.toStringWithPort());
+      ++g_stats.aclDrops;
+      return false;
+    }
+  }
+
+  return true;
+}
index 9ca60eb361f5619179c1ea1ab9e3de2adf7ac789..fac7dea6181def9716efb5320a0060a101aa4a6f 100644 (file)
 
 #include "dnsdist.hh"
 
+extern NetmaskGroup g_proxyProtocolACL;
+extern size_t g_proxyProtocolMaximumSize;
+extern bool g_applyACLToProxiedClients;
+
 std::string getProxyProtocolPayload(const DNSQuestion& dq);
 
 bool addProxyProtocol(DNSQuestion& dq);
 bool addProxyProtocol(DNSQuestion& dq, const std::string& payload);
 bool addProxyProtocol(PacketBuffer& buffer, const std::string& payload);
 bool addProxyProtocol(PacketBuffer& buffer, bool tcp, const ComboAddress& source, const ComboAddress& destination, const std::vector<ProxyProtocolValue>& values);
+
+bool expectProxyProtocolFrom(const ComboAddress& remote);
+bool handleProxyProtocol(const ComboAddress& remote, bool isTCP, const NetmaskGroup& acl, PacketBuffer& query, ComboAddress& realRemote, ComboAddress& realDestination, std::vector<ProxyProtocolValue>& values);
index abbe81329b7e064f5782316eeb53cb7fbf59559f..a9f690ca9ff5bba9e2e7ee98fe6a194a520478ad 100644 (file)
@@ -1177,3 +1177,38 @@ public:
 private:
   func_t d_func;
 };
+
+class ProxyProtocolValueRule : public DNSRule
+{
+public:
+  ProxyProtocolValueRule(uint8_t type, boost::optional<std::string> value): d_value(value), d_type(type)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if (!dq->proxyProtocolValues) {
+      return false;
+    }
+
+    for (const auto& entry : *dq->proxyProtocolValues) {
+      if (entry.type == d_type && (!d_value || entry.content == *d_value)) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  string toString() const override
+  {
+    if (d_value) {
+      return "proxy protocol value of type " + std::to_string(d_type) + " matches";
+    }
+    return "proxy protocol value of type " + std::to_string(d_type) + " is present";
+  }
+
+private:
+  boost::optional<std::string> d_value;
+  uint8_t d_type;
+};
index 58210dc40ebbe84f4f5c274235fd6edffc29aa45..983d96f2d05f9d82bb0b66063f9e85a3b461e728 100644 (file)
@@ -65,6 +65,8 @@ public:
     if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_origDest), &socklen)) {
       d_origDest = d_ci.cs->local;
     }
+    d_proxiedDestination = d_origDest;
+    d_proxiedRemote = d_ci.remote;
   }
 
   IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
@@ -137,36 +139,6 @@ public:
     return false;
   }
 
-#if 0
-  void dump() const
-  {
-    static std::mutex s_mutex;
-
-    struct timeval now;
-    gettimeofday(&now, 0);
-
-    {
-      std::lock_guard<std::mutex> lock(s_mutex);
-      fprintf(stderr, "State is %p\n", this);
-      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
-      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
-      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
-      if (d_state > State::doingHandshake) {
-        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
-      }
-      if (d_state > State::readingQuerySize) {
-        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
-      }
-      if (d_state > State::readingQuerySize) {
-        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
-      }
-      if (d_state > State::readingQuery) {
-        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
-      }
-    }
-  }
-#endif
-
   std::shared_ptr<TCPConnectionToBackend> getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds);
   std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const struct timeval& now);
   void registerActiveDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn);
@@ -196,7 +168,7 @@ public:
     return d_ioState != nullptr;
   }
 
-  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
+  enum class State { doingHandshake, readingProxyProtocolHeader, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
 
   std::map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> d_activeConnectionsToBackend;
   PacketBuffer d_buffer;
@@ -205,14 +177,18 @@ public:
   TCPResponse d_currentResponse;
   ConnectionInfo d_ci;
   ComboAddress d_origDest;
+  ComboAddress d_proxiedRemote;
+  ComboAddress d_proxiedDestination;
   TCPIOHandler d_handler;
   std::unique_ptr<IOStateHandler> d_ioState{nullptr};
+  std::unique_ptr<std::vector<ProxyProtocolValue>> d_proxyProtocolValues{nullptr};
   struct timeval d_connectionStartTime;
   struct timeval d_handshakeDoneTime;
   struct timeval d_firstQuerySizeReadTime;
   struct timeval d_querySizeReadTime;
   struct timeval d_queryReadTime;
   size_t d_currentPos{0};
+  size_t d_proxyProtocolNeed{0};
   size_t d_queriesCount{0};
   size_t d_currentQueriesCount{0};
   unsigned int d_remainingTime{0};
index 004ebba58f4229da434c9a5da82c775ebb515320..e424f5be6ce83f3a0660d675d5f7d1e49e83e5e3 100644 (file)
@@ -15,4 +15,4 @@ In addition to the global settings, rules and Lua bindings can alter this behavi
 
 In effect this means that for the EDNS Client Subnet option to be added to the request, ``useClientSubnet`` should be set to ``true`` for the backend used (default to ``false``) and ECS should not have been disabled by calling :func:`DisableECSAction` or setting ``dq.useECS`` to ``false`` (default to true).
 
-Note that any trailing data present in the incoming query is removed by default when an OPT (or XPF) record has to be inserted. This behaviour can be modified using :func:`setPreserveTrailingData()`.
+Note that any trailing data present in the incoming query is removed when an OPT (or XPF) record has to be inserted.
index 90db38748329bfbd01d2e5372a96b378d6c70402..7c29c23e37281a1d3a335a00db6670c8921ecaaa 100644 (file)
@@ -3,7 +3,12 @@ Using the Proxy Protocol
 
 In order to provide the downstream server with the address of the real client, or at least the one talking to dnsdist, the ``useProxyProtocol`` parameter can be used when creating a :func:`new server <newServer>`.
 This parameter indicates whether a Proxy Protocol version 2 (binary) header should be prepended to the query before forwarding it to the backend, over UDP or TCP. This header contains the initial source and destination addresses and ports, and can also contain several custom values in a Type-Length-Value format. More information about the Proxy Protocol can be found at https://www.haproxy.org/download/2.2/doc/proxy-protocol.txt
+Such a Proxy Protocol header can also be passed from the client to dnsdist, using :func:`setProxyProtocolACL` to specify which clients to accept it from.
+If :func:`setProxyProtocolApplyACLToProxiedClients` is set (default is false), the general ACL will be applied to the source IP address as seen by dnsdist first, but also to the source IP address provided in the Proxy Protocol header.
 
-Custom values can be added to the header via :meth:`DNSQuestion:setProxyProtocolValues` and :func:`SetProxyProtocolValuesAction`.
+Custom values can be added to the header via :meth:`DNSQuestion:addProxyProtocolValue`, :meth:`DNSQuestion:setProxyProtocolValues`, :func:`AddProxyProtocolValueAction` and :func:`SetProxyProtocolValuesAction`.
+Be careful that Proxy Protocol values are sent once at the beginning of the TCP connection for TCP and DoT queries.
+That means that values received on an incoming TCP connection will be inherited by subsequent queries received over the same incoming TCP connection, if any, but values set to a query will not be inherited by subsequent queries.
+Please also note that the maximum size of a Proxy Protocol header dnsdist is willing to accept is 512 bytes by default, although it can be set via :func:`setProxyProtocolMaximumPayloadSize`.
 
-As of 1.5.0 only outgoing Proxy Protocol support has been implemented, although support for parsing incoming Proxy Protocol headers will likely be implemented in the future.
+dnsdist 1.5.0 only supports outgoing Proxy Protocol. Support for parsing incoming Proxy Protocol headers has been implemented in 1.6.0, except for DoH where it does not make sense anyway, since HTTP headers already provide a mechanism for that.
index 1c2e996992513baab435894aeac1f3f386cf18b9..a6f848dd42c7d87774165f7f6320a64fb07f371e 100644 (file)
@@ -428,6 +428,22 @@ Access Control Lists
 
   :param str fname: The path to a file containing a list of netmasks. Empty lines or lines starting with "#" are ignored.
 
+.. function:: setProxyProtocolACL(netmasks)
+
+  .. versionadded:: 1.6.0
+
+  Set the list of netmasks from which a Proxy Protocol header will be accepted, over UDP, TCP and DNS over TLS. The default is empty. Note that, if :func:`setProxyProtocolApplyACLToProxiedClients` is set (default is false), the general ACL will be applied to the source IP address as seen by dnsdist first, but also to the source IP address provided in the Proxy Protocol header.
+
+  :param {str} netmasks: A table of CIDR netmask, e.g. ``{"192.0.2.0/24", "2001:DB8:14::/56"}``. Without a subnetmask, only the specific address is allowed.
+
+.. function:: setProxyProtocolApplyACL(apply)
+
+  .. versionadded:: 1.6.0
+
+  Whether the general ACL should be applied to the source IP address provided in the Proxy Protocol header, in addition to being applied to the source IP address as seen by dnsdist first.
+
+  :param bool apply: Whether it should be applied or not (default is false).
+
 .. function:: showACL()
 
   Print a list of all netmasks allowed to send queries over UDP, TCP, DNS over TLS and DNS over HTTPS. See :ref:`ACL` for more information.
@@ -1504,6 +1520,14 @@ Other functions
 
   Set to true (defaults to false) to allow empty responses (qdcount=0) with a NoError or NXDomain rcode (default) from backends. dnsdist drops these responses by default because it can't match them against the initial query since they don't contain the qname, qtype and qclass, and therefore the risk of collision is much higher than with regular responses.
 
+.. function:: setProxyProtocolMaximumPayloadSize(size)
+
+  .. versionadded:: 1.6.0
+
+  Set the maximum size of a Proxy Protocol payload that dnsdist is willing to accept, in bytes. The default is 512, which is more than enough except for very large TLV data. This setting can't be set to a value lower than 16 since it would deny of Proxy Protocol headers.
+
+  :param int size: The maximum size in bytes (default is 512)
+
 .. function:: makeIPCipherKey(password) -> string
 
   .. versionadded:: 1.4.0
index 0a69752eea8908482f73bdd836d0819b07aec53a..748a69b98f8453e98a4f8e5f29e2c555103b2598 100644 (file)
@@ -76,6 +76,15 @@ This state can be modified from the various hooks.
 
   It also supports the following methods:
 
+  .. method:: DNSQuestion:addProxyProtocolValue(type, value)
+
+    .. versionadded:: 1.6.0
+
+    Add a proxy protocol TLV entry of type ``type`` and ``value`` to the current query.
+
+    :param int type: The type of the new value, ranging from 0 to 255 (both included)
+    :param str value: The binary-safe value
+
   .. method:: DNSQuestion:getDO() -> bool
 
     .. versionadded:: 1.2.0
@@ -132,6 +141,14 @@ This state can be modified from the various hooks.
 
     :returns: The scheme of the DoH query, for example ``http`` or ``https``
 
+  .. method:: DNSQuestion:getProxyProtocolValues() -> table
+
+    .. versionadded:: 1.6.0
+
+    Return a table of the Proxy Protocol values currently set for this query.
+
+    :returns: A table whose keys are types and values are binary-safe strings
+
   .. method:: DNSQuestion:getServerNameIndication() -> string
 
     .. versionadded:: 1.4.0
index 3bfa7f153f324d693f1c59cd3133c36dd94babdd..928999670403f82414bd1e0c0dffdc0597f0a2e1 100644 (file)
@@ -762,6 +762,16 @@ These ``DNSRule``\ s be one of the following items:
 
   :param double probability: Probability of a match
 
+.. function:: ProxyProtocolValueRule(type [, value])
+
+  .. versionadded:: 1.6.0
+
+  Matches queries that have a proxy protocol TVL value of the specified type. If ``value`` is set,
+  the content of the value should also match the content of ``value``.
+
+  :param int type: The type of the value, ranging from 0 to 255 (both included)
+  :param str value: The optional binary-safe value to match
+
 .. function:: QClassRule(qclass)
 
   Matches queries with the specified ``qclass``.
@@ -965,6 +975,19 @@ Actions
 Some actions allow further processing of rules, this is noted in their description.
 The following actions exist.
 
+.. function:: AddProxyProtocolValueAction(type, value)
+
+  .. versionadded:: 1.6.0
+
+  Add a Proxy-Protocol Type-Length value to be sent to the server along with this query. It does not replace any
+  existing value with the same type but adds a new value.
+  Be careful that Proxy Protocol values are sent once at the beginning of the TCP connection for TCP and DoT queries.
+  That means that values received on an incoming TCP connection will be inherited by subsequent queries received over
+  the same incoming TCP connection, if any, but values set to a query will not be inherited by subsequent queries.
+
+  :param int type: The type of the value to send, ranging from 0 to 255 (both included)
+  :param str value: The binary-safe value
+
 .. function:: AllowAction()
 
   Let these packets go through.
index 5e62e9ea0114f9d4de5fdf70ccccca62d72fc721..2b07f63f963e5b41dcc8dbcfb757fd4d83fcc8ad 100644 (file)
@@ -112,7 +112,7 @@ std::string makeProxyHeader(bool tcp, const ComboAddress& source, const ComboAdd
 /* returns: number of bytes consumed (positive) after successful parse
          or number of bytes missing (negative)
          or unfixable parse error (0)*/
-ssize_t isProxyHeaderComplete(const std::string& header, bool* proxy, bool* tcp, size_t* addrSizeOut, uint8_t* protocolOut)
+template<typename Container> ssize_t isProxyHeaderComplete(const Container& header, bool* proxy, bool* tcp, size_t* addrSizeOut, uint8_t* protocolOut)
 {
   static const size_t addr4Size = sizeof(ComboAddress::sin4.sin_addr.s_addr);
   static const size_t addr6Size = sizeof(ComboAddress::sin6.sin6_addr.s6_addr);
@@ -125,7 +125,7 @@ ssize_t isProxyHeaderComplete(const std::string& header, bool* proxy, bool* tcp,
     return -(s_proxyProtocolMinimumHeaderSize - header.size());
   }
 
-  if (header.compare(0, proxymagic.size(), proxymagic) != 0) {
+  if (std::memcmp(&header.at(0), &proxymagic.at(0), proxymagic.size()) != 0) {
     // wrong magic, can not be a proxy header
     return 0;
   }
@@ -208,7 +208,7 @@ ssize_t isProxyHeaderComplete(const std::string& header, bool* proxy, bool* tcp,
 /* returns: number of bytes consumed (positive) after successful parse
          or number of bytes missing (negative)
          or unfixable parse error (0)*/
-ssize_t parseProxyHeader(const std::string& header, bool& proxy, ComboAddress& source, ComboAddress& destination, bool& tcp, std::vector<ProxyProtocolValue>& values)
+template<typename Container> ssize_t parseProxyHeader(const Container& header, bool& proxy, ComboAddress& source, ComboAddress& destination, bool& tcp, std::vector<ProxyProtocolValue>& values)
 {
   size_t addrSize = 0;
   uint8_t protocol = 0;
@@ -220,9 +220,9 @@ ssize_t parseProxyHeader(const std::string& header, bool& proxy, ComboAddress& s
   size_t pos = s_proxyProtocolMinimumHeaderSize;
 
   if (proxy) {
-    source = makeComboAddressFromRaw(protocol, &header.at(pos), addrSize);
+    source = makeComboAddressFromRaw(protocol, reinterpret_cast<const char*>(&header.at(pos)), addrSize);
     pos = pos + addrSize;
-    destination = makeComboAddressFromRaw(protocol, &header.at(pos), addrSize);
+    destination = makeComboAddressFromRaw(protocol, reinterpret_cast<const char*>(&header.at(pos)), addrSize);
     pos = pos + addrSize;
     source.setPort((static_cast<uint8_t>(header.at(pos)) << 8) + static_cast<uint8_t>(header.at(pos+1)));
     pos = pos + sizeof(uint16_t);
@@ -243,7 +243,7 @@ ssize_t parseProxyHeader(const std::string& header, bool& proxy, ComboAddress& s
         return 0;
       }
 
-      values.push_back({ std::string(&header.at(pos), len), type });
+      values.push_back({ std::string(reinterpret_cast<const char*>(&header.at(pos)), len), type });
       pos += len;
     }
     else {
@@ -255,3 +255,9 @@ ssize_t parseProxyHeader(const std::string& header, bool& proxy, ComboAddress& s
 
   return pos;
 }
+
+#include "noinitvector.hh"
+template ssize_t isProxyHeaderComplete<std::string>(const std::string& header, bool* proxy, bool* tcp, size_t* addrSizeOut, uint8_t* protocolOut);
+template ssize_t isProxyHeaderComplete<PacketBuffer>(const PacketBuffer& header, bool* proxy, bool* tcp, size_t* addrSizeOut, uint8_t* protocolOut);
+template ssize_t parseProxyHeader<std::string>(const std::string& header, bool& proxy, ComboAddress& source, ComboAddress& destination, bool& tcp, std::vector<ProxyProtocolValue>& values);
+template ssize_t parseProxyHeader<PacketBuffer>(const PacketBuffer& header, bool& proxy, ComboAddress& source, ComboAddress& destination, bool& tcp, std::vector<ProxyProtocolValue>& values);
index 66b8f0aaf4fdbc564b386ccbadcb0b8717ca365e..97f7ac777e6b8b9ecc3e3b6365cdad7cd716e91d 100644 (file)
@@ -38,9 +38,9 @@ std::string makeProxyHeader(bool tcp, const ComboAddress& source, const ComboAdd
 /* returns: number of bytes consumed (positive) after successful parse
          or number of bytes missing (negative)
          or unfixable parse error (0)*/
-ssize_t isProxyHeaderComplete(const std::string& header, bool* proxy=nullptr, bool* tcp=nullptr, size_t* addrSizeOut=nullptr, uint8_t* protocolOut=nullptr);
+template<typename Container> ssize_t isProxyHeaderComplete(const Container& header, bool* proxy=nullptr, bool* tcp=nullptr, size_t* addrSizeOut=nullptr, uint8_t* protocolOut=nullptr);
 
 /* returns: number of bytes consumed (positive) after successful parse
          or number of bytes missing (negative)
          or unfixable parse error (0)*/
-ssize_t parseProxyHeader(const std::string& payload, bool& proxy, ComboAddress& source, ComboAddress& destination, bool& tcp, std::vector<ProxyProtocolValue>& values);
+template<typename Container> ssize_t parseProxyHeader(const Container& header, bool& proxy, ComboAddress& source, ComboAddress& destination, bool& tcp, std::vector<ProxyProtocolValue>& values);