]> git.ipfire.org Git - thirdparty/strongswan.git/blob - src/libstrongswan/plugins/gmp/gmp_rsa_private_key.c
e24fda8c26a26dbb8a1a6ebba074fd58e7034ce8
[thirdparty/strongswan.git] / src / libstrongswan / plugins / gmp / gmp_rsa_private_key.c
1 /*
2 * Copyright (C) 2017 Tobias Brunner
3 * Copyright (C) 2005 Jan Hutter
4 * Copyright (C) 2005-2009 Martin Willi
5 * Copyright (C) 2012 Andreas Steffen
6 * HSR Hochschule fuer Technik Rapperswil
7 *
8 * This program is free software; you can redistribute it and/or modify it
9 * under the terms of the GNU General Public License as published by the
10 * Free Software Foundation; either version 2 of the License, or (at your
11 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
12 *
13 * This program is distributed in the hope that it will be useful, but
14 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
16 * for more details.
17 */
18
19 #include <gmp.h>
20 #include <sys/stat.h>
21 #include <unistd.h>
22 #include <string.h>
23
24 #include "gmp_rsa_private_key.h"
25 #include "gmp_rsa_public_key.h"
26
27 #include <utils/debug.h>
28 #include <asn1/oid.h>
29 #include <asn1/asn1.h>
30 #include <asn1/asn1_parser.h>
31 #include <credentials/keys/signature_params.h>
32
/* If GMP provides mpz_powm_sec() (signalled by HAVE_MPZ_POWM_SEC, presumably
 * set by the build configuration), redirect every mpz_powm() call in this file
 * to it. GMP documents mpz_powm_sec() as taking the same time and memory
 * access pattern for same-size operands, which matters when exponentiating
 * with private key material. */
#ifdef HAVE_MPZ_POWM_SEC
# undef mpz_powm
# define mpz_powm mpz_powm_sec
#endif

/**
 * Public exponent to use for key generation (0x10001 = 65537).
 */
#define PUBLIC_EXPONENT 0x10001
42
typedef struct private_gmp_rsa_private_key_t private_gmp_rsa_private_key_t;

/**
 * Private data of a gmp_rsa_private_key_t object.
 *
 * The mpz_t members are initialized by gmp_rsa_private_key_gen() or
 * gmp_rsa_private_key_load() and released again in destroy().
 */
struct private_gmp_rsa_private_key_t {
	/**
	 * Public interface for this signer.
	 */
	gmp_rsa_private_key_t public;

	/**
	 * Public modulus n = p * q.
	 */
	mpz_t n;

	/**
	 * Public exponent e.
	 */
	mpz_t e;

	/**
	 * Private prime 1 (p).
	 */
	mpz_t p;

	/**
	 * Private Prime 2 (q).
	 */
	mpz_t q;

	/**
	 * Carmichael function m = lambda(n) = lcm(p-1,q-1), the modulus modulo
	 * which d is the inverse of e.
	 */
	mpz_t m;

	/**
	 * Array of "threshold" integers: d[0] is the private exponent, further
	 * entries are the coefficients of the optional secret sharing polynomial.
	 */
	mpz_t *d;

	/**
	 * Private exponent 1, exp1 = d mod (p-1) (CRT exponent).
	 */
	mpz_t exp1;

	/**
	 * Private exponent 2, exp2 = d mod (q-1) (CRT exponent).
	 */
	mpz_t exp2;

	/**
	 * Private coefficient, coeff = q^-1 mod p (CRT coefficient).
	 */
	mpz_t coeff;

	/**
	 * Total number of private key shares
	 */
	u_int shares;

	/**
	 * Secret sharing threshold, also the number of entries in d
	 * (1 if no secret sharing is used)
	 */
	u_int threshold;

	/**
	 * Optional verification key (threshold > 1), zero otherwise.
	 */
	mpz_t v;

	/**
	 * Keysize in bytes, also the length of the RSADP/RSASP1 output.
	 */
	size_t k;

