/* SPDX-License-Identifier: GPL-2.0-only */
 /*
- * aesce-ccm-core.S - AES-CCM transform for ARMv8 with Crypto Extensions
+ * aes-ce-ccm-core.S - AES-CCM transform for ARMv8 with Crypto Extensions
  *
- * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2013 - 2017 Linaro Ltd.
+ * Copyright (C) 2024 Google LLC
+ *
+ * Author: Ard Biesheuvel <ardb@kernel.org>
  */
 
 #include <linux/linkage.h>
        ld1     {v2.16b}, [x1], #16             /* load next input block */
        .if     \enc == 1
        eor     v2.16b, v2.16b, v5.16b          /* final round enc+mac */
-       eor     v1.16b, v1.16b, v2.16b          /* xor with crypted ctr */
+       eor     v6.16b, v1.16b, v2.16b          /* xor with crypted ctr */
        .else
        eor     v2.16b, v2.16b, v1.16b          /* xor with crypted ctr */
-       eor     v1.16b, v2.16b, v5.16b          /* final round enc */
+       eor     v6.16b, v2.16b, v5.16b          /* final round enc */
        .endif
        eor     v0.16b, v0.16b, v2.16b          /* xor mac with pt ^ rk[last] */
-       st1     {v1.16b}, [x0], #16             /* write output block */
+       st1     {v6.16b}, [x0], #16             /* write output block */
        bne     0b
 CPU_LE(        rev     x8, x8                  )
        st1     {v0.16b}, [x5]                  /* store mac */
 
 6:     eor     v0.16b, v0.16b, v5.16b          /* final round mac */
        eor     v1.16b, v1.16b, v5.16b          /* final round enc */
-       st1     {v0.16b}, [x5]                  /* store mac */
-       add     w2, w2, #16                     /* process partial tail block */
-7:     ldrb    w9, [x1], #1                    /* get 1 byte of input */
-       umov    w6, v1.b[0]                     /* get top crypted ctr byte */
-       umov    w7, v0.b[0]                     /* get top mac byte */
+
+       add     x1, x1, w2, sxtw                /* rewind the input pointer (w2 < 0) */
+       add     x0, x0, w2, sxtw                /* rewind the output pointer */
+
+       adr_l   x8, .Lpermute                   /* load permute vectors */
+       add     x9, x8, w2, sxtw
+       sub     x8, x8, w2, sxtw
+       ld1     {v7.16b-v8.16b}, [x9]
+       ld1     {v9.16b}, [x8]
+
+       ld1     {v2.16b}, [x1]                  /* load a full block of input */
+       tbl     v1.16b, {v1.16b}, v7.16b        /* move keystream to end of register */
        .if     \enc == 1
-       eor     w7, w7, w9
-       eor     w9, w9, w6
+       tbl     v7.16b, {v2.16b}, v9.16b        /* copy plaintext to start of v7 */
+       eor     v2.16b, v2.16b, v1.16b          /* encrypt partial input block */
        .else
-       eor     w9, w9, w6
-       eor     w7, w7, w9
+       eor     v2.16b, v2.16b, v1.16b          /* decrypt partial input block */
+       tbl     v7.16b, {v2.16b}, v9.16b        /* copy plaintext to start of v7 */
        .endif
-       strb    w9, [x0], #1                    /* store out byte */
-       strb    w7, [x5], #1                    /* store mac byte */
-       subs    w2, w2, #1
-       beq     5b
-       ext     v0.16b, v0.16b, v0.16b, #1      /* shift out mac byte */
-       ext     v1.16b, v1.16b, v1.16b, #1      /* shift out ctr byte */
-       b       7b
+       eor     v0.16b, v0.16b, v7.16b          /* fold plaintext into mac */
+       tbx     v2.16b, {v6.16b}, v8.16b        /* insert output from previous iteration */
+
+       st1     {v0.16b}, [x5]                  /* store mac */
+       st1     {v2.16b}, [x0]                  /* store output block */
+       ret
        .endm
 
        /*
 SYM_FUNC_START(ce_aes_ccm_decrypt)
        aes_ccm_do_crypt        0
 SYM_FUNC_END(ce_aes_ccm_decrypt)
+
+       .section ".rodata", "a"
+       .align  6
+       .fill   15, 1, 0xff
+.Lpermute:
+       .byte   0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7
+       .byte   0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf
+       .fill   15, 1, 0xff