]> git.ipfire.org Git - thirdparty/tor.git/commitdiff
Optimize the everloving heck out of OpenSSL AES as used by CGO.
authorNick Mathewson <nickm@torproject.org>
Thu, 15 May 2025 16:21:49 +0000 (12:21 -0400)
committerNick Mathewson <nickm@torproject.org>
Wed, 21 May 2025 17:00:03 +0000 (13:00 -0400)
Optimizations:

1. Calling EVP_CryptInit with a cipher returned by
   e.g. EVP_aes_128_ctr() is quite slow, since it needs to look
   up the _actual_ EVP_CIPHER corresponding to the given EVP,
   which involves grabbing locks, doing a search through a
   provider, and so on.  We use EVP_CIPHER_fetch to speed
   that up a lot.

2. There is not in fact any need to EVP_CIPHER_CTX_Reset a
   cipher before calling EVP_CryptInit on it a second time

2. Using an ECB cipher + CRYPTO_ctr128_encrypt was not in fact
   the most efficient way to implement a counter mode with an
   adjustable IV.  Instead, the fastest way seems to be:
     - Set the IV manually
     - Ensure that we are always aligned to block boundary
       when we do so.

src/core/crypto/relay_crypto_cgo.c
src/core/crypto/relay_crypto_cgo.h
src/lib/crypt_ops/.may_include
src/lib/crypt_ops/aes.h
src/lib/crypt_ops/aes_nss.c
src/lib/crypt_ops/aes_openssl.c
src/test/bench.c
src/test/test_crypto.c

index c833e6a87b2742cb0b2aba5c4c82e1bcc25c384d..d3622de22883f00403cfea19579947e7ac8174ef 100644 (file)
@@ -147,9 +147,10 @@ STATIC int
 cgo_prf_init(cgo_prf_t *prf, int aesbits,
              const uint8_t *key)
 {
+  const uint8_t iv[16] = {0};
   size_t aes_key_bytes = aesbits / 8;
   memset(prf,0, sizeof(*prf));
-  prf->k = aes_raw_new(key, aesbits, true);
+  prf->k = aes_new_cipher(key, iv, aesbits);
   polyval_key_init(&prf->b, key + aes_key_bytes);
   return 0;
 }