	/**
	 * reference count
	 */
	refcount_t ref;
};
124
125 /**
126 * Convert a MP integer into a chunk_t
127 */
128 chunk_t gmp_mpz_to_chunk(const mpz_t value)
129 {
130 chunk_t n;
131
132 n.len = 1 + mpz_sizeinbase(value, 2) / BITS_PER_BYTE;
133 n.ptr = mpz_export(NULL, NULL, 1, n.len, 1, 0, value);
134 if (n.ptr == NULL)
135 { /* if we have zero in "value", gmp returns NULL */
136 n.len = 0;
137 }
138 return n;
139 }
140
141 /**
142 * Auxiliary function overwriting private key material with zero bytes
143 */
144 static void mpz_clear_sensitive(mpz_t z)
145 {
146 size_t len = mpz_size(z) * GMP_LIMB_BITS / BITS_PER_BYTE;
147 uint8_t *zeros = alloca(len);
148
149 memset(zeros, 0, len);
150 /* overwrite mpz_t with zero bytes before clearing it */
151 mpz_import(z, len, 1, 1, 1, 0, zeros);
152 mpz_clear(z);
153 }
154
155 /**
156 * Create a mpz prime of at least prime_size
157 */
158 static status_t compute_prime(size_t prime_size, bool safe, mpz_t *p, mpz_t *q)
159 {
160 rng_t *rng;
161 chunk_t random_bytes;
162 int count = 0;
163
164 rng = lib->crypto->create_rng(lib->crypto, RNG_TRUE);
165 if (!rng)
166 {
167 DBG1(DBG_LIB, "no RNG of quality %N found", rng_quality_names,
168 RNG_TRUE);
169 return FAILED;
170 }
171
172 mpz_init(*p);
173 mpz_init(*q);
174
175 do
176 {
177 if (!rng->allocate_bytes(rng, prime_size, &random_bytes))
178 {
179 DBG1(DBG_LIB, "failed to allocate random prime");
180 mpz_clear(*p);
181 mpz_clear(*q);
182 rng->destroy(rng);
183 return FAILED;
184 }
185
186 /* make sure the two most significant bits are set */
187 if (safe)
188 {
189 random_bytes.ptr[0] &= 0x7F;
190 random_bytes.ptr[0] |= 0x60;
191 mpz_import(*q, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
192 do
193 {
194 count++;
195 mpz_nextprime (*q, *q);
196 mpz_mul_ui(*p, *q, 2);
197 mpz_add_ui(*p, *p, 1);
198 }
199 while (mpz_probab_prime_p(*p, 10) == 0);
200 DBG2(DBG_LIB, "safe prime found after %d iterations", count);
201 }
202 else
203 {
204 random_bytes.ptr[0] |= 0xC0;
205 mpz_import(*p, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
206 mpz_nextprime (*p, *p);
207 }
208 chunk_clear(&random_bytes);
209 }
210
211 /* check if the prime isn't too large */
212 while (((mpz_sizeinbase(*p, 2) + 7) / 8) > prime_size);
213
214 rng->destroy(rng);
215
216 /* additionally return p-1 */
217 mpz_sub_ui(*q, *p, 1);
218
219 return SUCCESS;
220 }
221
222 /**
223 * PKCS#1 RSADP function
224 */
225 static chunk_t rsadp(private_gmp_rsa_private_key_t *this, chunk_t data)
226 {
227 mpz_t t1, t2;
228 chunk_t decrypted;
229
230 mpz_init(t1);
231 mpz_init(t2);
232
233 mpz_import(t1, data.len, 1, 1, 1, 0, data.ptr);
234
235 mpz_powm(t2, t1, this->exp1, this->p); /* m1 = c^dP mod p */
236 mpz_powm(t1, t1, this->exp2, this->q); /* m2 = c^dQ mod Q */
237 mpz_sub(t2, t2, t1); /* h = qInv (m1 - m2) mod p */
238 mpz_mod(t2, t2, this->p);
239 mpz_mul(t2, t2, this->coeff);
240 mpz_mod(t2, t2, this->p);
241
242 mpz_mul(t2, t2, this->q); /* m = m2 + h q */
243 mpz_add(t1, t1, t2);
244
245 decrypted.len = this->k;
246 decrypted.ptr = mpz_export(NULL, NULL, 1, decrypted.len, 1, 0, t1);
247 if (decrypted.ptr == NULL)
248 {
249 decrypted.len = 0;
250 }
251
252 mpz_clear_sensitive(t1);
253 mpz_clear_sensitive(t2);
254
255 return decrypted;
256 }
257
258 /**
259 * PKCS#1 RSASP1 function
260 */
261 static chunk_t rsasp1(private_gmp_rsa_private_key_t *this, chunk_t data)
262 {
263 return rsadp(this, data);
264 }
265
266 /**
267 * Build a signature using the PKCS#1 EMSA scheme
268 */
269 static bool build_emsa_pkcs1_signature(private_gmp_rsa_private_key_t *this,
270 hash_algorithm_t hash_algorithm,
271 chunk_t data, chunk_t *signature)
272 {
273 chunk_t digestInfo = chunk_empty;
274 chunk_t em;
275
276 if (hash_algorithm != HASH_UNKNOWN)
277 {
278 hasher_t *hasher;
279 chunk_t hash;
280 int hash_oid = hasher_algorithm_to_oid(hash_algorithm);
281
282 if (hash_oid == OID_UNKNOWN)
283 {
284 return FALSE;
285 }
286
287 hasher = lib->crypto->create_hasher(lib->crypto, hash_algorithm);
288 if (!hasher || !hasher->allocate_hash(hasher, data, &hash))
289 {
290 DESTROY_IF(hasher);
291 return FALSE;
292 }
293 hasher->destroy(hasher);
294
295 /* build DER-encoded digestInfo */
296 digestInfo = asn1_wrap(ASN1_SEQUENCE, "mm",
297 asn1_algorithmIdentifier(hash_oid),
298 asn1_simple_object(ASN1_OCTET_STRING, hash)
299 );
300 chunk_free(&hash);
301 data = digestInfo;
302 }
303
304 if (data.len > this->k - 3)
305 {
306 free(digestInfo.ptr);
307 DBG1(DBG_LIB, "unable to sign %d bytes using a %dbit key", data.len,
308 mpz_sizeinbase(this->n, 2));
309 return FALSE;
310 }
311
312 /* build chunk to rsa-decrypt:
313 * EM = 0x00 || 0x01 || PS || 0x00 || T.
314 * PS = 0xFF padding, with length to fill em
315 * T = encoded_hash
316 */
317 em.len = this->k;
318 em.ptr = malloc(em.len);
319
320 /* fill em with padding */
321 memset(em.ptr, 0xFF, em.len);
322 /* set magic bytes */
323 *(em.ptr) = 0x00;
324 *(em.ptr+1) = 0x01;
325 *(em.ptr + em.len - data.len - 1) = 0x00;
326 /* set DER-encoded hash */
327 memcpy(em.ptr + em.len - data.len, data.ptr, data.len);
328
329 /* build signature */
330 *signature = rsasp1(this, em);
331
332 free(digestInfo.ptr);
333 free(em.ptr);
334
335 return TRUE;
336 }
337
/**
 * Build a signature using the PKCS#1 EMSA PSS scheme (EMSA-PSS encoding
 * followed by RSASP1).
 *
 * params->hash is used for mHash and H, params->mgf1_hash selects the MGF1
 * function. A params->salt_len greater than RSA_PSS_SALT_LEN_DEFAULT is used
 * as the salt length; otherwise the salt is as long as the hash output. The
 * salt is taken from an RNG_STRONG random generator.
 *
 * @param params		RSA-PSS parameters, NULL is rejected
 * @param data			message to sign (hashed here)
 * @param signature		receives the allocated signature on success
 * @return				TRUE if the signature was built
 */
static bool build_emsa_pss_signature(private_gmp_rsa_private_key_t *this,
									 rsa_pss_params_t *params, chunk_t data,
									 chunk_t *signature)
{
	ext_out_function_t xof;
	hasher_t *hasher = NULL;
	rng_t *rng = NULL;
	xof_t *mgf = NULL;
	chunk_t hash, salt = chunk_empty, m, ps, db, dbmask, em;
	size_t embits, emlen, maskbits;
	bool success = FALSE;

	if (!params)
	{
		return FALSE;
	}
	xof = xof_mgf1_from_hash_algorithm(params->mgf1_hash);
	if (xof == XOF_UNDEFINED)
	{
		DBG1(DBG_LIB, "%N is not supported for MGF1", hash_algorithm_names,
			 params->mgf1_hash);
		return FALSE;
	}
	/* emBits = modBits - 1 */
	embits = mpz_sizeinbase(this->n, 2) - 1;
	/* emLen = ceil(emBits/8) */
	emlen = (embits + 7) / BITS_PER_BYTE;
	/* mHash = Hash(M) */
	hasher = lib->crypto->create_hasher(lib->crypto, params->hash);
	if (!hasher)
	{
		DBG1(DBG_LIB, "hash algorithm %N not supported",
			 hash_algorithm_names, params->hash);
		return FALSE;
	}
	hash = chunk_alloca(hasher->get_hash_size(hasher));
	if (!hasher->get_hash(hasher, data, hash.ptr))
	{
		goto error;
	}

	/* sLen defaults to hLen unless an explicit length is configured */
	salt.len = hash.len;
	if (params->salt_len > RSA_PSS_SALT_LEN_DEFAULT)
	{
		salt.len = params->salt_len;
	}
	/* encoding error if emLen < hLen + sLen + 2 */
	if (emlen < (hash.len + salt.len + 2))
	{	/* too long */
		goto error;
	}
	if (salt.len)
	{
		salt = chunk_alloca(salt.len);
		rng = lib->crypto->create_rng(lib->crypto, RNG_STRONG);
		if (!rng || !rng->get_bytes(rng, salt.len, salt.ptr))
		{
			goto error;
		}
	}
	/* M' = 0x0000000000000000 | mHash | salt */
	m = chunk_cata("ccc",
				   chunk_from_chars(0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00),
				   hash, salt);
	/* H = Hash(M')
	 * NOTE: H overwrites mHash in the same buffer; this relies on chunk_cata()
	 * having copied mHash into m above */
	if (!hasher->get_hash(hasher, m, hash.ptr))
	{
		goto error;
	}
	/* PS = 00...<padding depending on hash and salt length> */
	ps = chunk_alloca(emlen - salt.len - hash.len - 2);
	memset(ps.ptr, 0, ps.len);
	/* DB = PS | 0x01 | salt */
	db = chunk_cata("ccc", ps, chunk_from_chars(0x01), salt);
	/* dbMask = MGF(H, emLen - hLen - 1) */
	mgf = lib->crypto->create_xof(lib->crypto, xof);
	dbmask = chunk_alloca(db.len);
	if (!mgf)
	{
		DBG1(DBG_LIB, "%N not supported", ext_out_function_names, xof);
		goto error;
	}
	if (!mgf->set_seed(mgf, hash) ||
		!mgf->get_bytes(mgf, dbmask.len, dbmask.ptr))
	{
		goto error;
	}
	/* maskedDB = DB xor dbMask */
	memxor(db.ptr, dbmask.ptr, db.len);
	/* zero out unused bits (the 8*emLen - emBits leftmost bits of maskedDB) */
	maskbits = (8 * emlen) - embits;
	if (maskbits)
	{
		db.ptr[0] &= (0xff >> maskbits);
	}
	/* EM = maskedDB | H | 0xbc */
	em = chunk_cata("ccc", db, hash, chunk_from_chars(0xbc));
	/* S = RSASP1(K, EM) */
	*signature = rsasp1(this, em);
	success = TRUE;

error:
	/* common cleanup for success and failure paths */
	DESTROY_IF(hasher);
	DESTROY_IF(rng);
	DESTROY_IF(mgf);
	return success;
}
447
448 METHOD(private_key_t, get_type, key_type_t,
449 private_gmp_rsa_private_key_t *this)
450 {
451 return KEY_RSA;
452 }
453
454 METHOD(private_key_t, sign, bool,
455 private_gmp_rsa_private_key_t *this, signature_scheme_t scheme,
456 void *params, chunk_t data, chunk_t *signature)
457 {
458 switch (scheme)
459 {
460 case SIGN_RSA_EMSA_PKCS1_NULL:
461 return build_emsa_pkcs1_signature(this, HASH_UNKNOWN, data, signature);
462 case SIGN_RSA_EMSA_PKCS1_SHA2_224:
463 return build_emsa_pkcs1_signature(this, HASH_SHA224, data, signature);
464 case SIGN_RSA_EMSA_PKCS1_SHA2_256:
465 return build_emsa_pkcs1_signature(this, HASH_SHA256, data, signature);
466 case SIGN_RSA_EMSA_PKCS1_SHA2_384:
467 return build_emsa_pkcs1_signature(this, HASH_SHA384, data, signature);
468 case SIGN_RSA_EMSA_PKCS1_SHA2_512:
469 return build_emsa_pkcs1_signature(this, HASH_SHA512, data, signature);
470 case SIGN_RSA_EMSA_PKCS1_SHA3_224:
471 return build_emsa_pkcs1_signature(this, HASH_SHA3_224, data, signature);
472 case SIGN_RSA_EMSA_PKCS1_SHA3_256:
473 return build_emsa_pkcs1_signature(this, HASH_SHA3_256, data, signature);
474 case SIGN_RSA_EMSA_PKCS1_SHA3_384:
475 return build_emsa_pkcs1_signature(this, HASH_SHA3_384, data, signature);
476 case SIGN_RSA_EMSA_PKCS1_SHA3_512:
477 return build_emsa_pkcs1_signature(this, HASH_SHA3_512, data, signature);
478 case SIGN_RSA_EMSA_PKCS1_SHA1:
479 return build_emsa_pkcs1_signature(this, HASH_SHA1, data, signature);
480 case SIGN_RSA_EMSA_PKCS1_MD5:
481 return build_emsa_pkcs1_signature(this, HASH_MD5, data, signature);
482 case SIGN_RSA_EMSA_PSS:
483 return build_emsa_pss_signature(this, params, data, signature);
484 default:
485 DBG1(DBG_LIB, "signature scheme %N not supported in RSA",
486 signature_scheme_names, scheme);
487 return FALSE;
488 }
489 }
490
491 METHOD(private_key_t, decrypt, bool,
492 private_gmp_rsa_private_key_t *this, encryption_scheme_t scheme,
493 chunk_t crypto, chunk_t *plain)
494 {
495 chunk_t em, stripped;
496 bool success = FALSE;
497
498 if (scheme != ENCRYPT_RSA_PKCS1)
499 {
500 DBG1(DBG_LIB, "encryption scheme %N not supported",
501 encryption_scheme_names, scheme);
502 return FALSE;
503 }
504 /* rsa decryption using PKCS#1 RSADP */
505 stripped = em = rsadp(this, crypto);
506
507 /* PKCS#1 v1.5 8.1 encryption-block formatting (EB = 00 || 02 || PS || 00 || D) */
508
509 /* check for hex pattern 00 02 in decrypted message */
510 if ((*stripped.ptr++ != 0x00) || (*(stripped.ptr++) != 0x02))
511 {
512 DBG1(DBG_LIB, "incorrect padding - probably wrong rsa key");
513 goto end;
514 }
515 stripped.len -= 2;
516
517 /* the plaintext data starts after first 0x00 byte */
518 while (stripped.len-- > 0 && *stripped.ptr++ != 0x00)
519
520 if (stripped.len == 0)
521 {
522 DBG1(DBG_LIB, "no plaintext data");
523 goto end;
524 }
525
526 *plain = chunk_clone(stripped);
527 success = TRUE;
528
529 end:
530 chunk_clear(&em);
531 return success;
532 }
533
534 METHOD(private_key_t, get_keysize, int,
535 private_gmp_rsa_private_key_t *this)
536 {
537 return mpz_sizeinbase(this->n, 2);
538 }
539
540 METHOD(private_key_t, get_public_key, public_key_t*,
541 private_gmp_rsa_private_key_t *this)
542 {
543 chunk_t n, e;
544 public_key_t *public;
545
546 n = gmp_mpz_to_chunk(this->n);
547 e = gmp_mpz_to_chunk(this->e);
548
549 public = lib->creds->create(lib->creds, CRED_PUBLIC_KEY, KEY_RSA,
550 BUILD_RSA_MODULUS, n, BUILD_RSA_PUB_EXP, e, BUILD_END);
551 chunk_free(&n);
552 chunk_free(&e);
553
554 return public;
555 }
556
557 METHOD(private_key_t, get_encoding, bool,
558 private_gmp_rsa_private_key_t *this, cred_encoding_type_t type,
559 chunk_t *encoding)
560 {
561 chunk_t n, e, d, p, q, exp1, exp2, coeff;
562 bool success;
563
564 n = gmp_mpz_to_chunk(this->n);
565 e = gmp_mpz_to_chunk(this->e);
566 d = gmp_mpz_to_chunk(*this->d);
567 p = gmp_mpz_to_chunk(this->p);
568 q = gmp_mpz_to_chunk(this->q);
569 exp1 = gmp_mpz_to_chunk(this->exp1);
570 exp2 = gmp_mpz_to_chunk(this->exp2);
571 coeff = gmp_mpz_to_chunk(this->coeff);
572
573 success = lib->encoding->encode(lib->encoding,
574 type, NULL, encoding, CRED_PART_RSA_MODULUS, n,
575 CRED_PART_RSA_PUB_EXP, e, CRED_PART_RSA_PRIV_EXP, d,
576 CRED_PART_RSA_PRIME1, p, CRED_PART_RSA_PRIME2, q,
577 CRED_PART_RSA_EXP1, exp1, CRED_PART_RSA_EXP2, exp2,
578 CRED_PART_RSA_COEFF, coeff, CRED_PART_END);
579 chunk_free(&n);
580 chunk_free(&e);
581 chunk_clear(&d);
582 chunk_clear(&p);
583 chunk_clear(&q);
584 chunk_clear(&exp1);
585 chunk_clear(&exp2);
586 chunk_clear(&coeff);
587
588 return success;
589 }
590
591 METHOD(private_key_t, get_fingerprint, bool,
592 private_gmp_rsa_private_key_t *this, cred_encoding_type_t type, chunk_t *fp)
593 {
594 chunk_t n, e;
595 bool success;
596
597 if (lib->encoding->get_cache(lib->encoding, type, this, fp))
598 {
599 return TRUE;
600 }
601 n = gmp_mpz_to_chunk(this->n);
602 e = gmp_mpz_to_chunk(this->e);
603
604 success = lib->encoding->encode(lib->encoding, type, this, fp,
605 CRED_PART_RSA_MODULUS, n, CRED_PART_RSA_PUB_EXP, e, CRED_PART_END);
606 chunk_free(&n);
607 chunk_free(&e);
608
609 return success;
610 }
611
612 METHOD(private_key_t, get_ref, private_key_t*,
613 private_gmp_rsa_private_key_t *this)
614 {
615 ref_get(&this->ref);
616 return &this->public.key;
617 }
618
619 METHOD(private_key_t, destroy, void,
620 private_gmp_rsa_private_key_t *this)
621 {
622 if (ref_put(&this->ref))
623 {
624 int i;
625
626 mpz_clear(this->n);
627 mpz_clear(this->e);
628 mpz_clear(this->v);
629 mpz_clear_sensitive(this->p);
630 mpz_clear_sensitive(this->q);
631 mpz_clear_sensitive(this->m);
632 mpz_clear_sensitive(this->exp1);
633 mpz_clear_sensitive(this->exp2);
634 mpz_clear_sensitive(this->coeff);
635
636 for (i = 0; i < this->threshold; i++)
637 {
638 mpz_clear_sensitive(*this->d + i);
639 }
640 free(this->d);
641
642 lib->encoding->clear_cache(lib->encoding, this);
643 free(this);
644 }
645 }
646
647 /**
648 * Check the loaded key if it is valid and usable
649 */
650 static status_t check(private_gmp_rsa_private_key_t *this)
651 {
652 mpz_t u, p1, q1;
653 status_t status = SUCCESS;
654
655 /* PKCS#1 1.5 section 6 requires modulus to have at least 12 octets.
656 * We actually require more (for security).
657 */
658 if (this->k < 512 / BITS_PER_BYTE)
659 {
660 DBG1(DBG_LIB, "key shorter than 512 bits");
661 return FAILED;
662 }
663
664 /* we picked a max modulus size to simplify buffer allocation */
665 if (this->k > 8192 / BITS_PER_BYTE)
666 {
667 DBG1(DBG_LIB, "key larger than 8192 bits");
668 return FAILED;
669 }
670
671 mpz_init(u);
672 mpz_init(p1);
673 mpz_init(q1);
674
675 /* precompute p1 = p-1 and q1 = q-1 */
676 mpz_sub_ui(p1, this->p, 1);
677 mpz_sub_ui(q1, this->q, 1);
678
679 /* check that n == p * q */
680 mpz_mul(u, this->p, this->q);
681 if (mpz_cmp(u, this->n) != 0)
682 {
683 status = FAILED;
684 }
685
686 /* check that e divides neither p-1 nor q-1 */
687 mpz_mod(u, p1, this->e);
688 if (mpz_cmp_ui(u, 0) == 0)
689 {
690 status = FAILED;
691 }
692
693 mpz_mod(u, q1, this->e);
694 if (mpz_cmp_ui(u, 0) == 0)
695 {
696 status = FAILED;
697 }
698
699 /* check that d is e^-1 (mod lcm(p-1, q-1)) */
700 /* see PKCS#1v2, aka RFC 2437, for the "lcm" */
701 mpz_lcm(this->m, p1, q1);
702 mpz_mul(u, *this->d, this->e);
703 mpz_mod(u, u, this->m);
704 if (mpz_cmp_ui(u, 1) != 0)
705 {
706 status = FAILED;
707 }
708
709 /* check that exp1 is d mod (p-1) */
710 mpz_mod(u, *this->d, p1);
711 if (mpz_cmp(u, this->exp1) != 0)
712 {
713 status = FAILED;
714 }
715
716 /* check that exp2 is d mod (q-1) */
717 mpz_mod(u, *this->d, q1);
718 if (mpz_cmp(u, this->exp2) != 0)
719 {
720 status = FAILED;
721 }
722
723 /* check that coeff is (q^-1) mod p */
724 mpz_mul(u, this->coeff, this->q);
725 mpz_mod(u, u, this->p);
726 if (mpz_cmp_ui(u, 1) != 0)
727 {
728 status = FAILED;
729 }
730
731 mpz_clear_sensitive(u);
732 mpz_clear_sensitive(p1);
733 mpz_clear_sensitive(q1);
734
735 if (status != SUCCESS)
736 {
737 DBG1(DBG_LIB, "key integrity tests failed");
738 }
739 return status;
740 }
741
742 /**
743 * Internal generic constructor
744 */
745 static private_gmp_rsa_private_key_t *gmp_rsa_private_key_create_empty(void)
746 {
747 private_gmp_rsa_private_key_t *this;
748
749 INIT(this,
750 .public = {
751 .key = {
752 .get_type = _get_type,
753 .sign = _sign,
754 .decrypt = _decrypt,
755 .get_keysize = _get_keysize,
756 .get_public_key = _get_public_key,
757 .equals = private_key_equals,
758 .belongs_to = private_key_belongs_to,
759 .get_fingerprint = _get_fingerprint,
760 .has_fingerprint = private_key_has_fingerprint,
761 .get_encoding = _get_encoding,
762 .get_ref = _get_ref,
763 .destroy = _destroy,
764 },
765 },
766 .threshold = 1,
767 .ref = 1,
768 );
769 return this;
770 }
771
772 /**
773 * See header.
774 */
775 gmp_rsa_private_key_t *gmp_rsa_private_key_gen(key_type_t type, va_list args)
776 {
777 private_gmp_rsa_private_key_t *this;
778 u_int key_size = 0, shares = 0, threshold = 1;
779 bool safe_prime = FALSE, rng_failed = FALSE, invert_failed = FALSE;
780 mpz_t p, q, p1, q1, d;
781 ;
782
783 while (TRUE)
784 {
785 switch (va_arg(args, builder_part_t))
786 {
787 case BUILD_KEY_SIZE:
788 key_size = va_arg(args, u_int);
789 continue;
790 case BUILD_SAFE_PRIMES:
791 safe_prime = TRUE;
792 continue;
793 case BUILD_SHARES:
794 shares = va_arg(args, u_int);
795 continue;
796 case BUILD_THRESHOLD:
797 threshold = va_arg(args, u_int);
798 continue;
799 case BUILD_END:
800 break;
801 default:
802 return NULL;
803 }
804 break;
805 }
806 if (!key_size)
807 {
808 return NULL;
809 }
810 key_size = key_size / BITS_PER_BYTE;
811
812 /* Get values of primes p and q */
813 if (compute_prime(key_size/2, safe_prime, &p, &p1) != SUCCESS)
814 {
815 return NULL;
816 }
817 if (compute_prime(key_size/2, safe_prime, &q, &q1) != SUCCESS)
818 {
819 mpz_clear(p);
820 mpz_clear(p1);
821 return NULL;
822 }
823
824 /* Swapping Primes so p is larger then q */
825 if (mpz_cmp(p, q) < 0)
826 {
827 mpz_swap(p, q);
828 mpz_swap(p1, q1);
829 }
830
831 /* Create and initialize RSA private key object */
832 this = gmp_rsa_private_key_create_empty();
833 this->shares = shares;
834 this->threshold = threshold;
835 this->d = malloc(threshold * sizeof(mpz_t));
836 *this->p = *p;
837 *this->q = *q;
838
839 mpz_init_set_ui(this->e, PUBLIC_EXPONENT);
840 mpz_init(this->n);
841 mpz_init(this->m);
842 mpz_init(this->exp1);
843 mpz_init(this->exp2);
844 mpz_init(this->coeff);
845 mpz_init(this->v);
846 mpz_init(d);
847
848 mpz_mul(this->n, p, q); /* n = p*q */
849 mpz_lcm(this->m, p1, q1); /* m = lcm(p-1,q-1) */
850 mpz_invert(d, this->e, this->m); /* e has an inverse mod m */
851 mpz_mod(this->exp1, d, p1); /* exp1 = d mod p-1 */
852 mpz_mod(this->exp2, d, q1); /* exp2 = d mod q-1 */
853 mpz_invert(this->coeff, q, p); /* coeff = q^-1 mod p */
854
855 invert_failed = mpz_cmp_ui(this->m, 0) == 0 ||
856 mpz_cmp_ui(this->coeff, 0) == 0;
857
858 /* store secret exponent d */
859 (*this->d)[0] = *d;
860
861 /* generate and store random coefficients of secret sharing polynomial */
862 if (threshold > 1)
863 {
864 rng_t *rng;
865 chunk_t random_bytes;
866 mpz_t u;
867 int i;
868
869 rng = lib->crypto->create_rng(lib->crypto, RNG_TRUE);
870 mpz_init(u);
871
872 for (i = 1; i < threshold; i++)
873 {
874 mpz_init(d);
875
876 if (!rng->allocate_bytes(rng, key_size, &random_bytes))
877 {
878 rng_failed = TRUE;
879 continue;
880 }
881 mpz_import(d, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
882 mpz_mod(d, d, this->m);
883 (*this->d)[i] = *d;
884 chunk_clear(&random_bytes);
885 }
886
887 /* generate verification key v as a square number */
888 do
889 {
890 if (!rng->allocate_bytes(rng, key_size, &random_bytes))
891 {
892 rng_failed = TRUE;
893 break;
894 }
895 mpz_import(this->v, random_bytes.len, 1, 1, 1, 0, random_bytes.ptr);
896 mpz_mul(this->v, this->v, this->v);
897 mpz_mod(this->v, this->v, this->n);
898 mpz_gcd(u, this->v, this->n);
899 chunk_free(&random_bytes);
900 }
901 while (mpz_cmp_ui(u, 1) != 0);
902
903 mpz_clear(u);
904 rng->destroy(rng);
905 }
906
907 mpz_clear_sensitive(p1);
908 mpz_clear_sensitive(q1);
909
910 if (rng_failed || invert_failed)
911 {
912 DBG1(DBG_LIB, "rsa key generation failed");
913 destroy(this);
914 return NULL;
915 }
916
917 /* set key size in bytes */
918 this->k = key_size;
919
920 return &this->public;
921 }
922
923 /**
924 * Recover the primes from n, e and d using the algorithm described in
925 * Appendix C of NIST SP 800-56B.
926 */
927 static bool calculate_pq(private_gmp_rsa_private_key_t *this)
928 {
929 gmp_randstate_t rstate;
930 mpz_t k, r, g, y, n1, x;
931 int i, t, j;
932 bool success = FALSE;
933
934 gmp_randinit_default(rstate);
935 mpz_inits(k, r, g, y, n1, x, NULL);
936 /* k = (d * e) - 1 */
937 mpz_mul(k, *this->d, this->e);
938 mpz_sub_ui(k, k, 1);
939 if (mpz_odd_p(k))
940 {
941 goto error;
942 }
943 /* k = 2^t * r, where r is the largest odd integer dividing k, and t >= 1 */
944 mpz_set(r, k);
945 for (t = 0; !mpz_odd_p(r); t++)
946 { /* r = r/2 */
947 mpz_divexact_ui(r, r, 2);
948 }
949 /* we need n-1 below */
950 mpz_sub_ui(n1, this->n, 1);
951 for (i = 0; i < 100; i++)
952 { /* generate random integer g in [0, n-1] */
953 mpz_urandomm(g, rstate, this->n);
954 /* y = g^r mod n */
955 mpz_powm_sec(y, g, r, this->n);
956 /* try again if y == 1 or y == n-1 */
957 if (mpz_cmp_ui(y, 1) == 0 || mpz_cmp(y, n1) == 0)
958 {
959 continue;
960 }
961 for (j = 0; j < t; j++)
962 { /* x = y^2 mod n */
963 mpz_powm_ui(x, y, 2, this->n);
964 /* stop if x == 1 */
965 if (mpz_cmp_ui(x, 1) == 0)
966 {
967 goto done;
968 }
969 /* retry with new g if x = n-1 */
970 if (mpz_cmp(x, n1) == 0)
971 {
972 break;
973 }
974 /* y = x */
975 mpz_set(y, x);
976 }
977 }
978 goto error;
979
980 done:
981 /* p = gcd(y-1, n) */
982 mpz_sub_ui(y, y, 1);
983 mpz_gcd(this->p, y, this->n);
984 /* q = n/p */
985 mpz_divexact(this->q, this->n, this->p);
986 success = TRUE;
987
988 error:
989 mpz_clear_sensitive(k);
990 mpz_clear_sensitive(r);
991 mpz_clear_sensitive(g);
992 mpz_clear_sensitive(y);
993 mpz_clear_sensitive(x);
994 mpz_clear(n1);
995 gmp_randclear(rstate);
996 return success;
997 }
998
999 /**
1000 * See header.
1001 */
1002 gmp_rsa_private_key_t *gmp_rsa_private_key_load(key_type_t type, va_list args)
1003 {
1004 private_gmp_rsa_private_key_t *this;
1005 chunk_t n, e, d, p, q, exp1, exp2, coeff;
1006
1007 n = e = d = p = q = exp1 = exp2 = coeff = chunk_empty;
1008 while (TRUE)
1009 {
1010 switch (va_arg(args, builder_part_t))
1011 {
1012 case BUILD_RSA_MODULUS:
1013 n = va_arg(args, chunk_t);
1014 continue;
1015 case BUILD_RSA_PUB_EXP:
1016 e = va_arg(args, chunk_t);
1017 continue;
1018 case BUILD_RSA_PRIV_EXP:
1019 d = va_arg(args, chunk_t);
1020 continue;
1021 case BUILD_RSA_PRIME1:
1022 p = va_arg(args, chunk_t);
1023 continue;
1024 case BUILD_RSA_PRIME2:
1025 q = va_arg(args, chunk_t);
1026 continue;
1027 case BUILD_RSA_EXP1:
1028 exp1 = va_arg(args, chunk_t);
1029 continue;
1030 case BUILD_RSA_EXP2:
1031 exp2 = va_arg(args, chunk_t);
1032 continue;
1033 case BUILD_RSA_COEFF:
1034 coeff = va_arg(args, chunk_t);
1035 continue;
1036 case BUILD_END:
1037 break;
1038 default:
1039 return NULL;
1040 }
1041 break;
1042 }
1043
1044 this = gmp_rsa_private_key_create_empty();
1045
1046 this->d = malloc(sizeof(mpz_t));
1047 mpz_init(this->n);
1048 mpz_init(this->e);
1049 mpz_init(*this->d);
1050 mpz_init(this->p);
1051 mpz_init(this->q);
1052 mpz_init(this->m);
1053 mpz_init(this->exp1);
1054 mpz_init(this->exp2);
1055 mpz_init(this->coeff);
1056 mpz_init(this->v);
1057
1058 mpz_import(this->n, n.len, 1, 1, 1, 0, n.ptr);
1059 mpz_import(this->e, e.len, 1, 1, 1, 0, e.ptr);
1060 mpz_import(*this->d, d.len, 1, 1, 1, 0, d.ptr);
1061 if (p.len)
1062 {
1063 mpz_import(this->p, p.len, 1, 1, 1, 0, p.ptr);
1064 }
1065 if (q.len)
1066 {
1067 mpz_import(this->q, q.len, 1, 1, 1, 0, q.ptr);
1068 }
1069 if (!p.len && !q.len)
1070 { /* p and q missing in key, recalculate from n, e and d */
1071 if (!calculate_pq(this))
1072 {
1073 destroy(this);
1074 return NULL;
1075 }
1076 }
1077 else if (!p.len)
1078 { /* p missing in key, recalculate: p = n / q */
1079 mpz_divexact(this->p, this->n, this->q);
1080 }
1081 else if (!q.len)
1082 { /* q missing in key, recalculate: q = n / p */
1083 mpz_divexact(this->q, this->n, this->p);
1084 }
1085 if (!exp1.len)
1086 { /* exp1 missing in key, recalculate: exp1 = d mod (p-1) */
1087 mpz_sub_ui(this->exp1, this->p, 1);
1088 mpz_mod(this->exp1, *this->d, this->exp1);
1089 }
1090 else
1091 {
1092 mpz_import(this->exp1, exp1.len, 1, 1, 1, 0, exp1.ptr);
1093 }
1094 if (!exp2.len)
1095 { /* exp2 missing in key, recalculate: exp2 = d mod (q-1) */
1096 mpz_sub_ui(this->exp2, this->q, 1);
1097 mpz_mod(this->exp2, *this->d, this->exp2);
1098 }
1099 else
1100 {
1101 mpz_import(this->exp2, exp2.len, 1, 1, 1, 0, exp2.ptr);
1102 }
1103 if (!coeff.len)
1104 { /* coeff missing in key, recalculate: coeff = q^-1 mod p */
1105 mpz_invert(this->coeff, this->q, this->p);
1106 }
1107 else
1108 {
1109 mpz_import(this->coeff, coeff.len, 1, 1, 1, 0, coeff.ptr);
1110 }
1111 this->k = (mpz_sizeinbase(this->n, 2) + 7) / BITS_PER_BYTE;
1112 if (check(this) != SUCCESS)
1113 {
1114 destroy(this);
1115 return NULL;
1116 }
1117 return &this->public;
1118 }