]> git.ipfire.org Git - thirdparty/ipxe.git/commitdiff
[crypto] Separate out bigint_reduce() from bigint_mod_multiply()
authorMichael Brown <mcb30@ipxe.org>
Tue, 15 Oct 2024 12:50:51 +0000 (13:50 +0100)
committerMichael Brown <mcb30@ipxe.org>
Tue, 15 Oct 2024 12:50:51 +0000 (13:50 +0100)
Faster modular multiplication algorithms such as Montgomery
multiplication will still require the ability to perform a single
direct modular reduction.

Neaten up the implementation of direct reduction and split it out into
a separate bigint_reduce() function, complete with its own unit tests.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
src/crypto/bigint.c
src/include/ipxe/bigint.h
src/tests/bigint_test.c

index c7b6dafc9a4bd3279825481a9316c44f8a35f7fc..a8b99ec3c07c0162291241960cb4b8f45ad8f419 100644 (file)
@@ -34,22 +34,14 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
  * Big integer support
  */
 
+/** Modular direct reduction profiler */
+static struct profiler bigint_mod_profiler __profiler =
+       { .name = "bigint_mod" };
+
 /** Modular multiplication overall profiler */
 static struct profiler bigint_mod_multiply_profiler __profiler =
        { .name = "bigint_mod_multiply" };
 
