]> git.ipfire.org Git - thirdparty/ipxe.git/commitdiff
[crypto] Remove the concept of a public-key algorithm reusable context
authorMichael Brown <mcb30@ipxe.org>
Wed, 21 Aug 2024 15:25:10 +0000 (16:25 +0100)
committerMichael Brown <mcb30@ipxe.org>
Wed, 21 Aug 2024 20:00:57 +0000 (21:00 +0100)
Instances of cipher and digest algorithms tend to get called
repeatedly to process substantial amounts of data.  This is not true
for public-key algorithms, which tend to get called only once or twice
for a given key.

Simplify the public-key algorithm API so that there is no reusable
algorithm context.  In particular, this allows callers to omit the
error handling currently required to handle memory allocation (or key
parsing) errors from pubkey_init(), and to omit the cleanup calls to
pubkey_final().

This change does remove the ability for a caller to distinguish
between a verification failure due to a memory allocation failure and
a verification failure due to a bad signature.  This difference is not
material in practice: in both cases, for whatever reason, the caller
was unable to verify the signature and so cannot proceed further, and
the cause of the error will be visible to the user via the return
status code.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
src/crypto/cms.c
src/crypto/crypto_null.c
src/crypto/ocsp.c
src/crypto/rsa.c
src/crypto/x509.c
src/drivers/net/iphone.c
src/include/ipxe/crypto.h
src/include/ipxe/rsa.h
src/include/ipxe/tls.h
src/net/tls.c
src/tests/pubkey_test.c

