]> git.ipfire.org Git - thirdparty/ipxe.git/commitdiff
[crypto] Use Montgomery reduction for modular exponentiation
authorMichael Brown <mcb30@ipxe.org>
Mon, 25 Nov 2024 15:59:22 +0000 (15:59 +0000)
committerMichael Brown <mcb30@ipxe.org>
Thu, 28 Nov 2024 15:06:01 +0000 (15:06 +0000)
Speed up modular exponentiation by using Montgomery reduction rather
than direct modular reduction.

Montgomery reduction in base 2^n requires the modulus to be coprime to
2^n, which would limit us to requiring that the modulus is an odd
number.  Extend the implementation to include support for
exponentiation with even moduli via Garner's algorithm as described in
"Montgomery reduction with even modulus" (KoƧ, 1994).

Since almost all use cases for modular exponentation require a large
prime (and hence odd) modulus, the support for even moduli could
potentially be removed in future.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
src/crypto/bigint.c
src/crypto/dhe.c
src/crypto/rsa.c
src/include/ipxe/bigint.h
src/tests/bigint_test.c

index 6d75fbe9b8252365acef044d8bc946f2cd422e4e..39e1a25cde9d89103f6e7dacceec4567cf37fc14 100644 (file)
@@ -505,25 +505,142 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0,
                *exponent = ( ( const void * ) exponent0 );
        bigint_t ( size ) __attribute__ (( may_alias )) *result =
                ( ( void * ) result0 );
-       size_t mod_multiply_len = bigint_mod_multiply_tmp_len ( modulus );
+       const unsigned int width = ( 8 * sizeof ( bigint_element_t ) );
        struct {
-               bigint_t ( size ) base;
-               bigint_t ( exponent_size ) exponent;
-               uint8_t mod_multiply[mod_multiply_len];
+               union {
+                       bigint_t ( 2 * size ) padded_modulus;
+                       struct {
+                               bigint_t ( size ) modulus;
+                               bigint_t ( size ) stash;
+                       };
+               };
+               union {
+                       bigint_t ( 2 * size ) full;
+                       bigint_t ( size ) low;
+               } product;
        } *temp = tmp;
-       static const uint8_t start[1] = { 0x01 };
+       const uint8_t one[1] = { 1 };
+       bigint_t ( 1 ) modinv;
+       bigint_element_t submask;
+       unsigned int subsize;
+       unsigned int scale;
+       unsigned int max;
+       unsigned int bit;
+
+       /* Sanity check */
+       assert ( sizeof ( *temp ) == bigint_mod_exp_tmp_len ( modulus ) );
+
+       /* Handle degenerate case of zero modulus */
+       if ( ! bigint_max_set_bit ( modulus ) ) {
+               memset ( result, 0, sizeof ( *result ) );
+               return;
+       }
 
-       memcpy ( &temp->base, base, sizeof ( temp->base ) );
-       memcpy ( &temp->exponent, exponent, sizeof ( temp->exponent ) );
-       bigint_init ( result, start, sizeof ( start ) );
+       /* Factor modulus as (N * 2^scale) where N is odd */
+       bigint_grow ( modulus, &temp->padded_modulus );
+       for ( scale = 0 ; ( ! bigint_bit_is_set ( &temp->modulus, 0 ) ) ;
+             scale++ ) {
+               bigint_shr ( &temp->modulus );
+       }
+       subsize = ( ( scale + width - 1 ) / width );
+       submask = ( ( 1UL << ( scale % width ) ) - 1 );
+       if ( ! submask )
+               submask = ~submask;
+
+       /* Calculate inverse of (scaled) modulus N modulo element size */
+       bigint_mod_invert ( &temp->modulus, &modinv );
+
+       /* Calculate (R^2 mod N) via direct reduction of (R^2 - N) */
+       memset ( &temp->product.full, 0, sizeof ( temp->product.full ) );
+       bigint_subtract ( &temp->padded_modulus, &temp->product.full );
+       bigint_reduce ( &temp->padded_modulus, &temp->product.full );
+       bigint_copy ( &temp->product.low, &temp->stash );
+
+       /* Initialise result = Montgomery(1, R^2 mod N) */
+       bigint_montgomery ( &temp->modulus, &modinv,
+                           &temp->product.full, result );
+
+       /* Convert base into Montgomery form */
+       bigint_multiply ( base, &temp->stash, &temp->product.full );
+       bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
+                           &temp->stash );
+
+       /* Calculate x1 = base^exponent modulo N */
+       max = bigint_max_set_bit ( exponent );
+       for ( bit = 1 ; bit <= max ; bit++ ) {
+
+               /* Square (and reduce) */
+               bigint_multiply ( result, result, &temp->product.full );
+               bigint_montgomery ( &temp->modulus, &modinv,
+                                   &temp->product.full, result );
+
+               /* Multiply (and reduce) */
+               bigint_multiply ( &temp->stash, result, &temp->product.full );
+               bigint_montgomery ( &temp->modulus, &modinv,
+                                   &temp->product.full, &temp->product.low );
+
+               /* Conditionally swap the multiplied result */
+               bigint_swap ( result, &temp->product.low,
+                             bigint_bit_is_set ( exponent, ( max - bit ) ) );
+       }
 
