]> git.ipfire.org Git - thirdparty/squid.git/blob - src/ssl/certificate_db.cc
Merged from trunk
[thirdparty/squid.git] / src / ssl / certificate_db.cc
1 /*
2 * $Id$
3 */
4
5 #include "config.h"
6 #include "ssl/certificate_db.h"
7 #if HAVE_FSTREAM
8 #include <fstream>
9 #endif
10 #if HAVE_STDEXCEPT
11 #include <stdexcept>
12 #endif
13 #if HAVE_SYS_STAT_H
14 #include <sys/stat.h>
15 #endif
16 #if HAVE_SYS_FILE_H
17 #include <sys/file.h>
18 #endif
19 #if HAVE_FCNTL_H
20 #include <fcntl.h>
21 #endif
22
23 Ssl::FileLocker::FileLocker(std::string const & filename)
24 : fd(-1)
25 {
26 #if _SQUID_MSWIN_
27 hFile = CreateFile(TEXT(filename.c_str()), GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
28 if (hFile != INVALID_HANDLE_VALUE)
29 LockFile(hFile, 0, 0, 1, 0);
30 #else
31 fd = open(filename.c_str(), 0);
32 if (fd != -1)
33 flock(fd, LOCK_EX);
34 #endif
35 }
36
37 Ssl::FileLocker::~FileLocker()
38 {
39 #if _SQUID_MSWIN_
40 if (hFile != INVALID_HANDLE_VALUE) {
41 UnlockFile(hFile, 0, 0, 1, 0);
42 CloseHandle(hFile);
43 }
44 #else
45 if (fd != -1) {
46 flock(fd, LOCK_UN);
47 close(fd);
48 }
49 #endif
50 }
51
52 Ssl::CertificateDb::Row::Row()
53 : width(cnlNumber)
54 {
55 row = new char *[width + 1];
56 for (size_t i = 0; i < width + 1; i++)
57 row[i] = NULL;
58 }
59
60 Ssl::CertificateDb::Row::~Row()
61 {
62 if (row) {
63 for (size_t i = 0; i < width + 1; i++) {
64 delete[](row[i]);
65 }
66 delete[](row);
67 }
68 }
69
70 void Ssl::CertificateDb::Row::reset()
71 {
72 row = NULL;
73 }
74
75 void Ssl::CertificateDb::Row::setValue(size_t cell, char const * value)
76 {
77 assert(cell < width);
78 if (row[cell]) {
79 free(row[cell]);
80 }
81 if (value) {
82 row[cell] = static_cast<char *>(malloc(sizeof(char) * (strlen(value) + 1)));
83 memcpy(row[cell], value, sizeof(char) * (strlen(value) + 1));
84 } else
85 row[cell] = NULL;
86 }
87
88 char ** Ssl::CertificateDb::Row::getRow()
89 {
90 return row;
91 }
92
93 unsigned long Ssl::CertificateDb::index_serial_hash(const char **a)
94 {
95 const char *n = a[Ssl::CertificateDb::cnlSerial];
96 while (*n == '0') n++;
97 return lh_strhash(n);
98 }
99
100 int Ssl::CertificateDb::index_serial_cmp(const char **a, const char **b)
101 {
102 const char *aa, *bb;
103 for (aa = a[Ssl::CertificateDb::cnlSerial]; *aa == '0'; aa++);
104 for (bb = b[Ssl::CertificateDb::cnlSerial]; *bb == '0'; bb++);
105 return strcmp(aa, bb);
106 }
107
108 unsigned long Ssl::CertificateDb::index_name_hash(const char **a)
109 {
110 return(lh_strhash(a[Ssl::CertificateDb::cnlName]));
111 }
112
113 int Ssl::CertificateDb::index_name_cmp(const char **a, const char **b)
114 {
115 return(strcmp(a[Ssl::CertificateDb::cnlName], b[CertificateDb::cnlName]));
116 }
117
118 const std::string Ssl::CertificateDb::serial_file("serial");
119 const std::string Ssl::CertificateDb::db_file("index.txt");
120 const std::string Ssl::CertificateDb::cert_dir("certs");
121 const std::string Ssl::CertificateDb::size_file("size");
122 const size_t Ssl::CertificateDb::min_db_size(4096);
123
124 Ssl::CertificateDb::CertificateDb(std::string const & aDb_path, size_t aMax_db_size, size_t aFs_block_size)
125 : db_path(aDb_path),
126 serial_full(aDb_path + "/" + serial_file),
127 db_full(aDb_path + "/" + db_file),
128 cert_full(aDb_path + "/" + cert_dir),
129 size_full(aDb_path + "/" + size_file),
130 db(NULL),
131 max_db_size(aMax_db_size),
132 fs_block_size(aFs_block_size),
133 enabled_disk_store(true)
134 {
135 if (db_path.empty() && !max_db_size)
136 enabled_disk_store = false;
137 else if ((db_path.empty() && max_db_size) || (!db_path.empty() && !max_db_size))
138 throw std::runtime_error("ssl_crtd is missing the required parameter. There should be -s and -M parameters together.");
139 else
140 load();
141 }
142
143 bool Ssl::CertificateDb::find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
144 {
145 FileLocker db_locker(db_full);
146 load();
147 return pure_find(host_name, cert, pkey);
148 }
149
150 bool Ssl::CertificateDb::addCertAndPrivateKey(Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
151 {
152 FileLocker db_locker(db_full);
153 load();
154 if (!db || !cert || !pkey || min_db_size > max_db_size)
155 return false;
156 Row row;
157 ASN1_INTEGER * ai = X509_get_serialNumber(cert.get());
158 std::string serial_string;
159 Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(ai, NULL));
160 {
161 TidyPointer<char, tidyFree> hex_bn(BN_bn2hex(serial.get()));
162 serial_string = std::string(hex_bn.get());
163 }
164 row.setValue(cnlSerial, serial_string.c_str());
165 char ** rrow = TXT_DB_get_by_index(db.get(), cnlSerial, row.getRow());
166 if (rrow != NULL)
167 return false;
168
169 {
170 TidyPointer<char, tidyFree> subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0));
171 if (pure_find(subject.get(), cert, pkey))
172 return true;
173 }
174 // check db size.
175 while (max_db_size < size()) {
176 if (!deleteInvalidCertificate())
177 break;
178 }
179
180 while (max_db_size < size()) {
181 deleteOldestCertificate();
182 }
183
184 row.setValue(cnlType, "V");
185 ASN1_UTCTIME * tm = X509_get_notAfter(cert.get());
186 row.setValue(cnlExp_date, std::string(reinterpret_cast<char *>(tm->data), tm->length).c_str());
187 row.setValue(cnlFile, "unknown");
188 {
189 TidyPointer<char, tidyFree> subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0));
190 row.setValue(cnlName, subject.get());
191 }
192
193 if (!TXT_DB_insert(db.get(), row.getRow()))
194 return false;
195
196 row.reset();
197 std::string filename(cert_full + "/" + serial_string + ".pem");
198 FileLocker cert_locker(filename);
199 if (!writeCertAndPrivateKeyToFile(cert, pkey, filename.c_str()))
200 return false;
201 addSize(filename);
202
203 save();
204 return true;
205 }
206
207 BIGNUM * Ssl::CertificateDb::getCurrentSerialNumber()
208 {
209 FileLocker serial_locker(serial_full);
210 // load serial number from file.
211 Ssl::BIO_Pointer file(BIO_new(BIO_s_file()));
212 if (!file)
213 return NULL;
214
215 if (BIO_rw_filename(file.get(), const_cast<char *>(serial_full.c_str())) <= 0)
216 return NULL;
217
218 Ssl::ASN1_INT_Pointer serial_ai(ASN1_INTEGER_new());
219 if (!serial_ai)
220 return NULL;
221
222 char buffer[1024];
223 if (!a2i_ASN1_INTEGER(file.get(), serial_ai.get(), buffer, sizeof(buffer)))
224 return NULL;
225
226 Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(serial_ai.get(), NULL));
227
228 if (!serial)
229 return NULL;
230
231 // increase serial number.
232 Ssl::BIGNUM_Pointer increased_serial(BN_dup(serial.get()));
233 if (!increased_serial)
234 return NULL;
235
236 BN_add_word(increased_serial.get(), 1);
237
238 // save increased serial number.
239 if (BIO_seek(file.get(), 0))
240 return NULL;
241
242 Ssl::ASN1_INT_Pointer increased_serial_ai(BN_to_ASN1_INTEGER(increased_serial.get(), NULL));
243 if (!increased_serial_ai)
244 return NULL;
245
246 i2a_ASN1_INTEGER(file.get(), increased_serial_ai.get());
247 BIO_puts(file.get(),"\n");
248
249 return serial.release();
250 }
251
252 void Ssl::CertificateDb::create(std::string const & db_path, int serial)
253 {
254 if (db_path == "")
255 throw std::runtime_error("Path to db is empty");
256 std::string serial_full(db_path + "/" + serial_file);
257 std::string db_full(db_path + "/" + db_file);
258 std::string cert_full(db_path + "/" + cert_dir);
259 std::string size_full(db_path + "/" + size_file);
260
261 #if _SQUID_MSWIN_
262 if (mkdir(db_path.c_str()))
263 #else
264 if (mkdir(db_path.c_str(), 0777))
265 #endif
266 throw std::runtime_error("Cannot create " + db_path);
267
268 #if _SQUID_MSWIN_
269 if (mkdir(cert_full.c_str()))
270 #else
271 if (mkdir(cert_full.c_str(), 0777))
272 #endif
273 throw std::runtime_error("Cannot create " + cert_full);
274
275 Ssl::ASN1_INT_Pointer i(ASN1_INTEGER_new());
276 ASN1_INTEGER_set(i.get(), serial);
277
278 Ssl::BIO_Pointer file(BIO_new(BIO_s_file()));
279 if (!file)
280 throw std::runtime_error("SSL error");
281
282 if (BIO_write_filename(file.get(), const_cast<char *>(serial_full.c_str())) <= 0)
283 throw std::runtime_error("Cannot open " + cert_full + " to open");
284
285 i2a_ASN1_INTEGER(file.get(), i.get());
286
287 std::ofstream size(size_full.c_str());
288 if (size)
289 size << 0;
290 else
291 throw std::runtime_error("Cannot open " + size_full + " to open");
292 std::ofstream db(db_full.c_str());
293 if (!db)
294 throw std::runtime_error("Cannot open " + db_full + " to open");
295 }
296
297 void Ssl::CertificateDb::check(std::string const & db_path, size_t max_db_size)
298 {
299 CertificateDb db(db_path, max_db_size, 0);
300 }
301
302 std::string Ssl::CertificateDb::getSNString() const
303 {
304 FileLocker serial_locker(serial_full);
305 std::ifstream file(serial_full.c_str());
306 if (!file)
307 return "";
308 std::string serial;
309 file >> serial;
310 return serial;
311 }
312
313 bool Ssl::CertificateDb::pure_find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey)
314 {
315 if (!db)
316 return false;
317
318 Row row;
319 row.setValue(cnlName, host_name.c_str());
320
321 char **rrow = TXT_DB_get_by_index(db.get(), cnlName, row.getRow());
322 if (rrow == NULL)
323 return false;
324
325 if (!sslDateIsInTheFuture(rrow[cnlExp_date])) {
326 deleteByHostname(rrow[cnlName]);
327 return false;
328 }
329
330 // read cert and pkey from file.
331 std::string filename(cert_full + "/" + rrow[cnlSerial] + ".pem");
332 FileLocker cert_locker(filename);
333 readCertAndPrivateKeyFromFiles(cert, pkey, filename.c_str(), NULL);
334 if (!cert || !pkey)
335 return false;
336 return true;
337 }
338
339 size_t Ssl::CertificateDb::size() const
340 {
341 FileLocker size_locker(size_full);
342 return readSize();
343 }
344
345 void Ssl::CertificateDb::addSize(std::string const & filename)
346 {
347 FileLocker size_locker(size_full);
348 writeSize(readSize() + getFileSize(filename));
349 }
350
351 void Ssl::CertificateDb::subSize(std::string const & filename)
352 {
353 FileLocker size_locker(size_full);
354 writeSize(readSize() - getFileSize(filename));
355 }
356
357 size_t Ssl::CertificateDb::readSize() const
358 {
359 size_t db_size;
360 std::ifstream size_file(size_full.c_str());
361 if (!size_file && enabled_disk_store)
362 throw std::runtime_error("cannot read \"" + size_full + "\" file");
363 size_file >> db_size;
364 return db_size;
365 }
366
367 void Ssl::CertificateDb::writeSize(size_t db_size)
368 {
369 std::ofstream size_file(size_full.c_str());
370 if (!size_file && enabled_disk_store)
371 throw std::runtime_error("cannot write \"" + size_full + "\" file");
372 size_file << db_size;
373 }
374
375 size_t Ssl::CertificateDb::getFileSize(std::string const & filename)
376 {
377 std::ifstream file(filename.c_str(), std::ios::binary);
378 file.seekg(0, std::ios_base::end);
379 size_t file_size = file.tellg();
380 return ((file_size + fs_block_size - 1) / fs_block_size) * fs_block_size;
381 }
382
383 void Ssl::CertificateDb::load()
384 {
385 // Load db from file.
386 Ssl::BIO_Pointer in(BIO_new(BIO_s_file()));
387 if (!in || BIO_read_filename(in.get(), db_full.c_str()) <= 0)
388 throw std::runtime_error("Uninitialized SSL certificate database directory: " + db_path + ". To initialize, run \"ssl_crtd -c -s " + db_path + "\".");
389
390 bool corrupt = false;
391 Ssl::TXT_DB_Pointer temp_db(TXT_DB_read(in.get(), cnlNumber));
392 if (!temp_db)
393 corrupt = true;
394
395 // Create indexes in db.
396 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
397 if (!corrupt && !TXT_DB_create_index(temp_db.get(), cnlSerial, NULL, LHASH_HASH_FN(index_serial), LHASH_COMP_FN(index_serial)))
398 corrupt = true;
399
400 if (!corrupt && !TXT_DB_create_index(temp_db.get(), cnlName, NULL, LHASH_HASH_FN(index_name), LHASH_COMP_FN(index_name)))
401 corrupt = true;
402 #else
403 if (!corrupt && !TXT_DB_create_index(temp_db.get(), cnlSerial, NULL, LHASH_HASH_FN(index_serial_hash), LHASH_COMP_FN(index_serial_cmp)))
404 corrupt = true;
405
406 if (!corrupt && !TXT_DB_create_index(temp_db.get(), cnlName, NULL, LHASH_HASH_FN(index_name_hash), LHASH_COMP_FN(index_name_cmp)))
407 corrupt = true;
408 #endif
409
410 if (corrupt)
411 throw std::runtime_error("The SSL certificate database " + db_path + " is corrupted. Please rebuild");
412
413 db.reset(temp_db.release());
414 }
415
416 void Ssl::CertificateDb::save()
417 {
418 if (!db)
419 throw std::runtime_error("The certificates database is not loaded");;
420
421 // To save the db to file, create a new BIO with BIO file methods.
422 Ssl::BIO_Pointer out(BIO_new(BIO_s_file()));
423 if (!out || !BIO_write_filename(out.get(), const_cast<char *>(db_full.c_str())))
424 throw std::runtime_error("Failed to initialize " + db_full + " file for writing");;
425
426 if (TXT_DB_write(out.get(), db.get()) < 0)
427 throw std::runtime_error("Failed to write " + db_full + " file");
428 }
429
430 // Normally defined in defines.h file
431 #define countof(arr) (sizeof(arr)/sizeof(*arr))
432 void Ssl::CertificateDb::deleteRow(const char **row, int rowIndex)
433 {
434 const std::string filename(cert_full + "/" + row[cnlSerial] + ".pem");
435 const FileLocker cert_locker(filename);
436 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
437 sk_OPENSSL_PSTRING_delete(db.get()->data, rowIndex);
438 #else
439 sk_delete(db.get()->data, rowIndex);
440 #endif
441
442 const Columns db_indexes[]={cnlSerial, cnlName};
443 for (unsigned int i = 0; i < countof(db_indexes); i++) {
444 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
445 if (LHASH_OF(OPENSSL_STRING) *fieldIndex = db.get()->index[db_indexes[i]])
446 lh_OPENSSL_STRING_delete(fieldIndex, (char **)row);
447 #else
448 if (LHASH *fieldIndex = db.get()->index[db_indexes[i]])
449 lh_delete(fieldIndex, row);
450 #endif
451 }
452
453 subSize(filename);
454 int ret = remove(filename.c_str());
455 if (ret < 0)
456 throw std::runtime_error("Failed to remove certficate file " + filename + " from db");
457 }
458
459 bool Ssl::CertificateDb::deleteInvalidCertificate()
460 {
461 if (!db)
462 return false;
463
464 bool removed_one = false;
465 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
466 for (int i = 0; i < sk_OPENSSL_PSTRING_num(db.get()->data); i++) {
467 const char ** current_row = ((const char **)sk_OPENSSL_PSTRING_value(db.get()->data, i));
468 #else
469 for (int i = 0; i < sk_num(db.get()->data); i++) {
470 const char ** current_row = ((const char **)sk_value(db.get()->data, i));
471 #endif
472
473 if (!sslDateIsInTheFuture(current_row[cnlExp_date])) {
474 deleteRow(current_row, i);
475 removed_one = true;
476 break;
477 }
478 }
479
480 if (!removed_one)
481 return false;
482 return true;
483 }
484
485 bool Ssl::CertificateDb::deleteOldestCertificate()
486 {
487 if (!db)
488 return false;
489
490 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
491 if (sk_OPENSSL_PSTRING_num(db.get()->data) == 0)
492 #else
493 if (sk_num(db.get()->data) == 0)
494 #endif
495 return false;
496
497 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
498 const char **row = (const char **)sk_OPENSSL_PSTRING_value(db.get()->data, 0);
499 #else
500 const char **row = (const char **)sk_value(db.get()->data, 0);
501 #endif
502
503 deleteRow(row, 0);
504
505 return true;
506 }
507
508 bool Ssl::CertificateDb::deleteByHostname(std::string const & host)
509 {
510 if (!db)
511 return false;
512
513 #if OPENSSL_VERSION_NUMBER >= 0x1000004fL
514 for (int i = 0; i < sk_OPENSSL_PSTRING_num(db.get()->data); i++) {
515 const char ** current_row = ((const char **)sk_OPENSSL_PSTRING_value(db.get()->data, i));
516 #else
517 for (int i = 0; i < sk_num(db.get()->data); i++) {
518 const char ** current_row = ((const char **)sk_value(db.get()->data, i));
519 #endif
520 if (host == current_row[cnlName]) {
521 deleteRow(current_row, i);
522 return true;
523 }
524 }
525 return false;
526 }
527
528 bool Ssl::CertificateDb::IsEnabledDiskStore() const
529 {
530 return enabled_disk_store;
531 }