]> git.ipfire.org Git - thirdparty/openssl.git/blob - crypto/rsa/rsa_pss.c
RSA padding Zeroization fixes
[thirdparty/openssl.git] / crypto / rsa / rsa_pss.c
1 /*
2 * Copyright 2005-2018 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the OpenSSL license (the "License"). You may not use
5 * this file except in compliance with the License. You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10 #include <stdio.h>
11 #include "internal/cryptlib.h"
12 #include <openssl/bn.h>
13 #include <openssl/rsa.h>
14 #include <openssl/evp.h>
15 #include <openssl/rand.h>
16 #include <openssl/sha.h>
17 #include "rsa_locl.h"
18
/* Eight zero octets; hashed ahead of mHash by the PSS routines in this file. */
static const unsigned char zeroes[8] = { 0 };
20
21 #if defined(_MSC_VER) && defined(_ARM_)
22 # pragma optimize("g", off)
23 #endif
24
25 int RSA_verify_PKCS1_PSS(RSA *rsa, const unsigned char *mHash,
26 const EVP_MD *Hash, const unsigned char *EM,
27 int sLen)
28 {
29 return RSA_verify_PKCS1_PSS_mgf1(rsa, mHash, Hash, NULL, EM, sLen);
30 }
31
32 int RSA_verify_PKCS1_PSS_mgf1(RSA *rsa, const unsigned char *mHash,
33 const EVP_MD *Hash, const EVP_MD *mgf1Hash,
34 const unsigned char *EM, int sLen)
35 {
36 int i;
37 int ret = 0;
38 int hLen, maskedDBLen, MSBits, emLen;
39 const unsigned char *H;
40 unsigned char *DB = NULL;
41 EVP_MD_CTX *ctx = EVP_MD_CTX_new();
42 unsigned char H_[EVP_MAX_MD_SIZE];
43
44
45 if (ctx == NULL)
46 goto err;
47
48 if (mgf1Hash == NULL)
49 mgf1Hash = Hash;
50
51 hLen = EVP_MD_size(Hash);
52 if (hLen < 0)
53 goto err;
54 /*-
55 * Negative sLen has special meanings:
56 * -1 sLen == hLen
57 * -2 salt length is autorecovered from signature
58 * -N reserved
59 */
60 if (sLen == -1)
61 sLen = hLen;
62 else if (sLen == -2)
63 sLen = -2;
64 else if (sLen < -2) {
65 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_SLEN_CHECK_FAILED);
66 goto err;
67 }
68
69 MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
70 emLen = RSA_size(rsa);
71 if (EM[0] & (0xFF << MSBits)) {
72 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_FIRST_OCTET_INVALID);
73 goto err;
74 }
75 if (MSBits == 0) {
76 EM++;
77 emLen--;
78 }
79 if (emLen < hLen + 2) {
80 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_DATA_TOO_LARGE);
81 goto err;
82 }
83 if (sLen > emLen - hLen - 2) { /* sLen can be small negative */
84 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_DATA_TOO_LARGE);
85 goto err;
86 }
87 if (EM[emLen - 1] != 0xbc) {
88 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_LAST_OCTET_INVALID);
89 goto err;
90 }
91 maskedDBLen = emLen - hLen - 1;
92 H = EM + maskedDBLen;
93 DB = OPENSSL_malloc(maskedDBLen);
94 if (DB == NULL) {
95 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, ERR_R_MALLOC_FAILURE);
96 goto err;
97 }
98 if (PKCS1_MGF1(DB, maskedDBLen, H, hLen, mgf1Hash) < 0)
99 goto err;
100 for (i = 0; i < maskedDBLen; i++)
101 DB[i] ^= EM[i];
102 if (MSBits)
103 DB[0] &= 0xFF >> (8 - MSBits);
104 for (i = 0; DB[i] == 0 && i < (maskedDBLen - 1); i++) ;
105 if (DB[i++] != 0x1) {
106 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_SLEN_RECOVERY_FAILED);
107 goto err;
108 }
109 if (sLen >= 0 && (maskedDBLen - i) != sLen) {
110 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_SLEN_CHECK_FAILED);
111 goto err;
112 }
113 if (!EVP_DigestInit_ex(ctx, Hash, NULL)
114 || !EVP_DigestUpdate(ctx, zeroes, sizeof(zeroes))
115 || !EVP_DigestUpdate(ctx, mHash, hLen))
116 goto err;
117 if (maskedDBLen - i) {
118 if (!EVP_DigestUpdate(ctx, DB + i, maskedDBLen - i))
119 goto err;
120 }
121 if (!EVP_DigestFinal_ex(ctx, H_, NULL))
122 goto err;
123 if (memcmp(H_, H, hLen)) {
124 RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS_MGF1, RSA_R_BAD_SIGNATURE);
125 ret = 0;
126 } else
127 ret = 1;
128
129 err:
130 OPENSSL_free(DB);
131 EVP_MD_CTX_free(ctx);
132
133 return ret;
134
135 }
136
137 int RSA_padding_add_PKCS1_PSS(RSA *rsa, unsigned char *EM,
138 const unsigned char *mHash,
139 const EVP_MD *Hash, int sLen)
140 {
141 return RSA_padding_add_PKCS1_PSS_mgf1(rsa, EM, mHash, Hash, NULL, sLen);
142 }
143
144 int RSA_padding_add_PKCS1_PSS_mgf1(RSA *rsa, unsigned char *EM,
145 const unsigned char *mHash,
146 const EVP_MD *Hash, const EVP_MD *mgf1Hash,
147 int sLen)
148 {
149 int i;
150 int ret = 0;
151 int hLen, maskedDBLen, MSBits, emLen;
152 unsigned char *H, *salt = NULL, *p;
153 EVP_MD_CTX *ctx = NULL;
154
155 if (mgf1Hash == NULL)
156 mgf1Hash = Hash;
157
158 hLen = EVP_MD_size(Hash);
159 if (hLen < 0)
160 goto err;
161 /*-
162 * Negative sLen has special meanings:
163 * -1 sLen == hLen
164 * -2 salt length is maximized
165 * -N reserved
166 */
167 if (sLen == -1)
168 sLen = hLen;
169 else if (sLen == -2)
170 sLen = -2;
171 else if (sLen < -2) {
172 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_PSS_MGF1, RSA_R_SLEN_CHECK_FAILED);
173 goto err;
174 }
175
176 MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
177 emLen = RSA_size(rsa);
178 if (MSBits == 0) {
179 *EM++ = 0;
180 emLen--;
181 }
182 if (emLen < hLen + 2) {
183 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_PSS_MGF1,
184 RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
185 goto err;
186 }
187 if (sLen == -2) {
188 sLen = emLen - hLen - 2;
189 } else if (sLen > emLen - hLen - 2) {
190 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_PSS_MGF1,
191 RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
192 goto err;
193 }
194 if (sLen > 0) {
195 salt = OPENSSL_malloc(sLen);
196 if (salt == NULL) {
197 RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_PSS_MGF1,
198 ERR_R_MALLOC_FAILURE);
199 goto err;
200 }
201 if (RAND_bytes(salt, sLen) <= 0)
202 goto err;
203 }
204 maskedDBLen = emLen - hLen - 1;
205 H = EM + maskedDBLen;
206 ctx = EVP_MD_CTX_new();
207 if (ctx == NULL)
208 goto err;
209 if (!EVP_DigestInit_ex(ctx, Hash, NULL)
210 || !EVP_DigestUpdate(ctx, zeroes, sizeof(zeroes))
211 || !EVP_DigestUpdate(ctx, mHash, hLen))
212 goto err;
213 if (sLen && !EVP_DigestUpdate(ctx, salt, sLen))
214 goto err;
215 if (!EVP_DigestFinal_ex(ctx, H, NULL))
216 goto err;
217
218 /* Generate dbMask in place then perform XOR on it */
219 if (PKCS1_MGF1(EM, maskedDBLen, H, hLen, mgf1Hash))
220 goto err;
221
222 p = EM;
223
224 /*
225 * Initial PS XORs with all zeroes which is a NOP so just update pointer.
226 * Note from a test above this value is guaranteed to be non-negative.
227 */
228 p += emLen - sLen - hLen - 2;
229 *p++ ^= 0x1;
230 if (sLen > 0) {
231 for (i = 0; i < sLen; i++)
232 *p++ ^= salt[i];
233 }
234 if (MSBits)
235 EM[0] &= 0xFF >> (8 - MSBits);
236
237 /* H is already in place so just set final 0xbc */
238
239 EM[emLen - 1] = 0xbc;
240
241 ret = 1;
242
243 err:
244 EVP_MD_CTX_free(ctx);
245 OPENSSL_clear_free(salt, sLen);
246
247 return ret;
248
249 }
250
251 #if defined(_MSC_VER)
252 # pragma optimize("",on)
253 #endif