]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Alternative fix for CVE-2022-4304
authorBernd Edlinger <bernd.edlinger@hotmail.de>
Mon, 13 Feb 2023 16:46:41 +0000 (17:46 +0100)
committerBernd Edlinger <bernd.edlinger@hotmail.de>
Fri, 31 Mar 2023 19:06:23 +0000 (21:06 +0200)
This is about a timing leak in the topmost limb
of the internal result of RSA_private_decrypt,
before the padding check.

There are in fact at least three bugs together that
caused the timing leak:

First and probably most important is the fact that
the blinding did not use the constant time code path
at all when the RSA object was used for a private
decrypt, due to the fact that the Montgomery context
rsa->_method_mod_n was not set up early enough in
rsa_ossl_private_decrypt, when BN_BLINDING_create_param
needed it, and that was persisted as blinding->m_ctx,
although the RSA object creates the Montgomery context
just a bit later.

Then the infamous bn_correct_top was used on the
secret value right after the blinding was removed.

And finally the function BN_bn2binpad did not use
the constant-time code path since the BN_FLG_CONSTTIME
was not set on the secret value.

In order to address the first problem, this patch
makes sure that the rsa->_method_mod_n is initialized
right before the blinding context.

And to fix the second problem, we add a new utility
function bn_correct_top_consttime, a const-time
variant of bn_correct_top.

Together with the fact, that BN_bn2binpad is already
constant time if the flag BN_FLG_CONSTTIME is set,
this should eliminate the timing oracle completely.

In addition the no-asm variant may also have
branches that depend on secret values, because the last
invocation of bn_sub_words in bn_from_montgomery_word
had branches when the function is compiled by certain
gcc compiler versions, due to the clumsy coding style.

So additionally this patch stream-lined the no-asm
C-code in order to avoid branches where possible and
improve the resulting code quality.

Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/20284)

CHANGES
crypto/bn/bn_asm.c
crypto/bn/bn_blind.c
crypto/bn/bn_lib.c
crypto/bn/bn_local.h
crypto/rsa/rsa_ossl.c

diff --git a/CHANGES b/CHANGES
index b19f1429bbb0fcabf6d3bb8b070053d066c2a239..430e32e624e2b03eb1bc7b648d354f910eccf372 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -9,6 +9,16 @@
 
  Changes between 1.1.1t and 1.1.1u [xx XXX xxxx]
 
+  *) Reworked the Fix for the Timing Oracle in RSA Decryption (CVE-2022-4304).
+     The previous fix for this timing side channel turned out to cause
+     a severe 2-3x performance regression in the typical use case
+     compared to 1.1.1s. The new fix uses existing constant time
+     code paths, and restores the previous performance level while
+     fully eliminating all existing timing side channels.
+     The fix was developed by Bernd Edlinger with testing support
+     by Hubert Kario.
+     [Bernd Edlinger]
+
   *) Corrected documentation of X509_VERIFY_PARAM_add0_policy() to mention
      that it does not enable policy checking. Thanks to
      David Benjamin for discovering this issue. (CVE-2023-0466)
index 4d83a8cf1115de72cafcff85b43aa933f93db7cc..177558c6477f49eea557935c66c0e8ce59838f5f 100644 (file)
@@ -381,25 +381,33 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
 #ifndef OPENSSL_SMALL_FOOTPRINT
     while (n & ~3) {
         t1 = a[0];
-        t2 = b[0];
-        r[0] = (t1 - t2 - c) & BN_MASK2;
-        if (t1 != t2)
-            c = (t1 < t2);
+        t2 = (t1 - c) & BN_MASK2;
+        c  = (t2 > t1);
+        t1 = b[0];
+        t1 = (t2 - t1) & BN_MASK2;
+        r[0] = t1;
+        c += (t1 > t2);
         t1 = a[1];
-        t2 = b[1];
-        r[1] = (t1 - t2 - c) & BN_MASK2;
-        if (t1 != t2)
-            c = (t1 < t2);
+        t2 = (t1 - c) & BN_MASK2;
+        c  = (t2 > t1);
+        t1 = b[1];
+        t1 = (t2 - t1) & BN_MASK2;
+        r[1] = t1;
+        c += (t1 > t2);
         t1 = a[2];
-        t2 = b[2];
-        r[2] = (t1 - t2 - c) & BN_MASK2;
-        if (t1 != t2)
-            c = (t1 < t2);
+        t2 = (t1 - c) & BN_MASK2;
+        c  = (t2 > t1);
+        t1 = b[2];
+        t1 = (t2 - t1) & BN_MASK2;
+        r[2] = t1;
+        c += (t1 > t2);
         t1 = a[3];
-        t2 = b[3];
-        r[3] = (t1 - t2 - c) & BN_MASK2;
-        if (t1 != t2)
-            c = (t1 < t2);
+        t2 = (t1 - c) & BN_MASK2;
+        c  = (t2 > t1);
+        t1 = b[3];
+        t1 = (t2 - t1) & BN_MASK2;
+        r[3] = t1;
+        c += (t1 > t2);
         a += 4;
         b += 4;
         r += 4;
@@ -408,10 +416,12 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
 #endif
     while (n) {
         t1 = a[0];
-        t2 = b[0];
-        r[0] = (t1 - t2 - c) & BN_MASK2;
-        if (t1 != t2)
-            c = (t1 < t2);
+        t2 = (t1 - c) & BN_MASK2;
+        c  = (t2 > t1);
+        t1 = b[0];
+        t1 = (t2 - t1) & BN_MASK2;
+        r[0] = t1;
+        c += (t1 > t2);
         a++;
         b++;
         r++;
@@ -446,7 +456,7 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
         t += c0;                /* no carry */  \
         c0 = (BN_ULONG)Lw(t);                   \
         hi = (BN_ULONG)Hw(t);                   \
-        c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
+        c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi);   \
         } while(0)
 
 #  define mul_add_c2(a,b,c0,c1,c2)      do {    \