@@ -161,7 +162,7 @@ cgo_prf_set_key(cgo_prf_t *prf, int aesbits,
                 const uint8_t *key)
 {
   size_t aes_key_bytes = aesbits / 8;
-  aes_raw_set_key(&prf->k, key, aesbits, true);
+  aes_cipher_set_key(prf->k, key, aesbits);
   polyval_key_init(&prf->b, key + aes_key_bytes);
 }
 /**
@@ -183,7 +184,15 @@ cgo_prf_xor_t0(cgo_prf_t *prf, const uint8_t *input,
   polyval_get_tag(&pv, hash);
   hash[15] &= 0xC0; // Clear the low six bits.
 
-  aes_raw_counter_xor(prf->k, hash, 0, data, PRF_T0_DATA_LEN);
+  aes_cipher_set_iv_aligned(prf->k, hash);
+  aes_crypt_inplace(prf->k, (char*) data, PRF_T0_DATA_LEN);
+
+  // Re-align the cipher.
+  //
+  // This approach is faster than EVP_CIPHER_set_num!
+  const int ns = 16 - (PRF_T0_DATA_LEN % 0xf);
+  // We're not using the hash for anything, so it's okay to overwrite
+  aes_crypt_inplace(prf->k, (char*)hash,  ns);
 }
 /**
  * Generate 'n' bytes of the PRF's results on 'input', for position t=1,
@@ -202,9 +211,18 @@ cgo_prf_gen_t1(cgo_prf_t *prf, const uint8_t *input,
   polyval_add_block(&pv, input);
   polyval_get_tag(&pv, hash);
   hash[15] &= 0xC0; // Clear the low six bits.
+  hash[15] += T1_OFFSET; // Can't overflow!
 
   memset(buf, 0, n);
-  aes_raw_counter_xor(prf->k, hash, T1_OFFSET, buf, n);
+  aes_cipher_set_iv_aligned(prf->k, hash);
+  aes_crypt_inplace(prf->k, (char*)buf, n);
+
+  // Re-align the cipher.
+  size_t ns = 16-(n&0x0f);
+  if (ns) {
+    // We're not using the hash for anything, so it's okay to overwrite
+    aes_crypt_inplace(prf->k, (char*) hash, ns);
+  }
 }
 /**
  * Release any storage held in 'prf'.
@@ -214,7 +232,7 @@ cgo_prf_gen_t1(cgo_prf_t *prf, const uint8_t *input,
 STATIC void
 cgo_prf_clear(cgo_prf_t *prf)
 {
-  aes_raw_free(prf->k);
+  aes_cipher_free(prf->k);
 }
 
 static int
index 83a8274ef529e48240a5f9bdcecbe253210a273e..ef1ffb0f8c89af7a6bfc7a695cc813733cf1cc8a 100644 (file)
@@ -86,14 +86,9 @@ typedef struct cgo_et_t {
  */
 typedef struct cgo_prf_t {
   /**
-   * AES key: may be 128, 192, or 256 bits.
-   *
-   * Even though we're going to be using this in counter mode,
-   * we don't make an aes_cnt_cipher_t here, since that type
-   * does not support efficient re-use of the key with multiple
-   * IVs.
+   * AES stream cipher: may be 128, 192, or 256 bits.
    */
-  aes_raw_t *k;
+  aes_cnt_cipher_t *k;
   /**
    * Polyval instance.
    */
index 810e7772711d5b7c7a1512d03de1d6ffe556ad83..3fa3b9268a9b5c21c4cd8f343b4f56d02525f603 100644 (file)
@@ -26,3 +26,4 @@ keccak-tiny/*.h
 ed25519/*.h
 
 ext/siphash.h
+ext/polyval/*.h
index d6844d9b9d0d48b80491ea9dac6fe7b50a7f61f9..cee14b18389f36a08f396bea37daef46114f7ed5 100644 (file)
@@ -21,6 +21,9 @@ typedef struct aes_cnt_cipher_t aes_cnt_cipher_t;
 
 aes_cnt_cipher_t* aes_new_cipher(const uint8_t *key, const uint8_t *iv,
                                  int key_bits);
+void aes_cipher_set_iv_aligned(aes_cnt_cipher_t *cipher_, const uint8_t *iv);
+void aes_cipher_set_key(aes_cnt_cipher_t *cipher_,
+                        const uint8_t *key, int key_bits);
 void aes_cipher_free_(aes_cnt_cipher_t *cipher);
 #define aes_cipher_free(cipher) \
   FREE_AND_NULL(aes_cnt_cipher_t, aes_cipher_free_, (cipher))
@@ -40,27 +43,6 @@ void aes_raw_free_(aes_raw_t *cipher);
   FREE_AND_NULL(aes_raw_t, aes_raw_free_, (cipher))
 void aes_raw_encrypt(const aes_raw_t *cipher, uint8_t *block);
 void aes_raw_decrypt(const aes_raw_t *cipher, uint8_t *block);
-
-void aes_raw_counter_xor(const aes_raw_t *aes,
-                         const uint8_t *iv, uint32_t iv_offset,
-                         uint8_t *data, size_t n);
-#endif
-
-#ifdef TOR_AES_PRIVATE
-#include "lib/arch/bytes.h"
-
-/** Increment the big-endian 128-bit counter in 'iv' by 'offset'. */
-static inline void
-aes_ctr_add_iv_offset(uint8_t *iv, uint32_t offset)
-{
-
-  uint64_t h_hi = tor_ntohll(get_uint64(iv + 0));
-  uint64_t h_lo = tor_ntohll(get_uint64(iv + 8));
-  h_lo += offset;
-  h_hi += (h_lo < offset);
-  set_uint64(iv + 0, tor_htonll(h_hi));
-  set_uint64(iv + 8, tor_htonll(h_lo));
-}
 #endif
 
 #endif /* !defined(TOR_AES_H) */
index f2550e91c1fa1e87b15b6cc2eb082522b5aadd4e..ab72c12fe123a5a10ebcb52f45092f9bcdfd05d5 100644 (file)
@@ -23,9 +23,18 @@ DISABLE_GCC_WARNING("-Wstrict-prototypes")
 #include <secerr.h>
 ENABLE_GCC_WARNING("-Wstrict-prototypes")
 
-aes_cnt_cipher_t *
-aes_new_cipher(const uint8_t *key, const uint8_t *iv,
-               int key_bits)
+struct aes_cnt_cipher_t {
+  PK11Context *context;
+  // We need to keep a copy of the key here since we can't set the IV only.
+  // It would be nice to fix that, but NSS doesn't see a huge number of
+  // users.
+  uint8_t kbytes;
+  uint8_t key[32];
+};
+
+static PK11Context *
+aes_new_cipher_internal(const uint8_t *key, const uint8_t *iv,
+                        int key_bits)
 {
   const CK_MECHANISM_TYPE ckm = CKM_AES_CTR;
   SECItem keyItem = { .type = siBuffer,
@@ -68,7 +77,18 @@ aes_new_cipher(const uint8_t *key, const uint8_t *iv,
     PK11_FreeSlot(slot);
 
   tor_assert(result);
-  return (aes_cnt_cipher_t *)result;
+  return result;
+}
+
+aes_cnt_cipher_t *
+aes_new_cipher(const uint8_t *key, const uint8_t *iv,
+                        int key_bits)
+{
+  aes_cnt_cipher_t *cipher = tor_malloc_zero(sizeof(*cipher));
+  cipher->context = aes_new_cipher_internal(key, iv, key_bits);
+  cipher->kbytes = key_bits / 8;
+  memcpy(cipher->key, key, cipher->kbytes);
+  return cipher;
 }
 
 void
@@ -76,7 +96,34 @@ aes_cipher_free_(aes_cnt_cipher_t *cipher)
 {
   if (!cipher)
     return;
-  PK11_DestroyContext((PK11Context*) cipher, PR_TRUE);
+  PK11_DestroyContext(cipher->context, PR_TRUE);
+  memwipe(cipher, 0, sizeof(*cipher));
+  tor_free(cipher);
+}
+
+void
+aes_cipher_set_iv_aligned(aes_cnt_cipher_t *cipher, const uint8_t *iv)
+{
+  // For NSS, I could not find a method to change the IV
+  // of an existing context.  Maybe I missed one?
+  PK11_DestroyContext(cipher->context, PR_TRUE);
+  cipher->context = aes_new_cipher_internal(cipher->key, iv,
+                                            8*(int)cipher->kbytes);
+}
+
+void
+aes_cipher_set_key(aes_cnt_cipher_t *cipher,
+                   const uint8_t *key, int key_bits)
+{
+  const uint8_t iv[16] = {0};
+  // For NSS, I could not find a method to change the key
+  // of an existing context. Maybe I missed one?
+  PK11_DestroyContext(cipher->context, PR_TRUE);
+  memwipe(cipher->key, 0, sizeof(cipher->key));
+
+  cipher->context = aes_new_cipher_internal(key, iv, key_bits);
+  cipher->kbytes = key_bits / 8;
+  memcpy(cipher->key, key, cipher->kbytes);
 }
 
 void
@@ -85,12 +132,11 @@ aes_crypt_inplace(aes_cnt_cipher_t *cipher, char *data_, size_t len_)
   tor_assert(len_ <= INT_MAX);
 
   SECStatus s;
-  PK11Context *ctx = (PK11Context*)cipher;
   unsigned char *data = (unsigned char *)data_;
   int len = (int) len_;
   int result_len = 0;
 
-  s = PK11_CipherOp(ctx, data, &result_len, len, data, len);
+  s = PK11_CipherOp(cipher->context, data, &result_len, len, data, len);
   tor_assert(s == SECSuccess);
   tor_assert(result_len == len);
 }
@@ -186,37 +232,3 @@ aes_raw_decrypt(const aes_raw_t *cipher, uint8_t *block)
   /* This is the same function call for NSS. */
   aes_raw_encrypt(cipher, block);
 }
-
-static inline void
-xor_bytes(uint8_t *outp, const uint8_t *inp, size_t n)
-{
-  for (size_t i = 0; i < n; ++i) {
-    outp[i] ^= inp[i];
-  }
-}
-
-void
-aes_raw_counter_xor(const aes_raw_t *cipher,
-                    const uint8_t *iv, uint32_t iv_offset,
-                    uint8_t *data, size_t n)
-{
-  uint8_t counter[16];
-  uint8_t buf[16];
-
-  memcpy(counter, iv, 16);
-  aes_ctr_add_iv_offset(counter, iv_offset);
-
-  while (n) {
-    memcpy(buf, counter, 16);
-    aes_raw_encrypt(cipher, buf);
-    if (n >= 16) {
-      xor_bytes(data, buf, 16);
-      n -= 16;
-      data += 16;
-    } else {
-      xor_bytes(data, buf, n);
-      break;
-    }
-    aes_ctr_add_iv_offset(counter, 1);
-  }
-}
index 7a03024440fcee7ae11579bfbda7c83b6e546dd9..270712be87b83a7840e7bb03ed0a96ae51317015 100644 (file)
@@ -72,6 +72,44 @@ ENABLE_GCC_WARNING("-Wredundant-decls")
 
 #endif /* OPENSSL_VERSION_NUMBER >= OPENSSL_V_NOPATCH(1,1,0) || ... */
 
+/* Cached values of our EVP_CIPHER items.  If we don't pre-fetch them,
+ * then EVP_CipherInit calls EVP_CIPHER_fetch itself,
+ * which is surprisingly expensive.
+ */
+static const EVP_CIPHER *aes128ctr = NULL;
+static const EVP_CIPHER *aes192ctr = NULL;
+static const EVP_CIPHER *aes256ctr = NULL;
+static const EVP_CIPHER *aes128ecb = NULL;
+static const EVP_CIPHER *aes192ecb = NULL;
+static const EVP_CIPHER *aes256ecb = NULL;
+
+#if OPENSSL_VERSION_NUMBER >= OPENSSL_V_NOPATCH(3,0,0) \
+  && !defined(LIBRESSL_VERSION_NUMBER)
+#define RESOLVE_CIPHER(c) \
+  EVP_CIPHER_fetch(NULL, OBJ_nid2sn(EVP_CIPHER_get_nid(c)), "")
+#else
+#define RESOLVE_CIPHER(c) (c)
+#endif
+
+/**
+ * Pre-fetch the versions of every AES cipher with its associated provider.
+ */
+static void
+init_ciphers(void)
+{
+  aes128ctr = RESOLVE_CIPHER(EVP_aes_128_ctr());
+  aes192ctr = RESOLVE_CIPHER(EVP_aes_192_ctr());
+  aes256ctr = RESOLVE_CIPHER(EVP_aes_256_ctr());
+  aes128ecb = RESOLVE_CIPHER(EVP_aes_128_ecb());
+  aes192ecb = RESOLVE_CIPHER(EVP_aes_192_ecb());
+  aes256ecb = RESOLVE_CIPHER(EVP_aes_256_ecb());
+}
+#define INIT_CIPHERS() STMT_BEGIN { \
+    if (PREDICT_UNLIKELY(NULL == aes128ctr)) {  \
+      init_ciphers();                           \
+    }                                           \
+  } STMT_END
+
 /* We have 2 strategies for getting the AES block cipher: Via OpenSSL's
  * AES_encrypt function, or via OpenSSL's EVP_EncryptUpdate function.
  *
@@ -91,17 +129,6 @@ ENABLE_GCC_WARNING("-Wredundant-decls")
  * make sure that we have a fixed version.)
  */
 
