crypto/arm64: aes/xts - Use single ksimd scope to reduce stack bloat
author    Ard Biesheuvel <ardb@kernel.org>
          Wed, 3 Dec 2025 16:38:05 +0000 (17:38 +0100)
committer Eric Biggers <ebiggers@kernel.org>
          Tue, 9 Dec 2025 23:10:21 +0000 (15:10 -0800)

The ciphertext stealing logic in the AES-XTS implementation creates a
separate ksimd scope to call into the FP/SIMD core routines, and in some
cases (CONFIG_KASAN_STACK is one, but there might be others), the
528-byte kernel-mode FP/SIMD buffer that is allocated inside this scope is
not shared with the preceding ksimd scope, resulting in unnecessary
stack bloat.

Considering that

a) the XTS ciphertext stealing logic is never called for block
   encryption use cases, and XTS is rarely used for anything else,

b) in the vast majority of cases, the entire input block is processed
   during the first iteration of the loop,

we can combine both ksimd scopes into a single one, with no practical
impact on how often or for how long FP/SIMD is enabled and disabled,
allowing us to reuse the same stack slot for both FP/SIMD routine calls.
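
To make the stack-slot behaviour concrete, here is a minimal
user-space sketch (illustrative only: use() stands in for an FP/SIMD
core routine call, and the 528-byte size is the one quoted above).
KASAN's stack instrumentation, much like userland ASan, can keep the
compiler from reusing the slots of locals whose scope has ended, so
the two sibling scopes in f() may each get their own 528-byte slot,
while the single merged scope in g() needs only one:

  void use(char *buf);          /* stand-in for an FP/SIMD core routine */

  void f(void)                  /* two sibling ksimd-style scopes */
  {
          {
                  char simd_state[528];
                  use(simd_state);
          }
          {
                  char simd_state[528];   /* may get a second slot when
                                             slot reuse is inhibited */
                  use(simd_state);
          }
  }

  void g(void)                  /* one merged scope */
  {
          char simd_state[528];
          use(simd_state);
          use(simd_state);      /* same buffer, same stack slot */
  }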

Fixes: ba3c1b3b5ac9 ("crypto/arm64: aes-blk - Switch to 'ksimd' scoped guard API")
Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
Tested-by: Arnd Bergmann <arnd@arndb.de>
Link: https://lore.kernel.org/r/20251203163803.157541-5-ardb@kernel.org
Signed-off-by: Eric Biggers <ebiggers@kernel.org>
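
For background, 'ksimd' refers to a scoped guard API. Below is a
minimal, self-contained sketch of the general scoped-guard idiom,
modeled on scoped_guard() from <linux/cleanup.h>; the guard and
helper names here are invented for the sketch, and the real
scoped_ksimd() definition (including where it keeps its FP/SIMD
state buffer) may differ:

  #include <stdio.h>

  struct simd_guard { char state[528]; };  /* size as quoted in the commit */

  static struct simd_guard simd_begin(void)
  {
          puts("FP/SIMD enabled");     /* kernel: enter kernel-mode NEON */
          return (struct simd_guard){0};
  }

  static void simd_end(struct simd_guard *g)
  {
          (void)g;
          puts("FP/SIMD disabled");    /* kernel: leave kernel-mode NEON */
  }

  /*
   * One-pass for loop: the guard object (and its state buffer) is scoped
   * to the attached block, and simd_end() runs on any exit from it.
   */
  #define scoped_simd()                                                    \
          for (struct simd_guard __g __attribute__((cleanup(simd_end))) =  \
                       simd_begin(), *__once = &__g;                       \
               __once; __once = NULL)

  int main(void)
  {
          scoped_simd() {
                  puts("first FP/SIMD call");
                  puts("second FP/SIMD call");   /* same guard, one buffer */
          }
          return 0;
  }

Shaped like this, the single merged scope in the patch below means the
compiler sees a single guard object, so only one state buffer needs a
stack slot regardless of how the code inside it is instrumented.
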
arch/arm64/crypto/aes-glue.c
arch/arm64/crypto/aes-neonbs-glue.c

diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index b087b900d2790b4e1b328f083b91546098d26bb8..c51d4487e9e9b6eb9b8b7d1d2b0324c59e273a1b 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -549,38 +549,37 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
                tail = 0;
        }
 
-       for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
-               int nbytes = walk.nbytes;
+       scoped_ksimd() {
+               for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+                       int nbytes = walk.nbytes;
 
-               if (walk.nbytes < walk.total)
-                       nbytes &= ~(AES_BLOCK_SIZE - 1);
+                       if (walk.nbytes < walk.total)
+                               nbytes &= ~(AES_BLOCK_SIZE - 1);
 
-               scoped_ksimd()
                        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                        ctx->key1.key_enc, rounds, nbytes,
                                        ctx->key2.key_enc, walk.iv, first);
-               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
-       }
+                       err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+               }
 
-       if (err || likely(!tail))
-               return err;
+               if (err || likely(!tail))
+                       return err;
 
-       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
-       if (req->dst != req->src)
-               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+               dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+               if (req->dst != req->src)
+                       dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
 
-       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
-                                  req->iv);
+               skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                          req->iv);
 
-       err = skcipher_walk_virt(&walk, &subreq, false);
-       if (err)
-               return err;
+               err = skcipher_walk_virt(&walk, &subreq, false);
+               if (err)
+                       return err;
 
-       scoped_ksimd()
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, walk.nbytes,
                                ctx->key2.key_enc, walk.iv, first);
-
+       }
        return skcipher_walk_done(&walk, 0);
 }
 
@@ -619,39 +618,37 @@ static int __maybe_unused xts_decrypt(struct skcipher_request *req)
                tail = 0;
        }
 
-       for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
-               int nbytes = walk.nbytes;
+       scoped_ksimd() {
+               for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+                       int nbytes = walk.nbytes;
 
-               if (walk.nbytes < walk.total)
-                       nbytes &= ~(AES_BLOCK_SIZE - 1);
+                       if (walk.nbytes < walk.total)
+                               nbytes &= ~(AES_BLOCK_SIZE - 1);
 
-               scoped_ksimd()
                        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                        ctx->key1.key_dec, rounds, nbytes,
                                        ctx->key2.key_enc, walk.iv, first);
-               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
-       }
+                       err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+               }
 
-       if (err || likely(!tail))
-               return err;
-
-       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
-       if (req->dst != req->src)
-               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+               if (err || likely(!tail))
+                       return err;
 
-       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
-                                  req->iv);
+               dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+               if (req->dst != req->src)
+                       dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
 
-       err = skcipher_walk_virt(&walk, &subreq, false);
-       if (err)
-               return err;
+               skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                          req->iv);
 
+               err = skcipher_walk_virt(&walk, &subreq, false);
+               if (err)
+                       return err;
 
-       scoped_ksimd()
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, walk.nbytes,
                                ctx->key2.key_enc, walk.iv, first);
-
+       }
        return skcipher_walk_done(&walk, 0);
 }
 
diff --git a/arch/arm64/crypto/aes-neonbs-glue.c b/arch/arm64/crypto/aes-neonbs-glue.c
index d496effb0a5b77119b4d018770c0ddbe749b3efc..cb87c8fc66b3b056ff4dde39d974c786253f4713 100644
--- a/arch/arm64/crypto/aes-neonbs-glue.c
+++ b/arch/arm64/crypto/aes-neonbs-glue.c
@@ -312,13 +312,13 @@ static int __xts_crypt(struct skcipher_request *req, bool encrypt,
        if (err)
                return err;
 
-       while (walk.nbytes >= AES_BLOCK_SIZE) {
-               int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
-               out = walk.dst.virt.addr;
-               in = walk.src.virt.addr;
-               nbytes = walk.nbytes;
+       scoped_ksimd() {
+               while (walk.nbytes >= AES_BLOCK_SIZE) {
+                       int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
+                       out = walk.dst.virt.addr;
+                       in = walk.src.virt.addr;
+                       nbytes = walk.nbytes;
 
-               scoped_ksimd() {
                        if (blocks >= 8) {
                                if (first == 1)
                                        neon_aes_ecb_encrypt(walk.iv, walk.iv,
@@ -344,30 +344,28 @@ static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                                                             ctx->twkey, walk.iv, first);
                                nbytes = first = 0;
                        }
+                       err = skcipher_walk_done(&walk, nbytes);
                }
-               err = skcipher_walk_done(&walk, nbytes);
-       }
 
-       if (err || likely(!tail))
-               return err;
+               if (err || likely(!tail))
+                       return err;
 
-       /* handle ciphertext stealing */
-       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
-       if (req->dst != req->src)
-               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+               /* handle ciphertext stealing */
+               dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+               if (req->dst != req->src)
+                       dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
 
-       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
-                                  req->iv);
+               skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                          req->iv);
 
-       err = skcipher_walk_virt(&walk, req, false);
-       if (err)
-               return err;
+               err = skcipher_walk_virt(&walk, req, false);
+               if (err)
+                       return err;
 
-       out = walk.dst.virt.addr;
-       in = walk.src.virt.addr;
-       nbytes = walk.nbytes;
+               out = walk.dst.virt.addr;
+               in = walk.src.virt.addr;
+               nbytes = walk.nbytes;
 
-       scoped_ksimd() {
                if (encrypt)
                        neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
                                             ctx->key.rounds, nbytes, ctx->twkey,