]> git.ipfire.org Git - thirdparty/squid.git/commitdiff
Fix CertficateDB locking scheme
authorChristos Tsantilas <chtsanti@users.sourceforge.net>
Thu, 6 Oct 2011 18:18:24 +0000 (21:18 +0300)
committerChristos Tsantilas <chtsanti@users.sourceforge.net>
Thu, 6 Oct 2011 18:18:24 +0000 (21:18 +0300)
Currently we are locking every file going to be accessed by CertificateDB code
even if it is not realy needed, because of a more general lock.

This patch:
   - Replace the old FileLocker class with the pair Lock/Locker classes
   - Remove most of the locks in CertificateDB with only two locks one
     for main database locking and one lock for the file contain the
     current serial number.

This is a Measurement Factory project

src/ssl/certificate_db.cc
src/ssl/certificate_db.h

index b2e3576e84da6f664ab0bb3c1c3705d1aa21df23..ebfb14db4df79d8fa65b7ad5de3b007b5a8b049b 100644 (file)
@@ -4,6 +4,9 @@
 
 #include "config.h"
 #include "ssl/certificate_db.h"
+#if HAVE_ERRNO_H
+#include <errno.h>
+#endif
 #if HAVE_FSTREAM
 #include <fstream>
 #endif
 #include <fcntl.h>
 #endif
 