-/** Modular multiplication multiply step profiler */
-static struct profiler bigint_mod_multiply_multiply_profiler __profiler =
-       { .name = "bigint_mod_multiply.multiply" };
-
-/** Modular multiplication rescale step profiler */
-static struct profiler bigint_mod_multiply_rescale_profiler __profiler =
-       { .name = "bigint_mod_multiply.rescale" };
-
-/** Modular multiplication subtract step profiler */
-static struct profiler bigint_mod_multiply_subtract_profiler __profiler =
-       { .name = "bigint_mod_multiply.subtract" };
-
 /**
  * Conditionally swap big integers (in constant time)
  *
@@ -144,6 +136,175 @@ void bigint_multiply_raw ( const bigint_element_t *multiplicand0,
        }
 }
 
+/**
+ * Reduce big integer
+ *
+ * @v minuend0         Element 0 of big integer to be reduced
+ * @v minuend_size     Number of elements in minuend
+ * @v modulus0         Element 0 of big integer modulus
+ * @v modulus_size     Number of elements in modulus and result
+ * @v result0          Element 0 of big integer to hold result
+ * @v tmp              Temporary working space
+ */
+void bigint_reduce_raw ( const bigint_element_t *minuend0,
+                        unsigned int minuend_size,
+                        const bigint_element_t *modulus0,
+                        unsigned int modulus_size,
+                        bigint_element_t *result0, void *tmp ) {
+       const bigint_t ( minuend_size ) __attribute__ (( may_alias ))
+               *minuend = ( ( const void * ) minuend0 );
+       const bigint_t ( modulus_size ) __attribute__ (( may_alias ))
+               *modulus = ( ( const void * ) modulus0 );
+       bigint_t ( modulus_size ) __attribute__ (( may_alias ))
+               *result = ( ( void * ) result0 );
+       struct {
+               bigint_t ( minuend_size ) minuend;
+               bigint_t ( minuend_size ) modulus;
+       } *temp = tmp;
+       const unsigned int width = ( 8 * sizeof ( bigint_element_t ) );
+       const bigint_element_t msb_mask = ( 1UL << ( width - 1 ) );
+       bigint_element_t *element;
+       unsigned int minuend_max;
+       unsigned int modulus_max;
+       unsigned int subshift;
+       bigint_element_t msb;
+       int offset;
+       int shift;
+       int i;
+
+       /* Start profiling */
+       profile_start ( &bigint_mod_profiler );
+
+       /* Sanity check */
+       assert ( minuend_size >= modulus_size );
+       assert ( sizeof ( *temp ) == bigint_reduce_tmp_len ( minuend ) );
+
+       /* Copy minuend and modulus to temporary working space */
+       bigint_shrink ( minuend, &temp->minuend );
+       bigint_grow ( modulus, &temp->modulus );
+
+       /* Normalise the modulus
+        *
+        * Scale the modulus by shifting left such that both modulus
+        * "m" and minuend "x" have the same most significant set bit.
+        * (If this is not possible, then the minuend is already less
+        * than the modulus, and we may therefore skip reduction
+        * completely.)
+        */
+       minuend_max = bigint_max_set_bit ( minuend );
+       modulus_max = bigint_max_set_bit ( modulus );
+       shift = ( minuend_max - modulus_max );
+       if ( shift < 0 )
+               goto skip;
+       subshift = ( shift & ( width - 1 ) );
+       offset = ( shift / width );
+       element = temp->modulus.element;
+       for ( i = ( ( minuend_max - 1 ) / width ) ; ; i-- ) {
+               element[i] = ( element[ i - offset ] << subshift );
+               if ( i <= offset )
+                       break;
+               if ( subshift ) {
+                       element[i] |= ( element[ i - offset - 1 ]
+                                       >> ( width - subshift ) );
+               }
+       }
+       for ( i-- ; i >= 0 ; i-- )
+               element[i] = 0;
+
+       /* Reduce the minuend "x" by iteratively adding or subtracting
+        * the scaled modulus "m".
+        *
+        * On each loop iteration, we maintain the invariant:
+        *
+        *    -2m <= x < 2m
+        *
+        * If x is positive, we obtain the new minuend x' by
+        * subtracting m, otherwise we add m:
+        *
+        *      0 <= x < 2m   =>   x' := x - m   =>   -m <= x' < m
+        *    -2m <= x < 0    =>   x' := x + m   =>   -m <= x' < m
+        *
+        * and then halve the modulus (by shifting right):
+        *
+        *      m' = m/2
+        *
+        * We therefore end up with:
+        *
+        *     -m <= x' < m   =>   -2m' <= x' < 2m'
+        *
+        * i.e. we have preseved the invariant while reducing the
+        * bounds on x' by one power of two.
+        *
+        * The issue remains of how to determine on each iteration
+        * whether or not x is currently positive, given that both
+        * input values are unsigned big integers that may use all
+        * available bits (including the MSB).
+        *
+        * On the first loop iteration, we may simply assume that x is
+        * positive, since it is unmodified from the input value and
+        * so is positive by definition (even if the MSB is set).  We
+        * therefore unconditionally perform a subtraction on the
+        * first loop iteration.
+        *
+        * Let k be the MSB after normalisation.  We then have:
+        *
+        *    2^k <= m < 2^(k+1)
+        *    2^k <= x < 2^(k+1)
+        *
+        * On the first loop iteration, we therefore have:
+        *
+        *     x' = (x - m)
+        *        < 2^(k+1) - 2^k
+        *        < 2^k
+        *
+        * Any positive value of x' therefore has its MSB set to zero,
+        * and so we may validly treat the MSB of x' as a sign bit at
+        * the end of the first loop iteration.
+        *
+        * On all subsequent loop iterations, the starting value m is
+        * guaranteed to have its MSB set to zero (since it has
+        * already been shifted right at least once).  Since we know
+        * from above that we preserve the loop invariant:
+        *
+        *     -m <= x' < m
+        *
+        * we immediately know that any positive value of x' also has
+        * its MSB set to zero, and so we may validly treat the MSB of
+        * x' as a sign bit at the end of all subsequent loop
+        * iterations.
+        *
+        * After the last loop iteration (when m' has been shifted
+        * back down to the original value of the modulus), we may
+        * need to add a single multiple of m' to ensure that x' is
+        * positive, i.e. lies within the range 0 <= x' < m'.  To
+        * allow for reusing the (inlined) expansion of
+        * bigint_subtract(), we achieve this via a potential
+        * additional loop iteration that performs the addition and is
+        * then guaranteed to terminate (since the result will be
+        * positive).
+        */
+       for ( msb = 0 ; ( msb || ( shift >= 0 ) ) ; shift-- ) {
+               if ( msb ) {
+                       bigint_add ( &temp->modulus, &temp->minuend );
+               } else {
+                       bigint_subtract ( &temp->modulus, &temp->minuend );
+               }
+               msb = ( temp->minuend.element[ minuend_size - 1 ] & msb_mask );
+               if ( shift > 0 )
+                       bigint_shr ( &temp->modulus );
+       }
+
+ skip:
+       /* Sanity check */
+       assert ( ! bigint_is_geq ( &temp->minuend, &temp->modulus ) );
+
+       /* Copy result */
+       bigint_shrink ( &temp->minuend, result );
+
+       /* Stop profiling */
+       profile_stop ( &bigint_mod_profiler );
+}
+
 /**
  * Perform modular multiplication of big integers
  *
@@ -171,8 +332,6 @@ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0,
                bigint_t ( size * 2 ) result;
                bigint_t ( size * 2 ) modulus;
        } *temp = tmp;
-       int shift;
-       int i;
 
        /* Start profiling */
        profile_start ( &bigint_mod_multiply_profiler );
