crypto/arm64: aes-blk - Switch to 'ksimd' scoped guard API
Author:     Ard Biesheuvel <ardb@kernel.org>
AuthorDate: Wed, 1 Oct 2025 11:43:33 +0000 (13:43 +0200)
Commit:     Ard Biesheuvel <ardb@kernel.org>
CommitDate: Wed, 12 Nov 2025 08:52:01 +0000 (09:52 +0100)
Switch to the more abstract 'scoped_ksimd()' API, which will be modified
in a future patch to transparently allocate a kernel mode FP/SIMD state
buffer on the stack, so that kernel mode FP/SIMD code remains
preemptible in principle, but without the memory overhead that adds 528
bytes to the size of struct task_struct.

Reviewed-by: Eric Biggers <ebiggers@kernel.org>
Reviewed-by: Jonathan Cameron <jonathan.cameron@huawei.com>
Acked-by: Catalin Marinas <catalin.marinas@arm.com>
Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
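
For reference, a minimal sketch of the conversion pattern applied below, assuming
scoped_ksimd() is a cleanup.h-style scoped guard wrapping kernel_neon_begin() /
kernel_neon_end(). The actual definition lives in <asm/simd.h> and may differ;
the guard construction shown here is illustrative only:

	#include <linux/cleanup.h>
	#include <asm/neon.h>

	/*
	 * Assumed guard definition: enter kernel-mode FP/SIMD when the scope
	 * is entered, and leave it automatically when the scope is exited.
	 */
	DEFINE_LOCK_GUARD_0(ksimd, kernel_neon_begin(), kernel_neon_end())
	#define scoped_ksimd()	scoped_guard(ksimd)

	/* Old pattern: explicit begin/end bracketing around the SIMD code. */
	kernel_neon_begin();
	__aes_ce_encrypt(ctx->key_enc, dst, src, num_rounds(ctx));
	kernel_neon_end();

	/* New pattern: the guard scopes the FP/SIMD section. */
	scoped_ksimd()
		__aes_ce_encrypt(ctx->key_enc, dst, src, num_rounds(ctx));

With that in place the per-call-site conversions are mechanical; sections spanning
more than one statement simply gain braces, as in ce_aes_expandkey(), xcbc_setkey()
and the bit-sliced ctr/xts helpers below.
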
arch/arm64/crypto/aes-ce-glue.c
arch/arm64/crypto/aes-glue.c
arch/arm64/crypto/aes-neonbs-glue.c

diff --git a/arch/arm64/crypto/aes-ce-glue.c b/arch/arm64/crypto/aes-ce-glue.c
index 00b8749013c5bf1a08985482597d5768b62eb012..a4dad370991df6eabcf4e494b17aa3e528ee1a59 100644
--- a/arch/arm64/crypto/aes-ce-glue.c
+++ b/arch/arm64/crypto/aes-ce-glue.c
@@ -52,9 +52,8 @@ static void aes_cipher_encrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
                return;
        }
 
-       kernel_neon_begin();
-       __aes_ce_encrypt(ctx->key_enc, dst, src, num_rounds(ctx));
-       kernel_neon_end();
+       scoped_ksimd()
+               __aes_ce_encrypt(ctx->key_enc, dst, src, num_rounds(ctx));
 }
 
 static void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
@@ -66,9 +65,8 @@ static void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
                return;
        }
 
-       kernel_neon_begin();
-       __aes_ce_decrypt(ctx->key_dec, dst, src, num_rounds(ctx));
-       kernel_neon_end();
+       scoped_ksimd()
+               __aes_ce_decrypt(ctx->key_dec, dst, src, num_rounds(ctx));
 }
 
 int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
@@ -94,47 +92,48 @@ int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
        for (i = 0; i < kwords; i++)
                ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
 
-       kernel_neon_begin();
-       for (i = 0; i < sizeof(rcon); i++) {
-               u32 *rki = ctx->key_enc + (i * kwords);
-               u32 *rko = rki + kwords;
-
-               rko[0] = ror32(__aes_ce_sub(rki[kwords - 1]), 8) ^ rcon[i] ^ rki[0];
-               rko[1] = rko[0] ^ rki[1];
-               rko[2] = rko[1] ^ rki[2];
-               rko[3] = rko[2] ^ rki[3];
-
-               if (key_len == AES_KEYSIZE_192) {
-                       if (i >= 7)
-                               break;
-                       rko[4] = rko[3] ^ rki[4];
-                       rko[5] = rko[4] ^ rki[5];
-               } else if (key_len == AES_KEYSIZE_256) {
-                       if (i >= 6)
-                               break;
-                       rko[4] = __aes_ce_sub(rko[3]) ^ rki[4];
-                       rko[5] = rko[4] ^ rki[5];
-                       rko[6] = rko[5] ^ rki[6];
-                       rko[7] = rko[6] ^ rki[7];
+       scoped_ksimd() {
+               for (i = 0; i < sizeof(rcon); i++) {
+                       u32 *rki = ctx->key_enc + (i * kwords);
+                       u32 *rko = rki + kwords;
+
+                       rko[0] = ror32(__aes_ce_sub(rki[kwords - 1]), 8) ^
+                                rcon[i] ^ rki[0];
+                       rko[1] = rko[0] ^ rki[1];
+                       rko[2] = rko[1] ^ rki[2];
+                       rko[3] = rko[2] ^ rki[3];
+
+                       if (key_len == AES_KEYSIZE_192) {
+                               if (i >= 7)
+                                       break;
+                               rko[4] = rko[3] ^ rki[4];
+                               rko[5] = rko[4] ^ rki[5];
+                       } else if (key_len == AES_KEYSIZE_256) {
+                               if (i >= 6)
+                                       break;
+                               rko[4] = __aes_ce_sub(rko[3]) ^ rki[4];
+                               rko[5] = rko[4] ^ rki[5];
+                               rko[6] = rko[5] ^ rki[6];
+                               rko[7] = rko[6] ^ rki[7];
+                       }
                }
-       }
 
-       /*
-        * Generate the decryption keys for the Equivalent Inverse Cipher.
-        * This involves reversing the order of the round keys, and applying
-        * the Inverse Mix Columns transformation on all but the first and
-        * the last one.
-        */
-       key_enc = (struct aes_block *)ctx->key_enc;
-       key_dec = (struct aes_block *)ctx->key_dec;
-       j = num_rounds(ctx);
-
-       key_dec[0] = key_enc[j];
-       for (i = 1, j--; j > 0; i++, j--)
-               __aes_ce_invert(key_dec + i, key_enc + j);
-       key_dec[i] = key_enc[0];
+               /*
+                * Generate the decryption keys for the Equivalent Inverse
+                * Cipher.  This involves reversing the order of the round
+                * keys, and applying the Inverse Mix Columns transformation on
+                * all but the first and the last one.
+                */
+               key_enc = (struct aes_block *)ctx->key_enc;
+               key_dec = (struct aes_block *)ctx->key_dec;
+               j = num_rounds(ctx);
+
+               key_dec[0] = key_enc[j];
+               for (i = 1, j--; j > 0; i++, j--)
+                       __aes_ce_invert(key_dec + i, key_enc + j);
+               key_dec[i] = key_enc[0];
+       }
 
-       kernel_neon_end();
        return 0;
 }
 EXPORT_SYMBOL(ce_aes_expandkey);
diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index 5e207ff34482f54b969b02b189555b080d7ea8f9..b087b900d2790b4e1b328f083b91546098d26bb8 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -5,8 +5,6 @@
  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
  */
 
-#include <asm/hwcap.h>
-#include <asm/neon.h>
 #include <crypto/aes.h>
 #include <crypto/ctr.h>
 #include <crypto/internal/hash.h>
@@ -20,6 +18,9 @@
 #include <linux/module.h>
 #include <linux/string.h>
 
+#include <asm/hwcap.h>
+#include <asm/simd.h>
+
 #include "aes-ce-setkey.h"
 
 #ifdef USE_V8_CRYPTO_EXTENSIONS
@@ -186,10 +187,9 @@ static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
        err = skcipher_walk_virt(&walk, req, false);
 
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
-               kernel_neon_begin();
-               aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key_enc, rounds, blocks);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                                       ctx->key_enc, rounds, blocks);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
@@ -206,10 +206,9 @@ static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
        err = skcipher_walk_virt(&walk, req, false);
 
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
-               kernel_neon_begin();
-               aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key_dec, rounds, blocks);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                                       ctx->key_dec, rounds, blocks);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
@@ -224,10 +223,9 @@ static int cbc_encrypt_walk(struct skcipher_request *req,
        unsigned int blocks;
 
        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
-               kernel_neon_begin();
-               aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
-                               ctx->key_enc, rounds, blocks, walk->iv);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
+                                       ctx->key_enc, rounds, blocks, walk->iv);
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
@@ -253,10 +251,9 @@ static int cbc_decrypt_walk(struct skcipher_request *req,
        unsigned int blocks;
 
        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
-               kernel_neon_begin();
-               aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
-                               ctx->key_dec, rounds, blocks, walk->iv);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
+                                       ctx->key_dec, rounds, blocks, walk->iv);
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
@@ -322,10 +319,9 @@ static int cts_cbc_encrypt(struct skcipher_request *req)
        if (err)
                return err;
 
-       kernel_neon_begin();
-       aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                           ctx->key_enc, rounds, walk.nbytes, walk.iv);
-       kernel_neon_end();
+       scoped_ksimd()
+               aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                                   ctx->key_enc, rounds, walk.nbytes, walk.iv);
 
        return skcipher_walk_done(&walk, 0);
 }
@@ -379,10 +375,9 @@ static int cts_cbc_decrypt(struct skcipher_request *req)
        if (err)
                return err;
 
-       kernel_neon_begin();
-       aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                           ctx->key_dec, rounds, walk.nbytes, walk.iv);
-       kernel_neon_end();
+       scoped_ksimd()
+               aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                                   ctx->key_dec, rounds, walk.nbytes, walk.iv);
 
        return skcipher_walk_done(&walk, 0);
 }
@@ -399,11 +394,11 @@ static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
 
        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
-               kernel_neon_begin();
-               aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                     ctx->key1.key_enc, rounds, blocks,
-                                     req->iv, ctx->key2.key_enc);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_essiv_cbc_encrypt(walk.dst.virt.addr,
+                                             walk.src.virt.addr,
+                                             ctx->key1.key_enc, rounds, blocks,
+                                             req->iv, ctx->key2.key_enc);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_encrypt_walk(req, &walk);
@@ -421,11 +416,11 @@ static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
 
        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
-               kernel_neon_begin();
-               aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                     ctx->key1.key_dec, rounds, blocks,
-                                     req->iv, ctx->key2.key_enc);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_essiv_cbc_decrypt(walk.dst.virt.addr,
+                                             walk.src.virt.addr,
+                                             ctx->key1.key_dec, rounds, blocks,
+                                             req->iv, ctx->key2.key_enc);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_decrypt_walk(req, &walk);
@@ -461,10 +456,9 @@ static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
                else if (nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);
 
-               kernel_neon_begin();
-               aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
-                                                walk.iv, byte_ctr);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
+                                                        walk.iv, byte_ctr);
 
                if (unlikely(nbytes < AES_BLOCK_SIZE))
                        memcpy(walk.dst.virt.addr,
@@ -506,10 +500,9 @@ static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
                else if (nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);
 
-               kernel_neon_begin();
-               aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
-                               walk.iv);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
+                                       walk.iv);
 
                if (unlikely(nbytes < AES_BLOCK_SIZE))
                        memcpy(walk.dst.virt.addr,
@@ -562,11 +555,10 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);
 
-               kernel_neon_begin();
-               aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key1.key_enc, rounds, nbytes,
-                               ctx->key2.key_enc, walk.iv, first);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                                       ctx->key1.key_enc, rounds, nbytes,
+                                       ctx->key2.key_enc, walk.iv, first);
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }
 
@@ -584,11 +576,10 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
        if (err)
                return err;
 
-       kernel_neon_begin();
-       aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                       ctx->key1.key_enc, rounds, walk.nbytes,
-                       ctx->key2.key_enc, walk.iv, first);
-       kernel_neon_end();
+       scoped_ksimd()
+               aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                               ctx->key1.key_enc, rounds, walk.nbytes,
+                               ctx->key2.key_enc, walk.iv, first);
 
        return skcipher_walk_done(&walk, 0);
 }
@@ -634,11 +625,10 @@ static int __maybe_unused xts_decrypt(struct skcipher_request *req)
                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);
 
-               kernel_neon_begin();
-               aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key1.key_dec, rounds, nbytes,
-                               ctx->key2.key_enc, walk.iv, first);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                                       ctx->key1.key_dec, rounds, nbytes,
+                                       ctx->key2.key_enc, walk.iv, first);
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }
 
@@ -657,11 +647,10 @@ static int __maybe_unused xts_decrypt(struct skcipher_request *req)
                return err;
 
 
-       kernel_neon_begin();
-       aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                       ctx->key1.key_dec, rounds, walk.nbytes,
-                       ctx->key2.key_enc, walk.iv, first);
-       kernel_neon_end();
+       scoped_ksimd()
+               aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                               ctx->key1.key_dec, rounds, walk.nbytes,
+                               ctx->key2.key_enc, walk.iv, first);
 
        return skcipher_walk_done(&walk, 0);
 }
@@ -808,10 +797,9 @@ static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                return err;
 
        /* encrypt the zero vector */
-       kernel_neon_begin();
-       aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
-                       rounds, 1);
-       kernel_neon_end();
+       scoped_ksimd()
+               aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){},
+                               ctx->key.key_enc, rounds, 1);
 
        cmac_gf128_mul_by_x(consts, consts);
        cmac_gf128_mul_by_x(consts + 1, consts);
@@ -837,10 +825,10 @@ static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
        if (err)
                return err;
 
-       kernel_neon_begin();
-       aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
-       aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
-       kernel_neon_end();
+       scoped_ksimd() {
+               aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
+               aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
+       }
 
        return cbcmac_setkey(tfm, key, sizeof(key));
 }
@@ -860,10 +848,9 @@ static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
        int rem;
 
        do {
-               kernel_neon_begin();
-               rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
-                                    dg, enc_before, !enc_before);
-               kernel_neon_end();
+               scoped_ksimd()
+                       rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
+                                            dg, enc_before, !enc_before);
                in += (blocks - rem) * AES_BLOCK_SIZE;
                blocks = rem;
        } while (blocks);
diff --git a/arch/arm64/crypto/aes-neonbs-glue.c b/arch/arm64/crypto/aes-neonbs-glue.c
index c4a623e865934b02232c5e23a6c6de91c2c00fb3..d496effb0a5b77119b4d018770c0ddbe749b3efc 100644
--- a/arch/arm64/crypto/aes-neonbs-glue.c
+++ b/arch/arm64/crypto/aes-neonbs-glue.c
@@ -85,9 +85,8 @@ static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 
        ctx->rounds = 6 + key_len / 4;
 
-       kernel_neon_begin();
-       aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
-       kernel_neon_end();
+       scoped_ksimd()
+               aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
 
        return 0;
 }
@@ -110,10 +109,9 @@ static int __ecb_crypt(struct skcipher_request *req,
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
 
-               kernel_neon_begin();
-               fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
-                  ctx->rounds, blocks);
-               kernel_neon_end();
+               scoped_ksimd()
+                       fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
+                          ctx->rounds, blocks);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
@@ -146,9 +144,8 @@ static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 
        memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));
 
-       kernel_neon_begin();
-       aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
-       kernel_neon_end();
+       scoped_ksimd()
+               aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        memzero_explicit(&rk, sizeof(rk));
 
        return 0;
@@ -167,11 +164,11 @@ static int cbc_encrypt(struct skcipher_request *req)
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 
                /* fall back to the non-bitsliced NEON implementation */
-               kernel_neon_begin();
-               neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                    ctx->enc, ctx->key.rounds, blocks,
-                                    walk.iv);
-               kernel_neon_end();
+               scoped_ksimd()
+                       neon_aes_cbc_encrypt(walk.dst.virt.addr,
+                                            walk.src.virt.addr,
+                                            ctx->enc, ctx->key.rounds, blocks,
+                                            walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
@@ -193,11 +190,10 @@ static int cbc_decrypt(struct skcipher_request *req)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
 
-               kernel_neon_begin();
-               aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                 ctx->key.rk, ctx->key.rounds, blocks,
-                                 walk.iv);
-               kernel_neon_end();
+               scoped_ksimd()
+                       aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                                         ctx->key.rk, ctx->key.rounds, blocks,
+                                         walk.iv);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
@@ -220,30 +216,32 @@ static int ctr_encrypt(struct skcipher_request *req)
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
 
-               kernel_neon_begin();
-               if (blocks >= 8) {
-                       aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
-                                         blocks, walk.iv);
-                       dst += blocks * AES_BLOCK_SIZE;
-                       src += blocks * AES_BLOCK_SIZE;
-               }
-               if (nbytes && walk.nbytes == walk.total) {
-                       u8 buf[AES_BLOCK_SIZE];
-                       u8 *d = dst;
-
-                       if (unlikely(nbytes < AES_BLOCK_SIZE))
-                               src = dst = memcpy(buf + sizeof(buf) - nbytes,
-                                                  src, nbytes);
-
-                       neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
-                                            nbytes, walk.iv);
+               scoped_ksimd() {
+                       if (blocks >= 8) {
+                               aesbs_ctr_encrypt(dst, src, ctx->key.rk,
+                                                 ctx->key.rounds, blocks,
+                                                 walk.iv);
+                               dst += blocks * AES_BLOCK_SIZE;
+                               src += blocks * AES_BLOCK_SIZE;
+                       }
+                       if (nbytes && walk.nbytes == walk.total) {
+                               u8 buf[AES_BLOCK_SIZE];
+                               u8 *d = dst;
+
+                               if (unlikely(nbytes < AES_BLOCK_SIZE))
+                                       src = dst = memcpy(buf + sizeof(buf) -
+                                                          nbytes, src, nbytes);
+
+                               neon_aes_ctr_encrypt(dst, src, ctx->enc,
+                                                    ctx->key.rounds, nbytes,
+                                                    walk.iv);
 
-                       if (unlikely(nbytes < AES_BLOCK_SIZE))
-                               memcpy(d, dst, nbytes);
+                               if (unlikely(nbytes < AES_BLOCK_SIZE))
+                                       memcpy(d, dst, nbytes);
 
-                       nbytes = 0;
+                               nbytes = 0;
+                       }
                }
-               kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }
        return err;
@@ -320,33 +318,33 @@ static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                in = walk.src.virt.addr;
                nbytes = walk.nbytes;
 
-               kernel_neon_begin();
-               if (blocks >= 8) {
-                       if (first == 1)
-                               neon_aes_ecb_encrypt(walk.iv, walk.iv,
-                                                    ctx->twkey,
-                                                    ctx->key.rounds, 1);
-                       first = 2;
-
-                       fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
-                          walk.iv);
-
-                       out += blocks * AES_BLOCK_SIZE;
-                       in += blocks * AES_BLOCK_SIZE;
-                       nbytes -= blocks * AES_BLOCK_SIZE;
+               scoped_ksimd() {
+                       if (blocks >= 8) {
+                               if (first == 1)
+                                       neon_aes_ecb_encrypt(walk.iv, walk.iv,
+                                                            ctx->twkey,
+                                                            ctx->key.rounds, 1);
+                               first = 2;
+
+                               fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
+                                  walk.iv);
+
+                               out += blocks * AES_BLOCK_SIZE;
+                               in += blocks * AES_BLOCK_SIZE;
+                               nbytes -= blocks * AES_BLOCK_SIZE;
+                       }
+                       if (walk.nbytes == walk.total && nbytes > 0) {
+                               if (encrypt)
+                                       neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
+                                                            ctx->key.rounds, nbytes,
+                                                            ctx->twkey, walk.iv, first);
+                               else
+                                       neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
+                                                            ctx->key.rounds, nbytes,
+                                                            ctx->twkey, walk.iv, first);
+                               nbytes = first = 0;
+                       }
                }
-               if (walk.nbytes == walk.total && nbytes > 0) {
-                       if (encrypt)
-                               neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
-                                                    ctx->key.rounds, nbytes,
-                                                    ctx->twkey, walk.iv, first);
-                       else
-                               neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
-                                                    ctx->key.rounds, nbytes,
-                                                    ctx->twkey, walk.iv, first);
-                       nbytes = first = 0;
-               }
-               kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }
 
@@ -369,14 +367,16 @@ static int __xts_crypt(struct skcipher_request *req, bool encrypt,
        in = walk.src.virt.addr;
        nbytes = walk.nbytes;
 
-       kernel_neon_begin();
-       if (encrypt)
-               neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
-                                    nbytes, ctx->twkey, walk.iv, first);
-       else
-               neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
-                                    nbytes, ctx->twkey, walk.iv, first);
-       kernel_neon_end();
+       scoped_ksimd() {
+               if (encrypt)
+                       neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
+                                            ctx->key.rounds, nbytes, ctx->twkey,
+                                            walk.iv, first);
+               else
+                       neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
+                                            ctx->key.rounds, nbytes, ctx->twkey,
+                                            walk.iv, first);
+       }
 
        return skcipher_walk_done(&walk, 0);
 }