]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Implement SNIRule for DoT
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 15 May 2019 15:04:09 +0000 (17:04 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 20 May 2019 09:10:16 +0000 (11:10 +0200)
12 files changed:
pdns/dnsdist-console.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsdistdist/tcpiohandler.cc
pdns/tcpiohandler.hh
regression-tests.dnsdist/configCA.conf
regression-tests.dnsdist/configServer.conf
regression-tests.dnsdist/runtests
regression-tests.dnsdist/test_TLS.py

index f20f3cf8f671c30994f3d1e7951bd38d0cfa0d43..65fcf8ab8814850ed85d1eb06a97a3d179905bdb 100644 (file)
@@ -546,6 +546,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "showVersion", true, "", "show the current version" },
   { "shutdown", true, "", "shut down `dnsdist`" },
   { "SkipCacheAction", true, "", "Don’t lookup the cache for this query, don’t store the answer" },
+  { "SNIRule", true, "name", "Create a rule which matches on the incoming TLS SNI value, if any (DoT or DoH)" },
   { "snmpAgent", true, "enableTraps [, masterSocket]", "enable `SNMP` support. `enableTraps` is a boolean indicating whether traps should be sent and `masterSocket` an optional string specifying how to connect to the master agent"},
   { "SNMPTrapAction", true, "[reason]", "send an SNMP trap, adding the optional `reason` string as the query description"},
   { "SNMPTrapResponseAction", true, "[reason]", "send an SNMP trap, adding the optional `reason` string as the response description"},
index 5fbf7643dbb19ef4e2fe312b94cee25d7538674b..b781e8c79ac5ecef92b8e0b98dc95ea2e9015cb6 100644 (file)
@@ -295,6 +295,10 @@ void setupLuaRules()
     });
 #endif
 
+  g_lua.writeFunction("SNIRule", [](const std::string& name) {
+      return std::shared_ptr<DNSRule>(new SNIRule(name));
+  });
+
   g_lua.writeFunction("SuffixMatchNodeRule", [](const SuffixMatchNode& smn, boost::optional<bool> quiet) {
       return std::shared_ptr<DNSRule>(new SuffixMatchNodeRule(smn, quiet ? *quiet : false));
     });
index 6c437e270f697824099b8c5da839c250ff8dc963..8419011661fd5022e8df2547b6c7d906652bcb72 100644 (file)
@@ -808,6 +808,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, stru
   DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
   DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
   dq.dnsCryptQuery = std::move(dnsCryptQuery);
+  dq.sni = state->d_handler.getServerNameIndication();
 
   state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
   if (state->d_isXFR) {
index 3d32ee2a10466a689beb77180f1a28b34bc8a142..a74d1646736541acab3284156f461c6c7e5fe4d4 100644 (file)
@@ -75,6 +75,7 @@ struct DNSQuestion
 #endif
   Netmask ecs;
   boost::optional<Netmask> subnet;
+  std::string sni; /* Server Name Indication, if any (DoT or DoH) */
   const DNSName* qname{nullptr};
   const ComboAddress* local{nullptr};
   const ComboAddress* remote{nullptr};
index a25d8572c058d9b63f2356e0254a5f919ab05ada..4827a6aa0274ce1096aee356b3b25f2ac8cae56c 100644 (file)
@@ -525,6 +525,24 @@ private:
 };
 #endif
 
+class SNIRule : public DNSRule
+{
+public:
+  SNIRule(const std::string& name) : d_sni(name)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->sni == d_sni;
+  }
+  string toString() const override
+  {
+    return "SNI == " + d_sni;
+  }
+private:
+  std::string d_sni;
+};
+
 class SuffixMatchNodeRule : public DNSRule
 {
 public:
index ae38919755feb2fab344c0c324ee6f758345f01c..546063d311bc1e4aedfb42ba9379d4a87932d5a7 100644 (file)
@@ -756,6 +756,15 @@ These ``DNSRule``\ s be one of the following items:
 
   :param str regex: The regular expression to match the QNAME.
 
+.. function:: SNIRule(name)
+  .. versionadded:: 1.4.0
+
+  Matches against the TLS Server Name Indication value sent by the client, if any. Only makes
+  sense for DoT or DoH, and for that last one matching on the HTTP Host header might provide
+  more consistent results.
+
+  :param str name: The exact SNI name to match.
+
 .. function:: SuffixMatchNodeRule(smn[, quiet])
 
   Matches based on a group of domain suffixes for rapid testing of membership.
index 6e77c7840aa30b78354f8c1e190798f1e98a6a07..bef5ce2faa9abe9699b0f63786f344fb01a910db 100644 (file)
@@ -352,6 +352,7 @@ public:
 
     return got;
   }