@@ -181,33 +340,13 @@ void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0,
        assert ( sizeof ( *temp ) == bigint_mod_multiply_tmp_len ( modulus ) );
 
        /* Perform multiplication */
-       profile_start ( &bigint_mod_multiply_multiply_profiler );
        bigint_multiply ( multiplicand, multiplier, &temp->result );
-       profile_stop ( &bigint_mod_multiply_multiply_profiler );
-
-       /* Rescale modulus to match result */
-       profile_start ( &bigint_mod_multiply_rescale_profiler );
-       bigint_grow ( modulus, &temp->modulus );
-       shift = ( bigint_max_set_bit ( &temp->result ) -
-                 bigint_max_set_bit ( &temp->modulus ) );
-       for ( i = 0 ; i < shift ; i++ )
-               bigint_shl ( &temp->modulus );
-       profile_stop ( &bigint_mod_multiply_rescale_profiler );
-
-       /* Subtract multiples of modulus */
-       profile_start ( &bigint_mod_multiply_subtract_profiler );
-       for ( i = 0 ; i <= shift ; i++ ) {
-               if ( bigint_is_geq ( &temp->result, &temp->modulus ) )
-                       bigint_subtract ( &temp->modulus, &temp->result );
-               bigint_shr ( &temp->modulus );
-       }
-       profile_stop ( &bigint_mod_multiply_subtract_profiler );
 
-       /* Resize result */
-       bigint_shrink ( &temp->result, result );
+       /* Reduce result */
+       bigint_reduce ( &temp->result, modulus, result, temp );
 
        /* Sanity check */
-       assert ( bigint_is_geq ( modulus, result ) );
+       assert ( ! bigint_is_geq ( result, modulus ) );
 
        /* Stop profiling */
        profile_stop ( &bigint_mod_multiply_profiler );
index c556afbc164591e3f6609547481fd555672abe54..c56b2155f1770d2741d46a3c5a479ef344fcc292 100644 (file)
@@ -217,6 +217,35 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
                              multiplier_size, (result)->element );     \
        } while ( 0 )
 