@@ -455,11 +465,11 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
         BN_ULLONG tt = t+c0;    /* no carry */  \
         c0 = (BN_ULONG)Lw(tt);                  \
         hi = (BN_ULONG)Hw(tt);                  \
-        c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
+        c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi);   \
         t += c0;                /* no carry */  \
         c0 = (BN_ULONG)Lw(t);                   \
         hi = (BN_ULONG)Hw(t);                   \
-        c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
+        c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi);   \
         } while(0)
 
 #  define sqr_add_c(a,i,c0,c1,c2)       do {    \
@@ -468,7 +478,7 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
         t += c0;                /* no carry */  \
         c0 = (BN_ULONG)Lw(t);                   \
         hi = (BN_ULONG)Hw(t);                   \
-        c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
+        c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi);   \
         } while(0)
 
 #  define sqr_add_c2(a,i,j,c0,c1,c2) \
@@ -483,26 +493,26 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
         BN_ULONG ta = (a), tb = (b);            \
         BN_ULONG lo, hi;                        \
         BN_UMULT_LOHI(lo,hi,ta,tb);             \
-        c0 += lo; hi += (c0<lo)?1:0;            \
-        c1 += hi; c2 += (c1<hi)?1:0;            \
+        c0 += lo; hi += (c0<lo);                \
+        c1 += hi; c2 += (c1<hi);                \
         } while(0)
 
 #  define mul_add_c2(a,b,c0,c1,c2)      do {    \
         BN_ULONG ta = (a), tb = (b);            \
         BN_ULONG lo, hi, tt;                    \
         BN_UMULT_LOHI(lo,hi,ta,tb);             \
-        c0 += lo; tt = hi+((c0<lo)?1:0);        \
-        c1 += tt; c2 += (c1<tt)?1:0;            \
-        c0 += lo; hi += (c0<lo)?1:0;            \
-        c1 += hi; c2 += (c1<hi)?1:0;            \
+        c0 += lo; tt = hi + (c0<lo);            \
+        c1 += tt; c2 += (c1<tt);                \
+        c0 += lo; hi += (c0<lo);                \
+        c1 += hi; c2 += (c1<hi);                \
         } while(0)
 
 #  define sqr_add_c(a,i,c0,c1,c2)       do {    \
         BN_ULONG ta = (a)[i];                   \
         BN_ULONG lo, hi;                        \
         BN_UMULT_LOHI(lo,hi,ta,ta);             \
-        c0 += lo; hi += (c0<lo)?1:0;            \
-        c1 += hi; c2 += (c1<hi)?1:0;            \
+        c0 += lo; hi += (c0<lo);                \
+        c1 += hi; c2 += (c1<hi);                \
         } while(0)
 
 #  define sqr_add_c2(a,i,j,c0,c1,c2)    \
@@ -517,26 +527,26 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
         BN_ULONG ta = (a), tb = (b);            \
         BN_ULONG lo = ta * tb;                  \
         BN_ULONG hi = BN_UMULT_HIGH(ta,tb);     \
-        c0 += lo; hi += (c0<lo)?1:0;            \
-        c1 += hi; c2 += (c1<hi)?1:0;            \
+        c0 += lo; hi += (c0<lo);                \
+        c1 += hi; c2 += (c1<hi);                \
         } while(0)
 
 #  define mul_add_c2(a,b,c0,c1,c2)      do {    \
         BN_ULONG ta = (a), tb = (b), tt;        \
         BN_ULONG lo = ta * tb;                  \
         BN_ULONG hi = BN_UMULT_HIGH(ta,tb);     \