-/* Helper function to use EVP with openssl's counter-mode wrapper. */
-static void
-evp_block128_fn(const uint8_t in[16],
-                uint8_t out[16],
-                const void *key)
-{
-  EVP_CIPHER_CTX *ctx = (void*)key;
-  int inl=16, outl=16;
-  EVP_EncryptUpdate(ctx, out, &outl, in, inl);
-}
-
 #ifdef USE_EVP_AES_CTR
 
 /* We don't actually define the struct here. */
@@ -109,12 +136,13 @@ evp_block128_fn(const uint8_t in[16],
 aes_cnt_cipher_t *
 aes_new_cipher(const uint8_t *key, const uint8_t *iv, int key_bits)
 {
+  INIT_CIPHERS();
   EVP_CIPHER_CTX *cipher = EVP_CIPHER_CTX_new();
   const EVP_CIPHER *c = NULL;
   switch (key_bits) {
-    case 128: c = EVP_aes_128_ctr(); break;
-    case 192: c = EVP_aes_192_ctr(); break;
-    case 256: c = EVP_aes_256_ctr(); break;
+    case 128: c = aes128ctr; break;
+    case 192: c = aes192ctr; break;
+    case 256: c = aes256ctr; break;
     default: tor_assert_unreached(); // LCOV_EXCL_LINE
   }
   EVP_EncryptInit(cipher, c, key, iv);
@@ -129,6 +157,44 @@ aes_cipher_free_(aes_cnt_cipher_t *cipher_)
   EVP_CIPHER_CTX_reset(cipher);
   EVP_CIPHER_CTX_free(cipher);
 }
+
+/** Changes the key of the cipher;
+ * sets the IV to 0.
+ */
+void
+aes_cipher_set_key(aes_cnt_cipher_t *cipher_, const uint8_t *key, int key_bits)
+{
+  EVP_CIPHER_CTX *cipher = (EVP_CIPHER_CTX *) cipher_;
+  uint8_t iv[16] = {0};
+  const EVP_CIPHER *c = NULL;
+  switch (key_bits) {
+    case 128: c = aes128ctr; break;
+    case 192: c = aes192ctr; break;
+    case 256: c = aes256ctr; break;
+    default: tor_assert_unreached(); // LCOV_EXCL_LINE
+  }
+
+  // No need to call EVP_CIPHER_CTX_Reset here; EncryptInit already
+  // does it for us.
+  EVP_EncryptInit(cipher, c, key, iv);
+}
+/** Change the IV of this stream cipher without changing the key.
+ *
+ * Requires that the cipher stream position is at an even multiple of 16 bytes.
+ */
+void
+aes_cipher_set_iv_aligned(aes_cnt_cipher_t *cipher_, const uint8_t *iv)
+{
+  EVP_CIPHER_CTX *cipher = (EVP_CIPHER_CTX *) cipher_;
+#ifdef LIBRESSL_VERSION_NUMBER
+  EVP_CIPHER_CTX_set_iv(cipher, iv, 16);
+#else
+  // We would have to do this if the cipher's position were not aligned:
+  // EVP_CIPHER_CTX_set_num(cipher, 0);
+
+  memcpy(EVP_CIPHER_CTX_iv_noconst(cipher), iv, 16);
+#endif
+}
 void
 aes_crypt_inplace(aes_cnt_cipher_t *cipher_, char *data, size_t len)
 {
@@ -306,9 +372,9 @@ aes_set_key(aes_cnt_cipher_t *cipher, const uint8_t *key, int key_bits)
   if (should_use_EVP) {
     const EVP_CIPHER *c = 0;
     switch (key_bits) {
-      case 128: c = EVP_aes_128_ecb(); break;
-      case 192: c = EVP_aes_192_ecb(); break;
-      case 256: c = EVP_aes_256_ecb(); break;
+      case 128: c = aes128ecb; break;
+      case 192: c = aes192ecb; break;
+      case 256: c = aes256ecb; break;
       default: tor_assert(0); // LCOV_EXCL_LINE
     }
     EVP_EncryptInit(&cipher->key.evp, c, key, NULL);
@@ -406,16 +472,19 @@ aes_crypt_inplace(aes_cnt_cipher_t *cipher, char *data, size_t len)
 aes_raw_t *
 aes_raw_new(const uint8_t *key, int key_bits, bool encrypt)
 {
+  INIT_CIPHERS();
   EVP_CIPHER_CTX *cipher = EVP_CIPHER_CTX_new();
   tor_assert(cipher);
   const EVP_CIPHER *c = NULL;
   switch (key_bits) {
-    case 128: c = EVP_aes_128_ecb(); break;
-    case 192: c = EVP_aes_192_ecb(); break;
-    case 256: c = EVP_aes_256_ecb(); break;
+    case 128: c = aes128ecb; break;
+    case 192: c = aes192ecb; break;
+    case 256: c = aes256ecb; break;
     default: tor_assert_unreached();
   }
 
+  // No need to call EVP_CIPHER_CTX_Reset here; EncryptInit already
+  // does it for us.
   int r = EVP_CipherInit(cipher, c, key, NULL, encrypt);
   tor_assert(r == 1);
   EVP_CIPHER_CTX_set_padding(cipher, 0);
@@ -432,14 +501,13 @@ aes_raw_set_key(aes_raw_t **cipher_, const uint8_t *key,
 {
   const EVP_CIPHER *c = *(EVP_CIPHER**) cipher_;
   switch (key_bits) {
-    case 128: c = EVP_aes_128_ecb(); break;
-    case 192: c = EVP_aes_192_ecb(); break;
-    case 256: c = EVP_aes_256_ecb(); break;
+    case 128: c = aes128ecb; break;
+    case 192: c = aes192ecb; break;
+    case 256: c = aes256ecb; break;
     default: tor_assert_unreached();
   }
   aes_raw_t *cipherp = *cipher_;
   EVP_CIPHER_CTX *cipher = (EVP_CIPHER_CTX *)cipherp;
-  EVP_CIPHER_CTX_reset(cipher);
   int r = EVP_CipherInit(cipher, c, key, NULL, encrypt);
   tor_assert(r == 1);
   EVP_CIPHER_CTX_set_padding(cipher, 0);
@@ -487,30 +555,3 @@ aes_raw_decrypt(const aes_raw_t *cipher, uint8_t *block)
   tor_assert(r == 1);
   tor_assert(outl == 16);
 }
-
-/**
- * Use the AES encryption key AES in counter mode,
- * starting at the position (iv + iv_offset)*16,
- * to encrypt the 'n' bytes of data in 'data'.
- *
- * Unlike aes_crypt_inplace, this function can re-use the same key repeatedly
- * with diferent IVs.
- */
-void
-aes_raw_counter_xor(const aes_raw_t *cipher,
-                    const uint8_t *iv, uint32_t iv_offset,
-                    uint8_t *data, size_t n)
-{
-  uint8_t counter[16];
-  uint8_t buf[16];
-  unsigned int pos = 0;
-
-  memcpy(counter, iv, 16);
-  if (iv_offset) {
-    aes_ctr_add_iv_offset(counter, iv_offset);
-  }
-
-  CRYPTO_ctr128_encrypt(data, data, n,
-                        (EVP_CIPHER_CTX *)cipher,
-                        counter, buf, &pos, evp_block128_fn);
-}
index 6ac57cf65c2d44890074738e1210d1fe6981a4dc..27e71c9eb8a83127496157d5c3091cca8c6f9c7c 100644 (file)
@@ -670,7 +670,7 @@ bench_cell_ops_cgo(void)
   printf("%s: %.2f per cell (%.2f cpb)\n",              \
          (operation),                                   \
          NANOCOUNT(start,end,iters),                    \
-         cpb(cstart, cend, iters * payload_len))
+         cpb(cstart, cend, (double)iters * payload_len))
 
   // Initialize crypto
   cgo_crypt_t *r_f = cgo_crypt_new(CGO_MODE_RELAY_FORWARD, 128, keys, keylen);
index f8e1c7a8fea2ef883b3236ff1b024c8c0df86346..1281545e29ddfc772f97cb7dbbadc8b5d5e09a18 100644 (file)
@@ -3340,103 +3340,102 @@ test_crypto_aes_raw(void *arg)
 #undef T
 }
 
+/** Make sure that we can set keys on live AES instances correctly. */
 static void
-test_crypto_aes_raw_ctr_equiv(void *arg)
+test_crypto_aes_keymanip_cnt(void *arg)
 {
   (void) arg;
-  size_t buflen = 65536;
-  uint8_t *buf = tor_malloc_zero(buflen);
-  aes_cnt_cipher_t *c = NULL;
-  aes_raw_t *c_raw = NULL;
-
-  const uint8_t iv[16];
-  const uint8_t key[16];
-
-  // Simple case, IV  with zero offset.
-  for (int i = 0; i < 32; ++i) {
-    crypto_rand((char*)iv, sizeof(iv));
-    crypto_rand((char*)key, sizeof(key));
-    c = aes_new_cipher(key, iv, 128);
-    c_raw = aes_raw_new(key, 128, true);
-
-    aes_crypt_inplace(c, (char*)buf, buflen);
-    aes_raw_counter_xor(c_raw, iv, 0, buf, buflen);
-    tt_assert(fast_mem_is_zero((char*)buf, buflen));
-
-    aes_cipher_free(c);
-    aes_raw_free(c_raw);
-  }
-  // Trickier case, IV with offset == 31.
-  for (int i = 0; i < 32; ++i) {
-    crypto_rand((char*)iv, sizeof(iv));
-    crypto_rand((char*)key, sizeof(key));
-    c = aes_new_cipher(key, iv, 128);
-    c_raw = aes_raw_new(key, 128, true);
-
-    aes_crypt_inplace(c, (char*)buf, buflen);
-    size_t off = 31*16;
-    aes_raw_counter_xor(c_raw, iv, 31, buf + off, buflen - off);
-    tt_assert(fast_mem_is_zero((char*)buf + off, buflen - off));
-
-    aes_cipher_free(c);
-    aes_raw_free(c_raw);
-  }
+  uint8_t k1[16] = "123456780123678";
+  uint8_t k2[16] = "abcdefghijklmno";
+  int kbits = 128;
+  uint8_t iv1[16]= "{return 4;}////";
+  uint8_t iv2[16] = {0};
+  uint8_t buf[128] = {0};
+  uint8_t buf2[128] = {0};
+
+  aes_cnt_cipher_t *aes = aes_new_cipher(k1, iv1, kbits);
+  aes_crypt_inplace(aes, (char*)buf, sizeof(buf));
+
+  aes_cnt_cipher_t *aes2 = aes_new_cipher(k2, iv2, kbits);
+  // 128-5 to make sure internal buf is cleared when we set key.
+  aes_crypt_inplace(aes2, (char*)buf2, sizeof(buf2)-5);
+  aes_cipher_set_key(aes2, k1, kbits);
+  aes_cipher_set_iv_aligned(aes2, iv1); // should work in this case.
+  memset(buf2, 0, sizeof(buf2));
+  aes_crypt_inplace(aes2, (char*)buf2, sizeof(buf2));
+  tt_mem_op(buf, OP_EQ, buf2, sizeof(buf));
 
  done:
-  aes_cipher_free(c);
-  aes_raw_free(c_raw);
-  tor_free(buf);
+  aes_cipher_free(aes);
+  aes_cipher_free(aes2);
 }
 
-/* Make sure that our IV addition code is correct.
- *
- * We test this function separately to make sure we handle corner cases well;
- * the corner cases are rare enough that we shouldn't expect to see them in
- * randomized testing.
- */
 static void
-test_crypto_aes_cnt_iv_manip(void *arg)
+test_crypto_aes_keymanip_ecb(void *arg)
 {
-  (void)arg;
-  uint8_t buf[16];
-  uint8_t expect[16];
-  int n;
-#define T(pre, off, post) STMT_BEGIN {                                  \
-    n = base16_decode((char*)buf, sizeof(buf),                          \
-                  (pre), strlen(pre));                                  \
-    tt_int_op(n, OP_EQ, sizeof(buf));                                   \
-    n = base16_decode((char*)expect, sizeof(expect),                    \
-                  (post), strlen(post));                                \
-    tt_int_op(n, OP_EQ, sizeof(expect));                                \
-    aes_ctr_add_iv_offset(buf, (off));                                  \
-    tt_mem_op(buf, OP_EQ, expect, 16);                                  \
-  } STMT_END
-
-  T("00000000000000000000000000000000", 0x4032,
-    "00000000000000000000000000004032");
-  T("0000000000000000000000000000ffff", 0x4032,
-    "00000000000000000000000000014031");
-  // We focus on "31" here because that's what CGO uses.
-  T("000000000000000000000000ffffffe0", 31,
-    "000000000000000000000000ffffffff");
-  T("000000000000000000000000ffffffe1", 31,
-    "00000000000000000000000100000000");
-  T("0000000100000000ffffffffffffffe0", 31,
-    "0000000100000000ffffffffffffffff");
-  T("0000000100000000ffffffffffffffe1", 31,
-    "00000001000000010000000000000000");
-  T("0000000ffffffffffffffffffffffff0", 31,
-    "0000001000000000000000000000000f");
-  T("ffffffffffffffffffffffffffffffe0", 31,
-    "ffffffffffffffffffffffffffffffff");
-  T("ffffffffffffffffffffffffffffffe1", 31,
-    "00000000000000000000000000000000");
-  T("ffffffffffffffffffffffffffffffe8", 31,
-    "00000000000000000000000000000007");
+  (void) arg;
+  uint8_t k1[16] = "123456780123678";
+  uint8_t k2[16] = "abcdefghijklmno";
+  int kbits = 128;
+  uint8_t buf_orig[16] = {1,2,3,0};
+  uint8_t buf1[16];
+  uint8_t buf2[16];
+
+  aes_raw_t *aes1 = aes_raw_new(k1, kbits, true);
+  aes_raw_t *aes2 = aes_raw_new(k1, kbits, false);
+  aes_raw_set_key(&aes2, k2, kbits, false);
+
+  memcpy(buf1, buf_orig, 16);
+  memcpy(buf2, buf_orig, 16);
+
+  aes_raw_encrypt(aes1, buf1);
+  aes_raw_encrypt(aes1, buf2);
+  tt_mem_op(buf1, OP_EQ, buf2, 16);
+
+  aes_raw_decrypt(aes2, buf1);
+  aes_raw_set_key(&aes2, k1, kbits, false);
+  aes_raw_decrypt(aes2, buf2);
+
+  tt_mem_op(buf1, OP_NE, buf2, 16);
+  tt_mem_op(buf2, OP_EQ, buf_orig, 16);
 
-#undef T
  done:
-  ;
+  aes_raw_free(aes1);
+  aes_raw_free(aes2);
+}
+
+static void
+test_crypto_aes_cnt_set_iv(void *arg)
+{
+  (void)arg;
+  uint8_t k1[16] = "123456780123678";
+  uint8_t iv_zero[16] = {0};
+  int kbits = 128;
+  const int iters = 100;
+  uint8_t buf1[128];
+  uint8_t buf2[128];
+
+  aes_cnt_cipher_t *aes1, *aes2 = NULL;
+  aes1 = aes_new_cipher(k1, iv_zero, kbits);
+
+  for (int i = 0; i < iters; ++i) {
+    uint8_t iv[16];
+    crypto_rand((char*) iv, sizeof(iv));
+    memset(buf1, 0, sizeof(buf1));
+    memset(buf2, 0, sizeof(buf2));
+
+    aes_cipher_set_iv_aligned(aes1, iv);
+    aes2 = aes_new_cipher(k1, iv, kbits);
+
+    aes_crypt_inplace(aes1, (char*)buf1, sizeof(buf1));
+    aes_crypt_inplace(aes2, (char*)buf1, sizeof(buf2));
+    tt_mem_op(buf1, OP_EQ, buf2, sizeof(buf1));
+
+    aes_cipher_free(aes2);
+  }
+ done:
+  aes_cipher_free(aes1);
+  aes_cipher_free(aes2);
 }
 
 #ifndef COCCI
@@ -3508,7 +3507,8 @@ struct testcase_t crypto_tests[] = {
   { "failure_modes", test_crypto_failure_modes, TT_FORK, NULL, NULL },
   { "polyval", test_crypto_polyval, 0, NULL, NULL },
   { "aes_raw", test_crypto_aes_raw, 0, NULL, NULL },
-  { "aes_raw_ctr_equiv", test_crypto_aes_raw_ctr_equiv, 0, NULL, NULL },
-  { "aes_cnt_iv_manip", test_crypto_aes_cnt_iv_manip, 0, NULL, NULL },
+  { "aes_keymanip_cnt", test_crypto_aes_keymanip_cnt, 0, NULL, NULL },
+  { "aes_keymanip_ecb", test_crypto_aes_keymanip_ecb, 0, NULL, NULL },
+  { "aes_cnt_set_iv", test_crypto_aes_cnt_set_iv, 0, NULL, NULL },
   END_OF_TESTCASES
 };