]> git.ipfire.org Git - thirdparty/openssl.git/blob - include/internal/safe_math.h
85c6147e55c81aa882592232452d58a66f0c7d17
[thirdparty/openssl.git] / include / internal / safe_math.h
1 /*
2 * Copyright 2021 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License 2.0 (the "License"). You may not use
5 * this file except in compliance with the License. You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10 #ifndef OSSL_INTERNAL_SAFE_MATH_H
11 # define OSSL_INTERNAL_SAFE_MATH_H
12 # pragma once
13
14 # include <openssl/e_os2.h> /* For 'ossl_inline' */
15
/*
 * has(func) is used by the #if tests below to ask whether the compiler
 * provides the builtin 'func'.  Defining OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING
 * forces the portable fallback implementations.  Otherwise a compiler that
 * offers __has_builtin is queried directly, and a compiler reporting
 * __GNUC__ > 5 is assumed to provide every builtin asked about.  In every
 * other case (i.e. when neither branch defined it) has() evaluates to 0.
 */
# if !defined(OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING) && defined(__has_builtin)
#  define has(func) __has_builtin(func)
# elif !defined(OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING) && __GNUC__ > 5
#  define has(func) 1
# endif

# ifndef has
#  define has(func) 0
# endif
27
28 /*
29 * Safe addition helpers
30 */
31 # if has(__builtin_add_overflow)
/*
 * Signed addition via the compiler builtin.  Only operands of the same sign
 * can overflow, so on overflow the sign of 'a' says which limit to saturate
 * to: min for negative operands, max otherwise.
 */
#  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max)                     \
    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        type sum;                                                            \
                                                                             \
        if (__builtin_add_overflow(a, b, &sum)) {                            \
            *err |= 1;                                                       \
            sum = a < 0 ? min : max;                                         \
        }                                                                    \
        return sum;                                                          \
    }
44
/*
 * Unsigned addition via the compiler builtin.  On overflow the error flag is
 * set and the wrapped sum (a + b reduced modulo 2^N) is returned; 'max' is
 * not needed by this variant.
 */
#  define OSSL_SAFE_MATH_ADDU(type_name, type, max)                          \
    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        type sum;                                                            \
        const int wrapped = __builtin_add_overflow(a, b, &sum);             \
                                                                             \
        if (wrapped) {                                                       \
            *err |= 1;                                                       \
            sum = a + b;                                                     \
        }                                                                    \
        return sum;                                                          \
    }
57
58 # else /* has(__builtin_add_overflow) */
/*
 * Portable signed addition.  A zero 'a', or operands of opposite sign, can
 * never overflow.  For operands of equal sign the headroom left by 'a'
 * (max - a when a is positive, min - a when a is negative) is computed
 * without overflow and compared with 'b'.  On overflow the result saturates
 * towards the sign of 'a'.
 */
#  define OSSL_SAFE_MATH_ADDS(type_name, type, min, max)                     \
    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (a == 0 || (a < 0) != (b < 0))                                    \
            return a + b;                                                    \
        if (a > 0 ? b <= max - a : b >= min - a)                             \
            return a + b;                                                    \
        *err |= 1;                                                           \
        return a < 0 ? min : max;                                            \
    }
72
/*
 * Portable unsigned addition.  max - a is the largest addend that still
 * fits above 'a'; a larger 'b' sets the error flag.  The wrapped sum is
 * returned either way.
 */
#  define OSSL_SAFE_MATH_ADDU(type_name, type, max)                          \
    static ossl_inline ossl_unused type safe_add_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (max - a < b)                                                     \
            *err |= 1;                                                       \
        return a + b;                                                        \
    }
82 # endif /* has(__builtin_add_overflow) */
83
84 /*
85 * Safe subtraction helpers
86 */
87 # if has(__builtin_sub_overflow)
/*
 * Signed subtraction via the compiler builtin.  Overflow requires operands
 * of opposite sign, so a negative 'a' means the true difference fell below
 * min and a non-negative 'a' means it rose above max; saturate accordingly.
 */
#  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max)                     \
    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        type diff;                                                           \
                                                                             \
        if (__builtin_sub_overflow(a, b, &diff)) {                           \
            *err |= 1;                                                       \
            diff = a < 0 ? min : max;                                        \
        }                                                                    \
        return diff;                                                         \
    }
100
101 # else /* has(__builtin_sub_overflow) */
/*
 * Portable signed subtraction.  Subtracting zero, or subtracting a value of
 * the same sign, cannot overflow.  Otherwise the bound min + b (positive b)
 * or max + b (negative b) is computed without overflow and compared with
 * 'a'.  On overflow the result saturates towards the sign of 'a'.
 */
