]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: [WIP] Add unit tests for the TCP stack
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 9 Feb 2021 17:56:09 +0000 (18:56 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 2 Mar 2021 09:50:46 +0000 (10:50 +0100)
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/test-dnsdisttcp_cc.cc [new file with mode: 0644]
pdns/epollmplexer.cc
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh

index 7e159311a0a9d4a76205c6d3937108e430405d52..ee44e4025333234cc47851fb7db51bf2f34215de 100644 (file)
@@ -64,6 +64,8 @@ size_t g_maxTCPConnectionDuration{0};
 size_t g_maxTCPConnectionsPerClient{0};
 size_t g_tcpInternalPipeBufferSize{0};
 uint16_t g_downstreamTCPCleanupInterval{60};
+int g_tcpRecvTimeout{2};
+int g_tcpSendTimeout{2};
 bool g_useTCPSinglePipe{false};
 
 class DownstreamConnectionsManager
@@ -703,7 +705,7 @@ void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcpar
   }
 
   struct timeval now;
-  gettimeofday(&now, 0);
+  gettimeofday(&now, nullptr);
   handleIO(conn, now);
 }
 
@@ -1008,7 +1010,7 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param
     g_tcpclientthreads->decrementQueuedCount();
 
     struct timeval now;
-    gettimeofday(&now, 0);
+    gettimeofday(&now, nullptr);
     auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
     delete citmp;
     citmp = nullptr;
@@ -1036,7 +1038,7 @@ static void tcpClientThread(int pipefd)
 
   data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
   struct timeval now;
-  gettimeofday(&now, 0);
+  gettimeofday(&now, nullptr);
   time_t lastTCPCleanup = now.tv_sec;
   time_t lastTimeoutScan = now.tv_sec;
 
