]> git.ipfire.org Git - thirdparty/ipxe.git/commitdiff
[crypto] Construct asymmetric ciphered data using ASN.1 builders
authorMichael Brown <mcb30@ipxe.org>
Tue, 2 Dec 2025 13:12:25 +0000 (13:12 +0000)
committerMichael Brown <mcb30@ipxe.org>
Tue, 2 Dec 2025 13:12:25 +0000 (13:12 +0000)
Signed-off-by: Michael Brown <mcb30@ipxe.org>
src/crypto/cms.c
src/crypto/crypto_null.c
src/crypto/rsa.c
src/include/ipxe/crypto.h
src/net/tls.c
src/tests/pubkey_test.c
src/tests/pubkey_test.h

index a3c03a9b48752f1c93154b96d150f84595ba68a5..7775e581b8a9dc9dd423c83c93e6d2bdb668d407 100644 (file)
@@ -917,29 +917,26 @@ static int cms_cipher_key ( struct cms_message *cms,
        struct pubkey_algorithm *pubkey = part->pubkey;
        const struct asn1_cursor *key = privkey_cursor ( private_key );
        const struct asn1_cursor *value = &part->value;
-       size_t max_len = pubkey_max_len ( pubkey, key );
-       uint8_t cipher_key[max_len];
-       int len;
+       struct asn1_builder cipher_key = { NULL, 0 };
        int rc;
 
        /* Decrypt cipher key */
-       len = pubkey_decrypt ( pubkey, key, value->data, value->len,
-                              cipher_key );
-       if ( len < 0 ) {
-               rc = len;
+       if ( ( rc = pubkey_decrypt ( pubkey, key, value,
+                                    &cipher_key ) ) != 0 ) {
                DBGC ( cms, "CMS %p/%p could not decrypt cipher key: %s\n",
                       cms, part, strerror ( rc ) );
                DBGC_HDA ( cms, 0, value->data, value->len );
-               return rc;
+               goto err_decrypt;
        }
        DBGC ( cms, "CMS %p/%p cipher key:\n", cms, part );
-       DBGC_HDA ( cms, 0, cipher_keylen );
+       DBGC_HDA ( cms, 0, cipher_key.data, cipher_key.len );
 
        /* Set cipher key */
-       if ( ( rc = cipher_setkey ( cipher, ctx, cipher_key, len ) ) != 0 ) {
+       if ( ( rc = cipher_setkey ( cipher, ctx, cipher_key.data,
+                                   cipher_key.len ) ) != 0 ) {
                DBGC ( cms, "CMS %p could not set cipher key: %s\n",
                       cms, strerror ( rc ) );
-               return rc;
+               goto err_setkey;
        }
 
        /* Set cipher initialization vector */
@@ -949,7 +946,10 @@ static int cms_cipher_key ( struct cms_message *cms,
                DBGC_HDA ( cms, 0, cms->iv.data, cms->iv.len );
        }
 
-       return 0;
+ err_setkey:
+ err_decrypt:
+       free ( cipher_key.data );
+       return rc;
 }
 
 /**
index ee948e00d93d56f05bca7f9da6530a143bd5fec8..e8f8cbde8ed01dbe5aeb2b588d957b365af4a764 100644 (file)
@@ -98,16 +98,14 @@ size_t pubkey_null_max_len ( const struct asn1_cursor *key __unused ) {
 }
 
 int pubkey_null_encrypt ( const struct asn1_cursor *key __unused,
-                         const void *plaintext __unused,
-                         size_t plaintext_len __unused,
-                         void *ciphertext __unused ) {
+                         const struct asn1_cursor *plaintext __unused,
+                         struct asn1_builder *ciphertext __unused ) {
        return 0;
 }
 
 int pubkey_null_decrypt ( const struct asn1_cursor *key __unused,
-                         const void *ciphertext __unused,
-                         size_t ciphertext_len __unused,
-                         void *plaintext __unused ) {
+                         const struct asn1_cursor *ciphertext __unused,
+                         struct asn1_builder *plaintext __unused ) {
        return 0;
 }
 
index fd6a1ef395a02fdaf3a8fc3620de5fea1d79b037..18b2b1c1479834372a48ba474d1c4681202a4e02 100644 (file)
@@ -338,12 +338,12 @@ static void rsa_cipher ( struct rsa_context *context,
  *
  * @v key              Key
  * @v plaintext                Plaintext
- * @v plaintext_len    Length of plaintext
  * @v ciphertext       Ciphertext
  * @ret ciphertext_len Length of ciphertext, or negative error
  */
-static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext,
-                        size_t plaintext_len, void *ciphertext ) {
+static int rsa_encrypt ( const struct asn1_cursor *key,
+                        const struct asn1_cursor *plaintext,
+                        struct asn1_builder *ciphertext ) {
        struct rsa_context context;
        void *temp;
        uint8_t *encoded;
@@ -352,7 +352,7 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext,
        int rc;
 
        DBGC ( &context, "RSA %p encrypting:\n", &context );
-       DBGC_HDA ( &context, 0, plaintext, plaintext_len );
+       DBGC_HDA ( &context, 0, plaintext->data, plaintext->len );
 
        /* Initialise context */
        if ( ( rc = rsa_init ( &context, key ) ) != 0 )
@@ -360,12 +360,12 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext,
 
        /* Calculate lengths */
        max_len = ( context.max_len - 11 );
-       random_nz_len = ( max_len - plaintext_len + 8 );
+       random_nz_len = ( max_len - plaintext->len + 8 );
 
        /* Sanity check */
-       if ( plaintext_len > max_len ) {
+       if ( plaintext->len > max_len ) {
                DBGC ( &context, "RSA %p plaintext too long (%zd bytes, max "
-                      "%zd)\n", &context, plaintext_len, max_len );
+                      "%zd)\n", &context, plaintext->len, max_len );
                rc = -ERANGE;
                goto err_sanity;
        }
@@ -383,19 +383,24 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext,
                goto err_random;
        }
        encoded[ 2 + random_nz_len ] = 0x00;
-       memcpy ( &encoded[ context.max_len - plaintext_len ],
-                plaintext, plaintext_len );
+       memcpy ( &encoded[ context.max_len - plaintext->len ],
+                plaintext->data, plaintext->len );
+
+       /* Create space for ciphertext */
+       if ( ( rc = asn1_grow ( ciphertext, context.max_len ) ) != 0 )
+               goto err_grow;
 
        /* Encipher the encoded message */
-       rsa_cipher ( &context, encoded, ciphertext );
+       rsa_cipher ( &context, encoded, ciphertext->data );
        DBGC ( &context, "RSA %p encrypted:\n", &context );
-       DBGC_HDA ( &context, 0, ciphertext, context.max_len );
+       DBGC_HDA ( &context, 0, ciphertext->data, context.max_len );
 
        /* Free context */
        rsa_free ( &context );
 
-       return context.max_len;
+       return 0;
 
+ err_grow:
  err_random:
  err_sanity:
        rsa_free ( &context );
@@ -408,33 +413,33 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext,
  *
  * @v key              Key
  * @v ciphertext       Ciphertext
- * @v ciphertext_len   Ciphertext length
  * @v plaintext                Plaintext
- * @ret plaintext_len  Plaintext length, or negative error
+ * @ret rc             Return status code
  */
-static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext,
-                        size_t ciphertext_len, void *plaintext ) {
+static int rsa_decrypt ( const struct asn1_cursor *key,
+                        const struct asn1_cursor *ciphertext,
+                        struct asn1_builder *plaintext ) {
        struct rsa_context context;
        void *temp;
        uint8_t *encoded;
        uint8_t *end;
        uint8_t *zero;
        uint8_t *start;
-       size_t plaintext_len;
+       size_t len;
        int rc;
 
        DBGC ( &context, "RSA %p decrypting:\n", &context );
-       DBGC_HDA ( &context, 0, ciphertext, ciphertext_len );
+       DBGC_HDA ( &context, 0, ciphertext->data, ciphertext->len );
 
        /* Initialise context */
        if ( ( rc = rsa_init ( &context, key ) ) != 0 )
                goto err_init;
 
        /* Sanity check */
-       if ( ciphertext_len != context.max_len ) {
+       if ( ciphertext->len != context.max_len ) {
                DBGC ( &context, "RSA %p ciphertext incorrect length (%zd "
                       "bytes, should be %zd)\n",
-                      &context, ciphertext_len, context.max_len );
+                      &context, ciphertext->len, context.max_len );
                rc = -ERANGE;
                goto err_sanity;
        }
@@ -444,7 +449,7 @@ static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext,
         */
        temp = context.input0;
        encoded = temp;
-       rsa_cipher ( &context, ciphertext, encoded );
+       rsa_cipher ( &context, ciphertext->data, encoded );
 
        /* Parse the message */
        end = ( encoded + context.max_len );
@@ -454,25 +459,31 @@ static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext,
        }
        zero = memchr ( &encoded[2], 0, ( end - &encoded[2] ) );
        if ( ! zero ) {
+               DBGC ( &context, "RSA %p invalid decrypted message:\n",
+                      &context );
+               DBGC_HDA ( &context, 0, encoded, context.max_len );
                rc = -EINVAL;
                goto err_invalid;
        }
        start = ( zero + 1 );
-       plaintext_len = ( end - start );
+       len = ( end - start );
+
+       /* Create space for plaintext */
+       if ( ( rc = asn1_grow ( plaintext, len ) ) != 0 )
+               goto err_grow;
 
        /* Copy out message */
-       memcpy ( plaintext, start, plaintext_len );
+       memcpy ( plaintext->data, start, len );
        DBGC ( &context, "RSA %p decrypted:\n", &context );
-       DBGC_HDA ( &context, 0, plaintext, plaintext_len );
+       DBGC_HDA ( &context, 0, plaintext->data, len );
 
        /* Free context */
        rsa_free ( &context );
 
-       return plaintext_len;
+       return 0;
 
+ err_grow:
  err_invalid:
-       DBGC ( &context, "RSA %p invalid decrypted message:\n", &context );
-       DBGC_HDA ( &context, 0, encoded, context.max_len );
  err_sanity:
        rsa_free ( &context );
  err_init:
index c457a74b11e3f3bc6a999b485b4e5ad7804cba7a..68bd230482d6ee012599dcd63e7132bd5a539b5d 100644 (file)
@@ -131,22 +131,22 @@ struct pubkey_algorithm {
         *
         * @v key               Key
         * @v plaintext         Plaintext
-        * @v plaintext_len     Length of plaintext
         * @v ciphertext        Ciphertext
-        * @ret ciphertext_len  Length of ciphertext, or negative error
+        * @ret rc              Return status code
         */
-       int ( * encrypt ) ( const struct asn1_cursor *key, const void *data,
-                           size_t len, void *out );
+       int ( * encrypt ) ( const struct asn1_cursor *key,
+                           const struct asn1_cursor *plaintext,
+                           struct asn1_builder *ciphertext );
        /** Decrypt
         *
         * @v key               Key
         * @v ciphertext        Ciphertext
-        * @v ciphertext_len    Ciphertext length
         * @v plaintext         Plaintext
-        * @ret plaintext_len   Plaintext length, or negative error
+        * @ret rc              Return status code
         */
-       int ( * decrypt ) ( const struct asn1_cursor *key, const void *data,
-                           size_t len, void *out );
+       int ( * decrypt ) ( const struct asn1_cursor *key,
+                           const struct asn1_cursor *ciphertext,
+                           struct asn1_builder *plaintext );
        /** Sign digest value
         *
         * @v key               Key
@@ -274,14 +274,16 @@ pubkey_max_len ( struct pubkey_algorithm *pubkey,
 
 static inline __attribute__ (( always_inline )) int
 pubkey_encrypt ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key,
-                const void *data, size_t len, void *out ) {
-       return pubkey->encrypt ( key, data, len, out );
+                const struct asn1_cursor *plaintext,
+                struct asn1_builder *ciphertext ) {
+       return pubkey->encrypt ( key, plaintext, ciphertext );
 }
 
 static inline __attribute__ (( always_inline )) int
 pubkey_decrypt ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key,
-                const void *data, size_t len, void *out ) {
-       return pubkey->decrypt ( key, data, len, out );
+                const struct asn1_cursor *ciphertext,
+                struct asn1_builder *plaintext ) {
+       return pubkey->decrypt ( key, ciphertext, plaintext );
 }
 
 static inline __attribute__ (( always_inline )) int
@@ -325,11 +327,11 @@ extern void cipher_null_auth ( void *ctx, void *auth );
 
 extern size_t pubkey_null_max_len ( const struct asn1_cursor *key );
 extern int pubkey_null_encrypt ( const struct asn1_cursor *key,
-                                const void *plaintext, size_t plaintext_len,
-                                void *ciphertext );
+                                const struct asn1_cursor *plaintext,
+                                struct asn1_builder *ciphertext );
 extern int pubkey_null_decrypt ( const struct asn1_cursor *key,
-                                const void *ciphertext, size_t ciphertext_len,
-                                void *plaintext );
+                                const struct asn1_cursor *ciphertext,
+                                struct asn1_builder *plaintext );
 extern int pubkey_null_sign ( const struct asn1_cursor *key,
                              struct digest_algorithm *digest,
                              const void *value,
index c01ce9515489a6854479c1c90f7c3d9d8756e6cf..6140ca58a90f07f934aec796de6af5d50c7781cd 100644 (file)
@@ -1416,59 +1416,69 @@ static int tls_send_certificate ( struct tls_connection *tls ) {
 static int tls_send_client_key_exchange_pubkey ( struct tls_connection *tls ) {
        struct tls_cipherspec *cipherspec = &tls->tx.cipherspec.pending;
        struct pubkey_algorithm *pubkey = cipherspec->suite->pubkey;
-       size_t max_len = pubkey_max_len ( pubkey, &tls->server.key );
        struct {
                uint16_t version;
                uint8_t random[46];
        } __attribute__ (( packed )) pre_master_secret;
-       struct {
-               uint32_t type_length;
-               uint16_t encrypted_pre_master_secret_len;
-               uint8_t encrypted_pre_master_secret[max_len];
-       } __attribute__ (( packed )) key_xchg;
-       size_t unused;
-       int len;
+       struct asn1_cursor cursor = {
+               .data = &pre_master_secret,
+               .len = sizeof ( pre_master_secret ),
+       };
+       struct asn1_builder builder = { NULL, 0 };
        int rc;
 
        /* Generate pre-master secret */
        pre_master_secret.version = htons ( TLS_VERSION_MAX );
        if ( ( rc = tls_generate_random ( tls, &pre_master_secret.random,
                          ( sizeof ( pre_master_secret.random ) ) ) ) != 0 ) {
-               return rc;
+               goto err_random;
        }
 
        /* Encrypt pre-master secret using server's public key */
-       memset ( &key_xchg, 0, sizeof ( key_xchg ) );
-       len = pubkey_encrypt ( pubkey, &tls->server.key, &pre_master_secret,
-                              sizeof ( pre_master_secret ),
-                              key_xchg.encrypted_pre_master_secret );
-       if ( len < 0 ) {
-               rc = len;
+       if ( ( rc = pubkey_encrypt ( pubkey, &tls->server.key, &cursor,
+                                    &builder ) ) != 0 ) {
                DBGC ( tls, "TLS %p could not encrypt pre-master secret: %s\n",
                       tls, strerror ( rc ) );
-               return rc;
+               goto err_encrypt;
+       }
+
+       /* Construct Client Key Exchange record */
+       {
+               struct {
+                       uint32_t type_length;
+                       uint16_t encrypted_pre_master_secret_len;
+               } __attribute__ (( packed )) header;
+
+               header.type_length =
+                       ( cpu_to_le32 ( TLS_CLIENT_KEY_EXCHANGE ) |
+                         htonl ( builder.len + sizeof ( header ) -
+                                 sizeof ( header.type_length ) ) );
+               header.encrypted_pre_master_secret_len = htons ( builder.len );
+
+               if ( ( rc = asn1_prepend_raw ( &builder, &header,
+                                              sizeof ( header ) ) ) != 0 ) {
+                       DBGC ( tls, "TLS %p could not construct Client Key "
+                              "Exchange: %s\n", tls, strerror ( rc ) );
+                       goto err_prepend;
+               }
        }
-       unused = ( max_len - len );
-       key_xchg.type_length =
-               ( cpu_to_le32 ( TLS_CLIENT_KEY_EXCHANGE ) |
-                 htonl ( sizeof ( key_xchg ) -
-                         sizeof ( key_xchg.type_length ) - unused ) );
-       key_xchg.encrypted_pre_master_secret_len =
-               htons ( sizeof ( key_xchg.encrypted_pre_master_secret ) -
-                       unused );
 
        /* Transmit Client Key Exchange record */
-       if ( ( rc = tls_send_handshake ( tls, &key_xchg,
-                                        ( sizeof ( key_xchg ) -
-                                          unused ) ) ) != 0 ) {
-               return rc;
+       if ( ( rc = tls_send_handshake ( tls, builder.data,
+                                        builder.len ) ) != 0 ) {
+               goto err_send;
        }
 
        /* Generate master secret */
        tls_generate_master_secret ( tls, &pre_master_secret,
                                     sizeof ( pre_master_secret ) );
 
-       return 0;
+ err_random:
+ err_encrypt:
+ err_prepend:
+ err_send:
+       free ( builder.data );
+       return rc;
 }
 
 /** Public key exchange algorithm */
index e3fbc3b3f530e85f79fb392325cff5d98ec546e5..d110b294623e8e60f6ff32a8bb03b88c8a1d561a 100644 (file)
@@ -50,41 +50,47 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
 void pubkey_okx ( struct pubkey_test *test, const char *file,
                  unsigned int line ) {
        struct pubkey_algorithm *pubkey = test->pubkey;
-       size_t max_len = pubkey_max_len ( pubkey, &test->private );
-       uint8_t encrypted[max_len];
-       uint8_t decrypted[max_len];
-       int encrypted_len;
-       int decrypted_len;
+       struct asn1_builder plaintext;
+       struct asn1_builder ciphertext;
 
        /* Test decrypting with private key to obtain known plaintext */
-       decrypted_len = pubkey_decrypt ( pubkey, &test->private,
-                                        test->ciphertext, test->ciphertext_len,
-                                        decrypted );
-       okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line );
-       okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0,
-             file, line );
+       plaintext.data = NULL;
+       plaintext.len = 0;
+       okx ( pubkey_decrypt ( pubkey, &test->private, &test->ciphertext,
+                              &plaintext ) == 0, file, line );
+       okx ( asn1_compare ( asn1_built ( &plaintext ),
+                            &test->plaintext ) == 0, file, line );
+       free ( plaintext.data );
 
        /* Test encrypting with private key and decrypting with public key */
-       encrypted_len = pubkey_encrypt ( pubkey, &test->private,
-                                        test->plaintext, test->plaintext_len,
-                                        encrypted );
-       okx ( encrypted_len >= 0, file, line );
-       decrypted_len = pubkey_decrypt ( pubkey, &test->public, encrypted,
-                                        encrypted_len, decrypted );
-       okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line );
-       okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0,
-             file, line );
+       ciphertext.data = NULL;
+       ciphertext.len = 0;
+       plaintext.data = NULL;
+       plaintext.len = 0;
+       okx ( pubkey_encrypt ( pubkey, &test->private, &test->plaintext,
+                              &ciphertext ) == 0, file, line );
+       okx ( pubkey_decrypt ( pubkey, &test->public,
+                              asn1_built ( &ciphertext ),
+                              &plaintext ) == 0, file, line );
+       okx ( asn1_compare ( asn1_built ( &plaintext ),
+                            &test->plaintext ) == 0, file, line );
+       free ( ciphertext.data );
+       free ( plaintext.data );
 
        /* Test encrypting with public key and decrypting with private key */
-       encrypted_len = pubkey_encrypt ( pubkey, &test->public,
-                                        test->plaintext, test->plaintext_len,
-                                        encrypted );
-       okx ( encrypted_len >= 0, file, line );
-       decrypted_len = pubkey_decrypt ( pubkey, &test->private, encrypted,
-                                        encrypted_len, decrypted );
-       okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line );
-       okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0,
-             file, line );
+       ciphertext.data = NULL;
+       ciphertext.len = 0;
+       plaintext.data = NULL;
+       plaintext.len = 0;
+       okx ( pubkey_encrypt ( pubkey, &test->public, &test->plaintext,
+                              &ciphertext ) == 0, file, line );
+       okx ( pubkey_decrypt ( pubkey, &test->private,
+                              asn1_built ( &ciphertext ),
+                              &plaintext ) == 0, file, line );
+       okx ( asn1_compare ( asn1_built ( &plaintext ),
+                            &test->plaintext ) == 0, file, line );
+       free ( ciphertext.data );
+       free ( plaintext.data );
 }
 
 /**
index 1bb6caf515eb1fb700093f9aeee9e90f3170e810..33b301a6ec3cf6dd9d0b778e9955407ae462a24c 100644 (file)
@@ -16,18 +16,14 @@ struct pubkey_test {
        /** Public key */
        const struct asn1_cursor public;
        /** Plaintext */