#  define OSSL_SAFE_MATH_SUBS(type_name, type, min, max)                     \
    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0 || (a < 0) == (b < 0))                                    \
            return a - b;                                                    \
        if (b > 0 ? a >= min + b : a <= max + b)                             \
            return a - b;                                                    \
        *err |= 1;                                                           \
        return a < 0 ? min : max;                                            \
    }
115
116 # endif /* has(__builtin_sub_overflow) */
117
/*
 * Unsigned subtraction, always done portably: the difference would be
 * negative, and is therefore flagged as an error, exactly when 'b' exceeds
 * 'a'.  The wrapped difference is returned either way.
 */
# define OSSL_SAFE_MATH_SUBU(type_name, type)                                 \
    static ossl_inline ossl_unused type safe_sub_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (a < b)                                                           \
            *err |= 1;                                                       \
        return a - b;                                                        \
    }
127
128 /*
129 * Safe multiplication helpers
130 */
131 # if has(__builtin_mul_overflow)
/*
 * Signed multiplication via the compiler builtin.  On overflow the result
 * saturates to min when exactly one operand is negative (the true product
 * is negative) and to max otherwise.
 */
#  define OSSL_SAFE_MATH_MULS(type_name, type, min, max)                     \
    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        type prod;                                                           \
                                                                             \
        if (__builtin_mul_overflow(a, b, &prod)) {                           \
            *err |= 1;                                                       \
            prod = (a < 0) != (b < 0) ? min : max;                           \
        }                                                                    \
        return prod;                                                         \
    }
144
/*
 * Unsigned multiplication via the compiler builtin.  On overflow the error
 * flag is set and the wrapped product a * b is returned; 'max' is not needed
 * by this variant.
 */
#  define OSSL_SAFE_MATH_MULU(type_name, type, max)                          \
    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        type prod;                                                           \
        const int overflowed = __builtin_mul_overflow(a, b, &prod);          \
                                                                             \
        if (overflowed) {                                                    \
            *err |= 1;                                                       \
            prod = a * b;                                                    \
        }                                                                    \
        return prod;                                                         \
    }
157
158 # else /* has(__builtin_mul_overflow) */
/*
 * Portable signed multiplication.  Factors of 0 and 1 are handled up front.
 * When neither operand is min, both magnitudes are representable and the
 * product fits if |a| <= max / |b|.  Anything else (including any product
 * with min as a factor) is treated as overflow: the error flag is set and
 * the result saturates to min when exactly one operand is negative, max
 * otherwise.
 */
#  define OSSL_SAFE_MATH_MULS(type_name, type, min, max)                     \
    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (a == 0 || b == 0)                                                \
            return 0;                                                        \
        if (a == 1 || b == 1)                                                \
            return a == 1 ? b : a;                                           \
        if (a != min && b != min) {                                          \
            const type mag_a = a < 0 ? -a : a;                               \
            const type mag_b = b < 0 ? -b : b;                               \
                                                                             \
            if (mag_a <= max / mag_b)                                        \
                return a * b;                                                \
        }                                                                    \
        *err |= 1;                                                           \
        return (a < 0) != (b < 0) ? min : max;                               \
    }
180
/*
 * Portable unsigned multiplication.  For a non-zero 'b' the product
 * overflows exactly when a > max / b.  A zero 'b' can never overflow, and it
 * must not be used as the divisor of that bound (division by zero is
 * undefined behaviour), so it is excluded first.  The error flag is set on
 * overflow and the wrapped product is returned either way, as the
 * __builtin_mul_overflow variant does.
 */
#  define OSSL_SAFE_MATH_MULU(type_name, type, max)                          \
    static ossl_inline ossl_unused type safe_mul_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b != 0 && a > max / b)                                           \
            *err |= 1;                                                       \
        return a * b;                                                        \
    }
190 # endif /* has(__builtin_mul_overflow) */
191
192 /*
193 * Safe division helpers
194 */
/*
 * Signed division.  Both error cases set the error flag:
 *  - a zero divisor saturates to min for a negative dividend, max otherwise;
 *  - min / -1, whose true value -min lies one past max, saturates to max.
 */
# define OSSL_SAFE_MATH_DIVS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b != 0 && !(b == -1 && a == min))                                \
            return a / b;                                                    \
        *err |= 1;                                                           \
        if (b == 0)                                                          \
            return a < 0 ? min : max;                                        \
        return max;                                                          \
    }