@@ -1054,7 +1056,7 @@ static void tcpClientThread(int pipefd)
       data.mplexer->runForAllWatchedFDs([](bool isRead, int fd, const FDMultiplexer::funcparam_t& param, struct timeval ttd)
       {
         struct timeval lnow;
-        gettimeofday(&lnow, 0);
+        gettimeofday(&lnow, nullptr);
         cerr<<"- "<<isRead<<" "<<fd<<": "<<" "<<(ttd.tv_sec-lnow.tv_sec)<<endl;
         if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
           auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
index bf68654f2c0359cabe0c292f2a3e21f8860a77ef..8aa49a489fbb716088401563cf6c6aecf7c50b0b 100644 (file)
@@ -133,8 +133,6 @@ GlobalStateHolder<servers_t> g_dstates;
 GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
 GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
 DNSAction::Action g_dynBlockAction = DNSAction::Action::Drop;
-int g_tcpRecvTimeout{2};
-int g_tcpSendTimeout{2};
 int g_udpTimeout{2};
 
 bool g_servFailOnNoPolicy{false};
index 41aa31b083601a7548ccabf44f0c3f126abd7137..ae01515a2b2d1ac6cf4612d5ed102696415c636c 100644 (file)
@@ -230,6 +230,7 @@ testrunner_SOURCES = \
        dnsdist-dynbpf.cc dnsdist-dynbpf.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
        dnsdist-kvs.cc dnsdist-kvs.hh \
+       dnsdist-idstate.cc \
        dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
        dnsdist-lua-bindings-dnsquestion.cc \
        dnsdist-lua-bindings-kvs.cc \
@@ -237,7 +238,10 @@ testrunner_SOURCES = \
        dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \
        dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \
        dnsdist-lua-vars.cc \
+       dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
+       dnsdist-tcp.cc \
+       dnsdist-tcp-downstream.cc \
        dnsdist-xpf.cc dnsdist-xpf.hh \
        dnsdist.hh \
        dnslabeltext.cc \
@@ -275,6 +279,7 @@ testrunner_SOURCES = \
        test-dnsdistpacketcache_cc.cc \
        test-dnsdistrings_cc.cc \
        test-dnsdistrules_cc.cc \
+       test-dnsdisttcp_cc.cc \
        test-dnsparser_cc.cc \
        test-iputils_hh.cc \
        test-luawrapper.cc \
diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc
new file mode 100644 (file)
index 0000000..0fc57aa
--- /dev/null
@@ -0,0 +1,481 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "dnswriter.hh"
+#include "dnsdist.hh"
+#include "dnsdist-rings.hh"
+#include "dnsdist-tcp-downstream.hh"
+#include "dnsdist-tcp-upstream.hh"
+
+struct DNSDistStats g_stats;
+GlobalStateHolder<NetmaskGroup> g_ACL;
+GlobalStateHolder<vector<DNSDistRuleAction> > g_rulactions;
+GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_resprulactions;
+GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_cachehitresprulactions;
+GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_selfansweredresprulactions;
+GlobalStateHolder<servers_t> g_dstates;
+
+QueryCount g_qcount;
+
+
+bool checkDNSCryptQuery(const ClientState& cs, PacketBuffer& query, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
+{
+  return false;
+}
+
+bool processResponse(PacketBuffer& response, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, bool muted)
+{
+  return false;
+}
+
+bool checkQueryHeaders(const struct dnsheader* dh)
+{
+  return true;
+}
+
+bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& qnameWireLength)
+{
+  return true;
+}
+
+uint64_t uptimeOfProcess(const std::string& str)
+{
+  return 0;
+}
+
+uint64_t getLatencyCount(const std::string&)
+{
+  return 0;
+}
+
+static std::function<ProcessQueryResult(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
+
+ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
+{
+  if (s_processQuery) {
+    return s_processQuery(dq, cs, holders, selectedBackend);
+  }
+
+  return ProcessQueryResult::Drop;
+}
+
+BOOST_AUTO_TEST_SUITE(test_dnsdisttcp_cc)
+
+struct ExpectedStep
+{
+public:
+  enum class ExpectedRequest { handshake, connect, read, write, close };
+
+  ExpectedStep(ExpectedRequest r, IOState n): ExpectedStep(r, n, 0)
+  {
+  }
+
+  ExpectedStep(ExpectedRequest r, IOState n, size_t b): request(r), nextState(n), bytes(b)
+  {
+  }
+
+  ExpectedRequest request;
+  IOState nextState;
+  size_t bytes{0};
+};
+
+static std::deque<ExpectedStep> s_steps;
+static ExpectedStep getStep()
+{
+  BOOST_REQUIRE(!s_steps.empty());
+  auto res = s_steps.front();
+  s_steps.pop_front();
+  return res;
+}
+
+static boost::optional<PacketBuffer> s_readBuffer;
+static PacketBuffer s_writeBuffer;
+
+std::ostream& operator<<(std::ostream &os, const ExpectedStep::ExpectedRequest d);
+
+std::ostream& operator<<(std::ostream &os, const ExpectedStep::ExpectedRequest d)
+{
+  static const std::vector<std::string> requests = { "handshake", "connect", "read", "write", "close" };
+  os<<requests.at(static_cast<size_t>(d));
+  return os;
+}
+
+class MockupTLSConnection : public TLSConnection
+{
+private:
+public:
+  ~MockupTLSConnection() { }
+
+  IOState tryHandshake() override
+  {
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::handshake);
+    return step.nextState;
+  }
+
+  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
+  {
+    if (buffer.size() < toWrite || pos >= toWrite) {
+      throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));
+    }
+
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::write);
+
+    if (step.bytes == 0) {
+      throw std::runtime_error("Remote host closed the connection");
+    }
+
+    toWrite -= pos;
+    BOOST_REQUIRE_GE(buffer.size(), pos + toWrite);
+
+    if (step.bytes < toWrite) {
+      toWrite = step.bytes;
+    }
+
+    s_writeBuffer.insert(s_writeBuffer.end(), buffer.begin() + pos, buffer.begin() + pos + toWrite);
+    pos += toWrite;
+
+    return step.nextState;
+  }
+
+  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
+  {
+    if (buffer.size() < toRead || pos >= toRead) {
+      throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
+    }
+
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::read);
+
+    if (step.bytes == 0) {
+      throw std::runtime_error("Remote host closed the connection");
+    }
+
+    if (s_readBuffer) {
+      toRead -= pos;
+
+      if (step.bytes < toRead) {
+        toRead = step.bytes;
+      }
+      BOOST_REQUIRE_GE(buffer.size(), toRead);
+      BOOST_REQUIRE_GE(s_readBuffer->size(), toRead);
+
+      std::copy(s_readBuffer->begin(), s_readBuffer->begin() + toRead, buffer.begin() + pos);
+      pos += toRead;
+      s_readBuffer->erase(s_readBuffer->begin(), s_readBuffer->begin() + toRead);
+    }
+
+    return step.nextState;
+  }
+
+  void close() override
+  {
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::close);
+  }
+
+  bool hasBufferedData() const override
+  {
+    return false;
+  }
+
+  std::string getServerNameIndication() const override
+  {
+    return "";
+  }
+
+  LibsslTLSVersion getTLSVersion() const override
+  {
+    return LibsslTLSVersion::TLS13;
+  }
+
+  bool hasSessionBeenResumed() const override
+  {
+    return false;
+  }
+
+  /* unused in that context, don't bother */
+  void doHandshake() override
+  {
+  }
+
+  void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) override
+  {
+  }
+
+  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
+  {
+    return IOState::Done;
+  }
+
+  size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) override
+  {
+    return 0;
+  }
+
+  size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) override
+  {
+    return 0;
+  }
+};
+
+class MockupTLSCtx : public TLSCtx
+{
+public:
+  ~MockupTLSCtx()
+  {
+  }
+
+  std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) override
+  {
+    return std::make_unique<MockupTLSConnection>();
+  }
+
+  void rotateTicketsKey(time_t now) override
+  {
+  }
+
+  size_t getTicketsKeysCount() override
+  {
+    return 0;
+  }
+
+  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) override
+  {
+    return nullptr;
+  }
+};
+
+class MockupFDMultiplexer : public FDMultiplexer
+{
+public:
+  MockupFDMultiplexer()
+  {
+  }
+
+  ~MockupFDMultiplexer()
+  {
+  }
+
+  int run(struct timeval* tv, int timeout=500) override
+  {
+    int ret = 0;
+
+    gettimeofday(tv, nullptr); // MANDATORY
+
+    for (const auto fd : ready) {
+      {
+        const auto& it = d_readCallbacks.find(fd);
+
+        if (it != d_readCallbacks.end()) {
+          it->d_callback(it->d_fd, it->d_parameter);
+          continue; // so we don't refind ourselves as writable!
+        }
+      }
+
+      {
+        const auto& it = d_writeCallbacks.find(fd);
+
+        if (it != d_writeCallbacks.end()) {
+          it->d_callback(it->d_fd, it->d_parameter);
+        }
+      }
+    }
+
+    return ret;
+  }
+
+  void getAvailableFDs(std::vector<int>& fds, int timeout) override
+  {
+  }
+
+  void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override
+  {
+    accountingAddFD(cbmap, fd, toDo, parameter, ttd);
+  }
+
+  void removeFD(callbackmap_t& cbmap, int fd) override
+  {
+    accountingRemoveFD(cbmap, fd);
+  }
+
+  void alterFD(callbackmap_t& from, callbackmap_t& to, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd) override
+  {
+    accountingRemoveFD(from, fd);
+    accountingAddFD(to, fd, toDo, parameter, ttd);
+  }
+
+  string getName() const override
+  {
+    return "mockup";
+  }
+
+  void setReady(int fd)
+  {
+    ready.insert(fd);
+  }
+
+  void setNotdReady(int fd)
+  {
+    ready.erase(fd);
+  }
+
+private:
+  std::set<int> ready;
+};
+
+BOOST_AUTO_TEST_CASE(test_IncomingConnection)
+{
+  //int sockets[2];
+  //int res = socketpair(AF_UNIX, SOCK_STREAM, 0, sockets);
+  //BOOST_REQUIRE_EQUAL(res, 0);
+  ComboAddress local("192.0.2.1:80");
+  ClientState localCS(local, true, false, false, "", {});
+  auto tlsCtx = std::make_shared<MockupTLSCtx>();
+  localCS.tlsFrontend = std::make_shared<TLSFrontend>(tlsCtx);
+
+  TCPClientThreadData threadData;
+  threadData.mplexer = std::make_unique<MockupFDMultiplexer>();
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+
+  PacketBuffer query;
+  GenericDNSPacketWriter<PacketBuffer> pwQ(query, DNSName("powerdns.com."), QType::A, QClass::IN, 0);
+  pwQ.getHeader()->rd = 1;
+
+  uint16_t querySize = static_cast<uint16_t>(query.size());
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(querySize / 256), static_cast<uint8_t>(querySize % 256) };
+  query.insert(query.begin(), sizeBytes, sizeBytes + 2);
+
+  g_verbose = true;
+
+  {
+    /* drop right away */
+    s_readBuffer = query;
+    s_writeBuffer.clear();
+    s_steps = {
+      { ExpectedStep::ExpectedRequest::handshake, IOState::Done },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, 2 },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, query.size() - 2 },
+      { ExpectedStep::ExpectedRequest::close, IOState::Done },
+    };
+    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      return ProcessQueryResult::Drop;
+    };
+
+    auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS), threadData, now);
+    IncomingTCPConnectionState::handleIO(state, now);
+    BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0);
+  }
+
+  {
+    /* self-generated REFUSED, client closes connection right away */
+    s_readBuffer = query;
+    s_writeBuffer.clear();
+    s_steps = {
+      { ExpectedStep::ExpectedRequest::handshake, IOState::Done },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, 2 },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, query.size() - 2 },
+      { ExpectedStep::ExpectedRequest::write, IOState::Done, 65537 },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, 0 },
+      { ExpectedStep::ExpectedRequest::close, IOState::Done },
+    };
+    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      // Would be nicer to actually turn it into a response
+      return ProcessQueryResult::SendAnswer;
+    };
+
+    auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS), threadData, now);
+    IncomingTCPConnectionState::handleIO(state, now);
+    BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
+  }
+
+  {
+    /* short read on the size, then on the query itself,
+       self-generated REFUSED, short write on the response, 
+       client closes connection right away */
+    s_readBuffer = query;
+    s_writeBuffer.clear();
+    s_steps = {
+      { ExpectedStep::ExpectedRequest::handshake, IOState::Done },
+      { ExpectedStep::ExpectedRequest::read, IOState::NeedRead, 1 },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, 1 },
+      { ExpectedStep::ExpectedRequest::read, IOState::NeedRead, query.size() - 3 },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, 1 },
+      { ExpectedStep::ExpectedRequest::write, IOState::NeedWrite, query.size() - 1},
+      { ExpectedStep::ExpectedRequest::write, IOState::Done, 1 },
+      { ExpectedStep::ExpectedRequest::read, IOState::Done, 0 },
+      { ExpectedStep::ExpectedRequest::close, IOState::Done },
+    };
+    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      // Would be nicer to actually turn it into a response
+      return ProcessQueryResult::SendAnswer;
+    };
+
+    /* mark the incoming FD as always ready */
+    dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
+
+    auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS), threadData, now);
+    IncomingTCPConnectionState::handleIO(state, now);
+    while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+      threadData.mplexer->run(&now);
+    }
+    BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
+  }
+
+  {
+#if 0
+    /* 10k self-generated REFUSED on the same connection */
+    size_t count = 10000;
+    s_readBuffer->clear();
+    s_writeBuffer.clear();
+    s_steps = { { ExpectedStep::ExpectedRequest::handshake, IOState::Done } };
+
+    for (size_t idx = 0; idx < count; idx++) {
+      s_readBuffer->insert(s_readBuffer->end(), query.begin(), query.end());
+      s_steps.push_back({ ExpectedStep::ExpectedRequest::read, IOState::Done, 2 });
+      s_steps.push_back({ ExpectedStep::ExpectedRequest::read, IOState::Done, query.size() - 2 });
+      s_steps.push_back({ ExpectedStep::ExpectedRequest::write, IOState::Done, query.size() + 2 });
+    };
+    s_steps.push_back({ ExpectedStep::ExpectedRequest::read, IOState::Done, 0 });
+    s_steps.push_back({ ExpectedStep::ExpectedRequest::close, IOState::Done });
+
+    size_t counter = 0;
+    s_processQuery = [&counter](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      // Would be nicer to actually turn it into a response
+      return ProcessQueryResult::SendAnswer;
+    };
+
+    auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS), threadData, now);
+    IncomingTCPConnectionState::handleIO(state, now);
+    BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * count);
+#endif
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
index 1a4b2a3b501ccbfb314b3661b3ee038db550cef6..7fa63e6237ba21b43fa693ca9cd2fad43f483a30 100644 (file)
@@ -107,7 +107,7 @@ void EpollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo
   eevent.data.u64=0; // placate valgrind (I love it so much)
   eevent.data.fd=fd;
 