+/**
+ * Reduce big integer
+ *
+ * @v minuend          Big integer to be reduced
+ * @v modulus          Big integer modulus
+ * @v result           Big integer to hold result
+ * @v tmp              Temporary working space
+ */
+#define bigint_reduce( minuend, modulus, result, tmp ) do {            \
+       unsigned int minuend_size = bigint_size (minuend);              \
+       unsigned int modulus_size = bigint_size (modulus);              \
+       bigint_reduce_raw ( (minuend)->element, minuend_size,           \
+                           (modulus)->element, modulus_size,           \
+                           (result)->element, tmp );                   \
+       } while ( 0 )
+
+/**
+ * Calculate temporary working space required for reduction
+ *
+ * @v minuend          Big integer to be reduced
+ * @ret len            Length of temporary working space
+ */
+#define bigint_reduce_tmp_len( minuend ) ( {                           \
+       unsigned int size = bigint_size (minuend);                      \
+       sizeof ( struct {                                               \
+               bigint_t ( size ) temp_minuend;                         \
+               bigint_t ( size ) temp_modulus;                         \
+       } ); } )
+
 /**
  * Perform modular multiplication of big integers
  *
@@ -339,6 +368,11 @@ void bigint_multiply_raw ( const bigint_element_t *multiplicand0,
                           const bigint_element_t *multiplier0,
                           unsigned int multiplier_size,
                           bigint_element_t *result0 );
+void bigint_reduce_raw ( const bigint_element_t *minuend0,
+                        unsigned int minuend_size,
+                        const bigint_element_t *modulus0,
+                        unsigned int modulus_size,
+                        bigint_element_t *result0, void *tmp );
 void bigint_mod_multiply_raw ( const bigint_element_t *multiplicand0,
                               const bigint_element_t *multiplier0,
                               const bigint_element_t *modulus0,
index 65f124f249e332e48da9d0cb1bb6c953ee70ab18..104e1f362cc1ead2c4b1056074f32493d6d51669 100644 (file)
@@ -185,6 +185,21 @@ void bigint_multiply_sample ( const bigint_element_t *multiplicand0,
        bigint_multiply ( multiplicand, multiplier, result );
 }
 
+void bigint_reduce_sample ( const bigint_element_t *minuend0,
+                           unsigned int minuend_size,
+                           const bigint_element_t *modulus0,
+                           unsigned int modulus_size,
+                           bigint_element_t *result0, void *tmp ) {
+       const bigint_t ( minuend_size ) __attribute__ (( may_alias ))
+               *minuend = ( ( const void * ) minuend0 );
+       const bigint_t ( modulus_size ) __attribute__ (( may_alias ))
+               *modulus = ( ( const void * ) modulus0 );
+       bigint_t ( modulus_size ) __attribute__ (( may_alias ))
+               *result = ( ( void * ) result0 );
+
+       bigint_reduce ( minuend, modulus, result, tmp );
+}
+
 void bigint_mod_multiply_sample ( const bigint_element_t *multiplicand0,
                                  const bigint_element_t *multiplier0,
                                  const bigint_element_t *modulus0,
@@ -516,6 +531,48 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0,
                      sizeof ( result_raw ) ) == 0 );                   \
        } while ( 0 )
 
+/**
+ * Report result of big integer modular direct reduction test
+ *
+ * @v minuend          Big integer to be reduced
+ * @v modulus          Big integer modulus
+ * @v expected         Big integer expected result
+ */
+#define bigint_reduce_ok( minuend, modulus, expected ) do {            \
+       static const uint8_t minuend_raw[] = minuend;                   \
+       static const uint8_t modulus_raw[] = modulus;                   \
+       static const uint8_t expected_raw[] = expected;                 \
+       uint8_t result_raw[ sizeof ( expected_raw ) ];                  \
+       unsigned int minuend_size =                                     \
+               bigint_required_size ( sizeof ( minuend_raw ) );        \
+       unsigned int modulus_size =                                     \
+               bigint_required_size ( sizeof ( modulus_raw ) );        \
+       bigint_t ( minuend_size ) minuend_temp;                         \
+       bigint_t ( modulus_size ) modulus_temp;                         \
+       bigint_t ( modulus_size ) result_temp;                          \
+       size_t tmp_len = bigint_reduce_tmp_len ( &minuend_temp );       \
+       uint8_t tmp[tmp_len];                                           \
+       {} /* Fix emacs alignment */                                    \
+                                                                       \
+       assert ( bigint_size ( &result_temp ) ==                        \
+                bigint_size ( &modulus_temp ) );                       \
+       bigint_init ( &minuend_temp, minuend_raw,                       \
+                     sizeof ( minuend_raw ) );                         \
+       bigint_init ( &modulus_temp, modulus_raw,                       \
+                     sizeof ( modulus_raw ) );                         \
+       DBG ( "Modular reduce:\n" );                                    \
+       DBG_HDA ( 0, &minuend_temp, sizeof ( minuend_temp ) );          \
+       DBG_HDA ( 0, &modulus_temp, sizeof ( modulus_temp ) );          \
+       bigint_reduce ( &minuend_temp, &modulus_temp, &result_temp,     \
+                       tmp );                                          \
+       DBG_HDA ( 0, &result_temp, sizeof ( result_temp ) );            \
+       bigint_done ( &result_temp, result_raw,                         \
+                     sizeof ( result_raw ) );                          \
+                                                                       \
+       ok ( memcmp ( result_raw, expected_raw,                         \
+                     sizeof ( result_raw ) ) == 0 );                   \
+       } while ( 0 )
+
 /**
  * Report result of big integer modular multiplication test
  *
@@ -1674,6 +1731,35 @@ static void bigint_test_exec ( void ) {
                                      0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                      0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                      0x00, 0x00, 0x00, 0x01 ) );
+       bigint_reduce_ok ( BIGINT ( 0x00 ),
+                          BIGINT ( 0xaf ),
+                          BIGINT ( 0x00 ) );
+       bigint_reduce_ok ( BIGINT ( 0xab ),
+                          BIGINT ( 0xab ),
+                          BIGINT ( 0x00 ) );
+       bigint_reduce_ok ( BIGINT ( 0x1d, 0x97, 0x63, 0xc9, 0x97, 0xcd, 0x43,
+                                   0xcb, 0x8e, 0x71, 0xac, 0x41, 0xdd ),
+                          BIGINT ( 0xcc, 0x9d, 0xa0, 0x79, 0x96, 0x6a, 0x46,
+                                   0xd5, 0xb4, 0x30, 0xd2, 0x2b, 0xbf ),
+                          BIGINT ( 0x1d, 0x97, 0x63, 0xc9, 0x97, 0xcd, 0x43,
+                                   0xcb, 0x8e, 0x71, 0xac, 0x41, 0xdd ) );
+       bigint_reduce_ok ( BIGINT ( 0x21, 0xfa, 0x4f, 0xce, 0x0f, 0x0f, 0x4d,
+                                   0x43, 0xaa, 0xad, 0x21, 0x30, 0xe5 ),
+                          BIGINT ( 0x21, 0xfa, 0x4f, 0xce, 0x0f, 0x0f, 0x4d,
+                                   0x43, 0xaa, 0xad, 0x21, 0x30, 0xe5 ),
+                          BIGINT ( 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                   0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ) );
+       bigint_reduce_ok ( BIGINT ( 0xf9, 0x78, 0x96, 0x39, 0xee, 0x98, 0x42,
+                                   0x6a, 0xb8, 0x74, 0x0b, 0xe8, 0x5c, 0x76,
+                                   0x34, 0xaf ),
+                          BIGINT ( 0xf3, 0x65, 0x35, 0x41, 0x66, 0x65 ),
+                          BIGINT ( 0xb3, 0x07, 0xe8, 0xb7, 0x01, 0xf6 ) );
+       bigint_reduce_ok ( BIGINT ( 0xfe, 0x30, 0xe1, 0xc6, 0x65, 0x97, 0x48,
+                                   0x2e, 0x94, 0xd4 ),
+                          BIGINT ( 0x47, 0xaa, 0x88, 0x00, 0xd0, 0x30, 0x62,
+                                   0xfb, 0x5d, 0x55 ),
+                          BIGINT ( 0x27, 0x31, 0x49, 0xc3, 0xf5, 0x06, 0x1f,
+                                   0x3c, 0x7c, 0xd5 ) );
        bigint_mod_multiply_ok ( BIGINT ( 0x37 ),
                                 BIGINT ( 0x67 ),
                                 BIGINT ( 0x3f ),