-       while ( ! bigint_is_zero ( &temp->exponent ) ) {
-               if ( bigint_bit_is_set ( &temp->exponent, 0 ) ) {
-                       bigint_mod_multiply ( result, &temp->base, modulus,
-                                             result, temp->mod_multiply );
+       /* Convert back out of Montgomery form */
+       bigint_grow ( result, &temp->product.full );
+       bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
+                           result );
+
+       /* Handle even moduli via Garner's algorithm */
+       if ( subsize ) {
+               const bigint_t ( subsize ) __attribute__ (( may_alias ))
+                       *subbase = ( ( const void * ) base );
+               bigint_t ( subsize ) __attribute__ (( may_alias ))
+                       *submodulus = ( ( void * ) &temp->modulus );
+               bigint_t ( subsize ) __attribute__ (( may_alias ))
+                       *substash = ( ( void * ) &temp->stash );
+               bigint_t ( subsize ) __attribute__ (( may_alias ))
+                       *subresult = ( ( void * ) result );
+               union {
+                       bigint_t ( 2 * subsize ) full;
+                       bigint_t ( subsize ) low;
+               } __attribute__ (( may_alias ))
+                       *subproduct = ( ( void * ) &temp->product.full );
+
+               /* Calculate x2 = base^exponent modulo 2^k */
+               bigint_init ( substash, one, sizeof ( one ) );
+               for ( bit = 1 ; bit <= max ; bit++ ) {
+
+                       /* Square (and reduce) */
+                       bigint_multiply ( substash, substash,
+                                         &subproduct->full );
+                       bigint_copy ( &subproduct->low, substash );
+
+                       /* Multiply (and reduce) */
+                       bigint_multiply ( subbase, substash,
+                                         &subproduct->full );
+
+                       /* Conditionally swap the multiplied result */
+                       bigint_swap ( substash, &subproduct->low,
+                                     bigint_bit_is_set ( exponent,
+                                                         ( max - bit ) ) );
                }
-               bigint_shr ( &temp->exponent );
-               bigint_mod_multiply ( &temp->base, &temp->base, modulus,
-                                     &temp->base, temp->mod_multiply );
+
+               /* Calculate N^-1 modulo 2^k */
+               bigint_mod_invert ( submodulus, &subproduct->low );
+               bigint_copy ( &subproduct->low, submodulus );
+
+               /* Calculate y = (x2 - x1) * N^-1 modulo 2^k */
+               bigint_subtract ( subresult, substash );
+               bigint_multiply ( substash, submodulus, &subproduct->full );
+               subproduct->low.element[ subsize - 1 ] &= submask;
+               bigint_grow ( &subproduct->low, &temp->stash );
+
+               /* Reconstruct N */
+               bigint_mod_invert ( submodulus, &subproduct->low );
+               bigint_copy ( &subproduct->low, submodulus );
+
+               /* Calculate x = x1 + N * y */
+               bigint_multiply ( &temp->modulus, &temp->stash,
+                                 &temp->product.full );
+               bigint_add ( &temp->product.low, result );
        }
 }