+
   void close() override
   {
     if (d_conn) {
@@ -359,6 +360,17 @@ public:
     }
   }
 
+  std::string getServerNameIndication()
+  {
+    if (d_conn) {
+      const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
+      if (value) {
+        return std::string(value);
+      }
+    }
+    return std::string();
+  }
+
 private:
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
   unsigned int d_timeout;
@@ -860,6 +872,23 @@ public:
     return got;
   }
 
+  std::string getServerNameIndication()
+  {
+    if (d_conn) {
+      unsigned int type;
+      size_t name_len = 256;
+      std::string sni;
+      sni.resize(name_len);
+
+      int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
+      if (res == GNUTLS_E_SUCCESS) {
+        sni.resize(name_len);
+        return sni;
+      }
+    }
+    return std::string();
+  }
+
   void close() override
   {
     if (d_conn) {
index 061ab884a49f2c69cfef5acaa83fca2fee082c07..dd82281a7a81331f63b3abe98ab7eda7d6d49e62 100644 (file)
@@ -16,6 +16,7 @@ public:
   virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
   virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0;
   virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0;
+  virtual std::string getServerNameIndication() = 0;
   virtual void close() = 0;
 
 protected:
@@ -275,6 +276,14 @@ public:
     }
   }
 
+  std::string getServerNameIndication()
+  {
+    if (d_conn) {
+      return d_conn->getServerNameIndication();
+    }
+    return std::string();
+  }
+
 private:
   std::unique_ptr<TLSConnection> d_conn{nullptr};
   int d_socket{-1};
index fa5d736985504707f2f255e5c0187b6e641fe5d2..ddb427ce01c301b3e189d836681b0fc08d7e59f1 100644 (file)
@@ -18,3 +18,6 @@ countryName = NL
 [custom_extensions]
 basicConstraints = CA:true
 keyUsage = cRLSign, keyCertSign
+
+[CA_default]
+copy_extensions = copy
index 030cd5959f4215ac57518a549a2b3e79be83a19f..f1aa4c7feddbb3ff07fe529033afbf40108b29ff 100644 (file)
@@ -3,9 +3,18 @@ default_bits = 2048
 encrypt_key = no
 prompt = no
 distinguished_name = server_distinguished_name
+req_extensions = v3_req
 
 [server_distinguished_name]
 CN = tls.tests.dnsdist.org
 OU = PowerDNS.com BV
 countryName = NL
 
+[v3_req]
+basicConstraints = CA:FALSE
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+subjectAltName = @alt_names
+
+[alt_names]
+DNS.1 = tls.tests.dnsdist.org
+DNS.2 = powerdns.com
index 251e76a6a79d99f39591f580c8792adbb3268dd1..1f6de2ea1c0f1504063781cec1186d8c742814a9 100755 (executable)
@@ -54,7 +54,7 @@ openssl req -new -x509 -days 1 -extensions v3_ca -keyout ca.key -out ca.pem -nod
 # Generate a new server certificate request
 openssl req -new -newkey rsa:2048 -nodes -keyout server.key -out server.csr -config configServer.conf
 # Sign the server cert
-openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server.csr -out server.pem
+openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server.csr -out server.pem -extfile configServer.conf -extensions v3_req
 # Generate a chain
 cat server.pem ca.pem > server.chain
 
index b31c6c2f9bfcf4933bdc99103b5eb2febbf2a035..6973613b45e0f6f90394a9ca3351bf6a12233810 100644 (file)
@@ -12,6 +12,7 @@ class TestTLS(DNSDistTest):
     _config_template = """
     newServer{address="127.0.0.1:%s"}
     addTLSLocal("127.0.0.1:%s", "%s", "%s")
+    addAction(SNIRule("powerdns.com"), SpoofAction("1.2.3.4"))
     """
     _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey']
 
@@ -90,3 +91,44 @@ class TestTLS(DNSDistTest):
             receivedQuery.id = query.id
             self.assertEquals(query, receivedQuery)
             self.assertEquals(response, receivedResponse)
+
+    def testTLSSNIRouting(self):
+        """
+        TLS: SNI Routing
+        """
+        name = 'sni.tls.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.4')
+        expectedResponse.answer.append(rrset)
+
+        # this SNI should match so we should get a spoofed answer
+        conn = self.openTLSConnection(self._tlsServerPort, 'powerdns.com', self._caCert)
+
+        self.sendTCPQueryOverConnection(conn, query, response=None)
+        receivedResponse = self.recvTCPResponseOverConnection(conn, useQueue=False)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        # this one should not
+        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
+
+        self.sendTCPQueryOverConnection(conn, query, response=response)
+        (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)