]> git.ipfire.org Git - thirdparty/hostap.git/commitdiff
SAE: Avoid branches in is_quadratic_residue_blind()
authorJouni Malinen <jouni@codeaurora.org>
Tue, 26 Feb 2019 17:34:38 +0000 (19:34 +0200)
committerJouni Malinen <j@w1.fi>
Tue, 9 Apr 2019 14:11:15 +0000 (17:11 +0300)
Make the non-failure path in the function proceed without branches based
on r_odd and in constant time to minimize risk of observable differences
in timing or cache use. (CVE-2019-9494)

Signed-off-by: Jouni Malinen <jouni@codeaurora.org>
src/common/sae.c

index d55323bcdd7e48a05d367729fcc5aef17161d52d..5df9b95aae705f9381162cd3a511268414316509 100644 (file)
@@ -232,12 +232,14 @@ get_rand_1_to_p_1(const u8 *prime, size_t prime_len, size_t prime_bits,
 
 static int is_quadratic_residue_blind(struct sae_data *sae,
                                      const u8 *prime, size_t bits,
-                                     const struct crypto_bignum *qr,
-                                     const struct crypto_bignum *qnr,
+                                     const u8 *qr, const u8 *qnr,
                                      const struct crypto_bignum *y_sqr)
 {
-       struct crypto_bignum *r, *num;
+       struct crypto_bignum *r, *num, *qr_or_qnr = NULL;
        int r_odd, check, res = -1;
+       u8 qr_or_qnr_bin[SAE_MAX_ECC_PRIME_LEN];
+       size_t prime_len = sae->tmp->prime_len;
+       unsigned int mask;
 
        /*
         * Use the blinding technique to mask y_sqr while determining
@@ -248,7 +250,7 @@ static int is_quadratic_residue_blind(struct sae_data *sae,
         * r = a random number between 1 and p-1, inclusive
         * num = (v * r * r) modulo p
         */
-       r = get_rand_1_to_p_1(prime, sae->tmp->prime_len, bits, &r_odd);
+       r = get_rand_1_to_p_1(prime, prime_len, bits, &r_odd);
        if (!r)
                return -1;
 
@@ -258,41 +260,45 @@ static int is_quadratic_residue_blind(struct sae_data *sae,
            crypto_bignum_mulmod(num, r, sae->tmp->prime, num) < 0)
                goto fail;
 
-       if (r_odd) {
-               /*
-                * num = (num * qr) module p
-                * LGR(num, p) = 1 ==> quadratic residue
-                */
-               if (crypto_bignum_mulmod(num, qr, sae->tmp->prime, num) < 0)
-                       goto fail;
-               check = 1;
-       } else {
-               /*
-                * num = (num * qnr) module p
-                * LGR(num, p) = -1 ==> quadratic residue
-                */
-               if (crypto_bignum_mulmod(num, qnr, sae->tmp->prime, num) < 0)
-                       goto fail;
-               check = -1;
-       }
+       /*
+        * Need to minimize differences in handling different cases, so try to
+        * avoid branches and timing differences.
+        *
+        * If r_odd:
+        * num = (num * qr) module p
+        * LGR(num, p) = 1 ==> quadratic residue
+        * else:
+        * num = (num * qnr) module p
+        * LGR(num, p) = -1 ==> quadratic residue
+        */
+       mask = const_time_is_zero(r_odd);
+       const_time_select_bin(mask, qnr, qr, prime_len, qr_or_qnr_bin);
+       qr_or_qnr = crypto_bignum_init_set(qr_or_qnr_bin, prime_len);
+       if (!qr_or_qnr ||
+           crypto_bignum_mulmod(num, qr_or_qnr, sae->tmp->prime, num) < 0)
+               goto fail;
+       /* r_odd is 0 or 1; branchless version of check = r_odd ? 1 : -1, */
+       check = const_time_select_int(mask, -1, 1);
 
        res = crypto_bignum_legendre(num, sae->tmp->prime);
        if (res == -2) {
                res = -1;
                goto fail;
        }
-       res = res == check;
+       /* branchless version of res = res == check
+        * (res is -1, 0, or 1; check is -1 or 1) */
+       mask = const_time_eq(res, check);
+       res = const_time_select_int(mask, 1, 0);
 fail:
        crypto_bignum_deinit(num, 1);
        crypto_bignum_deinit(r, 1);
+       crypto_bignum_deinit(qr_or_qnr, 1);
        return res;
 }
 
 
 static int sae_test_pwd_seed_ecc(struct sae_data *sae, const u8 *pwd_seed,
-                                const u8 *prime,
-                                const struct crypto_bignum *qr,
-                                const struct crypto_bignum *qnr,
+                                const u8 *prime, const u8 *qr, const u8 *qnr,
                                 u8 *pwd_value)
 {
        struct crypto_bignum *y_sqr, *x_cand;
@@ -452,6 +458,8 @@ static int sae_derive_pwe_ecc(struct sae_data *sae, const u8 *addr1,
        struct crypto_bignum *x = NULL, *qr = NULL, *qnr = NULL;
        u8 x_bin[SAE_MAX_ECC_PRIME_LEN];
        u8 x_cand_bin[SAE_MAX_ECC_PRIME_LEN];
+       u8 qr_bin[SAE_MAX_ECC_PRIME_LEN];
+       u8 qnr_bin[SAE_MAX_ECC_PRIME_LEN];
        size_t bits;
        int res = -1;
        u8 found = 0; /* 0 (false) or 0xff (true) to be used as const_time_*
@@ -476,7 +484,9 @@ static int sae_derive_pwe_ecc(struct sae_data *sae, const u8 *addr1,
         * (qnr) modulo p for blinding purposes during the loop.
         */
        if (get_random_qr_qnr(prime, prime_len, sae->tmp->prime, bits,
-                             &qr, &qnr) < 0)
+                             &qr, &qnr) < 0 ||
+           crypto_bignum_to_bin(qr, qr_bin, sizeof(qr_bin), prime_len) < 0 ||
+           crypto_bignum_to_bin(qnr, qnr_bin, sizeof(qnr_bin), prime_len) < 0)
                goto fail;
 
        wpa_hexdump_ascii_key(MSG_DEBUG, "SAE: password",
@@ -527,7 +537,7 @@ static int sae_derive_pwe_ecc(struct sae_data *sae, const u8 *addr1,
                        break;
 
                res = sae_test_pwd_seed_ecc(sae, pwd_seed,
-                                           prime, qr, qnr, x_cand_bin);
+                                           prime, qr_bin, qnr_bin, x_cand_bin);
                const_time_select_bin(found, x_bin, x_cand_bin, prime_len,
                                      x_bin);
                pwd_seed_odd = const_time_select_u8(