-Ssl::FileLocker::FileLocker(std::string const & filename)
-        :    fd(-1)
+#define HERE "(ssl_crtd) " << __FILE__ << ':' << __LINE__ << ": "
+
+Ssl::Lock::Lock(std::string const &aFilename) :
+    filename(aFilename),
+#if _SQUID_MSWIN_
+    hFile(INVALID_HANDLE_VALUE)
+#else
+    fd(-1)
+#endif
+{
+}
+
+bool Ssl::Lock::locked() const
 {
+#if _SQUID_MSWIN_
+    return hFile != INVALID_HANDLE_VALUE;
+#else
+    return fd != -1;
+#endif
+}
+
+void Ssl::Lock::lock()
+{
+
 #if _SQUID_MSWIN_
     hFile = CreateFile(TEXT(filename.c_str()), GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
-    if (hFile != INVALID_HANDLE_VALUE)
-        LockFile(hFile, 0, 0, 1, 0);
+    if (hFile == INVALID_HANDLE_VALUE)
 #else
     fd = open(filename.c_str(), 0);
-    if (fd != -1)
-        flock(fd, LOCK_EX);
+    if (fd == -1)
 #endif
+        throw std::runtime_error("Failed to open file " + filename);
+
+
+#if _SQUID_MSWIN_
+    if (!LockFile(hFile, 0, 0, 1, 0))
+#else
+    if (flock(fd, LOCK_EX) != 0)
+#endif
+        throw std::runtime_error("Failed to get a lock of " + filename);
 }
 
-Ssl::FileLocker::~FileLocker()
-{
+void Ssl::Lock::unlock()
+{ 
 #if _SQUID_MSWIN_
     if (hFile != INVALID_HANDLE_VALUE) {
         UnlockFile(hFile, 0, 0, 1, 0);
         CloseHandle(hFile);
+        hFile = INVALID_HANDLE_VALUE;
     }
 #else
     if (fd != -1) {
         flock(fd, LOCK_UN);
         close(fd);
+        fd = -1;
     }
 #endif
+    else
+        throw std::runtime_error("Lock is already unlocked for " + filename);
+}
+
+Ssl::Lock::~Lock()
+{
+    if (locked())
+        unlock();
+}
+
+Ssl::Locker::Locker(Lock &aLock, const char *aFileName, int aLineNo): 
+    weLocked(false), lock(aLock), fileName(aFileName), lineNo(aLineNo)
+{
+    if (!lock.locked()) {
+        lock.lock();
+        weLocked = true;
+    }
+}
+
+Ssl::Locker::~Locker()
+{
+    if (weLocked)
+        lock.unlock();
 }
 
 Ssl::CertificateDb::Row::Row()
@@ -130,26 +186,26 @@ Ssl::CertificateDb::CertificateDb(std::string const & aDb_path, size_t aMax_db_s
         db(NULL),
         max_db_size(aMax_db_size),
         fs_block_size(aFs_block_size),
+        dbLock(db_full),
+        dbSerialLock(serial_full),
         enabled_disk_store(true)
 {
     if (db_path.empty() && !max_db_size)
         enabled_disk_store = false;
     else if ((db_path.empty() && max_db_size) || (!db_path.empty() && !max_db_size))
         throw std::runtime_error("ssl_crtd is missing the required parameter. There should be -s and -M parameters together.");
-    else
-        load();
 }
 
 bool Ssl::CertificateDb::find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
 {
-    FileLocker db_locker(db_full);
+    const Locker locker(dbLock, Here);
     load();
     return pure_find(host_name, cert, pkey);
 }
 
 bool Ssl::CertificateDb::addCertAndPrivateKey(Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
 {
-    FileLocker db_locker(db_full);
+    const Locker locker(dbLock, Here);
     load();
     if (!db || !cert || !pkey || min_db_size > max_db_size)
         return false;
@@ -195,7 +251,6 @@ bool Ssl::CertificateDb::addCertAndPrivateKey(Ssl::X509_Pointer & cert, Ssl::EVP
 
     row.reset();
     std::string filename(cert_full + "/" + serial_string + ".pem");
-    FileLocker cert_locker(filename);
     if (!writeCertAndPrivateKeyToFile(cert, pkey, filename.c_str()))
         return false;
     addSize(filename);
@@ -206,7 +261,7 @@ bool Ssl::CertificateDb::addCertAndPrivateKey(Ssl::X509_Pointer & cert, Ssl::EVP
 
 BIGNUM * Ssl::CertificateDb::getCurrentSerialNumber()
 {
-    FileLocker serial_locker(serial_full);
+    const Locker locker(dbSerialLock, Here);
     // load serial number from file.
     Ssl::BIO_Pointer file(BIO_new(BIO_s_file()));
     if (!file)
@@ -297,11 +352,12 @@ void Ssl::CertificateDb::create(std::string const & db_path, int serial)
 void Ssl::CertificateDb::check(std::string const & db_path, size_t max_db_size)
 {
     CertificateDb db(db_path, max_db_size, 0);
+    db.load();
 }
 
 std::string Ssl::CertificateDb::getSNString() const
 {
-    FileLocker serial_locker(serial_full);
+    const Locker locker(dbSerialLock, Here);
     std::ifstream file(serial_full.c_str());
     if (!file)
         return "";
@@ -329,7 +385,6 @@ bool Ssl::CertificateDb::pure_find(std::string const & host_name, Ssl::X509_Poin
 
     // read cert and pkey from file.
     std::string filename(cert_full + "/" + rrow[cnlSerial] + ".pem");
-    FileLocker cert_locker(filename);
     readCertAndPrivateKeyFromFiles(cert, pkey, filename.c_str(), NULL);
     if (!cert || !pkey)
         return false;
@@ -338,19 +393,16 @@ bool Ssl::CertificateDb::pure_find(std::string const & host_name, Ssl::X509_Poin
 
 size_t Ssl::CertificateDb::size() const
 {
-    FileLocker size_locker(size_full);
     return readSize();
 }
 
 void Ssl::CertificateDb::addSize(std::string const & filename)
 {
-    FileLocker size_locker(size_full);
     writeSize(readSize() + getFileSize(filename));
 }
 
 void Ssl::CertificateDb::subSize(std::string const & filename)
 {
-    FileLocker size_locker(size_full);
     writeSize(readSize() - getFileSize(filename));
 }
 
@@ -432,7 +484,6 @@ void Ssl::CertificateDb::save()
 void Ssl::CertificateDb::deleteRow(const char **row, int rowIndex)
 {
     const std::string filename(cert_full + "/" + row[cnlSerial] + ".pem");
-    const FileLocker cert_locker(filename);
 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
     sk_OPENSSL_PSTRING_delete(db.get()->data, rowIndex);
 #else
index 0ed15d6fd7f8aa4bbbef4a03da1022098b457cf3..73ec704705f40ed3c6e918d299b1660a43d6d9ca 100644 (file)
 
 namespace Ssl
 {
-/// Cross platform file locker.
-class FileLocker
-{
+/// maintains an exclusive blocking file-based lock
+class Lock {
 public:
-    /// Lock file
-    FileLocker(std::string const & aFilename);
-    /// Unlock file
-    ~FileLocker();
+    explicit Lock(std::string const &filename); ///<  creates an unlocked lock
+    ~Lock(); ///<  releases the lock if it is locked
+    void lock(); ///<  locks the lock, may block
+    void unlock(); ///<  unlocks locked lock or throws
+    bool locked() const; ///<  whether our lock is locked
+    const char *name() const { return filename.c_str(); }
 private:
+    std::string filename;
 #if _SQUID_MSWIN_
     HANDLE hFile; ///< Windows file handle.
 #else
@@ -32,6 +34,24 @@ private:
 #endif
 };
 
+/// an exception-safe way to obtain and release a lock
+class Locker
+{
+public:
+    /// locks the lock if the lock was unlocked
+    Locker(Lock &lock, const char  *aFileName, int lineNo);
+    /// unlocks the lock if it was locked by us
+    ~Locker();
+private:
+    bool weLocked; ///<  whether we locked the lock
+    Lock &lock; ///<  the lock we are operating on
+    const std::string fileName; ///<  where the lock was needed
+    const int lineNo; ///<  where the lock was needed    
+};
+
+/// convenience macro to pass source code location to Locker and others
+#define Here __FILE__, __LINE__
+
 /**
  * Database class for storing SSL certificates and their private keys.
  * A database consist by:
@@ -150,6 +170,8 @@ private:
     TXT_DB_Pointer db; ///< Database with certificates info.
     const size_t max_db_size; ///< Max size of db.
     const size_t fs_block_size; ///< File system block size.
+    mutable Lock dbLock;  ///< protects the database file
+    mutable Lock dbSerialLock; ///< protects the serial number file
 
     bool enabled_disk_store; ///< The storage on the disk is enabled.
 };