index 2da107d24f89a1c46926a71a9dbba87cb33cf718..a249f9b40b3f4bb4ee60ccdd52afad5b92d30549 100644 (file)
@@ -57,8 +57,7 @@ int dhe_key ( const void *modulus, size_t len, const void *generator,
        unsigned int size = bigint_required_size ( len );
        unsigned int private_size = bigint_required_size ( private_len );
        bigint_t ( size ) *mod;
-       bigint_t ( private_size ) *exp;
-       size_t tmp_len = bigint_mod_exp_tmp_len ( mod, exp );
+       size_t tmp_len = bigint_mod_exp_tmp_len ( mod );
        struct {
                bigint_t ( size ) modulus;
                bigint_t ( size ) generator;
index 19472c121b5aca300e08cafe50ffcc1844731035..44041da3ec66e56511c25813ef4e254316db6586 100644 (file)
@@ -109,8 +109,7 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len,
        unsigned int size = bigint_required_size ( modulus_len );
        unsigned int exponent_size = bigint_required_size ( exponent_len );
        bigint_t ( size ) *modulus;
-       bigint_t ( exponent_size ) *exponent;
-       size_t tmp_len = bigint_mod_exp_tmp_len ( modulus, exponent );
+       size_t tmp_len = bigint_mod_exp_tmp_len ( modulus );
        struct {
                bigint_t ( size ) modulus;
                bigint_t ( exponent_size ) exponent;
index 6c9730252ea4f743120b6d0fd652644e61d230ed..3ca871962cff5457d5faec025b16062a97dc8de7 100644 (file)
@@ -322,18 +322,12 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
  * Calculate temporary working space required for moduluar exponentiation
  *
  * @v modulus          Big integer modulus
- * @v exponent         Big integer exponent
  * @ret len            Length of temporary working space
  */
-#define bigint_mod_exp_tmp_len( modulus, exponent ) ( {                        \
+#define bigint_mod_exp_tmp_len( modulus ) ( {                          \
        unsigned int size = bigint_size (modulus);                      \
-       unsigned int exponent_size = bigint_size (exponent);            \
-       size_t mod_multiply_len =                                       \
-               bigint_mod_multiply_tmp_len (modulus);                  \
        sizeof ( struct {                                               \
-               bigint_t ( size ) temp_base;                            \
-               bigint_t ( exponent_size ) temp_exponent;               \
-               uint8_t mod_multiply[mod_multiply_len];                 \
+               bigint_t ( size ) temp[4];                              \
        } ); } )
 
 #include <bits/bigint.h>
index 1f2f5f24430a575667103be71ee684289d9be27a..f3291f6a6b587cdf287c0a0ce894a239099dcf57 100644 (file)
@@ -746,8 +746,7 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0,
        bigint_t ( size ) modulus_temp;                                 \
        bigint_t ( exponent_size ) exponent_temp;                       \
        bigint_t ( size ) result_temp;                                  \
-       size_t tmp_len = bigint_mod_exp_tmp_len ( &modulus_temp,        \
-                                                 &exponent_temp );     \
+       size_t tmp_len = bigint_mod_exp_tmp_len ( &modulus_temp );      \
        uint8_t tmp[tmp_len];                                           \
        {} /* Fix emacs alignment */                                    \
                                                                        \
@@ -2070,6 +2069,14 @@ static void bigint_test_exec ( void ) {
                            BIGINT ( 0xb9 ),
                            BIGINT ( 0x39, 0x68, 0xba, 0x7d ),
                            BIGINT ( 0x17 ) );
+       bigint_mod_exp_ok ( BIGINT ( 0x71, 0x4d, 0x02, 0xe9 ),
+                           BIGINT ( 0x00, 0x00, 0x00, 0x00 ),
+                           BIGINT ( 0x91, 0x7f, 0x4e, 0x3a, 0x5d, 0x5c ),
+                           BIGINT ( 0x00, 0x00, 0x00, 0x00 ) );
+       bigint_mod_exp_ok ( BIGINT ( 0x2b, 0xf5, 0x07, 0xaf ),
+                           BIGINT ( 0x6e, 0xb5, 0xda, 0x5a ),
+                           BIGINT ( 0x00, 0x00, 0x00, 0x00, 0x00 ),
+                           BIGINT ( 0x00, 0x00, 0x00, 0x01 ) );
        bigint_mod_exp_ok ( BIGINT ( 0x2e ),
                            BIGINT ( 0xb7 ),
                            BIGINT ( 0x39, 0x07, 0x1b, 0x49, 0x5b, 0xea,
@@ -2774,6 +2781,25 @@ static void bigint_test_exec ( void ) {
                                     0xfa, 0x83, 0xd4, 0x7c, 0xe9, 0x77,
                                     0x46, 0x91, 0x3a, 0x50, 0x0d, 0x6a,
                                     0x25, 0xd0 ) );
+       bigint_mod_exp_ok ( BIGINT ( 0x5b, 0x80, 0xc5, 0x03, 0xb3, 0x1e,
+                                    0x46, 0x9b, 0xa3, 0x0a, 0x70, 0x43,
+                                    0x51, 0x2a, 0x4a, 0x44, 0xcb, 0x87,
+                                    0x3e, 0x00, 0x2a, 0x48, 0x46, 0xf5,
+                                    0xb3, 0xb9, 0x73, 0xa7, 0x77, 0xfc,
+                                    0x2a, 0x1d ),
+                           BIGINT ( 0x5e, 0x8c, 0x80, 0x03, 0xe7, 0xb0,
+                                    0x45, 0x23, 0x8f, 0xe0, 0x77, 0x02,
+                                    0xc0, 0x7e, 0xfb, 0xc4, 0xbe, 0x7b,
+                                    0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                    0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                    0x00, 0x00 ),
+                           BIGINT ( 0x71, 0xd9, 0x38, 0xb6 ),
+                           BIGINT ( 0x52, 0xfc, 0x73, 0x55, 0x2f, 0x86,
+                                    0x0f, 0xde, 0x04, 0xbc, 0x6d, 0xb8,
+                                    0xfd, 0x48, 0xf8, 0x8c, 0x91, 0x1c,
+                                    0xa0, 0x8a, 0x70, 0xa8, 0xc6, 0x20,
+                                    0x0a, 0x0d, 0x3b, 0x2a, 0x92, 0x65,
+                                    0x9c, 0x59 ) );
 }
 
 /** Big integer self-test */