]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Add helper to get map counts and remove unneeded error function.
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Fri, 24 Dec 2021 12:10:29 +0000 (13:10 +0100)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Fri, 2 Sep 2022 12:22:48 +0000 (14:22 +0200)
pdns/distributor.hh
pdns/dnspacket.cc
pdns/dnspacket.hh
pdns/gss_context.cc
pdns/gss_context.hh
pdns/tcpreceiver.cc
pdns/test-distributor_hh.cc
pdns/tkey.cc

index c3fce94cd5e489c760096928ca175d5dd158c2e9..a3dd1d08ce030ca9da4beddd6213a73ceef3e983 100644 (file)
@@ -245,6 +245,7 @@ retry:
       }
 
       QD->callback(a, QD->start);
+      QD->Q.cleanupGSS(a->d.rcode);
       QD.reset();
     }
 
@@ -306,6 +307,8 @@ retry:
     }
   }
   callback(a, start);
+  q.cleanupGSS(a->d.rcode);
+
   return 0;
 }
 
index 054e3539c63ff9b56944393d832773e7ac71c0fb..b58e1f1ca6401fb12b98021efbea7a7974ba9995 100644 (file)
@@ -759,3 +759,12 @@ bool DNSPacket::checkForCorrectTSIG(UeberBackend* B, DNSName* keyname, string* s
 const DNSName& DNSPacket::getTSIGKeyname() const {
   return d_tsigkeyname;
 }
+
+void DNSPacket::cleanupGSS(int rcode)
+{
+  if (rcode != RCode::NoError && d_tsig_algo == TSIG_GSS && !getTSIGKeyname().empty()) {
+    GssContext ctx(getTSIGKeyname());
+    ctx.destroy();
+  }
+}
index 097bb5b63f72b9f6952e81e75999eeccfcb270f1..02838b099b139da5ae50447de18c81f8eecfdf65 100644 (file)
@@ -172,6 +172,7 @@ public:
   static bool s_doEDNSSubnetProcessing;
   static bool s_doEDNSCookieProcessing;
   static string s_EDNSCookieKey;
+  void cleanupGSS(int rcode);
 
 private:
   void pasteQ(const char *question, int length); //!< set the question of this packet, useful for crafting replies
index 818322eca9e52af8498e67d2afd3d55a84d7038e..ec6bdd7279e30656d4600f5eb75b28cad151c04f 100644 (file)
@@ -27,6 +27,7 @@
 
 #ifndef ENABLE_GSS_TSIG
 
+std::tuple<size_t, size_t, size_t> GssContext::getCounts() { return std::make_tuple<size_t, size_t, size_t>(0, 0, 0); }
 bool GssContext::supported() { return false; }
 GssContext::GssContext() :
   d_error(GSS_CONTEXT_UNSUPPORTED), d_type(GSS_CONTEXT_NONE) {}
@@ -49,50 +50,20 @@ GssContextError GssContext::getError() { return GSS_CONTEXT_UNSUPPORTED; }
 
 #else
 
-static string gsserror(OM_uint32 status_code)
-{
-  OM_uint32 maj_status;
-  OM_uint32 min_status;
-  OM_uint32 message_context = 0;
-  gss_buffer_desc status_string;
-  std::basic_ostringstream<char> ret;
-  bool first = true;
-  do {
-    if (!first) {
-      ret << '/';
-    } else {
-      first = false;
-    }
-    maj_status = gss_display_status(&min_status,
-                                    status_code,
-                                    GSS_C_GSS_CODE,
-                                    GSS_C_NO_OID,
-                                    &message_context,
-                                    &status_string);
-    if (maj_status == GSS_S_COMPLETE) {
-      ret << string(static_cast<char*>(status_string.value), status_string.length);
-      gss_release_buffer(&min_status, &status_string);
-    } else {
-      // XXX to release or not to release?
-      ret << std::to_string(status_code);
-    }
-  } while (message_context != 0);
-  return ret.str();
-}
-
 class GssCredential : boost::noncopyable
 {
 public:
   GssCredential(const std::string& name, const gss_cred_usage_t usage) :
-    d_valid(false), d_nameS(name), d_name(GSS_C_NO_NAME), d_cred(GSS_C_NO_CREDENTIAL), d_usage(usage)
+    d_nameS(name), d_usage(usage)
   {
     gss_buffer_desc buffer;
 
     if (!name.empty()) {
       buffer.length = name.size();
       buffer.value = const_cast<void*>(static_cast<const void*>(name.c_str()));
-      d_maj = gss_import_name(&d_min, &buffer, (gss_OID)GSS_KRB5_NT_PRINCIPAL_NAME, &d_name);
-      if (d_maj != GSS_S_COMPLETE) {
+      OM_uint32 min;
+      auto maj = gss_import_name(&min, &buffer, (gss_OID)GSS_KRB5_NT_PRINCIPAL_NAME, &d_name);
+      if (maj != GSS_S_COMPLETE) {
         d_name = GSS_C_NO_NAME;
         d_valid = false;
         return;
@@ -105,25 +76,28 @@ public:
   ~GssCredential()
   {
     OM_uint32 tmp_maj __attribute__((unused)), tmp_min __attribute__((unused));
-    if (d_cred != GSS_C_NO_CREDENTIAL)
+    if (d_cred != GSS_C_NO_CREDENTIAL) {
       tmp_maj = gss_release_cred(&tmp_min, &d_cred);
-    if (d_name != GSS_C_NO_NAME)
+    }
+    if (d_name != GSS_C_NO_NAME) {
       tmp_maj = gss_release_name(&tmp_min, &d_name);
+    }
   };
 
   bool expired() const
   {
-    if (d_expires == -1)
+    if (d_expires == -1) {
       return false;
+    }
     return time(nullptr) > d_expires;
   }
 
   bool renew()
   {
     OM_uint32 time_rec, tmp_maj __attribute__((unused)), tmp_min __attribute__((unused));
-    d_maj = gss_acquire_cred(&d_min, d_name, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, d_usage, &d_cred, nullptr, &time_rec);
+    tmp_maj = gss_acquire_cred(&tmp_min, d_name, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, d_usage, &d_cred, nullptr, &time_rec);
 
-    if (d_maj != GSS_S_COMPLETE) {
+    if (tmp_maj != GSS_S_COMPLETE) {
       d_valid = false;
       tmp_maj = gss_release_name(&tmp_min, &d_name);
       d_name = GSS_C_NO_NAME;
@@ -132,12 +106,11 @@ public:
 
     d_valid = true;
 
-    if (time_rec > GSS_C_INDEFINITE) {
-      d_expires = time(nullptr) + time_rec;
-    }
-    else {
-      d_expires = -1;
+    // We do not want forever, but a good time
+    if (time_rec == GSS_C_INDEFINITE) {
+      time_rec = 24 * 60 * 60;
     }
+    d_expires = time(nullptr) + time_rec;
 
     return true;
   }
@@ -147,32 +120,30 @@ public:
     return d_valid && !expired();
   }
 
-  OM_uint32 d_maj, d_min;
-
-  bool d_valid;
-  int64_t d_expires;
   std::string d_nameS;
-  gss_name_t d_name;
-  gss_cred_id_t d_cred;
   gss_cred_usage_t d_usage;
-};
+  gss_name_t d_name{GSS_C_NO_NAME};
+  gss_cred_id_t d_cred{GSS_C_NO_CREDENTIAL};
+  time_t d_expires{time(nullptr) + 60}; // partly initialized wil be cleaned up
+  bool d_valid{false};
+}; // GssCredential
 
-LockGuarded<std::map<std::string, std::shared_ptr<GssCredential>>> s_gss_accept_creds;
-LockGuarded<std::map<std::string, std::shared_ptr<GssCredential>>> s_gss_init_creds;
+static LockGuarded<std::unordered_map<std::string, std::shared_ptr<GssCredential>>> s_gss_accept_creds;
+static LockGuarded<std::unordered_map<std::string, std::shared_ptr<GssCredential>>> s_gss_init_creds;
 
 class GssSecContext : boost::noncopyable
 {
 public:
-  GssSecContext(std::shared_ptr<GssCredential> cred)
+  GssSecContext(std::shared_ptr<GssCredential> cred) :
+    d_cred(cred)
   {
     if (!cred->valid()) {
-      throw PDNSException("Invalid credential " + cred->d_nameS + ": " + gsserror(cred->d_maj));
+      throw PDNSException("Invalid credential " + cred->d_nameS);
     }
     d_cred = cred;
     d_state = GssStateInitial;
     d_ctx = GSS_C_NO_CONTEXT;
     d_expires = 0;
-    d_maj = d_min = 0;
     d_peer_name = GSS_C_NO_NAME;
     d_type = GSS_CONTEXT_NONE;
   }
@@ -188,12 +159,11 @@ public:
     }
   }
 
-  GssContextType d_type;
-  gss_ctx_id_t d_ctx;
-  gss_name_t d_peer_name;
-  int64_t d_expires;
   std::shared_ptr<GssCredential> d_cred;
-  OM_uint32 d_maj, d_min;
+  GssContextType d_type{GSS_CONTEXT_NONE};
+  gss_ctx_id_t d_ctx{GSS_C_NO_CONTEXT};
+  gss_name_t d_peer_name{GSS_C_NO_NAME};
+  time_t d_expires{time(nullptr) + 60}; // partly initialized wil be cleaned up
 
   enum
   {
@@ -202,9 +172,36 @@ public:
     GssStateComplete,
     GssStateError
   } d_state;
-};
+}; // GssSecContext
+
+static LockGuarded<std::unordered_map<DNSName, std::shared_ptr<GssSecContext>>> s_gss_sec_context;
 
-LockGuarded<std::map<DNSName, std::shared_ptr<GssSecContext>>> s_gss_sec_context;
+template <typename T>
+static void doExpire(T& m, time_t now)
+{
+  auto lock = m.lock();
+  for (auto i = lock->begin(); i != lock->end();) {
+    if (now > i->second->d_expires) {
+      i = lock->erase(i);
+    }
+    else {
+      ++i;
+    }
+  }
+}
+
+static void expire()
+{
+  static time_t s_last_expired;
+  time_t now = time(nullptr);
+  if (now - s_last_expired < 60) {
+    return;
+  }
+  s_last_expired = now;
+  doExpire(s_gss_init_creds, now);
+  doExpire(s_gss_accept_creds, now);
+  doExpire(s_gss_sec_context, now);
+}
 
 bool GssContext::supported() { return true; }
 
@@ -239,52 +236,52 @@ void GssContext::setLabel(const DNSName& label)
 {
   d_label = label;
   auto lock = s_gss_sec_context.lock();
-  if (lock->find(d_label) != lock->end()) {
-    d_ctx = (*lock)[d_label];
-    d_type = d_ctx->d_type;
+  auto it = lock->find(d_label);
+  if (it != lock->end()) {
+    d_secctx = it->second;
+    d_type = d_secctx->d_type;
   }
 }
 
 bool GssContext::expired()
 {
-  return (!d_ctx || (d_ctx->d_expires > -1 && d_ctx->d_expires < time(nullptr)));
+  return (!d_secctx || (d_secctx->d_expires > -1 && d_secctx->d_expires < time(nullptr)));
 }
 
 bool GssContext::valid()
 {
-  return (d_ctx && !expired() && d_ctx->d_state == GssSecContext::GssStateComplete);
+  return (d_secctx && !expired() && d_secctx->d_state == GssSecContext::GssStateComplete);
 }
 
 bool GssContext::init(const std::string& input, std::string& output)
 {
+  expire();
+
   OM_uint32 tmp_maj __attribute__((unused)), tmp_min __attribute__((unused));
   OM_uint32 maj, min;
   gss_buffer_desc recv_tok, send_tok, buffer;
   OM_uint32 flags;
   OM_uint32 expires;
 
-  std::shared_ptr<GssCredential> cred;
   if (d_label.empty()) {
     d_error = GSS_CONTEXT_INVALID;
     return false;
   }
 
   d_type = GSS_CONTEXT_INIT;
-
+  std::shared_ptr<GssCredential> cred;
   {
     auto lock = s_gss_init_creds.lock();
-    if (lock->find(d_localPrincipal) != lock->end()) {
-      cred = (*lock)[d_localPrincipal];
-    }
-    else {
-      (*lock)[d_localPrincipal] = std::make_shared<GssCredential>(d_localPrincipal, GSS_C_INITIATE);
-      cred = (*lock)[d_localPrincipal];
+    auto it = lock->find(d_localPrincipal);
+    if (it == lock->end()) {
+      it = lock->emplace(d_localPrincipal, std::make_shared<GssCredential>(d_localPrincipal, GSS_C_INITIATE)).first;
     }
+    cred = it->second;
   }
 
   // see if we can find a context in non-completed state
-  if (d_ctx) {
-    if (d_ctx->d_state != GssSecContext::GssStateNegotiate) {
+  if (d_secctx) {
+    if (d_secctx->d_state != GssSecContext::GssStateNegotiate) {
       d_error = GSS_CONTEXT_INVALID;
       return false;
     }
@@ -292,10 +289,10 @@ bool GssContext::init(const std::string& input, std::string& output)
   else {
     // make context
     auto lock = s_gss_sec_context.lock();
-    (*lock)[d_label] = std::make_shared<GssSecContext>(cred);
-    (*lock)[d_label]->d_type = d_type;
-    d_ctx = (*lock)[d_label];
-    d_ctx->d_state = GssSecContext::GssStateNegotiate;
+    d_secctx = std::make_shared<GssSecContext>(cred);
+    d_secctx->d_state = GssSecContext::GssStateNegotiate;
+    d_secctx->d_type = d_type;
+    (*lock)[d_label] = d_secctx;
   }
 
   recv_tok.length = input.size();
@@ -304,14 +301,14 @@ bool GssContext::init(const std::string& input, std::string& output)
   if (!d_peerPrincipal.empty()) {
     buffer.value = const_cast<void*>(static_cast<const void*>(d_peerPrincipal.c_str()));
     buffer.length = d_peerPrincipal.size();
-    maj = gss_import_name(&min, &buffer, (gss_OID)GSS_KRB5_NT_PRINCIPAL_NAME, &(d_ctx->d_peer_name));
+    maj = gss_import_name(&min, &buffer, (gss_OID)GSS_KRB5_NT_PRINCIPAL_NAME, &(d_secctx->d_peer_name));
     if (maj != GSS_S_COMPLETE) {
       processError("gss_import_name", maj, min);
       return false;
     }
   }
 
-  maj = gss_init_sec_context(&min, cred->d_cred, &(d_ctx->d_ctx), d_ctx->d_peer_name, GSS_C_NO_OID, GSS_C_MUTUAL_FLAG | GSS_C_REPLAY_FLAG, GSS_C_INDEFINITE, GSS_C_NO_CHANNEL_BINDINGS, &recv_tok, nullptr, &send_tok, &flags, &expires);
+  maj = gss_init_sec_context(&min, cred->d_cred, &d_secctx->d_ctx, d_secctx->d_peer_name, GSS_C_NO_OID, GSS_C_MUTUAL_FLAG | GSS_C_REPLAY_FLAG, GSS_C_INDEFINITE, GSS_C_NO_CHANNEL_BINDINGS, &recv_tok, nullptr, &send_tok, &flags, &expires);
 
   if (send_tok.length > 0) {
     output.assign(static_cast<char*>(send_tok.value), send_tok.length);
@@ -319,13 +316,12 @@ bool GssContext::init(const std::string& input, std::string& output)
   }
 
   if (maj == GSS_S_COMPLETE) {
-    if (expires > GSS_C_INDEFINITE) {
-      d_ctx->d_expires = time(nullptr) + expires;
+    // We do not want forever
+    if (expires == GSS_C_INDEFINITE) {
+      expires = 60;
     }
-    else {
-      d_ctx->d_expires = -1;
-    }
-    d_ctx->d_state = GssSecContext::GssStateComplete;
+    d_secctx->d_expires = time(nullptr) + expires;
+    d_secctx->d_state = GssSecContext::GssStateComplete;
     return true;
   }
   else if (maj != GSS_S_CONTINUE_NEEDED) {
@@ -337,33 +333,33 @@ bool GssContext::init(const std::string& input, std::string& output)
 
 bool GssContext::accept(const std::string& input, std::string& output)
 {
+  expire();
+
   OM_uint32 tmp_maj __attribute__((unused)), tmp_min __attribute__((unused));
   OM_uint32 maj, min;
   gss_buffer_desc recv_tok, send_tok;
   OM_uint32 flags;
   OM_uint32 expires;
 
-  std::shared_ptr<GssCredential> cred;
   if (d_label.empty()) {
     d_error = GSS_CONTEXT_INVALID;
     return false;
   }
 
   d_type = GSS_CONTEXT_ACCEPT;
+  std::shared_ptr<GssCredential> cred;
   {
     auto lock = s_gss_accept_creds.lock();
-    if (lock->find(d_localPrincipal) != lock->end()) {
-      cred = (*lock)[d_localPrincipal];
-    }
-    else {
-      (*lock)[d_localPrincipal] = std::make_shared<GssCredential>(d_localPrincipal, GSS_C_ACCEPT);
-      cred = (*lock)[d_localPrincipal];
+    auto it = lock->find(d_localPrincipal);
+    if (it == lock->end()) {
+      it = lock->emplace(d_localPrincipal, std::make_shared<GssCredential>(d_localPrincipal, GSS_C_ACCEPT)).first;
     }
+    cred = it->second;
   }
 
   // see if we can find a context in non-completed state
-  if (d_ctx) {
-    if (d_ctx->d_state != GssSecContext::GssStateNegotiate) {
+  if (d_secctx) {
+    if (d_secctx->d_state != GssSecContext::GssStateNegotiate) {
       d_error = GSS_CONTEXT_INVALID;
       return false;
     }
@@ -371,16 +367,16 @@ bool GssContext::accept(const std::string& input, std::string& output)
   else {
     // make context
     auto lock = s_gss_sec_context.lock();
-    (*lock)[d_label] = std::make_shared<GssSecContext>(cred);
-    (*lock)[d_label]->d_type = d_type;
-    d_ctx = (*lock)[d_label];
-    d_ctx->d_state = GssSecContext::GssStateNegotiate;
+    d_secctx = std::make_shared<GssSecContext>(cred);
+    d_secctx->d_state = GssSecContext::GssStateNegotiate;
+    d_secctx->d_type = d_type;
+    (*lock)[d_label] = d_secctx;
   }
 
   recv_tok.length = input.size();
   recv_tok.value = const_cast<void*>(static_cast<const void*>(input.c_str()));
 
-  maj = gss_accept_sec_context(&min, &(d_ctx->d_ctx), cred->d_cred, &recv_tok, GSS_C_NO_CHANNEL_BINDINGS, &(d_ctx->d_peer_name), nullptr, &send_tok, &flags, &expires, nullptr);
+  maj = gss_accept_sec_context(&min, &d_secctx->d_ctx, cred->d_cred, &recv_tok, GSS_C_NO_CHANNEL_BINDINGS, &d_secctx->d_peer_name, nullptr, &send_tok, &flags, &expires, nullptr);
 
   if (send_tok.length > 0) {
     output.assign(static_cast<char*>(send_tok.value), send_tok.length);
@@ -388,13 +384,12 @@ bool GssContext::accept(const std::string& input, std::string& output)
   }
 
   if (maj == GSS_S_COMPLETE) {
-    if (expires > GSS_C_INDEFINITE) {
-      d_ctx->d_expires = time(nullptr) + expires;
-    }
-    else {
-      d_ctx->d_expires = -1;
+    // We do not want forever
+    if (expires == GSS_C_INDEFINITE) {
+      expires = 60;
     }
-    d_ctx->d_state = GssSecContext::GssStateComplete;
+    d_secctx->d_expires = time(nullptr) + expires;
+    d_secctx->d_state = GssSecContext::GssStateComplete;
     return true;
   }
   else if (maj != GSS_S_CONTINUE_NEEDED) {
@@ -414,7 +409,7 @@ bool GssContext::sign(const std::string& input, std::string& output)
   recv_tok.length = input.size();
   recv_tok.value = const_cast<void*>(static_cast<const void*>(input.c_str()));
 
-  maj = gss_get_mic(&min, d_ctx->d_ctx, GSS_C_QOP_DEFAULT, &recv_tok, &send_tok);
+  maj = gss_get_mic(&min, d_secctx->d_ctx, GSS_C_QOP_DEFAULT, &recv_tok, &send_tok);
 
   if (send_tok.length > 0) {
     output.assign(static_cast<char*>(send_tok.value), send_tok.length);
@@ -436,11 +431,11 @@ bool GssContext::verify(const std::string& input, const std::string& signature)
   gss_buffer_desc sign_tok = GSS_C_EMPTY_BUFFER;
 
   recv_tok.length = input.size();
-  recv_tok.value =  const_cast<void*>(static_cast<const void*>(input.c_str()));
+  recv_tok.value = const_cast<void*>(static_cast<const void*>(input.c_str()));
   sign_tok.length = signature.size();
-  sign_tok.value =  const_cast<void*>(static_cast<const void*>(signature.c_str()));
+  sign_tok.value = const_cast<void*>(static_cast<const void*>(signature.c_str()));
 
-  maj = gss_verify_mic(&min, d_ctx->d_ctx, &recv_tok, &sign_tok, nullptr);
+  maj = gss_verify_mic(&min, d_secctx->d_ctx, &recv_tok, &sign_tok, nullptr);
 
   if (maj != GSS_S_COMPLETE) {
     processError("gss_get_mic", maj, min);
@@ -451,7 +446,11 @@ bool GssContext::verify(const std::string& input, const std::string& signature)
 
 bool GssContext::destroy()
 {
-  return false;
+  if (d_label.empty()) {
+    return false;
+  }
+  auto lock = s_gss_sec_context.lock();
+  return lock->erase(d_label) == 1;
 }
 
 void GssContext::setLocalPrincipal(const std::string& name)
@@ -475,8 +474,8 @@ bool GssContext::getPeerPrincipal(std::string& name)
   gss_buffer_desc value;
   OM_uint32 maj, min;
 
-  if (d_ctx->d_peer_name != GSS_C_NO_NAME) {
-    maj = gss_display_name(&min, d_ctx->d_peer_name, &value, nullptr);
+  if (d_secctx->d_peer_name != GSS_C_NO_NAME) {
+    maj = gss_display_name(&min, d_secctx->d_peer_name, &value, nullptr);
     if (maj == GSS_S_COMPLETE && value.length > 0) {
       name.assign(static_cast<char*>(value.value), value.length);
       maj = gss_release_buffer(&min, &value);
@@ -491,6 +490,11 @@ bool GssContext::getPeerPrincipal(std::string& name)
   }
 }
 
+std::tuple<size_t, size_t, size_t> GssContext::getCounts()
+{
+  return std::make_tuple(s_gss_init_creds.lock()->size(), s_gss_accept_creds.lock()->size(), s_gss_sec_context.lock()->size());
+}
+
 void GssContext::processError(const std::string& method, OM_uint32 maj, OM_uint32 min)
 {
   OM_uint32 tmp_min;
@@ -502,6 +506,7 @@ void GssContext::processError(const std::string& method, OM_uint32 maj, OM_uint3
     ostringstream oss;
     gss_display_status(&tmp_min, maj, GSS_C_GSS_CODE, GSS_C_NULL_OID, &msg_ctx, &msg);
     oss << method << ": " << msg.value;
+    /// XXX leaks gss_buffer_desc?
     d_gss_errors.push_back(oss.str());
     if (!msg_ctx)
       break;
@@ -511,6 +516,7 @@ void GssContext::processError(const std::string& method, OM_uint32 maj, OM_uint3
     ostringstream oss;
     gss_display_status(&tmp_min, min, GSS_C_MECH_CODE, GSS_C_NULL_OID, &msg_ctx, &msg);
     oss << method << ": " << msg.value;
+    /// XXX leaks gss_buffer_desc?
     d_gss_errors.push_back(oss.str());
     if (!msg_ctx)
       break;
index 0fcf1bea7c63c53a24e57c3960a5d8414b7068b3..d75e211f2e8f894470aac1c87754eb01df49c150 100644 (file)
@@ -33,7 +33,6 @@
 #include <gssapi/gssapi_ext.h>
 #endif
 
-
 //! Generic errors
 enum GssContextError
 {
@@ -151,11 +150,12 @@ private:
   OM_uint32 d_maj, d_min;
   gss_name_t d_name;
 #endif
-};
+}; // GssName
 
 class GssContext
 {
 public:
+  static std::tuple<size_t, size_t, size_t> getCounts();
   static bool supported(); //<! Returns true if GSS is supported in the first place
   GssContext(); //<! Construct new GSS context with random name
   GssContext(const DNSName& label); //<! Create or open existing named context
@@ -192,8 +192,8 @@ private:
   GssContextError d_error; //<! Context error
   GssContextType d_type; //<! Context type
   std::vector<std::string> d_gss_errors; //<! Native error string(s)
-  std::shared_ptr<GssSecContext> d_ctx; //<! Attached security context
-};
+  std::shared_ptr<GssSecContext> d_secctx; //<! Attached security context
+}; // GssContext
 
 bool gss_add_signature(const DNSName& context, const std::string& message, std::string& mac); //<! Create signature
 bool gss_verify_signature(const DNSName& context, const std::string& message, const std::string& mac); //<! Validate signature
index baa06ea6384e7e8cb55908296abce431dbea32c3..e35960f505d12f0121cb6e94ecd26837374721ca 100644 (file)
@@ -409,6 +409,7 @@ void TCPNameserver::doConnection(int fd)
         break;
 
       sendPacket(reply, fd);
+      packet->cleanupGSS(reply->d.rcode);
     }
   }
   catch(PDNSException &ae) {
index f8799bd5f10702cc0a60af0b622d1d16a800a7f2..61cdf43a3263903f8e282a803b7692615931a451 100644 (file)
@@ -22,6 +22,7 @@ struct Question
   {
     return make_unique<DNSPacket>(false);
   }
+  void cleanupGSS(int){}
 };
 
 struct Backend
index 7c9e9aa316b4c065b5766f992ec36e443f5afda2..0186e7c0a008648b2ed9bf6ac458f47bb30dbc91 100644 (file)
@@ -4,6 +4,11 @@
 #include "packethandler.hh"
 
 void PacketHandler::tkeyHandler(const DNSPacket& p, std::unique_ptr<DNSPacket>& r) {
+#if 0
+  auto [i,a,s] = GssContext::getCounts();
+  cerr << "#init_creds: " << i << " #accept_creds: " << a << " #secctxs: " << s << endl;
+#endif
+
   TKEYRecordContent tkey_in;
   std::shared_ptr<TKEYRecordContent> tkey_out(new TKEYRecordContent());
   DNSName name;
@@ -22,8 +27,6 @@ void PacketHandler::tkeyHandler(const DNSPacket& p, std::unique_ptr<DNSPacket>&
   tkey_out->d_inception = time((time_t*)nullptr);
   tkey_out->d_expiration = tkey_out->d_inception+15;
 
-  GssContext ctx(name);
-
   if (tkey_in.d_mode == 3) { // establish context
     if (tkey_in.d_algo == DNSName("gss-tsig.")) {
       std::vector<std::string> meta;
@@ -34,14 +37,20 @@ void PacketHandler::tkeyHandler(const DNSPacket& p, std::unique_ptr<DNSPacket>&
         }
       } while(tmpName.chopOff());
 
-      if (meta.size()>0) {
+      if (meta.size() == 0) {
+        tkey_out->d_error = 20;
+      } else {
+        GssContext ctx(name);
         ctx.setLocalPrincipal(meta[0]);
+        // try to get a context
+        if (!ctx.accept(tkey_in.d_key, tkey_out->d_key)) {
+          ctx.destroy();
+          tkey_out->d_error = 19;
+        }
+        else {
+          sign = true;
+        }
       }
-      // try to get a context
-      if (!ctx.accept(tkey_in.d_key, tkey_out->d_key))
-        tkey_out->d_error = 19;
-      else
-        sign = true;
     } else {
       tkey_out->d_error = 21; // BADALGO
     }
@@ -53,10 +62,13 @@ void PacketHandler::tkeyHandler(const DNSPacket& p, std::unique_ptr<DNSPacket>&
         r->setRcode(RCode::NotAuth);
       return;
     }
-    if (ctx.valid())
+    GssContext ctx(name);
+    if (ctx.valid()) {
       ctx.destroy();
-    else
+    }
+    else {
       tkey_out->d_error = 20; // BADNAME (because we have no support for anything here)
+    }
   } else {
     if (p.d_havetsig == false && tkey_in.d_mode != 2) { // unauthenticated
       if (p.d.opcode == Opcode::Update)