crypto/arm64: sm4 - Switch to 'ksimd' scoped guard API
author Ard Biesheuvel <ardb@kernel.org>
Wed, 1 Oct 2025 11:50:51 +0000 (13:50 +0200)
committer Ard Biesheuvel <ardb@kernel.org>
Wed, 12 Nov 2025 08:52:02 +0000 (09:52 +0100)
Switch to the more abstract 'scoped_ksimd()' API, which will be modified
in a future patch to transparently allocate a kernel mode FP/SIMD state
buffer on the stack, so that kernel mode FP/SIMD code remains preemptible
in principle, but without the 528 bytes of memory overhead this currently
adds to the size of struct task_struct.

Reviewed-by: Eric Biggers <ebiggers@kernel.org>
Reviewed-by: Jonathan Cameron <jonathan.cameron@huawei.com>
Acked-by: Catalin Marinas <catalin.marinas@arm.com>
Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
arch/arm64/crypto/sm4-ce-ccm-glue.c
arch/arm64/crypto/sm4-ce-cipher-glue.c
arch/arm64/crypto/sm4-ce-gcm-glue.c
arch/arm64/crypto/sm4-ce-glue.c
arch/arm64/crypto/sm4-neon-glue.c
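
The conversion is mechanical throughout: every explicit
kernel_neon_begin()/kernel_neon_end() pair becomes a scoped_ksimd()
statement or block. The sketch below illustrates the idea, assuming
scoped_ksimd() is built on the kernel's cleanup.h scoped-guard machinery
and maps onto the existing NEON begin/end helpers on arm64; its real
definition comes from earlier patches in this series, not from this
commit. The setkey call shown is taken from the first hunk below.

        /* Sketch only: assumed shape of the guard, not part of this commit. */
        DEFINE_LOCK_GUARD_0(ksimd, kernel_neon_begin(), kernel_neon_end())
        #define scoped_ksimd()  scoped_guard(ksimd)

        /* Before: the caller must keep begin/end balanced on every path. */
        kernel_neon_begin();
        sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
                          crypto_sm4_fk, crypto_sm4_ck);
        kernel_neon_end();

        /* After: the SIMD region ends automatically when the scope is left. */
        scoped_ksimd()
                sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
                                  crypto_sm4_fk, crypto_sm4_ck);

Tying the SIMD region to a scope means early returns or later added error
paths cannot leave the NEON unit claimed, and it gives the core FP/SIMD
code a single point at which to hook the planned on-stack state buffer
allocation.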

diff --git a/arch/arm64/crypto/sm4-ce-ccm-glue.c b/arch/arm64/crypto/sm4-ce-ccm-glue.c
index e92cbdf1aaee70398fcc30a41caaf8ee33a81ab4..332f02167a96f9aca83722f14545f9d0f5d72576 100644
--- a/arch/arm64/crypto/sm4-ce-ccm-glue.c
+++ b/arch/arm64/crypto/sm4-ce-ccm-glue.c
@@ -11,7 +11,7 @@
 #include <linux/crypto.h>
 #include <linux/kernel.h>
 #include <linux/cpufeature.h>
-#include <asm/neon.h>
+#include <asm/simd.h>
 #include <crypto/scatterwalk.h>
 #include <crypto/internal/aead.h>
 #include <crypto/internal/skcipher.h>
@@ -35,10 +35,9 @@ static int ccm_setkey(struct crypto_aead *tfm, const u8 *key,
        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;
 
-       kernel_neon_begin();
-       sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
-                         crypto_sm4_fk, crypto_sm4_ck);
-       kernel_neon_end();
+       scoped_ksimd()
+               sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
+                                 crypto_sm4_fk, crypto_sm4_ck);
 
        return 0;
 }
@@ -167,28 +166,25 @@ static int ccm_crypt(struct aead_request *req, struct skcipher_walk *walk,
        memcpy(ctr0, walk->iv, SM4_BLOCK_SIZE);
        crypto_inc(walk->iv, SM4_BLOCK_SIZE);
 
-       kernel_neon_begin();
+       scoped_ksimd() {
+               if (req->assoclen)
+                       ccm_calculate_auth_mac(req, mac);
 
-       if (req->assoclen)
-               ccm_calculate_auth_mac(req, mac);
-
-       while (walk->nbytes) {
-               unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;
+               while (walk->nbytes) {
+                       unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;
 
-               if (walk->nbytes == walk->total)
-                       tail = 0;
+                       if (walk->nbytes == walk->total)
+                               tail = 0;
 
-               sm4_ce_ccm_crypt(rkey_enc, walk->dst.virt.addr,
-                                walk->src.virt.addr, walk->iv,
-                                walk->nbytes - tail, mac);
+                       sm4_ce_ccm_crypt(rkey_enc, walk->dst.virt.addr,
+                                        walk->src.virt.addr, walk->iv,
+                                        walk->nbytes - tail, mac);
 
-               err = skcipher_walk_done(walk, tail);
+                       err = skcipher_walk_done(walk, tail);
+               }
+               sm4_ce_ccm_final(rkey_enc, ctr0, mac);
        }
 
-       sm4_ce_ccm_final(rkey_enc, ctr0, mac);
-
-       kernel_neon_end();
-
        return err;
 }
 
diff --git a/arch/arm64/crypto/sm4-ce-cipher-glue.c b/arch/arm64/crypto/sm4-ce-cipher-glue.c
index c31d76fb5a17707440b299d164c68a41e2308f57..bceec833ef4e789d0f33b9a955e8e79c389fe784 100644
--- a/arch/arm64/crypto/sm4-ce-cipher-glue.c
+++ b/arch/arm64/crypto/sm4-ce-cipher-glue.c
@@ -32,9 +32,8 @@ static void sm4_ce_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
        if (!crypto_simd_usable()) {
                sm4_crypt_block(ctx->rkey_enc, out, in);
        } else {
-               kernel_neon_begin();
-               sm4_ce_do_crypt(ctx->rkey_enc, out, in);
-               kernel_neon_end();
+               scoped_ksimd()
+                       sm4_ce_do_crypt(ctx->rkey_enc, out, in);
        }
 }
 
@@ -45,9 +44,8 @@ static void sm4_ce_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
        if (!crypto_simd_usable()) {
                sm4_crypt_block(ctx->rkey_dec, out, in);
        } else {
-               kernel_neon_begin();
-               sm4_ce_do_crypt(ctx->rkey_dec, out, in);
-               kernel_neon_end();
+               scoped_ksimd()
+                       sm4_ce_do_crypt(ctx->rkey_dec, out, in);
        }
 }
 
diff --git a/arch/arm64/crypto/sm4-ce-gcm-glue.c b/arch/arm64/crypto/sm4-ce-gcm-glue.c
index 8f6fc8c33c3fece2753ec7b2e7a892a5c394beab..ef06f4f768a1dd620df25893041fdbe36f833a25 100644
--- a/arch/arm64/crypto/sm4-ce-gcm-glue.c
+++ b/arch/arm64/crypto/sm4-ce-gcm-glue.c
@@ -11,7 +11,7 @@
 #include <linux/crypto.h>
 #include <linux/kernel.h>
 #include <linux/cpufeature.h>
-#include <asm/neon.h>
+#include <asm/simd.h>
 #include <crypto/b128ops.h>
 #include <crypto/scatterwalk.h>
 #include <crypto/internal/aead.h>
@@ -48,13 +48,11 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *key,
        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;
 
-       kernel_neon_begin();
-
-       sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
-                         crypto_sm4_fk, crypto_sm4_ck);
-       sm4_ce_pmull_ghash_setup(ctx->key.rkey_enc, ctx->ghash_table);
-
-       kernel_neon_end();
+       scoped_ksimd() {
+               sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
+                               crypto_sm4_fk, crypto_sm4_ck);
+               sm4_ce_pmull_ghash_setup(ctx->key.rkey_enc, ctx->ghash_table);
+       }
        return 0;
 }
 
@@ -149,31 +147,28 @@ static int gcm_crypt(struct aead_request *req, struct skcipher_walk *walk,
        memcpy(iv, req->iv, GCM_IV_SIZE);
        put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
-       kernel_neon_begin();
+       scoped_ksimd() {
+               if (req->assoclen)
+                       gcm_calculate_auth_mac(req, ghash);
 
-       if (req->assoclen)
-               gcm_calculate_auth_mac(req, ghash);
+               do {
+                       unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;
+                       const u8 *src = walk->src.virt.addr;
+                       u8 *dst = walk->dst.virt.addr;
+                       const u8 *l = NULL;
 
-       do {
-               unsigned int tail = walk->nbytes % SM4_BLOCK_SIZE;
-               const u8 *src = walk->src.virt.addr;
-               u8 *dst = walk->dst.virt.addr;
-               const u8 *l = NULL;
-
-               if (walk->nbytes == walk->total) {
-                       l = (const u8 *)&lengths;
-                       tail = 0;
-               }
-
-               sm4_ce_pmull_gcm_crypt(ctx->key.rkey_enc, dst, src, iv,
-                                      walk->nbytes - tail, ghash,
-                                      ctx->ghash_table, l);
-
-               err = skcipher_walk_done(walk, tail);
-       } while (walk->nbytes);
+                       if (walk->nbytes == walk->total) {
+                               l = (const u8 *)&lengths;
+                               tail = 0;
+                       }
 
-       kernel_neon_end();
+                       sm4_ce_pmull_gcm_crypt(ctx->key.rkey_enc, dst, src, iv,
+                                              walk->nbytes - tail, ghash,
+                                              ctx->ghash_table, l);
 
+                       err = skcipher_walk_done(walk, tail);
+               } while (walk->nbytes);
+       }
        return err;
 }
 
diff --git a/arch/arm64/crypto/sm4-ce-glue.c b/arch/arm64/crypto/sm4-ce-glue.c
index 7a60e7b559dc9628f9d5fa475b031b345e9c06c5..5569cece5a0b85e6ac96b36eac660618c8a87540 100644
--- a/arch/arm64/crypto/sm4-ce-glue.c
+++ b/arch/arm64/crypto/sm4-ce-glue.c
@@ -8,7 +8,7 @@
  * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
  */
 
-#include <asm/neon.h>
+#include <asm/simd.h>
 #include <crypto/b128ops.h>
 #include <crypto/internal/hash.h>
 #include <crypto/internal/skcipher.h>
@@ -74,10 +74,9 @@ static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;
 
-       kernel_neon_begin();
-       sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
-                         crypto_sm4_fk, crypto_sm4_ck);
-       kernel_neon_end();
+       scoped_ksimd()
+               sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
+                                 crypto_sm4_fk, crypto_sm4_ck);
        return 0;
 }
 
@@ -94,12 +93,12 @@ static int sm4_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
        if (ret)
                return ret;
 
-       kernel_neon_begin();
-       sm4_ce_expand_key(key, ctx->key1.rkey_enc,
-                         ctx->key1.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
-       sm4_ce_expand_key(&key[SM4_KEY_SIZE], ctx->key2.rkey_enc,
-                         ctx->key2.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
-       kernel_neon_end();
+       scoped_ksimd() {
+               sm4_ce_expand_key(key, ctx->key1.rkey_enc,
+                               ctx->key1.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
+               sm4_ce_expand_key(&key[SM4_KEY_SIZE], ctx->key2.rkey_enc,
+                               ctx->key2.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
+       }
 
        return 0;
 }
@@ -117,16 +116,14 @@ static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;
 
-               kernel_neon_begin();
-
-               nblks = BYTES2BLKS(nbytes);
-               if (nblks) {
-                       sm4_ce_crypt(rkey, dst, src, nblks);
-                       nbytes -= nblks * SM4_BLOCK_SIZE;
+               scoped_ksimd() {
+                       nblks = BYTES2BLKS(nbytes);
+                       if (nblks) {
+                               sm4_ce_crypt(rkey, dst, src, nblks);
+                               nbytes -= nblks * SM4_BLOCK_SIZE;
+                       }
                }
 
-               kernel_neon_end();
-
                err = skcipher_walk_done(&walk, nbytes);
        }
 
@@ -167,16 +164,14 @@ static int sm4_cbc_crypt(struct skcipher_request *req,
 
                nblocks = nbytes / SM4_BLOCK_SIZE;
                if (nblocks) {
-                       kernel_neon_begin();
-
-                       if (encrypt)
-                               sm4_ce_cbc_enc(ctx->rkey_enc, dst, src,
-                                              walk.iv, nblocks);
-                       else
-                               sm4_ce_cbc_dec(ctx->rkey_dec, dst, src,
-                                              walk.iv, nblocks);
-
-                       kernel_neon_end();
+                       scoped_ksimd() {
+                               if (encrypt)
+                                       sm4_ce_cbc_enc(ctx->rkey_enc, dst, src,
+                                                      walk.iv, nblocks);
+                               else
+                                       sm4_ce_cbc_dec(ctx->rkey_dec, dst, src,
+                                                      walk.iv, nblocks);
+                       }
                }
 
                err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
@@ -249,16 +244,14 @@ static int sm4_cbc_cts_crypt(struct skcipher_request *req, bool encrypt)
        if (err)
                return err;
 
-       kernel_neon_begin();
-
-       if (encrypt)
-               sm4_ce_cbc_cts_enc(ctx->rkey_enc, walk.dst.virt.addr,
-                                  walk.src.virt.addr, walk.iv, walk.nbytes);
-       else
-               sm4_ce_cbc_cts_dec(ctx->rkey_dec, walk.dst.virt.addr,
-                                  walk.src.virt.addr, walk.iv, walk.nbytes);
-
-       kernel_neon_end();
+       scoped_ksimd() {
+               if (encrypt)
+                       sm4_ce_cbc_cts_enc(ctx->rkey_enc, walk.dst.virt.addr,
+                                          walk.src.virt.addr, walk.iv, walk.nbytes);
+               else
+                       sm4_ce_cbc_cts_dec(ctx->rkey_dec, walk.dst.virt.addr,
+                                          walk.src.virt.addr, walk.iv, walk.nbytes);
+       }
 
        return skcipher_walk_done(&walk, 0);
 }
@@ -288,28 +281,26 @@ static int sm4_ctr_crypt(struct skcipher_request *req)
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;
 
-               kernel_neon_begin();
-
-               nblks = BYTES2BLKS(nbytes);
-               if (nblks) {
-                       sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
-                       dst += nblks * SM4_BLOCK_SIZE;
-                       src += nblks * SM4_BLOCK_SIZE;
-                       nbytes -= nblks * SM4_BLOCK_SIZE;
-               }
-
-               /* tail */
-               if (walk.nbytes == walk.total && nbytes > 0) {
-                       u8 keystream[SM4_BLOCK_SIZE];
-
-                       sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
-                       crypto_inc(walk.iv, SM4_BLOCK_SIZE);
-                       crypto_xor_cpy(dst, src, keystream, nbytes);
-                       nbytes = 0;
+               scoped_ksimd() {
+                       nblks = BYTES2BLKS(nbytes);
+                       if (nblks) {
+                               sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
+                               dst += nblks * SM4_BLOCK_SIZE;
+                               src += nblks * SM4_BLOCK_SIZE;
+                               nbytes -= nblks * SM4_BLOCK_SIZE;
+                       }
+
+                       /* tail */
+                       if (walk.nbytes == walk.total && nbytes > 0) {
+                               u8 keystream[SM4_BLOCK_SIZE];
+
+                               sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
+                               crypto_inc(walk.iv, SM4_BLOCK_SIZE);
+                               crypto_xor_cpy(dst, src, keystream, nbytes);
+                               nbytes = 0;
+                       }
                }
 
-               kernel_neon_end();
-
                err = skcipher_walk_done(&walk, nbytes);
        }
 
@@ -359,18 +350,16 @@ static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
                if (nbytes < walk.total)
                        nbytes &= ~(SM4_BLOCK_SIZE - 1);
 
-               kernel_neon_begin();
-
-               if (encrypt)
-                       sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
-                                      walk.src.virt.addr, walk.iv, nbytes,
-                                      rkey2_enc);
-               else
-                       sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
-                                      walk.src.virt.addr, walk.iv, nbytes,
-                                      rkey2_enc);
-
-               kernel_neon_end();
+               scoped_ksimd() {
+                       if (encrypt)
+                               sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
+                                               walk.src.virt.addr, walk.iv, nbytes,
+                                               rkey2_enc);
+                       else
+                               sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
+                                               walk.src.virt.addr, walk.iv, nbytes,
+                                               rkey2_enc);
+               }
 
                rkey2_enc = NULL;
 
@@ -395,18 +384,16 @@ static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
        if (err)
                return err;
 
-       kernel_neon_begin();
-
-       if (encrypt)
-               sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
-                              walk.src.virt.addr, walk.iv, walk.nbytes,
-                              rkey2_enc);
-       else
-               sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
-                              walk.src.virt.addr, walk.iv, walk.nbytes,
-                              rkey2_enc);
-
-       kernel_neon_end();
+       scoped_ksimd() {
+               if (encrypt)
+                       sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
+                                       walk.src.virt.addr, walk.iv, walk.nbytes,
+                                       rkey2_enc);
+               else
+                       sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
+                                       walk.src.virt.addr, walk.iv, walk.nbytes,
+                                       rkey2_enc);
+       }
 
        return skcipher_walk_done(&walk, 0);
 }
@@ -510,11 +497,9 @@ static int sm4_cbcmac_setkey(struct crypto_shash *tfm, const u8 *key,
        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;
 
-       kernel_neon_begin();
-       sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
-                         crypto_sm4_fk, crypto_sm4_ck);
-       kernel_neon_end();
-
+       scoped_ksimd()
+               sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
+                               crypto_sm4_fk, crypto_sm4_ck);
        return 0;
 }
 
@@ -530,15 +515,13 @@ static int sm4_cmac_setkey(struct crypto_shash *tfm, const u8 *key,
 
        memset(consts, 0, SM4_BLOCK_SIZE);
 
-       kernel_neon_begin();
-
-       sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
-                         crypto_sm4_fk, crypto_sm4_ck);
+       scoped_ksimd() {
+               sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
+                               crypto_sm4_fk, crypto_sm4_ck);
 
-       /* encrypt the zero block */
-       sm4_ce_crypt_block(ctx->key.rkey_enc, (u8 *)consts, (const u8 *)consts);
-
-       kernel_neon_end();
+               /* encrypt the zero block */
+               sm4_ce_crypt_block(ctx->key.rkey_enc, (u8 *)consts, (const u8 *)consts);
+       }
 
        /* gf(2^128) multiply zero-ciphertext with u and u^2 */
        a = be64_to_cpu(consts[0].a);
@@ -568,18 +551,16 @@ static int sm4_xcbc_setkey(struct crypto_shash *tfm, const u8 *key,
        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;
 
-       kernel_neon_begin();
-
-       sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
-                         crypto_sm4_fk, crypto_sm4_ck);
+       scoped_ksimd() {
+               sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
+                               crypto_sm4_fk, crypto_sm4_ck);
 
-       sm4_ce_crypt_block(ctx->key.rkey_enc, key2, ks[0]);
-       sm4_ce_crypt(ctx->key.rkey_enc, ctx->consts, ks[1], 2);
+               sm4_ce_crypt_block(ctx->key.rkey_enc, key2, ks[0]);
+               sm4_ce_crypt(ctx->key.rkey_enc, ctx->consts, ks[1], 2);
 
-       sm4_ce_expand_key(key2, ctx->key.rkey_enc, ctx->key.rkey_dec,
-                         crypto_sm4_fk, crypto_sm4_ck);
-
-       kernel_neon_end();
+               sm4_ce_expand_key(key2, ctx->key.rkey_enc, ctx->key.rkey_dec,
+                               crypto_sm4_fk, crypto_sm4_ck);
+       }
 
        return 0;
 }
@@ -600,10 +581,9 @@ static int sm4_mac_update(struct shash_desc *desc, const u8 *p,
        unsigned int nblocks = len / SM4_BLOCK_SIZE;
 
        len %= SM4_BLOCK_SIZE;
-       kernel_neon_begin();
-       sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, p,
-                         nblocks, false, true);
-       kernel_neon_end();
+       scoped_ksimd()
+               sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, p,
+                               nblocks, false, true);
        return len;
 }
 
@@ -619,10 +599,9 @@ static int sm4_cmac_finup(struct shash_desc *desc, const u8 *src,
                ctx->digest[len] ^= 0x80;
                consts += SM4_BLOCK_SIZE;
        }
-       kernel_neon_begin();
-       sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, consts, 1,
-                         false, true);
-       kernel_neon_end();
+       scoped_ksimd()
+               sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, consts, 1,
+                                 false, true);
        memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
        return 0;
 }
@@ -635,10 +614,9 @@ static int sm4_cbcmac_finup(struct shash_desc *desc, const u8 *src,
 
        if (len) {
                crypto_xor(ctx->digest, src, len);
-               kernel_neon_begin();
-               sm4_ce_crypt_block(tctx->key.rkey_enc, ctx->digest,
-                                  ctx->digest);
-               kernel_neon_end();
+               scoped_ksimd()
+                       sm4_ce_crypt_block(tctx->key.rkey_enc, ctx->digest,
+                                          ctx->digest);
        }
        memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
        return 0;
diff --git a/arch/arm64/crypto/sm4-neon-glue.c b/arch/arm64/crypto/sm4-neon-glue.c
index e3500aca2d18bddb61ba5244ef1e8a95791ea22a..e944c2a2efb025212490cbea7c953d9494b4354b 100644
--- a/arch/arm64/crypto/sm4-neon-glue.c
+++ b/arch/arm64/crypto/sm4-neon-glue.c
@@ -48,11 +48,8 @@ static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
 
                nblocks = nbytes / SM4_BLOCK_SIZE;
                if (nblocks) {
-                       kernel_neon_begin();
-
-                       sm4_neon_crypt(rkey, dst, src, nblocks);
-
-                       kernel_neon_end();
+                       scoped_ksimd()
+                               sm4_neon_crypt(rkey, dst, src, nblocks);
                }
 
                err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
@@ -126,12 +123,9 @@ static int sm4_cbc_decrypt(struct skcipher_request *req)
 
                nblocks = nbytes / SM4_BLOCK_SIZE;
                if (nblocks) {
-                       kernel_neon_begin();
-
-                       sm4_neon_cbc_dec(ctx->rkey_dec, dst, src,
-                                        walk.iv, nblocks);
-
-                       kernel_neon_end();
+                       scoped_ksimd()
+                               sm4_neon_cbc_dec(ctx->rkey_dec, dst, src,
+                                                walk.iv, nblocks);
                }
 
                err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
@@ -157,12 +151,9 @@ static int sm4_ctr_crypt(struct skcipher_request *req)
 
                nblocks = nbytes / SM4_BLOCK_SIZE;
                if (nblocks) {
-                       kernel_neon_begin();
-
-                       sm4_neon_ctr_crypt(ctx->rkey_enc, dst, src,
-                                          walk.iv, nblocks);
-
-                       kernel_neon_end();
+                       scoped_ksimd()
+                               sm4_neon_ctr_crypt(ctx->rkey_enc, dst, src,
+                                                  walk.iv, nblocks);
 
                        dst += nblocks * SM4_BLOCK_SIZE;
                        src += nblocks * SM4_BLOCK_SIZE;