-        c0 += lo; tt = hi + ((c0<lo)?1:0);      \
-        c1 += tt; c2 += (c1<tt)?1:0;            \
-        c0 += lo; hi += (c0<lo)?1:0;            \
-        c1 += hi; c2 += (c1<hi)?1:0;            \
+        c0 += lo; tt = hi + (c0<lo);            \
+        c1 += tt; c2 += (c1<tt);                \
+        c0 += lo; hi += (c0<lo);                \
+        c1 += hi; c2 += (c1<hi);                \
         } while(0)
 
 #  define sqr_add_c(a,i,c0,c1,c2)       do {    \
         BN_ULONG ta = (a)[i];                   \
         BN_ULONG lo = ta * ta;                  \
         BN_ULONG hi = BN_UMULT_HIGH(ta,ta);     \
-        c0 += lo; hi += (c0<lo)?1:0;            \
-        c1 += hi; c2 += (c1<hi)?1:0;            \
+        c0 += lo; hi += (c0<lo);                \
+        c1 += hi; c2 += (c1<hi);                \
         } while(0)
 
 #  define sqr_add_c2(a,i,j,c0,c1,c2)      \
@@ -551,8 +561,8 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
         BN_ULONG lo = LBITS(a), hi = HBITS(a);  \
         BN_ULONG bl = LBITS(b), bh = HBITS(b);  \
         mul64(lo,hi,bl,bh);                     \
-        c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
-        c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
+        c0 = (c0+lo)&BN_MASK2; hi += (c0<lo);   \
+        c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi);   \
         } while(0)
 
 #  define mul_add_c2(a,b,c0,c1,c2)      do {    \
@@ -561,17 +571,17 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
         BN_ULONG bl = LBITS(b), bh = HBITS(b);  \
         mul64(lo,hi,bl,bh);                     \
         tt = hi;                                \
-        c0 = (c0+lo)&BN_MASK2; if (c0<lo) tt++; \
-        c1 = (c1+tt)&BN_MASK2; if (c1<tt) c2++; \
-        c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
-        c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
+        c0 = (c0+lo)&BN_MASK2; tt += (c0<lo);   \
+        c1 = (c1+tt)&BN_MASK2; c2 += (c1<tt);   \
+        c0 = (c0+lo)&BN_MASK2; hi += (c0<lo);   \
+        c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi);   \
         } while(0)
 
 #  define sqr_add_c(a,i,c0,c1,c2)       do {    \
         BN_ULONG lo, hi;                        \
         sqr64(lo,hi,(a)[i]);                    \
-        c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
-        c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
+        c0 = (c0+lo)&BN_MASK2; hi += (c0<lo);   \
+        c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi);   \
         } while(0)
 
 #  define sqr_add_c2(a,i,j,c0,c1,c2) \
index 15d9e0a5445be2a264b215456c06a4dc463c2f5f..e76f6107a7b552a004939b0a89e24ea953724e3b 100644 (file)
@@ -191,7 +191,8 @@ int BN_BLINDING_invert_ex(BIGNUM *n, const BIGNUM *r, BN_BLINDING *b,
             n->top = (int)(rtop & ~mask) | (ntop & mask);
             n->flags |= (BN_FLG_FIXED_TOP & ~mask);
         }
-        ret = BN_mod_mul_montgomery(n, n, r, b->m_ctx, ctx);
+        ret = bn_mul_mont_fixed_top(n, n, r, b->m_ctx, ctx);
+        bn_correct_top_consttime(n);
     } else {
         ret = BN_mod_mul(n, n, r, b->mod, ctx);
     }
index eb4a31849bef11837d273d7b7208b0e72d8caa3b..fe6fb0e40fbb8431a014322b306de06d7fe3c942 100644 (file)
@@ -1001,6 +1001,28 @@ BIGNUM *bn_wexpand(BIGNUM *a, int words)
     return (words <= a->dmax) ? a : bn_expand2(a, words);
 }
 
+void bn_correct_top_consttime(BIGNUM *a)
+{
+    int j, atop;
+    BN_ULONG limb;
+    unsigned int mask;
+
+    for (j = 0, atop = 0; j < a->dmax; j++) {
+        limb = a->d[j];
+        limb |= 0 - limb;
+        limb >>= BN_BITS2 - 1;
+        limb = 0 - limb;
+        mask = (unsigned int)limb;
+        mask &= constant_time_msb(j - a->top);
+        atop = constant_time_select_int(mask, j + 1, atop);
+    }
+
+    mask = constant_time_eq_int(atop, 0);
+    a->top = atop;
+    a->neg = constant_time_select_int(mask, 0, a->neg);
+    a->flags &= ~BN_FLG_FIXED_TOP;
+}
+
 void bn_correct_top(BIGNUM *a)
 {
     BN_ULONG *ftl;
index ee6342b60cbbeb2cc76b569f47d891c5ad952162..818e34348e1b964975e76a26973d3e0ff1c1d8ea 100644 (file)
@@ -515,10 +515,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         ret =  (r);                     \
         BN_UMULT_LOHI(low,high,w,tmp);  \
         ret += (c);                     \
-        (c) =  (ret<(c))?1:0;           \
+        (c) =  (ret<(c));               \
         (c) += high;                    \
         ret += low;                     \
-        (c) += (ret<low)?1:0;           \
+        (c) += (ret<low);               \
         (r) =  ret;                     \
         }
 