210
/*
 * Unsigned division; a zero divisor sets the error flag and saturates to
 * max.
 */
# define OSSL_SAFE_MATH_DIVU(type_name, type, max)                            \
    static ossl_inline ossl_unused type safe_div_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return max;                                                      \
        }                                                                    \
        return a / b;                                                        \
    }
221
222 /*
223 * Safe modulus helpers
224 */
/*
 * Signed remainder.  A zero divisor sets the error flag and returns 0.
 * min % -1 is undefined in C because the matching quotient -min is not
 * representable (even though the mathematical remainder is 0), so that case
 * is also reported as an error and returns max.
 */
# define OSSL_SAFE_MATH_MODS(type_name, type, min, max)                       \
    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b != 0 && (b != -1 || a != min))                                 \
            return a % b;                                                    \
        *err |= 1;                                                           \
        return b == 0 ? 0 : max;                                             \
    }
240
/*
 * Unsigned remainder; a zero divisor sets the error flag and returns 0.
 */
# define OSSL_SAFE_MATH_MODU(type_name, type)                                 \
    static ossl_inline ossl_unused type safe_mod_ ## type_name(type a,       \
                                                               type b,       \
                                                               int *err)     \
    {                                                                        \
        if (b == 0) {                                                        \
            *err |= 1;                                                       \
            return 0;                                                        \
        }                                                                    \
        return a % b;                                                        \
    }
251
252 /*
253 * Safe negation helpers
254 */
/*
 * Signed negation.  -min is not representable, so negating min sets the
 * error flag and returns min unchanged.
 */
# define OSSL_SAFE_MATH_NEGS(type_name, type, min)                            \
    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return -a;                                                           \
    }
264
/*
 * Unsigned negation.  Only zero has a representable negation; any other
 * value sets the error flag.  The value returned is always the two's
 * complement negation 1 + ~a, which is 0 when 'a' is 0.
 */
# define OSSL_SAFE_MATH_NEGU(type_name, type)                                 \
    static ossl_inline ossl_unused type safe_neg_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a != 0)                                                          \
            *err |= 1;                                                       \
        return 1 + ~a;                                                       \
    }
274
275 /*
276 * Safe absolute value helpers
277 */
/*
 * Signed absolute value.  |min| is not representable, so min sets the error
 * flag and is returned unchanged.
 */
# define OSSL_SAFE_MATH_ABSS(type_name, type, min)                            \
    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        if (a == min) {                                                      \
            *err |= 1;                                                       \
            return min;                                                      \
        }                                                                    \
        return a >= 0 ? a : -a;                                              \
    }
287
/*
 * Unsigned absolute value: every value is its own absolute value, so this
 * never fails and 'err' is never touched.
 */
# define OSSL_SAFE_MATH_ABSU(type_name, type)                                 \
    static ossl_inline ossl_unused type safe_abs_ ## type_name(type a,       \
                                                               int *err)     \
    {                                                                        \
        (void)err;                                                           \
        return a;                                                            \
    }
294
/*
 * Safe fused multiply divide helpers
 *
 * These are a bit obscure:
 * . They begin by checking the denominator for zero and getting rid of this
 *   corner case.
 *
 * . Second is an attempt to do the multiplication directly; if it doesn't
 *   overflow, the quotient is returned (for signed values there is a
 *   potential problem here which isn't present for unsigned).
 *
 * . Finally, the multiplication/division is transformed so that the larger
 *   of the two multiplicands is divided first.  This requires a remainder
 *   correction:
 *
 *       a * b / c = (a / c) * b + ((a mod c) * b) / c,  where a >= b
 *
 *   The individual operations need to be overflow checked (again signed
 *   being more problematic).
 *
 * The algorithm used is not perfect but it should be "good enough".
 */
/*
 * Signed a * b / c.
 *  - A zero divisor sets the error flag and returns 0 when either factor is
 *    0, max otherwise (the sign of the product is not considered).
 *  - If a * b is representable it is divided by c with safe_div.
 *  - Otherwise the operand that compares larger (by value) is split by c and
 *    the remainder term is added back:
 *        a * b / c = (a / c) * b + ((a % c) * b) / c
 *    Every step goes through the checked helpers, so an overflow anywhere
 *    sets the error flag.
 */
