From fae3b96ba6015c35a973da09bf313d90e4e4bb94 Mon Sep 17 00:00:00 2001 From: Eric Biggers Date: Wed, 1 Oct 2025 19:31:10 -0700 Subject: [PATCH] crypto: x86/aes-gcm - add VAES+AVX2 optimized code Add an implementation of AES-GCM that uses 256-bit vectors and the following CPU features: Vector AES (VAES), Vector Carryless Multiplication (VPCLMULQDQ), and AVX2. It doesn't require AVX512. So unlike the existing VAES+AVX512 code, it works on CPUs that support VAES but not AVX512, specifically: - AMD Zen 3, both client and server - Intel Alder Lake, Raptor Lake, Meteor Lake, Arrow Lake, and Lunar Lake. (These are client CPUs.) - Intel Sierra Forest. (This is a server CPU.) On these CPUs, this VAES+AVX2 code is much faster than the existing AES-NI code. The AES-NI code uses only 128-bit vectors. These CPUs are widely deployed, making VAES+AVX2 code worthwhile even though hopefully future x86_64 CPUs will uniformly support AVX512. This implementation will also serve as the fallback 256-bit implementation for older Intel CPUs (Ice Lake and Tiger Lake) that support AVX512 but downclock too eagerly when 512-bit vectors are used. Currently, the VAES+AVX10/256 implementation serves that purpose. A later commit will remove that and just use the VAES+AVX2 one. (Note that AES-XTS and AES-CTR already successfully use this approach.) I originally wrote this AES-GCM implementation for BoringSSL. It's been in BoringSSL for a while now, including in Chromium. This is a port of it to the Linux kernel. The main changes in the Linux version include: - Port from "perlasm" to a standard .S file. - Align all assembly functions with what aesni-intel_glue.c expects, including adding support for lengths not a multiple of 16 bytes. - Rework the en/decryption of the final 1 to 127 bytes. This commit increases AES-256-GCM throughput on AMD Milan (Zen 3) by up to 74%, as shown by the following tables: Table 1: AES-256-GCM encryption throughput change, CPU vs. message length in bytes: | 16384 | 4096 | 4095 | 1420 | 512 | 500 | ----------------------+-------+-------+-------+-------+-------+-------+ AMD Milan (Zen 3) | 67% | 59% | 61% | 39% | 23% | 27% | | 300 | 200 | 64 | 63 | 16 | ----------------------+-------+-------+-------+-------+-------+ AMD Milan (Zen 3) | 14% | 12% | 7% | 7% | 0% | Table 2: AES-256-GCM decryption throughput change, CPU vs. message length in bytes: | 16384 | 4096 | 4095 | 1420 | 512 | 500 | ----------------------+-------+-------+-------+-------+-------+-------+ AMD Milan (Zen 3) | 74% | 65% | 65% | 44% | 23% | 26% | | 300 | 200 | 64 | 63 | 16 | ----------------------+-------+-------+-------+-------+-------+ AMD Milan (Zen 3) | 12% | 11% | 3% | 2% | -3% | Acked-by: Ard Biesheuvel Tested-by: Ard Biesheuvel Link: https://lore.kernel.org/r/20251002023117.37504-2-ebiggers@kernel.org Signed-off-by: Eric Biggers --- arch/x86/crypto/Makefile | 1 + arch/x86/crypto/aes-gcm-vaes-avx2.S | 1145 +++++++++++++++++++++++++++ arch/x86/crypto/aesni-intel_glue.c | 111 ++- 3 files changed, 1255 insertions(+), 2 deletions(-) create mode 100644 arch/x86/crypto/aes-gcm-vaes-avx2.S diff --git a/arch/x86/crypto/Makefile b/arch/x86/crypto/Makefile index 2d30d5d361458..f6f7b2b8b853e 100644 --- a/arch/x86/crypto/Makefile +++ b/arch/x86/crypto/Makefile @@ -46,6 +46,7 @@ obj-$(CONFIG_CRYPTO_AES_NI_INTEL) += aesni-intel.o aesni-intel-y := aesni-intel_asm.o aesni-intel_glue.o aesni-intel-$(CONFIG_64BIT) += aes-ctr-avx-x86_64.o \ aes-gcm-aesni-x86_64.o \ + aes-gcm-vaes-avx2.o \ aes-xts-avx-x86_64.o \ aes-gcm-avx10-x86_64.o diff --git a/arch/x86/crypto/aes-gcm-vaes-avx2.S b/arch/x86/crypto/aes-gcm-vaes-avx2.S new file mode 100644 index 0000000000000..f58096a37342f --- /dev/null +++ b/arch/x86/crypto/aes-gcm-vaes-avx2.S @@ -0,0 +1,1145 @@ +/* SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause */ +// +// AES-GCM implementation for x86_64 CPUs that support the following CPU +// features: VAES && VPCLMULQDQ && AVX2 +// +// Copyright 2025 Google LLC +// +// Author: Eric Biggers +// +//------------------------------------------------------------------------------ +// +// This file is dual-licensed, meaning that you can use it under your choice of +// either of the following two licenses: +// +// Licensed under the Apache License 2.0 (the "License"). You may obtain a copy +// of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// or +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// ----------------------------------------------------------------------------- +// +// This is similar to aes-gcm-avx10-x86_64.S, but it uses AVX2 instead of +// AVX512. This means it can only use 16 vector registers instead of 32, the +// maximum vector length is 32 bytes, and some instructions such as vpternlogd +// and masked loads/stores are unavailable. However, it is able to run on CPUs +// that have VAES without AVX512, namely AMD Zen 3 (including "Milan" server +// CPUs), various Intel client CPUs such as Alder Lake, and Intel Sierra Forest. +// +// This implementation also uses Karatsuba multiplication instead of schoolbook +// multiplication for GHASH in its main loop. This does not help much on Intel, +// but it improves performance by ~5% on AMD Zen 3. Other factors weighing +// slightly in favor of Karatsuba multiplication in this implementation are the +// lower maximum vector length (which means there are fewer key powers, so we +// can cache the halves of each key power XOR'd together and still use less +// memory than the AVX512 implementation), and the unavailability of the +// vpternlogd instruction (which helped schoolbook a bit more than Karatsuba). + +#include + +.section .rodata +.p2align 4 + + // The below three 16-byte values must be in the order that they are, as + // they are really two 32-byte tables and a 16-byte value that overlap: + // + // - The first 32-byte table begins at .Lselect_high_bytes_table. + // For 0 <= len <= 16, the 16-byte value at + // '.Lselect_high_bytes_table + len' selects the high 'len' bytes of + // another 16-byte value when AND'ed with it. + // + // - The second 32-byte table begins at .Lrshift_and_bswap_table. + // For 0 <= len <= 16, the 16-byte value at + // '.Lrshift_and_bswap_table + len' is a vpshufb mask that does the + // following operation: right-shift by '16 - len' bytes (shifting in + // zeroes), then reflect all 16 bytes. + // + // - The 16-byte value at .Lbswap_mask is a vpshufb mask that reflects + // all 16 bytes. +.Lselect_high_bytes_table: + .octa 0 +.Lrshift_and_bswap_table: + .octa 0xffffffffffffffffffffffffffffffff +.Lbswap_mask: + .octa 0x000102030405060708090a0b0c0d0e0f + + // Sixteen 0x0f bytes. By XOR'ing an entry of .Lrshift_and_bswap_table + // with this, we get a mask that left-shifts by '16 - len' bytes. +.Lfifteens: + .octa 0x0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f + + // This is the GHASH reducing polynomial without its constant term, i.e. + // x^128 + x^7 + x^2 + x, represented using the backwards mapping + // between bits and polynomial coefficients. + // + // Alternatively, it can be interpreted as the naturally-ordered + // representation of the polynomial x^127 + x^126 + x^121 + 1, i.e. the + // "reversed" GHASH reducing polynomial without its x^128 term. +.Lgfpoly: + .octa 0xc2000000000000000000000000000001 + + // Same as above, but with the (1 << 64) bit set. +.Lgfpoly_and_internal_carrybit: + .octa 0xc2000000000000010000000000000001 + + // Values needed to prepare the initial vector of counter blocks. +.Lctr_pattern: + .octa 0 + .octa 1 + + // The number of AES blocks per vector, as a 128-bit value. +.Linc_2blocks: + .octa 2 + +// Offsets in struct aes_gcm_key_vaes_avx2 +#define OFFSETOF_AESKEYLEN 480 +#define OFFSETOF_H_POWERS 512 +#define NUM_H_POWERS 8 +#define OFFSETOFEND_H_POWERS (OFFSETOF_H_POWERS + (NUM_H_POWERS * 16)) +#define OFFSETOF_H_POWERS_XORED OFFSETOFEND_H_POWERS + +.text + +// Do one step of GHASH-multiplying the 128-bit lanes of \a by the 128-bit lanes +// of \b and storing the reduced products in \dst. Uses schoolbook +// multiplication. +.macro _ghash_mul_step i, a, b, dst, gfpoly, t0, t1, t2 +.if \i == 0 + vpclmulqdq $0x00, \a, \b, \t0 // LO = a_L * b_L + vpclmulqdq $0x01, \a, \b, \t1 // MI_0 = a_L * b_H +.elseif \i == 1 + vpclmulqdq $0x10, \a, \b, \t2 // MI_1 = a_H * b_L +.elseif \i == 2 + vpxor \t2, \t1, \t1 // MI = MI_0 + MI_1 +.elseif \i == 3 + vpclmulqdq $0x01, \t0, \gfpoly, \t2 // LO_L*(x^63 + x^62 + x^57) +.elseif \i == 4 + vpshufd $0x4e, \t0, \t0 // Swap halves of LO +.elseif \i == 5 + vpxor \t0, \t1, \t1 // Fold LO into MI (part 1) + vpxor \t2, \t1, \t1 // Fold LO into MI (part 2) +.elseif \i == 6 + vpclmulqdq $0x11, \a, \b, \dst // HI = a_H * b_H +.elseif \i == 7 + vpclmulqdq $0x01, \t1, \gfpoly, \t0 // MI_L*(x^63 + x^62 + x^57) +.elseif \i == 8 + vpshufd $0x4e, \t1, \t1 // Swap halves of MI +.elseif \i == 9 + vpxor \t1, \dst, \dst // Fold MI into HI (part 1) + vpxor \t0, \dst, \dst // Fold MI into HI (part 2) +.endif +.endm + +// GHASH-multiply the 128-bit lanes of \a by the 128-bit lanes of \b and store +// the reduced products in \dst. See _ghash_mul_step for full explanation. +.macro _ghash_mul a, b, dst, gfpoly, t0, t1, t2 +.irp i, 0,1,2,3,4,5,6,7,8,9 + _ghash_mul_step \i, \a, \b, \dst, \gfpoly, \t0, \t1, \t2 +.endr +.endm + +// GHASH-multiply the 128-bit lanes of \a by the 128-bit lanes of \b and add the +// *unreduced* products to \lo, \mi, and \hi. +.macro _ghash_mul_noreduce a, b, lo, mi, hi, t0 + vpclmulqdq $0x00, \a, \b, \t0 // a_L * b_L + vpxor \t0, \lo, \lo + vpclmulqdq $0x01, \a, \b, \t0 // a_L * b_H + vpxor \t0, \mi, \mi + vpclmulqdq $0x10, \a, \b, \t0 // a_H * b_L + vpxor \t0, \mi, \mi + vpclmulqdq $0x11, \a, \b, \t0 // a_H * b_H + vpxor \t0, \hi, \hi +.endm + +// Reduce the unreduced products from \lo, \mi, and \hi and store the 128-bit +// reduced products in \hi. See _ghash_mul_step for explanation of reduction. +.macro _ghash_reduce lo, mi, hi, gfpoly, t0 + vpclmulqdq $0x01, \lo, \gfpoly, \t0 + vpshufd $0x4e, \lo, \lo + vpxor \lo, \mi, \mi + vpxor \t0, \mi, \mi + vpclmulqdq $0x01, \mi, \gfpoly, \t0 + vpshufd $0x4e, \mi, \mi + vpxor \mi, \hi, \hi + vpxor \t0, \hi, \hi +.endm + +// This is a specialized version of _ghash_mul that computes \a * \a, i.e. it +// squares \a. It skips computing MI = (a_L * a_H) + (a_H * a_L) = 0. +.macro _ghash_square a, dst, gfpoly, t0, t1 + vpclmulqdq $0x00, \a, \a, \t0 // LO = a_L * a_L + vpclmulqdq $0x11, \a, \a, \dst // HI = a_H * a_H + vpclmulqdq $0x01, \t0, \gfpoly, \t1 // LO_L*(x^63 + x^62 + x^57) + vpshufd $0x4e, \t0, \t0 // Swap halves of LO + vpxor \t0, \t1, \t1 // Fold LO into MI + vpclmulqdq $0x01, \t1, \gfpoly, \t0 // MI_L*(x^63 + x^62 + x^57) + vpshufd $0x4e, \t1, \t1 // Swap halves of MI + vpxor \t1, \dst, \dst // Fold MI into HI (part 1) + vpxor \t0, \dst, \dst // Fold MI into HI (part 2) +.endm + +// void aes_gcm_precompute_vaes_avx2(struct aes_gcm_key_vaes_avx2 *key); +// +// Given the expanded AES key |key->base.aes_key|, derive the GHASH subkey and +// initialize |key->h_powers| and |key->h_powers_xored|. +// +// We use h_powers[0..7] to store H^8 through H^1, and h_powers_xored[0..7] to +// store the 64-bit halves of the key powers XOR'd together (for Karatsuba +// multiplication) in the order 8,6,7,5,4,2,3,1. +SYM_FUNC_START(aes_gcm_precompute_vaes_avx2) + + // Function arguments + .set KEY, %rdi + + // Additional local variables + .set POWERS_PTR, %rsi + .set RNDKEYLAST_PTR, %rdx + .set TMP0, %ymm0 + .set TMP0_XMM, %xmm0 + .set TMP1, %ymm1 + .set TMP1_XMM, %xmm1 + .set TMP2, %ymm2 + .set TMP2_XMM, %xmm2 + .set H_CUR, %ymm3 + .set H_CUR_XMM, %xmm3 + .set H_CUR2, %ymm4 + .set H_INC, %ymm5 + .set H_INC_XMM, %xmm5 + .set GFPOLY, %ymm6 + .set GFPOLY_XMM, %xmm6 + + // Encrypt an all-zeroes block to get the raw hash subkey. + movl OFFSETOF_AESKEYLEN(KEY), %eax + lea 6*16(KEY,%rax,4), RNDKEYLAST_PTR + vmovdqu (KEY), H_CUR_XMM // Zero-th round key XOR all-zeroes block + lea 16(KEY), %rax +1: + vaesenc (%rax), H_CUR_XMM, H_CUR_XMM + add $16, %rax + cmp %rax, RNDKEYLAST_PTR + jne 1b + vaesenclast (RNDKEYLAST_PTR), H_CUR_XMM, H_CUR_XMM + + // Reflect the bytes of the raw hash subkey. + vpshufb .Lbswap_mask(%rip), H_CUR_XMM, H_CUR_XMM + + // Finish preprocessing the byte-reflected hash subkey by multiplying it + // by x^-1 ("standard" interpretation of polynomial coefficients) or + // equivalently x^1 (natural interpretation). This gets the key into a + // format that avoids having to bit-reflect the data blocks later. + vpshufd $0xd3, H_CUR_XMM, TMP0_XMM + vpsrad $31, TMP0_XMM, TMP0_XMM + vpaddq H_CUR_XMM, H_CUR_XMM, H_CUR_XMM + vpand .Lgfpoly_and_internal_carrybit(%rip), TMP0_XMM, TMP0_XMM + vpxor TMP0_XMM, H_CUR_XMM, H_CUR_XMM + + // Load the gfpoly constant. + vbroadcasti128 .Lgfpoly(%rip), GFPOLY + + // Square H^1 to get H^2. + _ghash_square H_CUR_XMM, H_INC_XMM, GFPOLY_XMM, TMP0_XMM, TMP1_XMM + + // Create H_CUR = [H^2, H^1] and H_INC = [H^2, H^2]. + vinserti128 $1, H_CUR_XMM, H_INC, H_CUR + vinserti128 $1, H_INC_XMM, H_INC, H_INC + + // Compute H_CUR2 = [H^4, H^3]. + _ghash_mul H_INC, H_CUR, H_CUR2, GFPOLY, TMP0, TMP1, TMP2 + + // Store [H^2, H^1] and [H^4, H^3]. + vmovdqu H_CUR, OFFSETOF_H_POWERS+3*32(KEY) + vmovdqu H_CUR2, OFFSETOF_H_POWERS+2*32(KEY) + + // For Karatsuba multiplication: compute and store the two 64-bit halves + // of each key power XOR'd together. Order is 4,2,3,1. + vpunpcklqdq H_CUR, H_CUR2, TMP0 + vpunpckhqdq H_CUR, H_CUR2, TMP1 + vpxor TMP1, TMP0, TMP0 + vmovdqu TMP0, OFFSETOF_H_POWERS_XORED+32(KEY) + + // Compute and store H_CUR = [H^6, H^5] and H_CUR2 = [H^8, H^7]. + _ghash_mul H_INC, H_CUR2, H_CUR, GFPOLY, TMP0, TMP1, TMP2 + _ghash_mul H_INC, H_CUR, H_CUR2, GFPOLY, TMP0, TMP1, TMP2 + vmovdqu H_CUR, OFFSETOF_H_POWERS+1*32(KEY) + vmovdqu H_CUR2, OFFSETOF_H_POWERS+0*32(KEY) + + // Again, compute and store the two 64-bit halves of each key power + // XOR'd together. Order is 8,6,7,5. + vpunpcklqdq H_CUR, H_CUR2, TMP0 + vpunpckhqdq H_CUR, H_CUR2, TMP1 + vpxor TMP1, TMP0, TMP0 + vmovdqu TMP0, OFFSETOF_H_POWERS_XORED(KEY) + + vzeroupper + RET +SYM_FUNC_END(aes_gcm_precompute_vaes_avx2) + +// Do one step of the GHASH update of four vectors of data blocks. +// \i: the step to do, 0 through 9 +// \ghashdata_ptr: pointer to the data blocks (ciphertext or AAD) +// KEY: pointer to struct aes_gcm_key_vaes_avx2 +// BSWAP_MASK: mask for reflecting the bytes of blocks +// H_POW[2-1]_XORED: cached values from KEY->h_powers_xored +// TMP[0-2]: temporary registers. TMP[1-2] must be preserved across steps. +// LO, MI: working state for this macro that must be preserved across steps +// GHASH_ACC: the GHASH accumulator (input/output) +.macro _ghash_step_4x i, ghashdata_ptr + .set HI, GHASH_ACC # alias + .set HI_XMM, GHASH_ACC_XMM +.if \i == 0 + // First vector + vmovdqu 0*32(\ghashdata_ptr), TMP1 + vpshufb BSWAP_MASK, TMP1, TMP1 + vmovdqu OFFSETOF_H_POWERS+0*32(KEY), TMP2 + vpxor GHASH_ACC, TMP1, TMP1 + vpclmulqdq $0x00, TMP2, TMP1, LO + vpclmulqdq $0x11, TMP2, TMP1, HI + vpunpckhqdq TMP1, TMP1, TMP0 + vpxor TMP1, TMP0, TMP0 + vpclmulqdq $0x00, H_POW2_XORED, TMP0, MI +.elseif \i == 1 +.elseif \i == 2 + // Second vector + vmovdqu 1*32(\ghashdata_ptr), TMP1 + vpshufb BSWAP_MASK, TMP1, TMP1 + vmovdqu OFFSETOF_H_POWERS+1*32(KEY), TMP2 + vpclmulqdq $0x00, TMP2, TMP1, TMP0 + vpxor TMP0, LO, LO + vpclmulqdq $0x11, TMP2, TMP1, TMP0 + vpxor TMP0, HI, HI + vpunpckhqdq TMP1, TMP1, TMP0 + vpxor TMP1, TMP0, TMP0 + vpclmulqdq $0x10, H_POW2_XORED, TMP0, TMP0 + vpxor TMP0, MI, MI +.elseif \i == 3 + // Third vector + vmovdqu 2*32(\ghashdata_ptr), TMP1 + vpshufb BSWAP_MASK, TMP1, TMP1 + vmovdqu OFFSETOF_H_POWERS+2*32(KEY), TMP2 +.elseif \i == 4 + vpclmulqdq $0x00, TMP2, TMP1, TMP0 + vpxor TMP0, LO, LO + vpclmulqdq $0x11, TMP2, TMP1, TMP0 + vpxor TMP0, HI, HI +.elseif \i == 5 + vpunpckhqdq TMP1, TMP1, TMP0 + vpxor TMP1, TMP0, TMP0 + vpclmulqdq $0x00, H_POW1_XORED, TMP0, TMP0 + vpxor TMP0, MI, MI + + // Fourth vector + vmovdqu 3*32(\ghashdata_ptr), TMP1 + vpshufb BSWAP_MASK, TMP1, TMP1 +.elseif \i == 6 + vmovdqu OFFSETOF_H_POWERS+3*32(KEY), TMP2 + vpclmulqdq $0x00, TMP2, TMP1, TMP0 + vpxor TMP0, LO, LO + vpclmulqdq $0x11, TMP2, TMP1, TMP0 + vpxor TMP0, HI, HI + vpunpckhqdq TMP1, TMP1, TMP0 + vpxor TMP1, TMP0, TMP0 + vpclmulqdq $0x10, H_POW1_XORED, TMP0, TMP0 + vpxor TMP0, MI, MI +.elseif \i == 7 + // Finalize 'mi' following Karatsuba multiplication. + vpxor LO, MI, MI + vpxor HI, MI, MI + + // Fold lo into mi. + vbroadcasti128 .Lgfpoly(%rip), TMP2 + vpclmulqdq $0x01, LO, TMP2, TMP0 + vpshufd $0x4e, LO, LO + vpxor LO, MI, MI + vpxor TMP0, MI, MI +.elseif \i == 8 + // Fold mi into hi. + vpclmulqdq $0x01, MI, TMP2, TMP0 + vpshufd $0x4e, MI, MI + vpxor MI, HI, HI + vpxor TMP0, HI, HI +.elseif \i == 9 + vextracti128 $1, HI, TMP0_XMM + vpxor TMP0_XMM, HI_XMM, GHASH_ACC_XMM +.endif +.endm + +// Update GHASH with four vectors of data blocks. See _ghash_step_4x for full +// explanation. +.macro _ghash_4x ghashdata_ptr +.irp i, 0,1,2,3,4,5,6,7,8,9 + _ghash_step_4x \i, \ghashdata_ptr +.endr +.endm + +// Load 1 <= %ecx <= 16 bytes from the pointer \src into the xmm register \dst +// and zeroize any remaining bytes. Clobbers %rax, %rcx, and \tmp{64,32}. +.macro _load_partial_block src, dst, tmp64, tmp32 + sub $8, %ecx // LEN - 8 + jle .Lle8\@ + + // Load 9 <= LEN <= 16 bytes. + vmovq (\src), \dst // Load first 8 bytes + mov (\src, %rcx), %rax // Load last 8 bytes + neg %ecx + shl $3, %ecx + shr %cl, %rax // Discard overlapping bytes + vpinsrq $1, %rax, \dst, \dst + jmp .Ldone\@ + +.Lle8\@: + add $4, %ecx // LEN - 4 + jl .Llt4\@ + + // Load 4 <= LEN <= 8 bytes. + mov (\src), %eax // Load first 4 bytes + mov (\src, %rcx), \tmp32 // Load last 4 bytes + jmp .Lcombine\@ + +.Llt4\@: + // Load 1 <= LEN <= 3 bytes. + add $2, %ecx // LEN - 2 + movzbl (\src), %eax // Load first byte + jl .Lmovq\@ + movzwl (\src, %rcx), \tmp32 // Load last 2 bytes +.Lcombine\@: + shl $3, %ecx + shl %cl, \tmp64 + or \tmp64, %rax // Combine the two parts +.Lmovq\@: + vmovq %rax, \dst +.Ldone\@: +.endm + +// Store 1 <= %ecx <= 16 bytes from the xmm register \src to the pointer \dst. +// Clobbers %rax, %rcx, and \tmp{64,32}. +.macro _store_partial_block src, dst, tmp64, tmp32 + sub $8, %ecx // LEN - 8 + jl .Llt8\@ + + // Store 8 <= LEN <= 16 bytes. + vpextrq $1, \src, %rax + mov %ecx, \tmp32 + shl $3, %ecx + ror %cl, %rax + mov %rax, (\dst, \tmp64) // Store last LEN - 8 bytes + vmovq \src, (\dst) // Store first 8 bytes + jmp .Ldone\@ + +.Llt8\@: + add $4, %ecx // LEN - 4 + jl .Llt4\@ + + // Store 4 <= LEN <= 7 bytes. + vpextrd $1, \src, %eax + mov %ecx, \tmp32 + shl $3, %ecx + ror %cl, %eax + mov %eax, (\dst, \tmp64) // Store last LEN - 4 bytes + vmovd \src, (\dst) // Store first 4 bytes + jmp .Ldone\@ + +.Llt4\@: + // Store 1 <= LEN <= 3 bytes. + vpextrb $0, \src, 0(\dst) + cmp $-2, %ecx // LEN - 4 == -2, i.e. LEN == 2? + jl .Ldone\@ + vpextrb $1, \src, 1(\dst) + je .Ldone\@ + vpextrb $2, \src, 2(\dst) +.Ldone\@: +.endm + +// void aes_gcm_aad_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, +// u8 ghash_acc[16], +// const u8 *aad, int aadlen); +// +// This function processes the AAD (Additional Authenticated Data) in GCM. +// Using the key |key|, it updates the GHASH accumulator |ghash_acc| with the +// data given by |aad| and |aadlen|. On the first call, |ghash_acc| must be all +// zeroes. |aadlen| must be a multiple of 16, except on the last call where it +// can be any length. The caller must do any buffering needed to ensure this. +// +// This handles large amounts of AAD efficiently, while also keeping overhead +// low for small amounts which is the common case. TLS and IPsec use less than +// one block of AAD, but (uncommonly) other use cases may use much more. +SYM_FUNC_START(aes_gcm_aad_update_vaes_avx2) + + // Function arguments + .set KEY, %rdi + .set GHASH_ACC_PTR, %rsi + .set AAD, %rdx + .set AADLEN, %ecx // Must be %ecx for _load_partial_block + .set AADLEN64, %rcx // Zero-extend AADLEN before using! + + // Additional local variables. + // %rax and %r8 are used as temporary registers. + .set TMP0, %ymm0 + .set TMP0_XMM, %xmm0 + .set TMP1, %ymm1 + .set TMP1_XMM, %xmm1 + .set TMP2, %ymm2 + .set TMP2_XMM, %xmm2 + .set LO, %ymm3 + .set LO_XMM, %xmm3 + .set MI, %ymm4 + .set MI_XMM, %xmm4 + .set GHASH_ACC, %ymm5 + .set GHASH_ACC_XMM, %xmm5 + .set BSWAP_MASK, %ymm6 + .set BSWAP_MASK_XMM, %xmm6 + .set GFPOLY, %ymm7 + .set GFPOLY_XMM, %xmm7 + .set H_POW2_XORED, %ymm8 + .set H_POW1_XORED, %ymm9 + + // Load the bswap_mask and gfpoly constants. Since AADLEN is usually + // small, usually only 128-bit vectors will be used. So as an + // optimization, don't broadcast these constants to both 128-bit lanes + // quite yet. + vmovdqu .Lbswap_mask(%rip), BSWAP_MASK_XMM + vmovdqu .Lgfpoly(%rip), GFPOLY_XMM + + // Load the GHASH accumulator. + vmovdqu (GHASH_ACC_PTR), GHASH_ACC_XMM + + // Check for the common case of AADLEN <= 16, as well as AADLEN == 0. + test AADLEN, AADLEN + jz .Laad_done + cmp $16, AADLEN + jle .Laad_lastblock + + // AADLEN > 16, so we'll operate on full vectors. Broadcast bswap_mask + // and gfpoly to both 128-bit lanes. + vinserti128 $1, BSWAP_MASK_XMM, BSWAP_MASK, BSWAP_MASK + vinserti128 $1, GFPOLY_XMM, GFPOLY, GFPOLY + + // If AADLEN >= 128, update GHASH with 128 bytes of AAD at a time. + add $-128, AADLEN // 128 is 4 bytes, -128 is 1 byte + jl .Laad_loop_4x_done + vmovdqu OFFSETOF_H_POWERS_XORED(KEY), H_POW2_XORED + vmovdqu OFFSETOF_H_POWERS_XORED+32(KEY), H_POW1_XORED +.Laad_loop_4x: + _ghash_4x AAD + sub $-128, AAD + add $-128, AADLEN + jge .Laad_loop_4x +.Laad_loop_4x_done: + + // If AADLEN >= 32, update GHASH with 32 bytes of AAD at a time. + add $96, AADLEN + jl .Laad_loop_1x_done +.Laad_loop_1x: + vmovdqu (AAD), TMP0 + vpshufb BSWAP_MASK, TMP0, TMP0 + vpxor TMP0, GHASH_ACC, GHASH_ACC + vmovdqu OFFSETOFEND_H_POWERS-32(KEY), TMP0 + _ghash_mul TMP0, GHASH_ACC, GHASH_ACC, GFPOLY, TMP1, TMP2, LO + vextracti128 $1, GHASH_ACC, TMP0_XMM + vpxor TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM + add $32, AAD + sub $32, AADLEN + jge .Laad_loop_1x +.Laad_loop_1x_done: + add $32, AADLEN + // Now 0 <= AADLEN < 32. + + jz .Laad_done + cmp $16, AADLEN + jle .Laad_lastblock + + // Update GHASH with the remaining 17 <= AADLEN <= 31 bytes of AAD. + mov AADLEN, AADLEN // Zero-extend AADLEN to AADLEN64. + vmovdqu (AAD), TMP0_XMM + vmovdqu -16(AAD, AADLEN64), TMP1_XMM + vpshufb BSWAP_MASK_XMM, TMP0_XMM, TMP0_XMM + vpxor TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM + lea .Lrshift_and_bswap_table(%rip), %rax + vpshufb -16(%rax, AADLEN64), TMP1_XMM, TMP1_XMM + vinserti128 $1, TMP1_XMM, GHASH_ACC, GHASH_ACC + vmovdqu OFFSETOFEND_H_POWERS-32(KEY), TMP0 + _ghash_mul TMP0, GHASH_ACC, GHASH_ACC, GFPOLY, TMP1, TMP2, LO + vextracti128 $1, GHASH_ACC, TMP0_XMM + vpxor TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM + jmp .Laad_done + +.Laad_lastblock: + // Update GHASH with the remaining 1 <= AADLEN <= 16 bytes of AAD. + _load_partial_block AAD, TMP0_XMM, %r8, %r8d + vpshufb BSWAP_MASK_XMM, TMP0_XMM, TMP0_XMM + vpxor TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM + vmovdqu OFFSETOFEND_H_POWERS-16(KEY), TMP0_XMM + _ghash_mul TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM, GFPOLY_XMM, \ + TMP1_XMM, TMP2_XMM, LO_XMM + +.Laad_done: + // Store the updated GHASH accumulator back to memory. + vmovdqu GHASH_ACC_XMM, (GHASH_ACC_PTR) + + vzeroupper + RET +SYM_FUNC_END(aes_gcm_aad_update_vaes_avx2) + +// Do one non-last round of AES encryption on the blocks in the given AESDATA +// vectors using the round key that has been broadcast to all 128-bit lanes of +// \round_key. +.macro _vaesenc round_key, vecs:vararg +.irp i, \vecs + vaesenc \round_key, AESDATA\i, AESDATA\i +.endr +.endm + +// Generate counter blocks in the given AESDATA vectors, then do the zero-th AES +// round on them. Clobbers TMP0. +.macro _ctr_begin vecs:vararg + vbroadcasti128 .Linc_2blocks(%rip), TMP0 +.irp i, \vecs + vpshufb BSWAP_MASK, LE_CTR, AESDATA\i + vpaddd TMP0, LE_CTR, LE_CTR +.endr +.irp i, \vecs + vpxor RNDKEY0, AESDATA\i, AESDATA\i +.endr +.endm + +// Generate and encrypt counter blocks in the given AESDATA vectors, excluding +// the last AES round. Clobbers %rax and TMP0. +.macro _aesenc_loop vecs:vararg + _ctr_begin \vecs + lea 16(KEY), %rax +.Laesenc_loop\@: + vbroadcasti128 (%rax), TMP0 + _vaesenc TMP0, \vecs + add $16, %rax + cmp %rax, RNDKEYLAST_PTR + jne .Laesenc_loop\@ +.endm + +// Finalize the keystream blocks in the given AESDATA vectors by doing the last +// AES round, then XOR those keystream blocks with the corresponding data. +// Reduce latency by doing the XOR before the vaesenclast, utilizing the +// property vaesenclast(key, a) ^ b == vaesenclast(key ^ b, a). Clobbers TMP0. +.macro _aesenclast_and_xor vecs:vararg +.irp i, \vecs + vpxor \i*32(SRC), RNDKEYLAST, TMP0 + vaesenclast TMP0, AESDATA\i, AESDATA\i +.endr +.irp i, \vecs + vmovdqu AESDATA\i, \i*32(DST) +.endr +.endm + +// void aes_gcm_{enc,dec}_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, +// const u32 le_ctr[4], u8 ghash_acc[16], +// const u8 *src, u8 *dst, int datalen); +// +// This macro generates a GCM encryption or decryption update function with the +// above prototype (with \enc selecting which one). The function computes the +// next portion of the CTR keystream, XOR's it with |datalen| bytes from |src|, +// and writes the resulting encrypted or decrypted data to |dst|. It also +// updates the GHASH accumulator |ghash_acc| using the next |datalen| ciphertext +// bytes. +// +// |datalen| must be a multiple of 16, except on the last call where it can be +// any length. The caller must do any buffering needed to ensure this. Both +// in-place and out-of-place en/decryption are supported. +// +// |le_ctr| must give the current counter in little-endian format. This +// function loads the counter from |le_ctr| and increments the loaded counter as +// needed, but it does *not* store the updated counter back to |le_ctr|. The +// caller must update |le_ctr| if any more data segments follow. Internally, +// only the low 32-bit word of the counter is incremented, following the GCM +// standard. +.macro _aes_gcm_update enc + + // Function arguments + .set KEY, %rdi + .set LE_CTR_PTR, %rsi + .set LE_CTR_PTR32, %esi + .set GHASH_ACC_PTR, %rdx + .set SRC, %rcx // Assumed to be %rcx. + // See .Ltail_xor_and_ghash_1to16bytes + .set DST, %r8 + .set DATALEN, %r9d + .set DATALEN64, %r9 // Zero-extend DATALEN before using! + + // Additional local variables + + // %rax is used as a temporary register. LE_CTR_PTR is also available + // as a temporary register after the counter is loaded. + + // AES key length in bytes + .set AESKEYLEN, %r10d + .set AESKEYLEN64, %r10 + + // Pointer to the last AES round key for the chosen AES variant + .set RNDKEYLAST_PTR, %r11 + + // BSWAP_MASK is the shuffle mask for byte-reflecting 128-bit values + // using vpshufb, copied to all 128-bit lanes. + .set BSWAP_MASK, %ymm0 + .set BSWAP_MASK_XMM, %xmm0 + + // GHASH_ACC is the accumulator variable for GHASH. When fully reduced, + // only the lowest 128-bit lane can be nonzero. When not fully reduced, + // more than one lane may be used, and they need to be XOR'd together. + .set GHASH_ACC, %ymm1 + .set GHASH_ACC_XMM, %xmm1 + + // TMP[0-2] are temporary registers. + .set TMP0, %ymm2 + .set TMP0_XMM, %xmm2 + .set TMP1, %ymm3 + .set TMP1_XMM, %xmm3 + .set TMP2, %ymm4 + .set TMP2_XMM, %xmm4 + + // LO and MI are used to accumulate unreduced GHASH products. + .set LO, %ymm5 + .set LO_XMM, %xmm5 + .set MI, %ymm6 + .set MI_XMM, %xmm6 + + // H_POW[2-1]_XORED contain cached values from KEY->h_powers_xored. The + // descending numbering reflects the order of the key powers. + .set H_POW2_XORED, %ymm7 + .set H_POW2_XORED_XMM, %xmm7 + .set H_POW1_XORED, %ymm8 + + // RNDKEY0 caches the zero-th round key, and RNDKEYLAST the last one. + .set RNDKEY0, %ymm9 + .set RNDKEYLAST, %ymm10 + + // LE_CTR contains the next set of little-endian counter blocks. + .set LE_CTR, %ymm11 + + // AESDATA[0-3] hold the counter blocks that are being encrypted by AES. + .set AESDATA0, %ymm12 + .set AESDATA0_XMM, %xmm12 + .set AESDATA1, %ymm13 + .set AESDATA1_XMM, %xmm13 + .set AESDATA2, %ymm14 + .set AESDATA3, %ymm15 + +.if \enc + .set GHASHDATA_PTR, DST +.else + .set GHASHDATA_PTR, SRC +.endif + + vbroadcasti128 .Lbswap_mask(%rip), BSWAP_MASK + + // Load the GHASH accumulator and the starting counter. + vmovdqu (GHASH_ACC_PTR), GHASH_ACC_XMM + vbroadcasti128 (LE_CTR_PTR), LE_CTR + + // Load the AES key length in bytes. + movl OFFSETOF_AESKEYLEN(KEY), AESKEYLEN + + // Make RNDKEYLAST_PTR point to the last AES round key. This is the + // round key with index 10, 12, or 14 for AES-128, AES-192, or AES-256 + // respectively. Then load the zero-th and last round keys. + lea 6*16(KEY,AESKEYLEN64,4), RNDKEYLAST_PTR + vbroadcasti128 (KEY), RNDKEY0 + vbroadcasti128 (RNDKEYLAST_PTR), RNDKEYLAST + + // Finish initializing LE_CTR by adding 1 to the second block. + vpaddd .Lctr_pattern(%rip), LE_CTR, LE_CTR + + // If there are at least 128 bytes of data, then continue into the loop + // that processes 128 bytes of data at a time. Otherwise skip it. + add $-128, DATALEN // 128 is 4 bytes, -128 is 1 byte + jl .Lcrypt_loop_4x_done\@ + + vmovdqu OFFSETOF_H_POWERS_XORED(KEY), H_POW2_XORED + vmovdqu OFFSETOF_H_POWERS_XORED+32(KEY), H_POW1_XORED + + // Main loop: en/decrypt and hash 4 vectors (128 bytes) at a time. + +.if \enc + // Encrypt the first 4 vectors of plaintext blocks. + _aesenc_loop 0,1,2,3 + _aesenclast_and_xor 0,1,2,3 + sub $-128, SRC // 128 is 4 bytes, -128 is 1 byte + add $-128, DATALEN + jl .Lghash_last_ciphertext_4x\@ +.endif + +.align 16 +.Lcrypt_loop_4x\@: + + // Start the AES encryption of the counter blocks. + _ctr_begin 0,1,2,3 + cmp $24, AESKEYLEN + jl 128f // AES-128? + je 192f // AES-192? + // AES-256 + vbroadcasti128 -13*16(RNDKEYLAST_PTR), TMP0 + _vaesenc TMP0, 0,1,2,3 + vbroadcasti128 -12*16(RNDKEYLAST_PTR), TMP0 + _vaesenc TMP0, 0,1,2,3 +192: + vbroadcasti128 -11*16(RNDKEYLAST_PTR), TMP0 + _vaesenc TMP0, 0,1,2,3 + vbroadcasti128 -10*16(RNDKEYLAST_PTR), TMP0 + _vaesenc TMP0, 0,1,2,3 +128: + + // Finish the AES encryption of the counter blocks in AESDATA[0-3], + // interleaved with the GHASH update of the ciphertext blocks. +.irp i, 9,8,7,6,5,4,3,2,1 + _ghash_step_4x (9 - \i), GHASHDATA_PTR + vbroadcasti128 -\i*16(RNDKEYLAST_PTR), TMP0 + _vaesenc TMP0, 0,1,2,3 +.endr + _ghash_step_4x 9, GHASHDATA_PTR +.if \enc + sub $-128, DST // 128 is 4 bytes, -128 is 1 byte +.endif + _aesenclast_and_xor 0,1,2,3 + sub $-128, SRC +.if !\enc + sub $-128, DST +.endif + add $-128, DATALEN + jge .Lcrypt_loop_4x\@ + +.if \enc +.Lghash_last_ciphertext_4x\@: + // Update GHASH with the last set of ciphertext blocks. + _ghash_4x DST + sub $-128, DST +.endif + +.Lcrypt_loop_4x_done\@: + + // Undo the extra subtraction by 128 and check whether data remains. + sub $-128, DATALEN // 128 is 4 bytes, -128 is 1 byte + jz .Ldone\@ + + // The data length isn't a multiple of 128 bytes. Process the remaining + // data of length 1 <= DATALEN < 128. + // + // Since there are enough key powers available for all remaining data, + // there is no need to do a GHASH reduction after each iteration. + // Instead, multiply each remaining block by its own key power, and only + // do a GHASH reduction at the very end. + + // Make POWERS_PTR point to the key powers [H^N, H^(N-1), ...] where N + // is the number of blocks that remain. + .set POWERS_PTR, LE_CTR_PTR // LE_CTR_PTR is free to be reused. + .set POWERS_PTR32, LE_CTR_PTR32 + mov DATALEN, %eax + neg %rax + and $~15, %rax // -round_up(DATALEN, 16) + lea OFFSETOFEND_H_POWERS(KEY,%rax), POWERS_PTR + + // Start collecting the unreduced GHASH intermediate value LO, MI, HI. + .set HI, H_POW2_XORED // H_POW2_XORED is free to be reused. + .set HI_XMM, H_POW2_XORED_XMM + vpxor LO_XMM, LO_XMM, LO_XMM + vpxor MI_XMM, MI_XMM, MI_XMM + vpxor HI_XMM, HI_XMM, HI_XMM + + // 1 <= DATALEN < 128. Generate 2 or 4 more vectors of keystream blocks + // excluding the last AES round, depending on the remaining DATALEN. + cmp $64, DATALEN + jg .Ltail_gen_4_keystream_vecs\@ + _aesenc_loop 0,1 + cmp $32, DATALEN + jge .Ltail_xor_and_ghash_full_vec_loop\@ + jmp .Ltail_xor_and_ghash_partial_vec\@ +.Ltail_gen_4_keystream_vecs\@: + _aesenc_loop 0,1,2,3 + + // XOR the remaining data and accumulate the unreduced GHASH products + // for DATALEN >= 32, starting with one full 32-byte vector at a time. +.Ltail_xor_and_ghash_full_vec_loop\@: +.if \enc + _aesenclast_and_xor 0 + vpshufb BSWAP_MASK, AESDATA0, AESDATA0 +.else + vmovdqu (SRC), TMP1 + vpxor TMP1, RNDKEYLAST, TMP0 + vaesenclast TMP0, AESDATA0, AESDATA0 + vmovdqu AESDATA0, (DST) + vpshufb BSWAP_MASK, TMP1, AESDATA0 +.endif + // The ciphertext blocks (i.e. GHASH input data) are now in AESDATA0. + vpxor GHASH_ACC, AESDATA0, AESDATA0 + vmovdqu (POWERS_PTR), TMP2 + _ghash_mul_noreduce TMP2, AESDATA0, LO, MI, HI, TMP0 + vmovdqa AESDATA1, AESDATA0 + vmovdqa AESDATA2, AESDATA1 + vmovdqa AESDATA3, AESDATA2 + vpxor GHASH_ACC_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM + add $32, SRC + add $32, DST + add $32, POWERS_PTR + sub $32, DATALEN + cmp $32, DATALEN + jge .Ltail_xor_and_ghash_full_vec_loop\@ + test DATALEN, DATALEN + jz .Ltail_ghash_reduce\@ + +.Ltail_xor_and_ghash_partial_vec\@: + // XOR the remaining data and accumulate the unreduced GHASH products, + // for 1 <= DATALEN < 32. + vaesenclast RNDKEYLAST, AESDATA0, AESDATA0 + cmp $16, DATALEN + jle .Ltail_xor_and_ghash_1to16bytes\@ + + // Handle 17 <= DATALEN < 32. + + // Load a vpshufb mask that will right-shift by '32 - DATALEN' bytes + // (shifting in zeroes), then reflect all 16 bytes. + lea .Lrshift_and_bswap_table(%rip), %rax + vmovdqu -16(%rax, DATALEN64), TMP2_XMM + + // Move the second keystream block to its own register and left-align it + vextracti128 $1, AESDATA0, AESDATA1_XMM + vpxor .Lfifteens(%rip), TMP2_XMM, TMP0_XMM + vpshufb TMP0_XMM, AESDATA1_XMM, AESDATA1_XMM + + // Using overlapping loads and stores, XOR the source data with the + // keystream and write the destination data. Then prepare the GHASH + // input data: the full ciphertext block and the zero-padded partial + // ciphertext block, both byte-reflected, in AESDATA0. +.if \enc + vpxor -16(SRC, DATALEN64), AESDATA1_XMM, AESDATA1_XMM + vpxor (SRC), AESDATA0_XMM, AESDATA0_XMM + vmovdqu AESDATA1_XMM, -16(DST, DATALEN64) + vmovdqu AESDATA0_XMM, (DST) + vpshufb TMP2_XMM, AESDATA1_XMM, AESDATA1_XMM + vpshufb BSWAP_MASK_XMM, AESDATA0_XMM, AESDATA0_XMM +.else + vmovdqu -16(SRC, DATALEN64), TMP1_XMM + vmovdqu (SRC), TMP0_XMM + vpxor TMP1_XMM, AESDATA1_XMM, AESDATA1_XMM + vpxor TMP0_XMM, AESDATA0_XMM, AESDATA0_XMM + vmovdqu AESDATA1_XMM, -16(DST, DATALEN64) + vmovdqu AESDATA0_XMM, (DST) + vpshufb TMP2_XMM, TMP1_XMM, AESDATA1_XMM + vpshufb BSWAP_MASK_XMM, TMP0_XMM, AESDATA0_XMM +.endif + vpxor GHASH_ACC_XMM, AESDATA0_XMM, AESDATA0_XMM + vinserti128 $1, AESDATA1_XMM, AESDATA0, AESDATA0 + vmovdqu (POWERS_PTR), TMP2 + jmp .Ltail_ghash_last_vec\@ + +.Ltail_xor_and_ghash_1to16bytes\@: + // Handle 1 <= DATALEN <= 16. Carefully load and store the + // possibly-partial block, which we mustn't access out of bounds. + vmovdqu (POWERS_PTR), TMP2_XMM + mov SRC, KEY // Free up %rcx, assuming SRC == %rcx + mov DATALEN, %ecx + _load_partial_block KEY, TMP0_XMM, POWERS_PTR, POWERS_PTR32 + vpxor TMP0_XMM, AESDATA0_XMM, AESDATA0_XMM + mov DATALEN, %ecx + _store_partial_block AESDATA0_XMM, DST, POWERS_PTR, POWERS_PTR32 +.if \enc + lea .Lselect_high_bytes_table(%rip), %rax + vpshufb BSWAP_MASK_XMM, AESDATA0_XMM, AESDATA0_XMM + vpand (%rax, DATALEN64), AESDATA0_XMM, AESDATA0_XMM +.else + vpshufb BSWAP_MASK_XMM, TMP0_XMM, AESDATA0_XMM +.endif + vpxor GHASH_ACC_XMM, AESDATA0_XMM, AESDATA0_XMM + +.Ltail_ghash_last_vec\@: + // Accumulate the unreduced GHASH products for the last 1-2 blocks. The + // GHASH input data is in AESDATA0. If only one block remains, then the + // second block in AESDATA0 is zero and does not affect the result. + _ghash_mul_noreduce TMP2, AESDATA0, LO, MI, HI, TMP0 + +.Ltail_ghash_reduce\@: + // Finally, do the GHASH reduction. + vbroadcasti128 .Lgfpoly(%rip), TMP0 + _ghash_reduce LO, MI, HI, TMP0, TMP1 + vextracti128 $1, HI, GHASH_ACC_XMM + vpxor HI_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM + +.Ldone\@: + // Store the updated GHASH accumulator back to memory. + vmovdqu GHASH_ACC_XMM, (GHASH_ACC_PTR) + + vzeroupper + RET +.endm + +// void aes_gcm_enc_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, +// const u32 le_ctr[4], u8 ghash_acc[16], +// u64 total_aadlen, u64 total_datalen); +// bool aes_gcm_dec_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, +// const u32 le_ctr[4], const u8 ghash_acc[16], +// u64 total_aadlen, u64 total_datalen, +// const u8 tag[16], int taglen); +// +// This macro generates one of the above two functions (with \enc selecting +// which one). Both functions finish computing the GCM authentication tag by +// updating GHASH with the lengths block and encrypting the GHASH accumulator. +// |total_aadlen| and |total_datalen| must be the total length of the additional +// authenticated data and the en/decrypted data in bytes, respectively. +// +// The encryption function then stores the full-length (16-byte) computed +// authentication tag to |ghash_acc|. The decryption function instead loads the +// expected authentication tag (the one that was transmitted) from the 16-byte +// buffer |tag|, compares the first 4 <= |taglen| <= 16 bytes of it to the +// computed tag in constant time, and returns true if and only if they match. +.macro _aes_gcm_final enc + + // Function arguments + .set KEY, %rdi + .set LE_CTR_PTR, %rsi + .set GHASH_ACC_PTR, %rdx + .set TOTAL_AADLEN, %rcx + .set TOTAL_DATALEN, %r8 + .set TAG, %r9 + .set TAGLEN, %r10d // Originally at 8(%rsp) + .set TAGLEN64, %r10 + + // Additional local variables. + // %rax and %xmm0-%xmm3 are used as temporary registers. + .set AESKEYLEN, %r11d + .set AESKEYLEN64, %r11 + .set GFPOLY, %xmm4 + .set BSWAP_MASK, %xmm5 + .set LE_CTR, %xmm6 + .set GHASH_ACC, %xmm7 + .set H_POW1, %xmm8 + + // Load some constants. + vmovdqa .Lgfpoly(%rip), GFPOLY + vmovdqa .Lbswap_mask(%rip), BSWAP_MASK + + // Load the AES key length in bytes. + movl OFFSETOF_AESKEYLEN(KEY), AESKEYLEN + + // Set up a counter block with 1 in the low 32-bit word. This is the + // counter that produces the ciphertext needed to encrypt the auth tag. + // GFPOLY has 1 in the low word, so grab the 1 from there using a blend. + vpblendd $0xe, (LE_CTR_PTR), GFPOLY, LE_CTR + + // Build the lengths block and XOR it with the GHASH accumulator. + // Although the lengths block is defined as the AAD length followed by + // the en/decrypted data length, both in big-endian byte order, a byte + // reflection of the full block is needed because of the way we compute + // GHASH (see _ghash_mul_step). By using little-endian values in the + // opposite order, we avoid having to reflect any bytes here. + vmovq TOTAL_DATALEN, %xmm0 + vpinsrq $1, TOTAL_AADLEN, %xmm0, %xmm0 + vpsllq $3, %xmm0, %xmm0 // Bytes to bits + vpxor (GHASH_ACC_PTR), %xmm0, GHASH_ACC + + // Load the first hash key power (H^1), which is stored last. + vmovdqu OFFSETOFEND_H_POWERS-16(KEY), H_POW1 + + // Load TAGLEN if decrypting. +.if !\enc + movl 8(%rsp), TAGLEN +.endif + + // Make %rax point to the last AES round key for the chosen AES variant. + lea 6*16(KEY,AESKEYLEN64,4), %rax + + // Start the AES encryption of the counter block by swapping the counter + // block to big-endian and XOR-ing it with the zero-th AES round key. + vpshufb BSWAP_MASK, LE_CTR, %xmm0 + vpxor (KEY), %xmm0, %xmm0 + + // Complete the AES encryption and multiply GHASH_ACC by H^1. + // Interleave the AES and GHASH instructions to improve performance. + cmp $24, AESKEYLEN + jl 128f // AES-128? + je 192f // AES-192? + // AES-256 + vaesenc -13*16(%rax), %xmm0, %xmm0 + vaesenc -12*16(%rax), %xmm0, %xmm0 +192: + vaesenc -11*16(%rax), %xmm0, %xmm0 + vaesenc -10*16(%rax), %xmm0, %xmm0 +128: +.irp i, 0,1,2,3,4,5,6,7,8 + _ghash_mul_step \i, H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \ + %xmm1, %xmm2, %xmm3 + vaesenc (\i-9)*16(%rax), %xmm0, %xmm0 +.endr + _ghash_mul_step 9, H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \ + %xmm1, %xmm2, %xmm3 + + // Undo the byte reflection of the GHASH accumulator. + vpshufb BSWAP_MASK, GHASH_ACC, GHASH_ACC + + // Do the last AES round and XOR the resulting keystream block with the + // GHASH accumulator to produce the full computed authentication tag. + // + // Reduce latency by taking advantage of the property vaesenclast(key, + // a) ^ b == vaesenclast(key ^ b, a). I.e., XOR GHASH_ACC into the last + // round key, instead of XOR'ing the final AES output with GHASH_ACC. + // + // enc_final then returns the computed auth tag, while dec_final + // compares it with the transmitted one and returns a bool. To compare + // the tags, dec_final XORs them together and uses vptest to check + // whether the result is all-zeroes. This should be constant-time. + // dec_final applies the vaesenclast optimization to this additional + // value XOR'd too. +.if \enc + vpxor (%rax), GHASH_ACC, %xmm1 + vaesenclast %xmm1, %xmm0, GHASH_ACC + vmovdqu GHASH_ACC, (GHASH_ACC_PTR) +.else + vpxor (TAG), GHASH_ACC, GHASH_ACC + vpxor (%rax), GHASH_ACC, GHASH_ACC + vaesenclast GHASH_ACC, %xmm0, %xmm0 + lea .Lselect_high_bytes_table(%rip), %rax + vmovdqu (%rax, TAGLEN64), %xmm1 + vpshufb BSWAP_MASK, %xmm1, %xmm1 // select low bytes, not high + vptest %xmm1, %xmm0 + sete %al +.endif + // No need for vzeroupper here, since only used xmm registers were used. + RET +.endm + +SYM_FUNC_START(aes_gcm_enc_update_vaes_avx2) + _aes_gcm_update 1 +SYM_FUNC_END(aes_gcm_enc_update_vaes_avx2) +SYM_FUNC_START(aes_gcm_dec_update_vaes_avx2) + _aes_gcm_update 0 +SYM_FUNC_END(aes_gcm_dec_update_vaes_avx2) + +SYM_FUNC_START(aes_gcm_enc_final_vaes_avx2) + _aes_gcm_final 1 +SYM_FUNC_END(aes_gcm_enc_final_vaes_avx2) +SYM_FUNC_START(aes_gcm_dec_final_vaes_avx2) + _aes_gcm_final 0 +SYM_FUNC_END(aes_gcm_dec_final_vaes_avx2) diff --git a/arch/x86/crypto/aesni-intel_glue.c b/arch/x86/crypto/aesni-intel_glue.c index d953ac470aae3..e2847d67430fd 100644 --- a/arch/x86/crypto/aesni-intel_glue.c +++ b/arch/x86/crypto/aesni-intel_glue.c @@ -874,6 +874,36 @@ struct aes_gcm_key_aesni { #define AES_GCM_KEY_AESNI_SIZE \ (sizeof(struct aes_gcm_key_aesni) + (15 & ~(CRYPTO_MINALIGN - 1))) +/* Key struct used by the VAES + AVX2 implementation of AES-GCM */ +struct aes_gcm_key_vaes_avx2 { + /* + * Common part of the key. The assembly code prefers 16-byte alignment + * for the round keys; we get this by them being located at the start of + * the struct and the whole struct being 32-byte aligned. + */ + struct aes_gcm_key base; + + /* + * Powers of the hash key H^8 through H^1. These are 128-bit values. + * They all have an extra factor of x^-1 and are byte-reversed. + * The assembly code prefers 32-byte alignment for this. + */ + u64 h_powers[8][2] __aligned(32); + + /* + * Each entry in this array contains the two halves of an entry of + * h_powers XOR'd together, in the following order: + * H^8,H^6,H^7,H^5,H^4,H^2,H^3,H^1 i.e. indices 0,2,1,3,4,6,5,7. + * This is used for Karatsuba multiplication. + */ + u64 h_powers_xored[8]; +}; + +#define AES_GCM_KEY_VAES_AVX2(key) \ + container_of((key), struct aes_gcm_key_vaes_avx2, base) +#define AES_GCM_KEY_VAES_AVX2_SIZE \ + (sizeof(struct aes_gcm_key_vaes_avx2) + (31 & ~(CRYPTO_MINALIGN - 1))) + /* Key struct used by the VAES + AVX10 implementations of AES-GCM */ struct aes_gcm_key_avx10 { /* @@ -910,14 +940,17 @@ struct aes_gcm_key_avx10 { #define FLAG_RFC4106 BIT(0) #define FLAG_ENC BIT(1) #define FLAG_AVX BIT(2) -#define FLAG_AVX10_256 BIT(3) -#define FLAG_AVX10_512 BIT(4) +#define FLAG_VAES_AVX2 BIT(3) +#define FLAG_AVX10_256 BIT(4) +#define FLAG_AVX10_512 BIT(5) static inline struct aes_gcm_key * aes_gcm_key_get(struct crypto_aead *tfm, int flags) { if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) return PTR_ALIGN(crypto_aead_ctx(tfm), 64); + else if (flags & FLAG_VAES_AVX2) + return PTR_ALIGN(crypto_aead_ctx(tfm), 32); else return PTR_ALIGN(crypto_aead_ctx(tfm), 16); } @@ -927,6 +960,8 @@ aes_gcm_precompute_aesni(struct aes_gcm_key_aesni *key); asmlinkage void aes_gcm_precompute_aesni_avx(struct aes_gcm_key_aesni *key); asmlinkage void +aes_gcm_precompute_vaes_avx2(struct aes_gcm_key_vaes_avx2 *key); +asmlinkage void aes_gcm_precompute_vaes_avx10_256(struct aes_gcm_key_avx10 *key); asmlinkage void aes_gcm_precompute_vaes_avx10_512(struct aes_gcm_key_avx10 *key); @@ -947,6 +982,8 @@ static void aes_gcm_precompute(struct aes_gcm_key *key, int flags) aes_gcm_precompute_vaes_avx10_512(AES_GCM_KEY_AVX10(key)); else if (flags & FLAG_AVX10_256) aes_gcm_precompute_vaes_avx10_256(AES_GCM_KEY_AVX10(key)); + else if (flags & FLAG_VAES_AVX2) + aes_gcm_precompute_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key)); else if (flags & FLAG_AVX) aes_gcm_precompute_aesni_avx(AES_GCM_KEY_AESNI(key)); else @@ -960,6 +997,9 @@ asmlinkage void aes_gcm_aad_update_aesni_avx(const struct aes_gcm_key_aesni *key, u8 ghash_acc[16], const u8 *aad, int aadlen); asmlinkage void +aes_gcm_aad_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, + u8 ghash_acc[16], const u8 *aad, int aadlen); +asmlinkage void aes_gcm_aad_update_vaes_avx10(const struct aes_gcm_key_avx10 *key, u8 ghash_acc[16], const u8 *aad, int aadlen); @@ -969,6 +1009,9 @@ static void aes_gcm_aad_update(const struct aes_gcm_key *key, u8 ghash_acc[16], if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) aes_gcm_aad_update_vaes_avx10(AES_GCM_KEY_AVX10(key), ghash_acc, aad, aadlen); + else if (flags & FLAG_VAES_AVX2) + aes_gcm_aad_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key), + ghash_acc, aad, aadlen); else if (flags & FLAG_AVX) aes_gcm_aad_update_aesni_avx(AES_GCM_KEY_AESNI(key), ghash_acc, aad, aadlen); @@ -986,6 +1029,10 @@ aes_gcm_enc_update_aesni_avx(const struct aes_gcm_key_aesni *key, const u32 le_ctr[4], u8 ghash_acc[16], const u8 *src, u8 *dst, int datalen); asmlinkage void +aes_gcm_enc_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, + const u32 le_ctr[4], u8 ghash_acc[16], + const u8 *src, u8 *dst, int datalen); +asmlinkage void aes_gcm_enc_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key, const u32 le_ctr[4], u8 ghash_acc[16], const u8 *src, u8 *dst, int datalen); @@ -1003,6 +1050,10 @@ aes_gcm_dec_update_aesni_avx(const struct aes_gcm_key_aesni *key, const u32 le_ctr[4], u8 ghash_acc[16], const u8 *src, u8 *dst, int datalen); asmlinkage void +aes_gcm_dec_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, + const u32 le_ctr[4], u8 ghash_acc[16], + const u8 *src, u8 *dst, int datalen); +asmlinkage void aes_gcm_dec_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key, const u32 le_ctr[4], u8 ghash_acc[16], const u8 *src, u8 *dst, int datalen); @@ -1026,6 +1077,10 @@ aes_gcm_update(const struct aes_gcm_key *key, aes_gcm_enc_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key), le_ctr, ghash_acc, src, dst, datalen); + else if (flags & FLAG_VAES_AVX2) + aes_gcm_enc_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key), + le_ctr, ghash_acc, + src, dst, datalen); else if (flags & FLAG_AVX) aes_gcm_enc_update_aesni_avx(AES_GCM_KEY_AESNI(key), le_ctr, ghash_acc, @@ -1042,6 +1097,10 @@ aes_gcm_update(const struct aes_gcm_key *key, aes_gcm_dec_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key), le_ctr, ghash_acc, src, dst, datalen); + else if (flags & FLAG_VAES_AVX2) + aes_gcm_dec_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key), + le_ctr, ghash_acc, + src, dst, datalen); else if (flags & FLAG_AVX) aes_gcm_dec_update_aesni_avx(AES_GCM_KEY_AESNI(key), le_ctr, ghash_acc, @@ -1062,6 +1121,10 @@ aes_gcm_enc_final_aesni_avx(const struct aes_gcm_key_aesni *key, const u32 le_ctr[4], u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen); asmlinkage void +aes_gcm_enc_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, + const u32 le_ctr[4], u8 ghash_acc[16], + u64 total_aadlen, u64 total_datalen); +asmlinkage void aes_gcm_enc_final_vaes_avx10(const struct aes_gcm_key_avx10 *key, const u32 le_ctr[4], u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen); @@ -1076,6 +1139,10 @@ aes_gcm_enc_final(const struct aes_gcm_key *key, aes_gcm_enc_final_vaes_avx10(AES_GCM_KEY_AVX10(key), le_ctr, ghash_acc, total_aadlen, total_datalen); + else if (flags & FLAG_VAES_AVX2) + aes_gcm_enc_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key), + le_ctr, ghash_acc, + total_aadlen, total_datalen); else if (flags & FLAG_AVX) aes_gcm_enc_final_aesni_avx(AES_GCM_KEY_AESNI(key), le_ctr, ghash_acc, @@ -1097,6 +1164,11 @@ aes_gcm_dec_final_aesni_avx(const struct aes_gcm_key_aesni *key, u64 total_aadlen, u64 total_datalen, const u8 tag[16], int taglen); asmlinkage bool __must_check +aes_gcm_dec_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key, + const u32 le_ctr[4], const u8 ghash_acc[16], + u64 total_aadlen, u64 total_datalen, + const u8 tag[16], int taglen); +asmlinkage bool __must_check aes_gcm_dec_final_vaes_avx10(const struct aes_gcm_key_avx10 *key, const u32 le_ctr[4], const u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen, @@ -1113,6 +1185,11 @@ aes_gcm_dec_final(const struct aes_gcm_key *key, const u32 le_ctr[4], le_ctr, ghash_acc, total_aadlen, total_datalen, tag, taglen); + else if (flags & FLAG_VAES_AVX2) + return aes_gcm_dec_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key), + le_ctr, ghash_acc, + total_aadlen, total_datalen, + tag, taglen); else if (flags & FLAG_AVX) return aes_gcm_dec_final_aesni_avx(AES_GCM_KEY_AESNI(key), le_ctr, ghash_acc, @@ -1195,6 +1272,10 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key, BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers) != 496); BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers_xored) != 624); BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_times_x64) != 688); + BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.key_enc) != 0); + BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.key_length) != 480); + BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, h_powers) != 512); + BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, h_powers_xored) != 640); BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_enc) != 0); BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_length) != 480); BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, h_powers) != 512); @@ -1240,6 +1321,22 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key, gf128mul_lle(&h, &h1); } memset(k->padding, 0, sizeof(k->padding)); + } else if (flags & FLAG_VAES_AVX2) { + struct aes_gcm_key_vaes_avx2 *k = + AES_GCM_KEY_VAES_AVX2(key); + static const u8 indices[8] = { 0, 2, 1, 3, 4, 6, 5, 7 }; + + for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) { + k->h_powers[i][0] = be64_to_cpu(h.b); + k->h_powers[i][1] = be64_to_cpu(h.a); + gf128mul_lle(&h, &h1); + } + for (i = 0; i < ARRAY_SIZE(k->h_powers_xored); i++) { + int j = indices[i]; + + k->h_powers_xored[i] = k->h_powers[j][0] ^ + k->h_powers[j][1]; + } } else { struct aes_gcm_key_aesni *k = AES_GCM_KEY_AESNI(key); @@ -1508,6 +1605,11 @@ DEFINE_GCM_ALGS(aesni_avx, FLAG_AVX, "generic-gcm-aesni-avx", "rfc4106-gcm-aesni-avx", AES_GCM_KEY_AESNI_SIZE, 500); +/* aes_gcm_algs_vaes_avx2 */ +DEFINE_GCM_ALGS(vaes_avx2, FLAG_VAES_AVX2, + "generic-gcm-vaes-avx2", "rfc4106-gcm-vaes-avx2", + AES_GCM_KEY_VAES_AVX2_SIZE, 600); + /* aes_gcm_algs_vaes_avx10_256 */ DEFINE_GCM_ALGS(vaes_avx10_256, FLAG_AVX10_256, "generic-gcm-vaes-avx10_256", "rfc4106-gcm-vaes-avx10_256", @@ -1548,6 +1650,10 @@ static int __init register_avx_algs(void) ARRAY_SIZE(skcipher_algs_vaes_avx2)); if (err) return err; + err = crypto_register_aeads(aes_gcm_algs_vaes_avx2, + ARRAY_SIZE(aes_gcm_algs_vaes_avx2)); + if (err) + return err; if (!boot_cpu_has(X86_FEATURE_AVX512BW) || !boot_cpu_has(X86_FEATURE_AVX512VL) || @@ -1595,6 +1701,7 @@ static void unregister_avx_algs(void) unregister_aeads(aes_gcm_algs_aesni_avx); unregister_skciphers(skcipher_algs_vaes_avx2); unregister_skciphers(skcipher_algs_vaes_avx512); + unregister_aeads(aes_gcm_algs_vaes_avx2); unregister_aeads(aes_gcm_algs_vaes_avx10_256); unregister_aeads(aes_gcm_algs_vaes_avx10_512); } -- 2.47.3