-  if(epoll_ctl(d_epollfd, EPOLL_CTL_ADD, fd, &eevent) < 0) {
+  if (epoll_ctl(d_epollfd, EPOLL_CTL_ADD, fd, &eevent) < 0) {
     cbmap.erase(fd);
     throw FDMultiplexerException("Adding fd to epoll set: "+stringerror());
   }
index 76a18e5611bbd06c7234dc506e1fd280b6f2776a..07d6bf0a2391a4def867830dbe0d0c003d39194f 100644 (file)
@@ -234,7 +234,7 @@ public:
     }
   }
 
-  IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override
+  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
   {
     do {
       int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
@@ -846,7 +846,7 @@ public:
     throw std::runtime_error("Error accepting a new connection");
   }
 
-  IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) override
+  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
   {
     do {
       ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
index 1487ad883ad2d3b8ee9d79568aff5ff875abc8d9..e6c4983aa9ccba21437f67501aa268c71e4fa469 100644 (file)
@@ -18,7 +18,7 @@ public:
   virtual IOState tryHandshake() = 0;
   virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
   virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
-  virtual IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
+  virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
   virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) = 0;
   virtual bool hasBufferedData() const = 0;
   virtual std::string getServerNameIndication() const = 0;
@@ -106,6 +106,14 @@ protected:
 class TLSFrontend
 {
 public:
+  TLSFrontend()
+  {
+  }
+
+  TLSFrontend(std::shared_ptr<TLSCtx> ctx): d_ctx(std::move(ctx))
+  {
+  }
+
   bool setupTLS();
 
   void rotateTicketsKey(time_t now)
@@ -122,7 +130,7 @@ public:
     }
   }
 
-  std::shared_ptr<TLSCtx> getContext()
+  std::shared_ptr<TLSCtx>& getContext()
   {
     return d_ctx;
   }
@@ -173,7 +181,7 @@ public:
   ComboAddress d_addr;
   std::string d_provider;
 
-private:
+protected:
   std::shared_ptr<TLSCtx> d_ctx{nullptr};
 };
 
@@ -302,7 +310,7 @@ public:
      return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
      would block.
   */
-  IOState tryWrite(PacketBuffer& buffer, size_t& pos, size_t toWrite)
+  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite)
   {
     if (buffer.size() < toWrite || pos >= toWrite) {
       throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));