# define OSSL_SAFE_MATH_MULDIVS(type_name, type, max)                         \
    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \
                                                                  type b,    \
                                                                  type c,    \
                                                                  int *err)  \
    {                                                                        \
        int mul_err = 0;                                                     \
        type product, hi, lo, quot, rem, whole, part;                        \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 || b == 0 ? 0 : max;                               \
        }                                                                    \
        product = safe_mul_ ## type_name(a, b, &mul_err);                    \
        if (mul_err == 0)                                                    \
            return safe_div_ ## type_name(product, c, err);                  \
        if (b > a) {                                                         \
            hi = b;                                                          \
            lo = a;                                                          \
        } else {                                                             \
            hi = a;                                                          \
            lo = b;                                                          \
        }                                                                    \
        quot = safe_div_ ## type_name(hi, c, err);                           \
        rem = safe_mod_ ## type_name(hi, c, err);                            \
        whole = safe_mul_ ## type_name(quot, lo, err);                       \
        part = safe_mul_ ## type_name(rem, lo, err);                         \
        part = safe_div_ ## type_name(part, c, err);                         \
        return safe_add_ ## type_name(whole, part, err);                     \
    }
345
/*
 * Unsigned a * b / c.
 *  - A zero divisor sets the error flag and returns 0 when either factor is
 *    0, max otherwise.
 *  - If a * b is representable it is divided by c directly.
 *  - Otherwise the larger operand is split by c:
 *        a * b / c = (a / c) * b + ((a % c) * b) / c
 *    with both multiplications and the final addition overflow checked.
 */
# define OSSL_SAFE_MATH_MULDIVU(type_name, type, max)                         \
    static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a,    \
                                                                  type b,    \
                                                                  type c,    \
                                                                  int *err)  \
    {                                                                        \
        int mul_err = 0;                                                     \
        type product, hi, lo, whole, part;                                   \
                                                                             \
        if (c == 0) {                                                        \
            *err |= 1;                                                       \
            return a == 0 || b == 0 ? 0 : max;                               \
        }                                                                    \
        product = safe_mul_ ## type_name(a, b, &mul_err);                    \
        if (mul_err == 0)                                                    \
            return product / c;                                              \
        hi = b > a ? b : a;                                                  \
        lo = b > a ? a : b;                                                  \
        whole = safe_mul_ ## type_name(hi / c, lo, err);                     \
        part = safe_mul_ ## type_name(hi % c, lo, err) / c;                  \
        return safe_add_ ## type_name(whole, part, err);                     \
    }
371
/*
 * Calculate ranges of types.
 *
 * These assume a two's complement representation without padding bits and
 * 8-bit bytes.  Shifting a 1 into the sign bit of a signed type is undefined
 * behaviour, so the signed limits are built without doing so:
 *     max = (2^(N-2) - 1) * 2 + 1 = 2^(N-1) - 1,    min = -max - 1
 * Every result is cast back to 'type'.  Otherwise, types narrower than int
 * (which are promoted during the arithmetic) would yield int values such as
 * +32768 for MINS(short) or -1 for MAXU(unsigned short).
 */
# define OSSL_SAFE_MATH_MAXS(type) \
    ((type)((((type)1 << (sizeof(type) * 8 - 2)) - 1) * 2 + 1))
# define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
# define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
376
377 /*
378 * Wrapper macros to create all the functions of a given type
379 */
/*
 * Instantiate every safe_*_<type_name>() helper for the signed integer type
 * 'type'.  safe_muldiv_<type_name>() calls the add, multiply, divide and
 * modulus helpers, so it is generated after them.
 */
# define OSSL_SAFE_MATH_SIGNED(type_name, type)                                \
    OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type))            \
    OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type))            \
    OSSL_SAFE_MATH_ADDS(type_name, type,                                       \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))  \
    OSSL_SAFE_MATH_SUBS(type_name, type,                                       \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))  \
    OSSL_SAFE_MATH_MULS(type_name, type,                                       \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))  \
    OSSL_SAFE_MATH_DIVS(type_name, type,                                       \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))  \
    OSSL_SAFE_MATH_MODS(type_name, type,                                       \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))  \
    OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))
394
/*
 * Instantiate every safe_*_<type_name>() helper for the unsigned integer
 * type 'type'.  safe_muldiv_<type_name>() calls the multiply and add
 * helpers, so it is generated after them.
 */
# define OSSL_SAFE_MATH_UNSIGNED(type_name, type)                              \
    OSSL_SAFE_MATH_NEGU(type_name, type)                                       \
    OSSL_SAFE_MATH_ABSU(type_name, type)                                       \
    OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type))            \
    OSSL_SAFE_MATH_SUBU(type_name, type)                                       \
    OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type))            \
    OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))            \
    OSSL_SAFE_MATH_MODU(type_name, type)                                       \
    OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))
404
405 #endif /* OSSL_INTERNAL_SAFE_MATH_H */