-       const void *plaintext;
-       /** Length of plaintext */
-       size_t plaintext_len;
+       const struct asn1_cursor plaintext;
        /** Ciphertext
         *
         * Note that the encryption process may include some random
         * padding, so a given plaintext will encrypt to multiple
         * different ciphertexts.
         */
-       const void *ciphertext;
-       /** Length of ciphertext */
-       size_t ciphertext_len;
+       const struct asn1_cursor ciphertext;
 };
 
 /** A public-key signature test */
@@ -90,10 +86,14 @@ struct pubkey_sign_test {
                        .data = name ## _public,                        \
                        .len = sizeof ( name ## _public ),              \
                },                                                      \
-               .plaintext = name ## _plaintext,                        \
-               .plaintext_len = sizeof ( name ## _plaintext ),         \
-               .ciphertext = name ## _ciphertext,                      \
-               .ciphertext_len = sizeof ( name ## _ciphertext ),       \
+               .plaintext = {                                          \
+                       .data = name ## _plaintext,                     \
+                       .len = sizeof ( name ## _plaintext ),           \
+               },                                                      \
+               .ciphertext = {                                         \
+                       .data = name ## _ciphertext,                    \
+                       .len = sizeof ( name ## _ciphertext ),          \
+               },                                                      \
        }
 
 /**