@@ -527,7 +527,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         BN_UMULT_LOHI(low,high,w,ta);   \
         ret =  low + (c);               \
         (c) =  high;                    \
-        (c) += (ret<low)?1:0;           \
+        (c) += (ret<low);               \
         (r) =  ret;                     \
         }
 
@@ -543,10 +543,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         high=  BN_UMULT_HIGH(w,tmp);    \
         ret += (c);                     \
         low =  (w) * tmp;               \
-        (c) =  (ret<(c))?1:0;           \
+        (c) =  (ret<(c));               \
         (c) += high;                    \
         ret += low;                     \
-        (c) += (ret<low)?1:0;           \
+        (c) += (ret<low);               \
         (r) =  ret;                     \
         }
 
@@ -556,7 +556,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         high=  BN_UMULT_HIGH(w,ta);     \
         ret =  low + (c);               \
         (c) =  high;                    \
-        (c) += (ret<low)?1:0;           \
+        (c) += (ret<low);               \
         (r) =  ret;                     \
         }
 
@@ -589,10 +589,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         lt=(bl)*(lt); \
         m1=(bl)*(ht); \
         ht =(bh)*(ht); \
-        m=(m+m1)&BN_MASK2; if (m < m1) ht+=L2HBITS((BN_ULONG)1); \
+        m=(m+m1)&BN_MASK2; ht += L2HBITS((BN_ULONG)(m < m1)); \
         ht+=HBITS(m); \
         m1=L2HBITS(m); \
-        lt=(lt+m1)&BN_MASK2; if (lt < m1) ht++; \
+        lt=(lt+m1)&BN_MASK2; ht += (lt < m1); \
         (l)=lt; \
         (h)=ht; \
         }
@@ -609,7 +609,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         h*=h; \
         h+=(m&BN_MASK2h1)>>(BN_BITS4-1); \
         m =(m&BN_MASK2l)<<(BN_BITS4+1); \
-        l=(l+m)&BN_MASK2; if (l < m) h++; \
+        l=(l+m)&BN_MASK2; h += (l < m); \
         (lo)=l; \
         (ho)=h; \
         }
@@ -623,9 +623,9 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         mul64(l,h,(bl),(bh)); \
  \
         /* non-multiply part */ \
-        l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
+        l=(l+(c))&BN_MASK2; h += (l < (c)); \
         (c)=(r); \
-        l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
+        l=(l+(c))&BN_MASK2; h += (l < (c)); \
         (c)=h&BN_MASK2; \
         (r)=l; \
         }
@@ -639,7 +639,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
         mul64(l,h,(bl),(bh)); \
  \
         /* non-multiply part */ \
-        l+=(c); if ((l&BN_MASK2) < (c)) h++; \
+        l+=(c); h += ((l&BN_MASK2) < (c)); \
         (c)=h&BN_MASK2; \
         (r)=l&BN_MASK2; \
         }
@@ -669,7 +669,7 @@ BN_ULONG bn_sub_part_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
                            int cl, int dl);
 int bn_mul_mont(BN_ULONG *rp, const BN_ULONG *ap, const BN_ULONG *bp,
                 const BN_ULONG *np, const BN_ULONG *n0, int num);
-
+void bn_correct_top_consttime(BIGNUM *a);
 BIGNUM *int_bn_mod_inverse(BIGNUM *in,
                            const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
                            int *noinv);
index 53cf2d03c947c211eed930c691b6d23c0365bb56..cf5a10ab435f89cff9f54f8e548465d815585f1f 100644 (file)
@@ -226,6 +226,7 @@ static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
      * will only read the modulus from BN_BLINDING. In both cases it's safe
      * to access the blinding without a lock.
      */
+    BN_set_flags(f, BN_FLG_CONSTTIME);
     return BN_BLINDING_invert_ex(f, unblind, b, ctx);
 }
 
@@ -412,6 +413,11 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
         goto err;
     }
 
+    if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
+        if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
+                                    rsa->n, ctx))
+            goto err;
+
     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
         if (blinding == NULL) {
@@ -449,13 +455,6 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
             goto err;
         }
         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
-
-        if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
-            if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
-                                        rsa->n, ctx)) {
-                BN_free(d);
-                goto err;
-            }
         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
                                    rsa->_method_mod_n)) {
             BN_free(d);