index 0b772f1cfa1633d26292b403b1b37fdf368e4db9..2e153d81919a7696c328b287708e97a71d060a1f 100644 (file)
@@ -612,33 +612,22 @@ static int cms_verify_digest ( struct cms_message *cms,
                               userptr_t data, size_t len ) {
        struct digest_algorithm *digest = part->digest;
        struct pubkey_algorithm *pubkey = part->pubkey;
-       struct x509_public_key *public_key = &cert->subject.public_key;
+       struct asn1_cursor *key = &cert->subject.public_key.raw;
        uint8_t digest_out[ digest->digestsize ];
-       uint8_t ctx[ pubkey->ctxsize ];
        int rc;
 
        /* Generate digest */
        cms_digest ( cms, part, data, len, digest_out );
 
-       /* Initialise public-key algorithm */
-       if ( ( rc = pubkey_init ( pubkey, ctx, &public_key->raw ) ) != 0 ) {
-               DBGC ( cms, "CMS %p/%p could not initialise public key: %s\n",
-                      cms, part, strerror ( rc ) );
-               goto err_init;
-       }
-
        /* Verify digest */
-       if ( ( rc = pubkey_verify ( pubkey, ctx, digest, digest_out,
+       if ( ( rc = pubkey_verify ( pubkey, key, digest, digest_out,
                                    part->value, part->len ) ) != 0 ) {
                DBGC ( cms, "CMS %p/%p signature verification failed: %s\n",
                       cms, part, strerror ( rc ) );
-               goto err_verify;
+               return rc;
        }
 
- err_verify:
-       pubkey_final ( pubkey, ctx );
- err_init:
-       return rc;
+       return 0;
 }
 
 /**
index b4169382b8f675f24c9984378338cd5f224c3f41..d5863f958e8f257eb69a9051a94b7ab5210a1680 100644 (file)
@@ -93,34 +93,31 @@ struct cipher_algorithm cipher_null = {
        .auth = cipher_null_auth,
 };
 
-int pubkey_null_init ( void *ctx __unused,
-                      const struct asn1_cursor *key __unused ) {
+size_t pubkey_null_max_len ( const struct asn1_cursor *key __unused ) {
        return 0;
 }
 
-size_t pubkey_null_max_len ( void *ctx __unused ) {
-       return 0;
-}
-
-int pubkey_null_encrypt ( void *ctx __unused, const void *plaintext __unused,
+int pubkey_null_encrypt ( const struct asn1_cursor *key __unused,
+                         const void *plaintext __unused,
                          size_t plaintext_len __unused,
                          void *ciphertext __unused ) {
        return 0;
 }
 
-int pubkey_null_decrypt ( void *ctx __unused, const void *ciphertext __unused,
+int pubkey_null_decrypt ( const struct asn1_cursor *key __unused,
+                         const void *ciphertext __unused,
                          size_t ciphertext_len __unused,
                          void *plaintext __unused ) {
        return 0;
 }
 
-int pubkey_null_sign ( void *ctx __unused,
+int pubkey_null_sign ( const struct asn1_cursor *key __unused,
                       struct digest_algorithm *digest __unused,
                       const void *value __unused, void *signature __unused ) {
        return 0;
 }
 
-int pubkey_null_verify ( void *ctx __unused,
+int pubkey_null_verify ( const struct asn1_cursor *key __unused,
                         struct digest_algorithm *digest __unused,
                         const void *value __unused,
                         const void *signature __unused ,
@@ -128,18 +125,11 @@ int pubkey_null_verify ( void *ctx __unused,
        return 0;
 }
 
-void pubkey_null_final ( void *ctx __unused ) {
-       /* Do nothing */
-}
-
 struct pubkey_algorithm pubkey_null = {
        .name = "null",
-       .ctxsize = 0,
-       .init = pubkey_null_init,
        .max_len = pubkey_null_max_len,
        .encrypt = pubkey_null_encrypt,
        .decrypt = pubkey_null_decrypt,
        .sign = pubkey_null_sign,
        .verify = pubkey_null_verify,
-       .final = pubkey_null_final,
 };
index f35593454214d9a3f186ee2d77979a496810cd71..e65f7180aedebfd9e677633a46943a56c7f5b955 100644 (file)
@@ -844,10 +844,9 @@ static int ocsp_check_signature ( struct ocsp_check *ocsp,
        struct ocsp_response *response = &ocsp->response;
        struct digest_algorithm *digest = response->algorithm->digest;
        struct pubkey_algorithm *pubkey = response->algorithm->pubkey;
-       struct x509_public_key *public_key = &signer->subject.public_key;
+       struct asn1_cursor *key = &signer->subject.public_key.raw;
        uint8_t digest_ctx[ digest->ctxsize ];
        uint8_t digest_out[ digest->digestsize ];
-       uint8_t pubkey_ctx[ pubkey->ctxsize ];
        int rc;
 
        /* Generate digest */
@@ -856,30 +855,18 @@ static int ocsp_check_signature ( struct ocsp_check *ocsp,
                        response->tbs.len );
        digest_final ( digest, digest_ctx, digest_out );
 
-       /* Initialise public-key algorithm */
-       if ( ( rc = pubkey_init ( pubkey, pubkey_ctx,
-                                 &public_key->raw ) ) != 0 ) {
-               DBGC ( ocsp, "OCSP %p \"%s\" could not initialise public key: "
-                      "%s\n", ocsp, x509_name ( ocsp->cert ), strerror ( rc ));
-               goto err_init;
-       }
-
        /* Verify digest */
-       if ( ( rc = pubkey_verify ( pubkey, pubkey_ctx, digest, digest_out,
+       if ( ( rc = pubkey_verify ( pubkey, key, digest, digest_out,
                                    response->signature.data,
                                    response->signature.len ) ) != 0 ) {
                DBGC ( ocsp, "OCSP %p \"%s\" signature verification failed: "
                       "%s\n", ocsp, x509_name ( ocsp->cert ), strerror ( rc ));
-               goto err_verify;
+               return rc;
        }
 
        DBGC2 ( ocsp, "OCSP %p \"%s\" signature is correct\n",
                ocsp, x509_name ( ocsp->cert ) );
-
- err_verify:
-       pubkey_final ( pubkey, pubkey_ctx );
- err_init:
-       return rc;
+       return 0;
 }
 
 /**
index 2d288a95349a073a7c01652ec4dae16d4d7101c5..19472c121b5aca300e08cafe50ffcc1844731035 100644 (file)
@@ -47,6 +47,28 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
 #define EINFO_EACCES_VERIFY \
        __einfo_uniqify ( EINFO_EACCES, 0x01, "RSA signature incorrect" )
 
+/** An RSA context */
+struct rsa_context {
+       /** Allocated memory */
+       void *dynamic;
+       /** Modulus */
+       bigint_element_t *modulus0;
+       /** Modulus size */
+       unsigned int size;
+       /** Modulus length */
+       size_t max_len;
+       /** Exponent */
+       bigint_element_t *exponent0;
+       /** Exponent size */
+       unsigned int exponent_size;
+       /** Input buffer */
+       bigint_element_t *input0;
+       /** Output buffer */
+       bigint_element_t *output0;
+       /** Temporary working space for modular exponentiation */
+       void *tmp;
+};
+
 /**
  * Identify RSA prefix
  *
@@ -69,10 +91,9 @@ rsa_find_prefix ( struct digest_algorithm *digest ) {
  *
  * @v context          RSA context
  */
-static void rsa_free ( struct rsa_context *context ) {
+static inline void rsa_free ( struct rsa_context *context ) {
 
        free ( context->dynamic );
-       context->dynamic = NULL;
 }
 
 /**
@@ -98,9 +119,6 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len,
                uint8_t tmp[tmp_len];
        } __attribute__ (( packed )) *dynamic;
 
-       /* Free any existing dynamic storage */
-       rsa_free ( context );
-
        /* Allocate dynamic storage */
        dynamic = malloc ( sizeof ( *dynamic ) );
        if ( ! dynamic )
@@ -231,12 +249,12 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus,
 /**
  * Initialise RSA cipher
  *
- * @v ctx              RSA context
+ * @v context          RSA context
  * @v key              Key
  * @ret rc             Return status code
  */
-static int rsa_init ( void *ctx, const struct asn1_cursor *key ) {
-       struct rsa_context *context = ctx;
+static int rsa_init ( struct rsa_context *context,
+                     const struct asn1_cursor *key ) {
        struct asn1_cursor modulus;
        struct asn1_cursor exponent;
        int rc;
@@ -277,13 +295,22 @@ static int rsa_init ( void *ctx, const struct asn1_cursor *key ) {
 /**
  * Calculate RSA maximum output length
  *
- * @v ctx              RSA context
+ * @v key              Key
  * @ret max_len                Maximum output length
  */
-static size_t rsa_max_len ( void *ctx ) {
-       struct rsa_context *context = ctx;
+static size_t rsa_max_len ( const struct asn1_cursor *key ) {
+       struct asn1_cursor modulus;
+       struct asn1_cursor exponent;
+       int rc;
 
-       return context->max_len;
+       /* Parse moduli and exponents */
+       if ( ( rc = rsa_parse_mod_exp ( &modulus, &exponent, key ) ) != 0 ) {
+               /* Return a zero maximum length on error */
+               return 0;
+       }
+
+       /* Output length can never exceed modulus length */
+       return modulus.len;
 }
 
 /**
@@ -314,111 +341,147 @@ static void rsa_cipher ( struct rsa_context *context,
 /**
  * Encrypt using RSA
  *
- * @v ctx              RSA context
+ * @v key              Key
  * @v plaintext                Plaintext
  * @v plaintext_len    Length of plaintext
  * @v ciphertext       Ciphertext
  * @ret ciphertext_len Length of ciphertext, or negative error
  */
-static int rsa_encrypt ( void *ctx, const void *plaintext,
+static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext,
                         size_t plaintext_len, void *ciphertext ) {
-       struct rsa_context *context = ctx;
+       struct rsa_context context;
        void *temp;
        uint8_t *encoded;
-       size_t max_len = ( context->max_len - 11 );
-       size_t random_nz_len = ( max_len - plaintext_len + 8 );
+       size_t max_len;
+       size_t random_nz_len;
        int rc;
 
+       DBGC ( &context, "RSA %p encrypting:\n", &context );
+       DBGC_HDA ( &context, 0, plaintext, plaintext_len );
+
+       /* Initialise context */
+       if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+               goto err_init;
+
+       /* Calculate lengths */
+       max_len = ( context.max_len - 11 );
+       random_nz_len = ( max_len - plaintext_len + 8 );
+
        /* Sanity check */
        if ( plaintext_len > max_len ) {
-               DBGC ( context, "RSA %p plaintext too long (%zd bytes, max "
-                      "%zd)\n", context, plaintext_len, max_len );
-               return -ERANGE;
+               DBGC ( &context, "RSA %p plaintext too long (%zd bytes, max "
+                      "%zd)\n", &context, plaintext_len, max_len );
+               rc = -ERANGE;
+               goto err_sanity;
        }
-       DBGC ( context, "RSA %p encrypting:\n", context );
-       DBGC_HDA ( context, 0, plaintext, plaintext_len );
 
        /* Construct encoded message (using the big integer output
         * buffer as temporary storage)
         */
-       temp = context->output0;
+       temp = context.output0;
        encoded = temp;
        encoded[0] = 0x00;
        encoded[1] = 0x02;
        if ( ( rc = get_random_nz ( &encoded[2], random_nz_len ) ) != 0 ) {
-               DBGC ( context, "RSA %p could not generate random data: %s\n",
-                      context, strerror ( rc ) );
-               return rc;
+               DBGC ( &context, "RSA %p could not generate random data: %s\n",
+                      &context, strerror ( rc ) );
+               goto err_random;
        }
        encoded[ 2 + random_nz_len ] = 0x00;
-       memcpy ( &encoded[ context->max_len - plaintext_len ],
+       memcpy ( &encoded[ context.max_len - plaintext_len ],
                 plaintext, plaintext_len );
 
        /* Encipher the encoded message */
-       rsa_cipher ( context, encoded, ciphertext );
-       DBGC ( context, "RSA %p encrypted:\n", context );
-       DBGC_HDA ( context, 0, ciphertext, context->max_len );
+       rsa_cipher ( &context, encoded, ciphertext );
+       DBGC ( &context, "RSA %p encrypted:\n", &context );
+       DBGC_HDA ( &context, 0, ciphertext, context.max_len );
+
+       /* Free context */
+       rsa_free ( &context );
 
-       return context->max_len;
+       return context.max_len;
+
+ err_random:
+ err_sanity:
+       rsa_free ( &context );
+ err_init:
+       return rc;
 }
 
 /**
  * Decrypt using RSA
  *
- * @v ctx              RSA context
+ * @v key              Key
  * @v ciphertext       Ciphertext
  * @v ciphertext_len   Ciphertext length
  * @v plaintext                Plaintext
  * @ret plaintext_len  Plaintext length, or negative error
  */
-static int rsa_decrypt ( void *ctx, const void *ciphertext,
+static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext,
                         size_t ciphertext_len, void *plaintext ) {
-       struct rsa_context *context = ctx;
+       struct rsa_context context;
        void *temp;
        uint8_t *encoded;
        uint8_t *end;
        uint8_t *zero;
        uint8_t *start;
        size_t plaintext_len;
+       int rc;
+
+       DBGC ( &context, "RSA %p decrypting:\n", &context );
+       DBGC_HDA ( &context, 0, ciphertext, ciphertext_len );
+
+       /* Initialise context */
+       if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+               goto err_init;
 
        /* Sanity check */
-       if ( ciphertext_len != context->max_len ) {
-               DBGC ( context, "RSA %p ciphertext incorrect length (%zd "
+       if ( ciphertext_len != context.max_len ) {
+               DBGC ( &context, "RSA %p ciphertext incorrect length (%zd "
                       "bytes, should be %zd)\n",
-                      context, ciphertext_len, context->max_len );
-               return -ERANGE;
+                      &context, ciphertext_len, context.max_len );
+               rc = -ERANGE;
+               goto err_sanity;
        }
-       DBGC ( context, "RSA %p decrypting:\n", context );
-       DBGC_HDA ( context, 0, ciphertext, ciphertext_len );
 
        /* Decipher the message (using the big integer input buffer as
         * temporary storage)
         */
-       temp = context->input0;
+       temp = context.input0;
        encoded = temp;
-       rsa_cipher ( context, ciphertext, encoded );
+       rsa_cipher ( &context, ciphertext, encoded );
 
        /* Parse the message */
-       end = ( encoded + context->max_len );
-       if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) )
-               goto invalid;
+       end = ( encoded + context.max_len );
+       if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) ) {
+               rc = -EINVAL;
+               goto err_invalid;
+       }
        zero = memchr ( &encoded[2], 0, ( end - &encoded[2] ) );
-       if ( ! zero )
-               goto invalid;
+       if ( ! zero ) {
+               rc = -EINVAL;
+               goto err_invalid;
+       }
        start = ( zero + 1 );
        plaintext_len = ( end - start );
 
        /* Copy out message */
        memcpy ( plaintext, start, plaintext_len );
-       DBGC ( context, "RSA %p decrypted:\n", context );
-       DBGC_HDA ( context, 0, plaintext, plaintext_len );
+       DBGC ( &context, "RSA %p decrypted:\n", &context );
+       DBGC_HDA ( &context, 0, plaintext, plaintext_len );
+
+       /* Free context */
+       rsa_free ( &context );
 
        return plaintext_len;
 
- invalid:
-       DBGC ( context, "RSA %p invalid decrypted message:\n", context );
-       DBGC_HDA ( context, 0, encoded, context->max_len );
-       return -EINVAL;
+ err_invalid:
+       DBGC ( &context, "RSA %p invalid decrypted message:\n", &context );
+       DBGC_HDA ( &context, 0, encoded, context.max_len );
+ err_sanity:
+       rsa_free ( &context );
+ err_init:
+       return rc;
 }
 
 /**
@@ -452,9 +515,9 @@ static int rsa_encode_digest ( struct rsa_context *context,
        /* Sanity check */
        max_len = ( context->max_len - 11 );
        if ( digestinfo_len > max_len ) {
-               DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, max"
-                      "%zd)\n",
-                      context, digest->name, digestinfo_len, max_len );
+               DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, "
+                      "max %zd)\n", context, digest->name, digestinfo_len,
+                      max_len );
                return -ERANGE;
        }
        DBGC ( context, "RSA %p encoding %s digest:\n",
@@ -482,104 +545,125 @@ static int rsa_encode_digest ( struct rsa_context *context,
 /**
  * Sign digest value using RSA
  *
- * @v ctx              RSA context
+ * @v key              Key
  * @v digest           Digest algorithm
  * @v value            Digest value
  * @v signature                Signature
  * @ret signature_len  Signature length, or negative error
  */
-static int rsa_sign ( void *ctx, struct digest_algorithm *digest,
-                     const void *value, void *signature ) {
-       struct rsa_context *context = ctx;
+static int rsa_sign ( const struct asn1_cursor *key,
+                     struct digest_algorithm *digest, const void *value,
+                     void *signature ) {
+       struct rsa_context context;
        void *temp;
        int rc;
 
-       DBGC ( context, "RSA %p signing %s digest:\n", context, digest->name );
-       DBGC_HDA ( context, 0, value, digest->digestsize );
+       DBGC ( &context, "RSA %p signing %s digest:\n",
+              &context, digest->name );
+       DBGC_HDA ( &context, 0, value, digest->digestsize );
+
+       /* Initialise context */
+       if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+               goto err_init;
 
        /* Encode digest (using the big integer output buffer as
         * temporary storage)
         */
-       temp = context->output0;
-       if ( ( rc = rsa_encode_digest ( context, digest, value, temp ) ) != 0 )
-               return rc;
+       temp = context.output0;
+       if ( ( rc = rsa_encode_digest ( &context, digest, value, temp ) ) != 0 )
+               goto err_encode;
 
        /* Encipher the encoded digest */
-       rsa_cipher ( context, temp, signature );
-       DBGC ( context, "RSA %p signed %s digest:\n", context, digest->name );
-       DBGC_HDA ( context, 0, signature, context->max_len );
+       rsa_cipher ( &context, temp, signature );
+       DBGC ( &context, "RSA %p signed %s digest:\n", &context, digest->name );
+       DBGC_HDA ( &context, 0, signature, context.max_len );
+
+       /* Free context */
+       rsa_free ( &context );
 
-       return context->max_len;
+       return context.max_len;
+
+ err_encode:
+       rsa_free ( &context );
+ err_init:
+       return rc;
 }
 
 /**
  * Verify signed digest value using RSA
  *
- * @v ctx              RSA context
+ * @v key              Key
  * @v digest           Digest algorithm
  * @v value            Digest value
  * @v signature                Signature
  * @v signature_len    Signature length
  * @ret rc             Return status code
  */
-static int rsa_verify ( void *ctx, struct digest_algorithm *digest,
-                       const void *value, const void *signature,
-                       size_t signature_len ) {
-       struct rsa_context *context = ctx;
+static int rsa_verify ( const struct asn1_cursor *key,
+                       struct digest_algorithm *digest, const void *value,
+                       const void *signature, size_t signature_len ) {
+       struct rsa_context context;
        void *temp;
        void *expected;
        void *actual;
        int rc;
 
+       DBGC ( &context, "RSA %p verifying %s digest:\n",
+              &context, digest->name );
+       DBGC_HDA ( &context, 0, value, digest->digestsize );
+       DBGC_HDA ( &context, 0, signature, signature_len );
+
+       /* Initialise context */
+       if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+               goto err_init;
+
        /* Sanity check */
-       if ( signature_len != context->max_len ) {
-               DBGC ( context, "RSA %p signature incorrect length (%zd "
+       if ( signature_len != context.max_len ) {
+               DBGC ( &context, "RSA %p signature incorrect length (%zd "
                       "bytes, should be %zd)\n",
-                      context, signature_len, context->max_len );
-               return -ERANGE;
+                      &context, signature_len, context.max_len );
+               rc = -ERANGE;
+               goto err_sanity;
        }
-       DBGC ( context, "RSA %p verifying %s digest:\n",
-              context, digest->name );
-       DBGC_HDA ( context, 0, value, digest->digestsize );
-       DBGC_HDA ( context, 0, signature, signature_len );
 
        /* Decipher the signature (using the big integer input buffer
         * as temporary storage)
         */
-       temp = context->input0;
+       temp = context.input0;
        expected = temp;
-       rsa_cipher ( context, signature, expected );
-       DBGC ( context, "RSA %p deciphered signature:\n", context );
-       DBGC_HDA ( context, 0, expected, context->max_len );
+       rsa_cipher ( &context, signature, expected );
+       DBGC ( &context, "RSA %p deciphered signature:\n", &context );
+       DBGC_HDA ( &context, 0, expected, context.max_len );
 
        /* Encode digest (using the big integer output buffer as
         * temporary storage)
         */
-       temp = context->output0;
+       temp = context.output0;
        actual = temp;
-       if ( ( rc = rsa_encode_digest ( context, digest, value, actual ) ) !=0 )
-               return rc;
+       if ( ( rc = rsa_encode_digest ( &context, digest, value,
+                                       actual ) ) != 0 )
+               goto err_encode;
 
        /* Verify the signature */
-       if ( memcmp ( actual, expected, context->max_len ) != 0 ) {
-               DBGC ( context, "RSA %p signature verification failed\n",
-                      context );
-               return -EACCES_VERIFY;
+       if ( memcmp ( actual, expected, context.max_len ) != 0 ) {
+               DBGC ( &context, "RSA %p signature verification failed\n",
+                      &context );
+               rc = -EACCES_VERIFY;
+               goto err_verify;
        }
 
-       DBGC ( context, "RSA %p signature verified successfully\n", context );
-       return 0;
-}
+       /* Free context */
+       rsa_free ( &context );
 
-/**
- * Finalise RSA cipher
- *
- * @v ctx              RSA context
- */
-static void rsa_final ( void *ctx ) {
-       struct rsa_context *context = ctx;
+       DBGC ( &context, "RSA %p signature verified successfully\n", &context );
+       return 0;
 
-       rsa_free ( context );
+ err_verify:
+ err_encode:
+ err_sanity:
+       rsa_free ( &context );
+ err_init:
+       return rc;
 }
 
 /**
@@ -615,14 +699,11 @@ static int rsa_match ( const struct asn1_cursor *private_key,
 /** RSA public-key algorithm */
 struct pubkey_algorithm rsa_algorithm = {
        .name           = "rsa",
-       .ctxsize        = RSA_CTX_SIZE,
-       .init           = rsa_init,
        .max_len        = rsa_max_len,
        .encrypt        = rsa_encrypt,
        .decrypt        = rsa_decrypt,
        .sign           = rsa_sign,
        .verify         = rsa_verify,
-       .final          = rsa_final,
        .match          = rsa_match,
 };
 
index c0762740eea0e4d958077cc8a2acb10332bea175..4101c8094141d57b476c11ea2544f77da002d30d 100644 (file)
@@ -1125,7 +1125,6 @@ static int x509_check_signature ( struct x509_certificate *cert,
        struct pubkey_algorithm *pubkey = algorithm->pubkey;
        uint8_t digest_ctx[ digest->ctxsize ];
        uint8_t digest_out[ digest->digestsize ];
-       uint8_t pubkey_ctx[ pubkey->ctxsize ];
        int rc;
 
        /* Sanity check */
@@ -1149,14 +1148,8 @@ static int x509_check_signature ( struct x509_certificate *cert,
        }
 
        /* Verify signature using signer's public key */
-       if ( ( rc = pubkey_init ( pubkey, pubkey_ctx,
-                                 &public_key->raw ) ) != 0 ) {
-               DBGC ( cert, "X509 %p \"%s\" cannot initialise public key: "
-                      "%s\n", cert, x509_name ( cert ), strerror ( rc ) );
-               goto err_pubkey_init;
-       }
-       if ( ( rc = pubkey_verify ( pubkey, pubkey_ctx, digest, digest_out,
-                                   signature->value.data,
+       if ( ( rc = pubkey_verify ( pubkey, &public_key->raw, digest,
+                                   digest_out, signature->value.data,
                                    signature->value.len ) ) != 0 ) {
                DBGC ( cert, "X509 %p \"%s\" signature verification failed: "
                       "%s\n", cert, x509_name ( cert ), strerror ( rc ) );
@@ -1167,8 +1160,6 @@ static int x509_check_signature ( struct x509_certificate *cert,
        rc = 0;
 
  err_pubkey_verify:
-       pubkey_final ( pubkey, pubkey_ctx );
- err_pubkey_init:
  err_mismatch:
        return rc;
 }
index 96eb0952b7306553dc19cb5b998817a7da577d0a..08459a6e239e2807f250125283c6c505b626f8ca 100644 (file)
@@ -362,17 +362,9 @@ static int icert_cert ( struct icert *icert, struct asn1_cursor *subject,
        struct asn1_builder raw = { NULL, 0 };
        uint8_t digest_ctx[SHA256_CTX_SIZE];
        uint8_t digest_out[SHA256_DIGEST_SIZE];
-       uint8_t pubkey_ctx[RSA_CTX_SIZE];
        int len;
        int rc;
 
-       /* Initialise "private" key */
-       if ( ( rc = pubkey_init ( pubkey, pubkey_ctx, private ) ) != 0 ) {
-               DBGC ( icert, "ICERT %p could not initialise private key: "
-                      "%s\n", icert, strerror ( rc ) );
-               goto err_pubkey_init;
-       }
-
        /* Construct subjectPublicKeyInfo */
        if ( ( rc = ( asn1_prepend_raw ( &spki, public->data, public->len ),
                      asn1_prepend_raw ( &spki, icert_nul,
@@ -406,14 +398,14 @@ static int icert_cert ( struct icert *icert, struct asn1_cursor *subject,
        digest_update ( digest, digest_ctx, tbs.data, tbs.len );
        digest_final ( digest, digest_ctx, digest_out );
 
-       /* Construct signature */
-       if ( ( rc = asn1_grow ( &raw, pubkey_max_len ( pubkey,
-                                                      pubkey_ctx ) ) ) != 0 ) {
+       /* Construct signature using "private" key */
+       if ( ( rc = asn1_grow ( &raw,
+                               pubkey_max_len ( pubkey, private ) ) ) != 0 ) {
                DBGC ( icert, "ICERT %p could not build signature: %s\n",
                       icert, strerror ( rc ) );
                goto err_grow;
        }
-       if ( ( len = pubkey_sign ( pubkey, pubkey_ctx, digest, digest_out,
+       if ( ( len = pubkey_sign ( pubkey, private, digest, digest_out,
                                   raw.data ) ) < 0 ) {
                rc = len;
                DBGC ( icert, "ICERT %p could not sign: %s\n",
@@ -452,8 +444,6 @@ static int icert_cert ( struct icert *icert, struct asn1_cursor *subject,
  err_tbs:
        free ( spki.data );
  err_spki:
-       pubkey_final ( pubkey, pubkey_ctx );
- err_pubkey_init:
        return rc;
 }
 
index 8b6eb94f6abe2a0c4b64d970de1884690bba9f24..dcc73f3efcf182a547814fe322956023871b10f3 100644 (file)
@@ -121,68 +121,55 @@ struct cipher_algorithm {
 struct pubkey_algorithm {
        /** Algorithm name */
        const char *name;
-       /** Context size */
-       size_t ctxsize;
-       /** Initialise algorithm
-        *
-        * @v ctx               Context
-        * @v key               Key
-        * @ret rc              Return status code
-        */
-       int ( * init ) ( void *ctx, const struct asn1_cursor *key );
        /** Calculate maximum output length
         *
-        * @v ctx               Context
+        * @v key               Key
         * @ret max_len         Maximum output length
         */
-       size_t ( * max_len ) ( void *ctx );
+       size_t ( * max_len ) ( const struct asn1_cursor *key );
        /** Encrypt
         *
-        * @v ctx               Context
+        * @v key               Key
         * @v plaintext         Plaintext
         * @v plaintext_len     Length of plaintext
         * @v ciphertext        Ciphertext
         * @ret ciphertext_len  Length of ciphertext, or negative error
         */
-       int ( * encrypt ) ( void *ctx, const void *data, size_t len,
-                           void *out );
+       int ( * encrypt ) ( const struct asn1_cursor *key, const void *data,
+                           size_t len, void *out );
        /** Decrypt
         *
-        * @v ctx               Context
+        * @v key               Key
         * @v ciphertext        Ciphertext
         * @v ciphertext_len    Ciphertext length
         * @v plaintext         Plaintext
         * @ret plaintext_len   Plaintext length, or negative error
         */
-       int ( * decrypt ) ( void *ctx, const void *data, size_t len,
-                           void *out );
+       int ( * decrypt ) ( const struct asn1_cursor *key, const void *data,
+                           size_t len, void *out );
        /** Sign digest value
         *
-        * @v ctx               Context
+        * @v key               Key
         * @v digest            Digest algorithm
         * @v value             Digest value
         * @v signature         Signature
         * @ret signature_len   Signature length, or negative error
         */
-       int ( * sign ) ( void *ctx, struct digest_algorithm *digest,
-                        const void *value, void *signature );
+       int ( * sign ) ( const struct asn1_cursor *key,
+                        struct digest_algorithm *digest, const void *value,
+                        void *signature );
        /** Verify signed digest value
         *
-        * @v ctx               Context
+        * @v key               Key
         * @v digest            Digest algorithm
         * @v value             Digest value
         * @v signature         Signature
         * @v signature_len     Signature length
         * @ret rc              Return status code
         */
-       int ( * verify ) ( void *ctx, struct digest_algorithm *digest,
-                          const void *value, const void *signature,
-                          size_t signature_len );
-       /** Finalise algorithm
-        *
-        * @v ctx               Context
-        */
-       void ( * final ) ( void *ctx );
+       int ( * verify ) ( const struct asn1_cursor *key,
+                          struct digest_algorithm *digest, const void *value,
+                          const void *signature, size_t signature_len );
        /** Check that public key matches private key
         *
         * @v private_key       Private key
@@ -278,46 +265,36 @@ is_auth_cipher ( struct cipher_algorithm *cipher ) {
        return cipher->authsize;
 }
 
-static inline __attribute__ (( always_inline )) int
-pubkey_init ( struct pubkey_algorithm *pubkey, void *ctx,
-             const struct asn1_cursor *key ) {
-       return pubkey->init ( ctx, key );
-}
-
 static inline __attribute__ (( always_inline )) size_t
-pubkey_max_len ( struct pubkey_algorithm *pubkey, void *ctx ) {
-       return pubkey->max_len ( ctx );
+pubkey_max_len ( struct pubkey_algorithm *pubkey,
+                const struct asn1_cursor *key ) {
+       return pubkey->max_len ( key );
 }
 
 static inline __attribute__ (( always_inline )) int
-pubkey_encrypt ( struct pubkey_algorithm *pubkey, void *ctx,
+pubkey_encrypt ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key,
                 const void *data, size_t len, void *out ) {
-       return pubkey->encrypt ( ctx, data, len, out );
+       return pubkey->encrypt ( key, data, len, out );
 }
 
 static inline __attribute__ (( always_inline )) int
-pubkey_decrypt ( struct pubkey_algorithm *pubkey, void *ctx,
+pubkey_decrypt ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key,
                 const void *data, size_t len, void *out ) {
-       return pubkey->decrypt ( ctx, data, len, out );
+       return pubkey->decrypt ( key, data, len, out );
 }
 
 static inline __attribute__ (( always_inline )) int
-pubkey_sign ( struct pubkey_algorithm *pubkey, void *ctx,
+pubkey_sign ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key,
              struct digest_algorithm *digest, const void *value,
              void *signature ) {
-       return pubkey->sign ( ctx, digest, value, signature );
+       return pubkey->sign ( key, digest, value, signature );
 }
 
 static inline __attribute__ (( always_inline )) int
-pubkey_verify ( struct pubkey_algorithm *pubkey, void *ctx,
+pubkey_verify ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key,
                struct digest_algorithm *digest, const void *value,
                const void *signature, size_t signature_len ) {
-       return pubkey->verify ( ctx, digest, value, signature, signature_len );
-}
-
-static inline __attribute__ (( always_inline )) void
-pubkey_final ( struct pubkey_algorithm *pubkey, void *ctx ) {
-       pubkey->final ( ctx );
+       return pubkey->verify ( key, digest, value, signature, signature_len );
 }
 
 static inline __attribute__ (( always_inline )) int
@@ -345,15 +322,18 @@ extern void cipher_null_decrypt ( void *ctx, const void *src, void *dst,
                                  size_t len );
 extern void cipher_null_auth ( void *ctx, void *auth );
 
-extern int pubkey_null_init ( void *ctx, const struct asn1_cursor *key );
-extern size_t pubkey_null_max_len ( void *ctx );
-extern int pubkey_null_encrypt ( void *ctx, const void *plaintext,
-                                size_t plaintext_len, void *ciphertext );
-extern int pubkey_null_decrypt ( void *ctx, const void *ciphertext,
-                                size_t ciphertext_len, void *plaintext );
-extern int pubkey_null_sign ( void *ctx, struct digest_algorithm *digest,
+extern size_t pubkey_null_max_len ( const struct asn1_cursor *key );
+extern int pubkey_null_encrypt ( const struct asn1_cursor *key,
+                                const void *plaintext, size_t plaintext_len,
+                                void *ciphertext );
+extern int pubkey_null_decrypt ( const struct asn1_cursor *key,
+                                const void *ciphertext, size_t ciphertext_len,
+                                void *plaintext );
+extern int pubkey_null_sign ( const struct asn1_cursor *key,
+                             struct digest_algorithm *digest,
                              const void *value, void *signature );
-extern int pubkey_null_verify ( void *ctx, struct digest_algorithm *digest,
+extern int pubkey_null_verify ( const struct asn1_cursor *key,
+                               struct digest_algorithm *digest,
                                const void *value, const void *signature ,
                                size_t signature_len );
 
index a1b5e0c0399c832a3fd5293ba497b0a1875a5deb..e36a75edfa15e177573d525e1ef0ee6b3b2b10b1 100644 (file)
@@ -55,31 +55,6 @@ struct rsa_digestinfo_prefix {
 /** Declare an RSA digestInfo prefix */
 #define __rsa_digestinfo_prefix __table_entry ( RSA_DIGESTINFO_PREFIXES, 01 )
 
-/** An RSA context */
-struct rsa_context {
-       /** Allocated memory */
-       void *dynamic;
-       /** Modulus */
-       bigint_element_t *modulus0;
-       /** Modulus size */
-       unsigned int size;
-       /** Modulus length */
-       size_t max_len;
-       /** Exponent */
-       bigint_element_t *exponent0;
-       /** Exponent size */
-       unsigned int exponent_size;
-       /** Input buffer */
-       bigint_element_t *input0;
-       /** Output buffer */
-       bigint_element_t *output0;
-       /** Temporary working space for modular exponentiation */
-       void *tmp;
-};
-
-/** RSA context size */
-#define RSA_CTX_SIZE sizeof ( struct rsa_context )
-
 extern struct pubkey_algorithm rsa_algorithm;
 
 #endif /* _IPXE_RSA_H */
index 9494eaa05d2bb9956e7df44e82dc05cb65777dc2..08d58689e9bf9f087863437ed40d68679b789f25 100644 (file)
@@ -240,8 +240,6 @@ struct tls_cipherspec {
        struct tls_cipher_suite *suite;
        /** Dynamically-allocated storage */
        void *dynamic;
-       /** Public key encryption context */
-       void *pubkey_ctx;
        /** Bulk encryption cipher context */
        void *cipher_ctx;
        /** MAC secret */
@@ -402,6 +400,8 @@ struct tls_server {
        struct x509_root *root;
        /** Certificate chain */
        struct x509_chain *chain;
+       /** Public key (within server certificate) */
+       struct asn1_cursor key;
        /** Certificate validator */
        struct interface validator;
        /** Certificate validation pending operation */
index ec503e43d289958790e3723722043fa152dcb155..ded100d0e427959f63270bbe8cc989c242e93691 100644 (file)
@@ -856,10 +856,6 @@ tls_find_cipher_suite ( unsigned int cipher_suite ) {
 static void tls_clear_cipher ( struct tls_connection *tls __unused,
                               struct tls_cipherspec *cipherspec ) {
 
-       if ( cipherspec->suite ) {
-               pubkey_final ( cipherspec->suite->pubkey,
-                              cipherspec->pubkey_ctx );
-       }
        free ( cipherspec->dynamic );
        memset ( cipherspec, 0, sizeof ( *cipherspec ) );
        cipherspec->suite = &tls_cipher_suite_null;
@@ -876,7 +872,6 @@ static void tls_clear_cipher ( struct tls_connection *tls __unused,
 static int tls_set_cipher ( struct tls_connection *tls,
                            struct tls_cipherspec *cipherspec,
                            struct tls_cipher_suite *suite ) {
-       struct pubkey_algorithm *pubkey = suite->pubkey;
        struct cipher_algorithm *cipher = suite->cipher;
        size_t total;
        void *dynamic;
@@ -885,8 +880,7 @@ static int tls_set_cipher ( struct tls_connection *tls,
        tls_clear_cipher ( tls, cipherspec );
 
        /* Allocate dynamic storage */
-       total = ( pubkey->ctxsize + cipher->ctxsize + suite->mac_len +
-                 suite->fixed_iv_len );
+       total = ( cipher->ctxsize + suite->mac_len + suite->fixed_iv_len );
        dynamic = zalloc ( total );
        if ( ! dynamic ) {
                DBGC ( tls, "TLS %p could not allocate %zd bytes for crypto "
@@ -896,7 +890,6 @@ static int tls_set_cipher ( struct tls_connection *tls,
 
        /* Assign storage */
        cipherspec->dynamic = dynamic;
-       cipherspec->pubkey_ctx = dynamic;       dynamic += pubkey->ctxsize;
        cipherspec->cipher_ctx = dynamic;       dynamic += cipher->ctxsize;
        cipherspec->mac_secret = dynamic;       dynamic += suite->mac_len;
        cipherspec->fixed_iv = dynamic;         dynamic += suite->fixed_iv_len;
@@ -1392,7 +1385,7 @@ static int tls_send_certificate ( struct tls_connection *tls ) {
 static int tls_send_client_key_exchange_pubkey ( struct tls_connection *tls ) {
        struct tls_cipherspec *cipherspec = &tls->tx.cipherspec.pending;
        struct pubkey_algorithm *pubkey = cipherspec->suite->pubkey;
-       size_t max_len = pubkey_max_len ( pubkey, cipherspec->pubkey_ctx );
+       size_t max_len = pubkey_max_len ( pubkey, &tls->server.key );
        struct {
                uint16_t version;
                uint8_t random[46];
@@ -1419,8 +1412,8 @@ static int tls_send_client_key_exchange_pubkey ( struct tls_connection *tls ) {
 
        /* Encrypt pre-master secret using server's public key */
        memset ( &key_xchg, 0, sizeof ( key_xchg ) );
-       len = pubkey_encrypt ( pubkey, cipherspec->pubkey_ctx,
-                              &pre_master_secret, sizeof ( pre_master_secret ),
+       len = pubkey_encrypt ( pubkey, &tls->server.key, &pre_master_secret,
+                              sizeof ( pre_master_secret ),
                               key_xchg.encrypted_pre_master_secret );
        if ( len < 0 ) {
                rc = len;
@@ -1523,7 +1516,7 @@ static int tls_verify_dh_params ( struct tls_connection *tls,
                digest_final ( digest, ctx, hash );
 
                /* Verify signature */
-               if ( ( rc = pubkey_verify ( pubkey, cipherspec->pubkey_ctx,
+               if ( ( rc = pubkey_verify ( pubkey, &tls->server.key,
                                            digest, hash, signature,
                                            signature_len ) ) != 0 ) {
                        DBGC ( tls, "TLS %p ServerKeyExchange failed "
@@ -1820,20 +1813,12 @@ static int tls_send_certificate_verify ( struct tls_connection *tls ) {
        struct pubkey_algorithm *pubkey = cert->signature_algorithm->pubkey;
        struct asn1_cursor *key = privkey_cursor ( tls->client.key );
        uint8_t digest_out[ digest->digestsize ];
-       uint8_t ctx[ pubkey->ctxsize ];
        struct tls_signature_hash_algorithm *sig_hash = NULL;
        int rc;
 
        /* Generate digest to be signed */
        tls_verify_handshake ( tls, digest_out );
 
-       /* Initialise public-key algorithm */
-       if ( ( rc = pubkey_init ( pubkey, ctx, key ) ) != 0 ) {
-               DBGC ( tls, "TLS %p could not initialise %s client private "
-                      "key: %s\n", tls, pubkey->name, strerror ( rc ) );
-               goto err_pubkey_init;
-       }
-
        /* TLSv1.2 and later use explicit algorithm identifiers */
        if ( tls_version ( tls, TLS_VERSION_TLS_1_2 ) ) {
                sig_hash = tls_signature_hash_algorithm ( pubkey, digest );
@@ -1848,7 +1833,7 @@ static int tls_send_certificate_verify ( struct tls_connection *tls ) {
 
        /* Generate and transmit record */
        {
-               size_t max_len = pubkey_max_len ( pubkey, ctx );
+               size_t max_len = pubkey_max_len ( pubkey, key );
                int use_sig_hash = ( ( sig_hash == NULL ) ? 0 : 1 );
                struct {
                        uint32_t type_length;
@@ -1860,7 +1845,7 @@ static int tls_send_certificate_verify ( struct tls_connection *tls ) {
                int len;
 
                /* Sign digest */
-               len = pubkey_sign ( pubkey, ctx, digest, digest_out,
+               len = pubkey_sign ( pubkey, key, digest, digest_out,
                                    certificate_verify.signature );
                if ( len < 0 ) {
                        rc = len;
@@ -1893,8 +1878,6 @@ static int tls_send_certificate_verify ( struct tls_connection *tls ) {
 
  err_pubkey_sign:
  err_sig_hash:
-       pubkey_final ( pubkey, ctx );
- err_pubkey_init:
        return rc;
 }
 
@@ -2312,6 +2295,7 @@ static int tls_parse_chain ( struct tls_connection *tls,
        int rc;
 
        /* Free any existing certificate chain */
+       memset ( &tls->server.key, 0, sizeof ( tls->server.key ) );
        x509_chain_put ( tls->server.chain );
        tls->server.chain = NULL;
 
@@ -2371,6 +2355,7 @@ static int tls_parse_chain ( struct tls_connection *tls,
  err_parse:
  err_overlength:
  err_underlength:
+       memset ( &tls->server.key, 0, sizeof ( tls->server.key ) );
        x509_chain_put ( tls->server.chain );
        tls->server.chain = NULL;
  err_alloc_chain:
@@ -3555,8 +3540,6 @@ static struct interface_descriptor tls_cipherstream_desc =
  */
 static void tls_validator_done ( struct tls_connection *tls, int rc ) {
        struct tls_session *session = tls->session;
-       struct tls_cipherspec *cipherspec = &tls->tx.cipherspec.pending;
-       struct pubkey_algorithm *pubkey = cipherspec->suite->pubkey;
        struct x509_certificate *cert;
 
        /* Mark validation as complete */
@@ -3584,13 +3567,9 @@ static void tls_validator_done ( struct tls_connection *tls, int rc ) {
                goto err;
        }
 
-       /* Initialise public key algorithm */
-       if ( ( rc = pubkey_init ( pubkey, cipherspec->pubkey_ctx,
-                                 &cert->subject.public_key.raw ) ) != 0 ) {
-               DBGC ( tls, "TLS %p cannot initialise public key: %s\n",
-                      tls, strerror ( rc ) );
-               goto err;
-       }
+       /* Extract the now trusted server public key */
+       memcpy ( &tls->server.key, &cert->subject.public_key.raw,
+                sizeof ( tls->server.key ) );
 
        /* Schedule Client Key Exchange, Change Cipher, and Finished */
        tls->tx.pending |= ( TLS_TX_CLIENT_KEY_EXCHANGE |
index 93962516a5037638965e456134059abde62b9e65..ff318bfb78a3080476bceb1d4da5dd9d5d437688 100644 (file)
@@ -50,77 +50,41 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
 void pubkey_okx ( struct pubkey_test *test, const char *file,
                  unsigned int line ) {
        struct pubkey_algorithm *pubkey = test->pubkey;
-       uint8_t private_ctx[pubkey->ctxsize];
-       uint8_t public_ctx[pubkey->ctxsize];
-       size_t max_len;
-
-       /* Initialize contexts */
-       okx ( pubkey_init ( pubkey, private_ctx, &test->private ) == 0,
-             file, line );
-       okx ( pubkey_init ( pubkey, public_ctx, &test->public ) == 0,
-             file, line );
-       max_len = pubkey_max_len ( pubkey, private_ctx );
+       size_t max_len = pubkey_max_len ( pubkey, &test->private );
+       uint8_t encrypted[max_len];
+       uint8_t decrypted[max_len];
+       int encrypted_len;
+       int decrypted_len;
 
        /* Test decrypting with private key to obtain known plaintext */
-       {
-               uint8_t decrypted[max_len];
-               int decrypted_len;
-
-               decrypted_len = pubkey_decrypt ( pubkey, private_ctx,
-                                                test->ciphertext,
-                                                test->ciphertext_len,
-                                                decrypted );
-               okx ( decrypted_len == ( ( int ) test->plaintext_len ),
-                     file, line );
-               okx ( memcmp ( decrypted, test->plaintext,
-                              test->plaintext_len ) == 0, file, line );
-       }
+       decrypted_len = pubkey_decrypt ( pubkey, &test->private,
+                                        test->ciphertext, test->ciphertext_len,
+                                        decrypted );
+       okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line );
+       okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0,
+             file, line );
 
        /* Test encrypting with private key and decrypting with public key */
-       {
-               uint8_t encrypted[max_len];
-               uint8_t decrypted[max_len];
-               int encrypted_len;
-               int decrypted_len;
-
-               encrypted_len = pubkey_encrypt ( pubkey, private_ctx,
-                                                test->plaintext,
-                                                test->plaintext_len,
-                                                encrypted );
-               okx ( encrypted_len >= 0, file, line );
-               decrypted_len = pubkey_decrypt ( pubkey, public_ctx,
-                                                encrypted, encrypted_len,
-                                                decrypted );
-               okx ( decrypted_len == ( ( int ) test->plaintext_len ),
-                     file, line );
-               okx ( memcmp ( decrypted, test->plaintext,
-                              test->plaintext_len ) == 0, file, line );
-       }
+       encrypted_len = pubkey_encrypt ( pubkey, &test->private,
+                                        test->plaintext, test->plaintext_len,
+                                        encrypted );
+       okx ( encrypted_len >= 0, file, line );
+       decrypted_len = pubkey_decrypt ( pubkey, &test->public, encrypted,
+                                        encrypted_len, decrypted );
+       okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line );
+       okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0,
+             file, line );
 
        /* Test encrypting with public key and decrypting with private key */
-       {
-               uint8_t encrypted[max_len];
-               uint8_t decrypted[max_len];
-               int encrypted_len;
-               int decrypted_len;
-
-               encrypted_len = pubkey_encrypt ( pubkey, public_ctx,
-                                                test->plaintext,
-                                                test->plaintext_len,
-                                                encrypted );
-               okx ( encrypted_len >= 0, file, line );
-               decrypted_len = pubkey_decrypt ( pubkey, private_ctx,
-                                                encrypted, encrypted_len,
-                                                decrypted );
-               okx ( decrypted_len == ( ( int ) test->plaintext_len ),
-                     file, line );
-               okx ( memcmp ( decrypted, test->plaintext,
-                              test->plaintext_len ) == 0, file, line );
-       }
-
-       /* Free contexts */
-       pubkey_final ( pubkey, public_ctx );
-       pubkey_final ( pubkey, private_ctx );
+       encrypted_len = pubkey_encrypt ( pubkey, &test->public,
+                                        test->plaintext, test->plaintext_len,
+                                        encrypted );
+       okx ( encrypted_len >= 0, file, line );
+       decrypted_len = pubkey_decrypt ( pubkey, &test->private, encrypted,
+                                        encrypted_len, decrypted );
+       okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line );
+       okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0,
+             file, line );
 }
 
 /**
@@ -134,18 +98,12 @@ void pubkey_sign_okx ( struct pubkey_sign_test *test, const char *file,
                       unsigned int line ) {
        struct pubkey_algorithm *pubkey = test->pubkey;
        struct digest_algorithm *digest = test->digest;
-       uint8_t private_ctx[pubkey->ctxsize];
-       uint8_t public_ctx[pubkey->ctxsize];
+       size_t max_len = pubkey_max_len ( pubkey, &test->private );
+       uint8_t bad[test->signature_len];
        uint8_t digestctx[digest->ctxsize ];
        uint8_t digestout[digest->digestsize];
-       size_t max_len;
-
-       /* Initialize contexts */
-       okx ( pubkey_init ( pubkey, private_ctx, &test->private ) == 0,
-             file, line );
-       okx ( pubkey_init ( pubkey, public_ctx, &test->public ) == 0,
-             file, line );
-       max_len = pubkey_max_len ( pubkey, private_ctx );
+       uint8_t signature[max_len];
+       int signature_len;
 
        /* Construct digest over plaintext */
        digest_init ( digest, digestctx );
@@ -154,34 +112,20 @@ void pubkey_sign_okx ( struct pubkey_sign_test *test, const char *file,
        digest_final ( digest, digestctx, digestout );
 
        /* Test signing using private key */
-       {
-               uint8_t signature[max_len];
-               int signature_len;
-
-               signature_len = pubkey_sign ( pubkey, private_ctx, digest,
-                                             digestout, signature );
-               okx ( signature_len == ( ( int ) test->signature_len ),
-                     file, line );
-               okx ( memcmp ( signature, test->signature,
-                              test->signature_len ) == 0, file, line );
-       }
+       signature_len = pubkey_sign ( pubkey, &test->private, digest,
+                                     digestout, signature );
+       okx ( signature_len == ( ( int ) test->signature_len ), file, line );
+       okx ( memcmp ( signature, test->signature, test->signature_len ) == 0,
+             file, line );
 
        /* Test verification using public key */
-       okx ( pubkey_verify ( pubkey, public_ctx, digest, digestout,
+       okx ( pubkey_verify ( pubkey, &test->public, digest, digestout,
                              test->signature, test->signature_len ) == 0,
              file, line );
 
        /* Test verification failure of modified signature */
-       {
-               uint8_t bad[test->signature_len];
-
-               memcpy ( bad, test->signature, test->signature_len );
-               bad[ test->signature_len / 2 ] ^= 0x40;
-               okx ( pubkey_verify ( pubkey, public_ctx, digest, digestout,
-                                     bad, sizeof ( bad ) ) != 0, file, line );
-       }
-
-       /* Free contexts */
-       pubkey_final ( pubkey, public_ctx );
-       pubkey_final ( pubkey, private_ctx );
+       memcpy ( bad, test->signature, test->signature_len );
+       bad[ test->signature_len / 2 ] ^= 0x40;
+       okx ( pubkey_verify ( pubkey, &test->public, digest, digestout,
+                             bad, sizeof ( bad ) ) != 0, file, line );
 }