]> git.ipfire.org Git - thirdparty/gcc.git/blob - libphobos/src/std/numeric.d
d: Import dmd b8384668f, druntime e6caaab9, phobos 5ab9ad256 (v2.098.0-beta.1)
[thirdparty/gcc.git] / libphobos / src / std / numeric.d
1 // Written in the D programming language.
2
3 /**
4 This module is a port of a growing fragment of the $(D_PARAM numeric)
5 header in Alexander Stepanov's $(LINK2 https://en.wikipedia.org/wiki/Standard_Template_Library,
6 Standard Template Library), with a few additions.
7
8 Macros:
9 Copyright: Copyright Andrei Alexandrescu 2008 - 2009.
10 License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).
11 Authors: $(HTTP erdani.org, Andrei Alexandrescu),
12 Don Clugston, Robert Jacques, Ilya Yaroshenko
13 Source: $(PHOBOSSRC std/numeric.d)
14 */
15 /*
16 Copyright Andrei Alexandrescu 2008 - 2009.
17 Distributed under the Boost Software License, Version 1.0.
18 (See accompanying file LICENSE_1_0.txt or copy at
19 http://www.boost.org/LICENSE_1_0.txt)
20 */
21 module std.numeric;
22
23 import std.complex;
24 import std.math;
25 import core.math : fabs, ldexp, sin, sqrt;
26 import std.range.primitives;
27 import std.traits;
28 import std.typecons;
29
30 /// Format flags for CustomFloat.
// The members are bit flags: each basic option is a distinct power of two so
// that options can be combined with `|` and toggled with `^`.
31 public enum CustomFloatFlags
32 {
33 /// Adds a sign bit to allow for signed numbers.
34 signed = 1,
35
36 /**
37 * Store values in normalized form by default. The actual precision of the
38 * significand is extended by 1 bit by assuming an implicit leading bit of 1
39 * instead of 0, i.e. `1.nnnn` instead of `0.nnnn`.
40 * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) types.
41 */
42 storeNormalized = 2,
43
44 /**
45 * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
46 * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0.
47 */
48 allowDenorm = 4,
49
50 /**
51 * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity,
52 * IEEE754 _infinity) values.
53 */
54 infinity = 8,
55
56 /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values.
57 nan = 16,
58
59 /**
60 * If set, select an exponent bias such that max_exp = 1.
61 * i.e. so that the maximum value is >= 1.0 and < 2.0.
62 * Ignored if the exponent bias is manually specified.
63 */
64 probability = 32,
65
66 /// If set, unsigned custom floats are assumed to be negative.
67 negativeUnsigned = 64,
68
69 /** If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
70 * IEEE754 denormalized) number.
71 * Requires allowDenorm and storeNormalized.
72 */
// Composite value: its own bit (128) plus the two flags it requires, i.e. 134.
73 allowDenormZeroOnly = 128 | allowDenorm | storeNormalized,
74
75 /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options.
// Composite value: 1 | 2 | 4 | 8 | 16 == 31.
76 ieee = signed | storeNormalized | allowDenorm | infinity | nan ,
77
78 /// Include none of the above options.
79 none = 0
80 }
81
/*
 * Maps one of the predefined total bit widths (8, 16, 32, 64 or 80) to the
 * parameter tuple produced by the three-argument `CustomFloatParams` overload.
 * For any other width no eponymous alias is declared.
 */
private template CustomFloatParams(uint bits)
{
    // All predefined widths use the IEEE flag set; only the 80-bit layout
    // has `storeNormalized` toggled off.
    enum CustomFloatFlags flags = (bits == 80)
        ? (CustomFloatFlags.ieee ^ CustomFloatFlags.storeNormalized)
        : CustomFloatFlags.ieee;

    // (precision, exponentWidth) per predefined width.
    static if (bits == 8)
        alias CustomFloatParams = CustomFloatParams!( 4, 3, flags);
    else static if (bits == 16)
        alias CustomFloatParams = CustomFloatParams!(10, 5, flags);
    else static if (bits == 32)
        alias CustomFloatParams = CustomFloatParams!(23, 8, flags);
    else static if (bits == 64)
        alias CustomFloatParams = CustomFloatParams!(52, 11, flags);
    else static if (bits == 80)
        alias CustomFloatParams = CustomFloatParams!(64, 15, flags);
}
92
/*
 * Expands to the tuple `(precision, exponentWidth, flags, bias)`, where `bias`
 * is the default exponent bias derived from the exponent width and the flags.
 */
private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags)
{
    import std.meta : AliasSeq;

    // Flag tests, named once so the bias formula below stays readable.
    enum bool isProbability = (flags & CustomFloatFlags.probability) != 0;
    enum bool hasSpecials = (flags & (CustomFloatFlags.nan | CustomFloatFlags.infinity)) != 0;

    // Half of the exponent range (the full range for probability formats),
    // reduced by one when an exponent value is reserved for infinity/NaN and
    // by one more for probability formats.
    enum int defaultBias = (1 << (exponentWidth - !isProbability)) - hasSpecials - isProbability;

    alias CustomFloatParams = AliasSeq!(precision, exponentWidth, flags, defaultBias);
}
105
106 /**
107 * Allows user code to define custom floating-point formats. These formats are
108 * for storage only; all operations on them are performed by implicitly
109 * extracting them to `real` first. After the operation is completed the
110 * result can be stored in a custom floating-point value via assignment.
111 */
112 template CustomFloat(uint bits)
113 if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80)
114 {
// Forward to the four-parameter struct using the predefined layout for `bits`.
115 alias CustomFloat = CustomFloat!(CustomFloatParams!(bits));
116 }
117
118 /// ditto
119 template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee)
// Accepted only when sign bit + precision + exponentWidth is a non-zero multiple of 8 bits.
120 if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0)
121 {
// Forward to the four-parameter struct; CustomFloatParams supplies the default bias.
122 alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags));
123 }
124
125 ///
126 @safe unittest
127 {
128 import std.math.trigonometry : sin, cos;
129
130 // Define 16-bit floating point values
131 CustomFloat!16 x; // Using the number of bits
132 CustomFloat!(10, 5) y; // Using the precision and exponent width
133 CustomFloat!(10, 5,CustomFloatFlags.ieee) z; // Using the precision, exponent width and format flags
134 CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w; // Using the precision, exponent width, format flags and exponent offset bias
135
136 // Use the 16-bit floats mostly like normal numbers
137 w = x*y - 1;
138
139 // Function calls require conversion
140 z = sin(+x) + cos(+y); // Use unary plus to concisely convert to a real
141 z = sin(x.get!float) + cos(y.get!float); // Or use get!T
142 z = sin(cast(float) x) + cos(cast(float) y); // Or use cast(T) to explicitly convert
143
144 // Define an 8-bit custom float for storing probabilities
// (ieee with the probability bit toggled on and the signed bit toggled off)
145 alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed );
146 auto p = Probability(0.5);
147 }
148
149 // Facilitate converting numeric types to custom float.
// The union overlays a native value (`set`) with a CustomFloat of the matching
// predefined layout (`get`), so writing one member and reading the other
// reinterprets the same storage.
150 private union ToBinary(F)
// Usable when a predefined layout exists for F's size in bits, or when F is `real`.
151 if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real))
152 {
// Native floating-point view of the storage.
153 F set;
154
155 // If on Linux or Mac, where 80-bit reals are padded, ignore the
156 // padding.
// The layout width is capped at 80 bits even if F.sizeof*8 is larger.
157 import std.algorithm.comparison : min;
158 CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get;
159
160 // Convert F to the correct binary type.
// Stores `value` through `set` and returns the `get` view of the same storage.
161 static typeof(get) opCall(F value)
162 {
163 ToBinary r;
164 r.set = value;
165 return r.get;
166 }
// Lets a ToBinary be used wherever its CustomFloat view is expected.
167 alias get this;
168 }
169
170 /// ditto
171 struct CustomFloat(uint precision, // fraction bits (23 for float)
172 uint exponentWidth, // exponent bits (8 for float)
173 CustomFloatFlags flags,
174 uint bias)
175 if (isCorrectCustomFloat(precision, exponentWidth, flags))
176 {
// Storage-only floating-point value with `precision` significand bits,
// `exponentWidth` exponent bits, an optional sign bit (flags.signed) and
// exponent bias `bias`. Conversions to and from float/double/real go through
// an intermediate form handled by toNormalized / fromNormalized below.
// NOTE(review): `exponent_max` and (except when precision == 64)
// `significand_max` are not declared in this struct; presumably they are
// generated by the std.bitmanip.bitfields mixin - confirm.
177 import std.bitmanip : bitfields;
178 import std.meta : staticIndexOf;
179 private:
180 // get the correct unsigned bitfield type to support > 32 bits
181 template uType(uint bits)
182 {
183 static if (bits <= size_t.sizeof*8) alias uType = size_t;
184 else alias uType = ulong ;
185 }
186
187 // get the correct signed bitfield type to support > 32 bits
// (one bit less than ptrdiff_t's width is available for the magnitude)
188 template sType(uint bits)
189 {
190 static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t;
191 else alias sType = long;
192 }
193
194 alias T_sig = uType!precision;
195 alias T_exp = uType!exponentWidth;
196 alias T_signed_exp = sType!exponentWidth;
197
198 alias Flags = CustomFloatFlags;
199
200 // Perform IEEE rounding with round to nearest detection.
// Shifts `sig` right by `shift` bits in place, rounding the result:
//  - a shift of at least T's bit width yields 0;
//  - if the discarded bits are exactly one half (only the topmost discarded
//    bit is set), the shifted value is bumped only when it is odd;
//  - otherwise the last discarded bit is added back (round half up).
// Callers in this struct only pass shift >= 1.
201 void roundedShift(T,U)(ref T sig, U shift)
202 {
203 if (shift >= T.sizeof*8)
204 {
205 // avoid illegal shift
206 sig = 0;
207 }
// Left-aligning the discarded bits and comparing with the top bit detects an exact tie.
208 else if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1))
209 {
210 // round to even
211 sig >>= shift;
212 sig += sig & 1;
213 }
214 else
215 {
// Keep one extra low bit, add it to itself to round, then drop it.
216 sig >>= shift - 1;
217 sig += sig & 1;
218 // Perform standard rounding
219 sig >>= 1;
220 }
221 }
222
223 // Convert the current value to signed exponent, normalized form.
// Outputs (by reference): `sig` receives the significand shifted towards the
// top of T, `exp` the exponent with `bias` removed. Two sentinels are used:
// exp == exp.max marks infinity/NaN and exp == exp.min marks zero.
224 void toNormalized(T,U)(ref T sig, ref U exp)
225 {
226 sig = significand;
// Number of bits needed to left-align a `precision`-bit significand in T.
227 auto shift = (T.sizeof*8) - precision;
228 exp = exponent;
229 static if (flags&(Flags.infinity|Flags.nan))
230 {
231 // Handle inf or nan
232 if (exp == exponent_max)
233 {
234 exp = exp.max;
235 sig <<= shift;
236 static if (flags&Flags.storeNormalized)
237 {
238 // Save inf/nan in denormalized format
// (make room for, and set, the explicit leading bit)
239 sig >>= 1;
240 sig += cast(T) 1uL << (T.sizeof*8 - 1);
241 }
242 return;
243 }
244 }
// Taken when the format has no implicit leading bit, or for a stored
// denormal (exponent field 0 with allowDenorm).
245 if ((~flags&Flags.storeNormalized) ||
246 // Convert denormalized form to normalized form
247 ((flags&Flags.allowDenorm) && exp == 0))
248 {
249 if (sig > 0)
250 {
251 import core.bitop : bsr;
// bsr gives the index of the highest set bit; shift2 is how far the
// leading 1 sits below bit position `precision`.
252 auto shift2 = precision - bsr(sig);
253 exp -= shift2-1;
254 shift += shift2;
255 }
256 else // value = 0.0
257 {
258 exp = exp.min;
259 return;
260 }
261 }
262 sig <<= shift;
263 exp -= bias;
264 }
265
266 // Set the current value from signed exponent, normalized form.
// Rewrites `sig` and `exp` in place from the intermediate form produced by
// toNormalized into the field values for this format. It does not store
// them: the callers (opAssign / get) assign the exponent and significand
// fields afterwards. Unrepresentable inputs are reported through asserts.
267 void fromNormalized(T,U)(ref T sig, ref U exp)
268 {
269 auto shift = (T.sizeof*8) - precision;
// Sentinel set by toNormalized for infinity/NaN.
270 if (exp == exp.max)
271 {
272 // infinity or nan
273 exp = exponent_max;
274 static if (flags & Flags.storeNormalized)
275 sig <<= 1;
276
277 // convert back to normalized form
278 static if (~flags & Flags.infinity)
279 // No infinity support?
280 assert(sig != 0, "Infinity floating point value assigned to a "
281 ~ typeof(this).stringof ~ " (no infinity support).");
282
283 static if (~flags & Flags.nan) // No NaN support?
284 assert(sig == 0, "NaN floating point value assigned to a " ~
285 typeof(this).stringof ~ " (no nan support).");
286 sig >>= shift;
287 return;
288 }
// Sentinel set by toNormalized for zero.
289 if (exp == exp.min) // 0.0
290 {
291 exp = 0;
292 sig = 0;
293 return;
294 }
295
296 exp += bias;
// A biased exponent <= 0 is below the normal range of this format.
297 if (exp <= 0)
298 {
299 static if ((flags&Flags.allowDenorm) ||
300 // Convert from normalized form to denormalized
301 (~flags&Flags.storeNormalized))
302 {
// Widen the final right shift by the amount the exponent is below zero.
303 shift += -exp;
304 roundedShift(sig,1);
305 sig += cast(T) 1uL << (T.sizeof*8 - 1);
306 // Add the leading 1
307 exp = 0;
308 }
309 else
310 assert((flags&Flags.storeNormalized) && exp == 0,
311 "Underflow occured assigning to a " ~
312 typeof(this).stringof ~ " (no denormal support).");
313 }
314 else
315 {
316 static if (~flags&Flags.storeNormalized)
317 {
318 // Convert from normalized form to denormalized
319 roundedShift(sig,1);
320 sig += cast(T) 1uL << (T.sizeof*8 - 1);
321 // Add the leading 1
322 }
323 }
324
// Drop the bits that do not fit into `precision` bits, with rounding.
325 if (shift > 0)
326 roundedShift(sig,shift);
327 if (sig > significand_max)
328 {
329 // handle significand overflow (should only be 1 bit)
330 static if (~flags&Flags.storeNormalized)
331 {
332 sig >>= 1;
333 }
334 else
335 sig &= significand_max;
336 exp++;
337 }
338 static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly)
339 {
340 // disallow non-zero denormals
341 if (exp == 0)
342 {
343 sig <<= 1;
344 if (sig > significand_max && (sig&significand_max) > 0)
345 // Check and round to even
346 exp++;
347 sig = 0;
348 }
349 }
350
// Exponent at or above the maximum field value: overflow handling.
351 if (exp >= exponent_max)
352 {
353 static if (flags&(Flags.infinity|Flags.nan))
354 {
355 sig = 0;
356 exp = exponent_max;
357 static if (~flags&(Flags.infinity))
358 assert(0, "Overflow occured assigning to a " ~
359 typeof(this).stringof ~ " (no infinity support).");
360 }
361 else
362 assert(exp == exponent_max, "Overflow occured assigning to a "
363 ~ typeof(this).stringof ~ " (no infinity support).");
364 }
365 }
366
367 public:
// Storage. With a 64-bit significand the significand is a plain ulong field
// and only exponent and sign are packed into bitfields.
368 static if (precision == 64) // CustomFloat!80 support hack
369 {
370 ulong significand;
371 enum ulong significand_max = ulong.max;
372 mixin(bitfields!(
373 T_exp , "exponent", exponentWidth,
374 bool , "sign" , flags & flags.signed ));
375 }
376 else
377 {
// `flags & flags.signed` is 1 or 0, so the sign field is one bit wide or empty.
378 mixin(bitfields!(
379 T_sig, "significand", precision,
380 T_exp, "exponent" , exponentWidth,
381 bool , "sign" , flags & flags.signed ));
382 }
383
384 /// Returns: infinity value
385 static if (flags & Flags.infinity)
386 static @property CustomFloat infinity()
387 {
388 CustomFloat value;
389 static if (flags & Flags.signed)
390 value.sign = 0;
391 value.significand = 0;
392 value.exponent = exponent_max;
393 return value;
394 }
395
396 /// Returns: NaN value
397 static if (flags & Flags.nan)
398 static @property CustomFloat nan()
399 {
400 CustomFloat value;
401 static if (flags & Flags.signed)
402 value.sign = 0;
// Maximum exponent with only the top significand bit set.
403 value.significand = cast(typeof(significand_max)) 1L << (precision-1);
404 value.exponent = exponent_max;
405 return value;
406 }
407
408 /// Returns: number of decimal digits of precision
409 static @property size_t dig()
410 {
// One bit fewer when there is no implicit leading bit.
411 auto shiftcnt = precision - ((flags&Flags.storeNormalized) == 0);
// 64 is special-cased: `1uL << 64` is not a valid shift.
412 return shiftcnt == 64 ? 19 : cast(size_t) log10(1uL << shiftcnt);
413 }
414
415 /// Returns: smallest increment to the value 1
416 static @property CustomFloat epsilon()
417 {
418 CustomFloat one = CustomFloat(1);
419 CustomFloat onePlusEpsilon = one;
420 onePlusEpsilon.significand = onePlusEpsilon.significand | 1; // |= does not work here
421
// The subtraction is carried out in `real` (see opBinary) and converted back.
422 return CustomFloat(onePlusEpsilon - one);
423 }
424
425 /// the number of bits in mantissa
426 enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0);
427
428 /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable
429 static @property int max_10_exp(){ return cast(int) log10( +max ); }
430
431 /// maximum int value such that 2<sup>max_exp-1</sup> is representable
432 enum max_exp = exponent_max - bias - ((flags & (Flags.infinity | Flags.nan)) != 0) + 1;
433
434 /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable
435 static @property int min_10_exp(){ return cast(int) log10( +min_normal ); }
436
437 /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value
438 enum min_exp = cast(T_signed_exp) -(cast(long) bias) + 1 + ((flags & Flags.allowDenorm) != 0);
439
440 /// Returns: largest representable value that's not infinity
441 static @property CustomFloat max()
442 {
443 CustomFloat value;
444 static if (flags & Flags.signed)
445 value.sign = 0;
// The top exponent value is skipped when the format has infinity or NaN.
446 value.exponent = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0);
447 value.significand = significand_max;
448 return value;
449 }
450
451 /// Returns: smallest representable normalized value that's not 0
452 static @property CustomFloat min_normal()
453 {
454 CustomFloat value;
455 static if (flags & Flags.signed)
456 value.sign = 0;
// Exponent field 1 when allowDenorm is set, otherwise 0.
457 value.exponent = (flags & Flags.allowDenorm) != 0;
458 static if (flags & Flags.storeNormalized)
459 value.significand = 0;
460 else
// No implicit leading bit: the leading 1 is stored explicitly.
461 value.significand = cast(T_sig) 1uL << (precision - 1);
462 return value;
463 }
464
465 /// Returns: real part
466 @property CustomFloat re() { return this; }
467
468 /// Returns: imaginary part
469 static @property CustomFloat im() { return CustomFloat(0.0f); }
470
471 /// Initialize from any `real` compatible type.
472 this(F)(F input) if (__traits(compiles, cast(real) input ))
473 {
// Delegates to the matching opAssign overload.
474 this = input;
475 }
476
477 /// Self assignment
// Copies the fields of another value of the same CustomFloat type.
478 void opAssign(F:CustomFloat)(F input)
479 {
480 static if (flags & Flags.signed)
481 sign = input.sign;
482 exponent = input.exponent;
483 significand = input.significand;
484 }
485
486 /// Assigns from any `real` compatible type.
487 void opAssign(F)(F input)
488 if (__traits(compiles, cast(real) input))
489 {
490 import std.conv : text;
491
// float/double/real keep their own layout; every other type is passed to
// ToBinary!real and therefore converted to real.
492 static if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
493 auto value = ToBinary!(Unqual!F)(input);
494 else
495 auto value = ToBinary!(real )(input);
496
497 // Assign the sign bit
498 static if (~flags & Flags.signed)
499 assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0),
500 "Incorrectly signed floating point value assigned to a " ~
501 typeof(this).stringof ~ " (no sign support).");
502 else
503 sign = value.sign;
504
// Work in types wide enough for both the source and the target layout.
505 CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent;
506 CommonType!(T_sig, value.T_sig ) sig = value.significand;
507
// Source layout -> intermediate form -> this layout.
508 value.toNormalized(sig,exp);
509 fromNormalized(sig,exp);
510
511 assert(exp <= exponent_max, text(typeof(this).stringof ~
512 " exponent too large: " ,exp," > ",exponent_max, "\t",input,"\t",sig));
513 assert(sig <= significand_max, text(typeof(this).stringof ~
514 " significand too large: ",sig," > ",significand_max,
515 "\t",input,"\t",exp," ",exponent_max));
516 exponent = cast(T_exp) exp;
517 significand = cast(T_sig) sig;
518 }
519
520 /// Fetches the stored value either as a `float`, `double` or `real`.
521 @property F get(F)()
522 if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
523 {
524 import std.conv : text;
525
526 ToBinary!F result;
527
// Formats without a sign bit report the sign implied by negativeUnsigned.
528 static if (flags&Flags.signed)
529 result.sign = sign;
530 else
531 result.sign = (flags&flags.negativeUnsigned) > 0;
532
533 CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction
534 CommonType!(T_sig, result.get.T_sig ) sig = significand;
535
// This layout -> intermediate form -> layout of F.
536 toNormalized(sig,exp);
537 result.fromNormalized(sig,exp);
538 assert(exp <= result.exponent_max, text("get exponent too large: " ,exp," > ",result.exponent_max) );
539 assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) );
540 result.exponent = cast(result.get.T_exp) exp;
541 result.significand = cast(result.get.T_sig) sig;
// Read the same storage back through the native-typed union member.
542 return result.set;
543 }
544
545 ///ditto
546 alias opCast = get;
547
548 /// Convert the CustomFloat to a real and perform the relevant operator on the result
549 real opUnary(string op)()
550 if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--")
551 {
552 static if (op=="++" || op=="--")
553 {
// Applies the operator to a real copy, stores the updated value back into
// this object and returns that updated value.
554 auto result = get!real;
555 this = mixin(op~`result`);
556 return result;
557 }
558 else
559 return mixin(op~`get!real`);
560 }
561
562 /// ditto
563 // Define an opBinary `CustomFloat op CustomFloat` so that those below
564 // do not match equally, which is disallowed by the spec:
565 // https://dlang.org/spec/operatoroverloading.html#binary
566 real opBinary(string op,T)(T b)
567 if (__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
568 {
569 return mixin(`get!real`~op~`b.get!real`);
570 }
571
572 /// ditto
// Right operand without a usable `get!real`: it is used as-is.
573 real opBinary(string op,T)(T b)
574 if ( __traits(compiles, mixin(`get!real`~op~`b`)) &&
575 !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
576 {
577 return mixin(`get!real`~op~`b`);
578 }
579
580 /// ditto
// NOTE(review): the last two constraint clauses refer to `b`, but the
// parameter of this overload is named `a` and no `b` is declared in this
// struct. Unless a `b` is visible from an enclosing scope, those clauses
// cannot compile and the negations are always true - confirm the intent
// before changing them.
581 real opBinaryRight(string op,T)(T a)
582 if ( __traits(compiles, mixin(`a`~op~`get!real`)) &&
583 !__traits(compiles, mixin(`get!real`~op~`b`)) &&
584 !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
585 {
586 return mixin(`a`~op~`get!real`);
587 }
588
589 /// ditto
// Three-way comparison carried out in `real`. If either side is NaN both
// comparisons below are false and the result is 0.
590 int opCmp(T)(auto ref T b)
591 if (__traits(compiles, cast(real) b))
592 {
593 auto x = get!real;
594 auto y = cast(real) b;
595 return (x >= y)-(x <= y);
596 }
597
598 /// ditto
// Compound assignment: evaluates `this op cast(real) b` and assigns the result back.
599 void opOpAssign(string op, T)(auto ref T b)
600 if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`)))
601 {
602 return mixin(`this = this `~op~` cast(real) b`);
603 }
604
605 /// ditto
// Formats the value by formatting its `real` representation.
606 template toString()
607 {
608 import std.format.spec : FormatSpec;
609 import std.format.write : formatValue;
610 // Needs to be a template because of https://issues.dlang.org/show_bug.cgi?id=13737.
611 void toString()(scope void delegate(const(char)[]) sink, scope const ref FormatSpec!char fmt)
612 {
613 sink.formatValue(get!real, fmt);
614 }
615 }
616 }
617
// Round-trips power-of-two values through several custom formats and checks
// the compound assignment operators (-=, *=, /=, ^^=).
618 @safe unittest
619 {
620 import std.meta;
621 alias FPTypes =
622 AliasSeq!(
623 CustomFloat!(5, 10),
624 CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
625 CustomFloat!(1, 7, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
// NOTE(review): `^` binds tighter than `|` in D, so the flags below parse as
// `ieee | (probability ^ signed)` and the signed bit stays set - confirm
// that this is the intended format.
626 CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed)
627 );
628
629 foreach (F; FPTypes)
630 {
631 auto x = F(0.125);
632 assert(x.get!float == 0.125F);
633 assert(x.get!double == 0.125);
634
635 x -= 0.0625;
636 assert(x.get!float == 0.0625F);
637 assert(x.get!double == 0.0625);
638
639 x *= 2;
640 assert(x.get!float == 0.125F);
641 assert(x.get!double == 0.125);
642
643 x /= 4;
644 assert(x.get!float == 0.03125);
645 assert(x.get!double == 0.03125);
646
647 x = 0.5;
648 x ^^= 4;
649 assert(x.get!float == 1 / 16.0F);
650 assert(x.get!double == 1 / 16.0);
651 }
652 }
653
// String conversion goes through CustomFloat.toString.
654 @system unittest
655 {
656 // @system due to to!string(CustomFloat)
657 import std.conv;
658 CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125);
659 assert(y.to!string == "0.125");
660 }
661
// Checks the raw field values of every static property for a tiny 8-bit
// format (5 significand bits, 2 exponent bits, IEEE flags).
662 @safe unittest
663 {
664 alias cf = CustomFloat!(5, 2);
665
666 auto a = cf.infinity;
667 assert(a.sign == 0);
668 assert(a.exponent == 3);
669 assert(a.significand == 0);
670
671 auto b = cf.nan;
672 assert(b.exponent == 3);
673 assert(b.significand != 0);
674
675 assert(cf.dig == 1);
676
677 auto c = cf.epsilon;
678 assert(c.sign == 0);
679 assert(c.exponent == 0);
680 assert(c.significand == 1);
681
682 assert(cf.mant_dig == 6);
683
684 assert(cf.max_10_exp == 0);
685 assert(cf.max_exp == 2);
686 assert(cf.min_10_exp == 0);
687 assert(cf.min_exp == 1);
688
689 auto d = cf.max;
690 assert(d.sign == 0);
691 assert(d.exponent == 2);
692 assert(d.significand == 31);
693
694 auto e = cf.min_normal;
695 assert(e.sign == 0);
696 assert(e.exponent == 1);
697 assert(e.significand == 0);
698
699 assert(e.re == e);
700 assert(e.im == cf(0.0));
701 }
702
703 // check whether CustomFloats identical to float/double behave like float/double
704 @safe unittest
705 {
706 import std.conv : to;
707
// Same field widths as the built-in float.
708 alias myFloat = CustomFloat!(23, 8);
709
710 static assert(myFloat.dig == float.dig);
711 static assert(myFloat.mant_dig == float.mant_dig);
712 assert(myFloat.max_10_exp == float.max_10_exp);
713 static assert(myFloat.max_exp == float.max_exp);
714 assert(myFloat.min_10_exp == float.min_10_exp);
715 static assert(myFloat.min_exp == float.min_exp);
716 assert(to!float(myFloat.epsilon) == float.epsilon);
717 assert(to!float(myFloat.max) == float.max);
718 assert(to!float(myFloat.min_normal) == float.min_normal);
719
// Same field widths as the built-in double.
720 alias myDouble = CustomFloat!(52, 11);
721
722 static assert(myDouble.dig == double.dig);
723 static assert(myDouble.mant_dig == double.mant_dig);
724 assert(myDouble.max_10_exp == double.max_10_exp);
725 static assert(myDouble.max_exp == double.max_exp);
726 assert(myDouble.min_10_exp == double.min_10_exp);
727 static assert(myDouble.min_exp == double.min_exp);
728 assert(to!double(myDouble.epsilon) == double.epsilon);
729 assert(to!double(myDouble.max) == double.max);
730 assert(to!double(myDouble.min_normal) == double.min_normal);
731 }
732
733 // testing .dig
// (formats declared with CustomFloatFlags.none have no implicit leading bit)
734 @safe unittest
735 {
736 static assert(CustomFloat!(1, 6).dig == 0);
737 static assert(CustomFloat!(9, 6).dig == 2);
738 static assert(CustomFloat!(10, 5).dig == 3);
739 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).dig == 2);
740 static assert(CustomFloat!(11, 5, CustomFloatFlags.none).dig == 3);
// 64 significand bits take the special-cased branch of dig.
741 static assert(CustomFloat!(64, 7).dig == 19);
742 }
743
744 // testing .mant_dig
745 @safe unittest
746 {
747 static assert(CustomFloat!(10, 5).mant_dig == 11);
748 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).mant_dig == 10);
749 }
750
751 // testing .max_exp
752 @safe unittest
753 {
754 static assert(CustomFloat!(1, 6).max_exp == 2^^5);
755 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_exp == 2^^5);
756 static assert(CustomFloat!(5, 10).max_exp == 2^^9);
757 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_exp == 2^^9);
758 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_exp == 2^^5);
759 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_exp == 2^^9);
760 }
761
762 // testing .min_exp
763 @safe unittest
764 {
765 static assert(CustomFloat!(1, 6).min_exp == -2^^5+3);
766 static assert(CustomFloat!(5, 10).min_exp == -2^^9+3);
767 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_exp == -2^^5+1);
768 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_exp == -2^^9+1);
769 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_exp == -2^^5+2);
770 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_exp == -2^^9+2);
771 static assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_exp == -2^^5+2);
772 static assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_exp == -2^^9+2);
773 }
774
775 // testing .max_10_exp
// (a function, not an enum, hence run-time asserts)
776 @safe unittest
777 {
778 assert(CustomFloat!(1, 6).max_10_exp == 9);
779 assert(CustomFloat!(5, 10).max_10_exp == 154);
780 assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_10_exp == 9);
781 assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_10_exp == 154);
782 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_10_exp == 9);
783 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_10_exp == 154);
784 }
785
786 // testing .min_10_exp
787 @safe unittest
788 {
789 assert(CustomFloat!(1, 6).min_10_exp == -9);
790 assert(CustomFloat!(5, 10).min_10_exp == -153);
791 assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_10_exp == -9);
792 assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_10_exp == -154);
793 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_10_exp == -9);
794 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_10_exp == -153);
795 assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_10_exp == -9);
796 assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_10_exp == -153);
797 }
798
799 // testing .epsilon
// Checks the raw sign/exponent/significand fields of epsilon per format.
800 @safe unittest
801 {
802 assert(CustomFloat!(1,6).epsilon.sign == 0);
803 assert(CustomFloat!(1,6).epsilon.exponent == 30);
804 assert(CustomFloat!(1,6).epsilon.significand == 0);
805 assert(CustomFloat!(2,5).epsilon.sign == 0);
806 assert(CustomFloat!(2,5).epsilon.exponent == 13);
807 assert(CustomFloat!(2,5).epsilon.significand == 0);
808 assert(CustomFloat!(3,4).epsilon.sign == 0);
809 assert(CustomFloat!(3,4).epsilon.exponent == 4);
810 assert(CustomFloat!(3,4).epsilon.significand == 0);
811 // the following epsilons are only available, when denormalized numbers are allowed:
812 assert(CustomFloat!(4,3).epsilon.sign == 0);
813 assert(CustomFloat!(4,3).epsilon.exponent == 0);
814 assert(CustomFloat!(4,3).epsilon.significand == 4);
815 assert(CustomFloat!(5,2).epsilon.sign == 0);
816 assert(CustomFloat!(5,2).epsilon.exponent == 0);
817 assert(CustomFloat!(5,2).epsilon.significand == 1);
818 }
819
820 // testing .max
// With IEEE flags the top exponent value is reserved, so max uses the
// second-highest exponent; with CustomFloatFlags.none it uses the highest.
821 @safe unittest
822 {
823 static assert(CustomFloat!(5,2).max.sign == 0);
824 static assert(CustomFloat!(5,2).max.exponent == 2);
825 static assert(CustomFloat!(5,2).max.significand == 31);
826 static assert(CustomFloat!(4,3).max.sign == 0);
827 static assert(CustomFloat!(4,3).max.exponent == 6);
828 static assert(CustomFloat!(4,3).max.significand == 15);
829 static assert(CustomFloat!(3,4).max.sign == 0);
830 static assert(CustomFloat!(3,4).max.exponent == 14);
831 static assert(CustomFloat!(3,4).max.significand == 7);
832 static assert(CustomFloat!(2,5).max.sign == 0);
833 static assert(CustomFloat!(2,5).max.exponent == 30);
834 static assert(CustomFloat!(2,5).max.significand == 3);
835 static assert(CustomFloat!(1,6).max.sign == 0);
836 static assert(CustomFloat!(1,6).max.exponent == 62);
837 static assert(CustomFloat!(1,6).max.significand == 1);
838 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.exponent == 31);
839 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.significand == 7);
840 }
841
842 // testing .min_normal
// With CustomFloatFlags.none the leading 1 is stored explicitly, so the
// significand is 1 << (precision - 1) and the exponent field is 0.
843 @safe unittest
844 {
845 static assert(CustomFloat!(5,2).min_normal.sign == 0);
846 static assert(CustomFloat!(5,2).min_normal.exponent == 1);
847 static assert(CustomFloat!(5,2).min_normal.significand == 0);
848 static assert(CustomFloat!(4,3).min_normal.sign == 0);
849 static assert(CustomFloat!(4,3).min_normal.exponent == 1);
850 static assert(CustomFloat!(4,3).min_normal.significand == 0);
851 static assert(CustomFloat!(3,4).min_normal.sign == 0);
852 static assert(CustomFloat!(3,4).min_normal.exponent == 1);
853 static assert(CustomFloat!(3,4).min_normal.significand == 0);
854 static assert(CustomFloat!(2,5).min_normal.sign == 0);
855 static assert(CustomFloat!(2,5).min_normal.exponent == 1);
856 static assert(CustomFloat!(2,5).min_normal.significand == 0);
857 static assert(CustomFloat!(1,6).min_normal.sign == 0);
858 static assert(CustomFloat!(1,6).min_normal.exponent == 1);
859 static assert(CustomFloat!(1,6).min_normal.significand == 0);
860 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.exponent == 0);
861 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.significand == 4);
862 }
863
// Conversion edge cases for the 8-bit (5, 2) IEEE format: NaN round-trip,
// overflow to infinity, and values whose stored fields come out with a zero
// significand.
864 @safe unittest
865 {
866 import std.math.traits : isNaN;
867
868 alias cf = CustomFloat!(5, 2);
869
870 auto f = cf.nan.get!float();
871 assert(isNaN(f));
872
873 cf a;
// Too large for the format: stored as infinity.
874 a = real.max;
875 assert(a == cf.infinity);
876
// Expected to be stored with all-zero exponent and significand fields.
877 a = 0.015625;
878 assert(a.exponent == 0);
879 assert(a.significand == 0);
880
// Expected to be stored with exponent field 1 and a zero significand.
881 a = 0.984375;
882 assert(a.exponent == 1);
883 assert(a.significand == 0);
884 }
885
// Format without infinity or NaN: an overflowing assignment must trip an assert.
886 @system unittest
887 {
888 import std.exception : assertThrown;
889 import core.exception : AssertError;
890
891 alias cf = CustomFloat!(3, 5, CustomFloatFlags.none);
892
893 cf a;
894 assertThrown!AssertError(a = real.max);
895 }
896
// Format with NaN but no infinity: overflow still must trip an assert.
897 @system unittest
898 {
899 import std.exception : assertThrown;
900 import core.exception : AssertError;
901
902 alias cf = CustomFloat!(3, 5, CustomFloatFlags.nan);
903
904 cf a;
905 assertThrown!AssertError(a = real.max);
906 }
907
// Assigning an infinity to a format without infinity support must trip an assert.
908 @system unittest
909 {
910 import std.exception : assertThrown;
911 import core.exception : AssertError;
912
913 alias cf = CustomFloat!(24, 8, CustomFloatFlags.none);
914
915 cf a;
916 assertThrown!AssertError(a = float.infinity);
917 }
918
/*
 * Tells whether `(precision, exponentWidth, flags)` describes a layout that
 * `CustomFloat` can store. Used as the template constraint of the struct, so
 * it must be evaluable at compile time for every possible argument.
 *
 * Params:
 *     precision     = number of stored significand bits
 *     exponentWidth = number of exponent bits
 *     flags         = format flags
 *
 * Returns: `true` if the layout is supported, `false` otherwise.
 */
private bool isCorrectCustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags) @safe pure nothrow @nogc
{
    // Restrictions from bitfield
    // due to CustomFloat!80 support hack precision with 64 bits is handled specially
    immutable length = (flags & flags.signed) + exponentWidth + ((precision == 64) ? 0 : precision);
    if (length != 8 && length != 16 && length != 32 && length != 64) return false;

    // mantissa should have at least one bit
    if (precision == 0) return false;

    // mantissa needs to fit into real mantissa
    if (precision > real.mant_dig - 1 && precision != 64) return false;

    // exponent should have at least one bit, in some cases two
    if (exponentWidth <= ((flags & (flags.allowDenorm | flags.infinity | flags.nan)) != 0)) return false;

    // From here on exponentWidth >= 1, so `exponentWidth - 1` cannot wrap
    // around. Widths above 63 can never fit and would make the shift below
    // out of range, so reject them before shifting.
    if (exponentWidth > 63) return false;

    // exponent needs to fit into real exponent
    // (unsigned shift: a signed `1L << 63` would be negative and slip through)
    if ((1UL << (exponentWidth - 1)) > real.max_exp) return false;

    return true;
}
940
// Accept/reject cases for isCorrectCustomFloat(precision, exponentWidth, flags).
941 @safe pure nothrow @nogc unittest
942 {
943 assert(isCorrectCustomFloat(3,4,CustomFloatFlags.ieee));
944 assert(isCorrectCustomFloat(3,5,CustomFloatFlags.none));
// 1 + 3 + 3 = 7 bits: not one of the accepted total widths
945 assert(!isCorrectCustomFloat(3,3,CustomFloatFlags.ieee));
// precision 64 is excluded from the total width
946 assert(isCorrectCustomFloat(64,7,CustomFloatFlags.ieee));
947 assert(!isCorrectCustomFloat(64,4,CustomFloatFlags.ieee));
948 assert(!isCorrectCustomFloat(508,3,CustomFloatFlags.ieee));
949 assert(!isCorrectCustomFloat(3,100,CustomFloatFlags.ieee));
// no significand bits
950 assert(!isCorrectCustomFloat(0,7,CustomFloatFlags.ieee));
// a single exponent bit is rejected when allowDenorm, infinity or nan is set
951 assert(!isCorrectCustomFloat(6,1,CustomFloatFlags.ieee));
952 assert(isCorrectCustomFloat(7,1,CustomFloatFlags.none));
// no exponent bits
953 assert(!isCorrectCustomFloat(8,0,CustomFloatFlags.none));
954 }
955
956 /**
957 Defines the fastest type to use when storing temporaries of a
958 calculation intended to ultimately yield a result of type `F`
959 (where `F` must be one of `float`, `double`, or $(D
960 real)). When doing a multi-step computation, you may want to store
961 intermediate results as `FPTemporary!F`.
962
963 The necessity of `FPTemporary` stems from the optimized
964 floating-point operations and registers present in virtually all
965 processors. When adding numbers in the example below, the addition may
966 in fact be done in `real` precision internally. In that case,
967 storing the intermediate `result` in `double` format is not only
968 less precise, it is also (surprisingly) slower, because a conversion
969 from `real` to `double` is performed every pass through the
970 loop. This being a lose-lose situation, `FPTemporary!F` has been
971 defined as the $(I fastest) type to use for calculations at precision
972 `F`. There is no need to define a type for the $(I most accurate)
973 calculations, as that is always `real`.
974
975 Finally, there is no guarantee that using `FPTemporary!F` will
976 always be fastest, as the speed of floating-point calculations depends
977 on very many factors.
978 */
979 template FPTemporary(F)
980 if (isFloatingPoint!F)
981 {
982 version (X86)
983 alias FPTemporary = real;
984 else
985 alias FPTemporary = Unqual!F;
986 }
987
///
@safe unittest
{
    import std.math.operations : isClose;

    // Arithmetic mean of a slice; the running total lives in an FPTemporary.
    double avg(in double[] values)
    {
        if (!values.length)
            return 0;

        FPTemporary!double result = 0;
        foreach (v; values)
            result += v;
        return result / values.length;
    }

    auto samples = [1.0, 2.0, 3.0];
    assert(isClose(avg(samples), 2));
}
1005
/**
Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a
root of the function `fun` starting from points $(D [xn_1, x_n])
(ideally close to the root). `Num` may be `float`, `double`,
or `real`.

Iteration stops once the current step is close to zero (absolute
difference of `1e-5`) or is no longer finite; the last abscissa is returned.
*/
template secantMethod(alias fun)
{
    import std.functional : unaryFun;

    Num secantMethod(Num)(Num xn_1, Num xn)
    {
        auto fCur = unaryFun!(fun)(xn_1);
        auto step = xn_1 - xn;
        typeof(fCur) fPrev;

        xn = xn_1;
        for (;;)
        {
            // Converged, or the step degenerated to NaN/infinity.
            if (isClose(step, 0, 0.0, 1e-5) || !isFinite(step))
                break;

            xn_1 = xn;
            xn -= step;
            fPrev = fCur;
            fCur = unaryFun!(fun)(xn);
            step *= -fCur / (fCur - fPrev);
        }
        return xn;
    }
}
1032
///
@safe unittest
{
    import std.math.operations : isClose;
    import std.math.trigonometry : cos;

    // Function whose root is sought: cos(x) - x^3.
    float residual(float x)
    {
        return cos(x) - x*x*x;
    }

    auto root = secantMethod!(residual)(0f, 1f);
    assert(root.isClose(0.865474));
}
1046
@system unittest
{
    // @system because of __gshared stderr
    import std.stdio;
    scope(failure) stderr.writeln("Failure testing secantMethod");

    float residual(float x)
    {
        return cos(x) - x*x*x;
    }

    // Pass the nested function directly ...
    immutable viaFunction = secantMethod!(residual)(0f, 1f);
    assert(isClose(viaFunction, 0.865474));

    // ... and through a delegate variable.
    auto dg = &residual;
    immutable viaDelegate = secantMethod!(dg)(0f, 1f);
    assert(isClose(viaDelegate, 0.865474));
}
1062
1063
/**
 * Returns: `true` if the sign bits of `a` and `b` differ, `false` otherwise.
 */
private bool oppositeSigns(T1, T2)(T1 a, T2 b)
{
    immutable bool aNegative = signbit(a) != 0;
    immutable bool bNegative = signbit(b) != 0;
    return aNegative != bNegative;
}
1071
1072 public:
1073
/**  Find a real root of a real function f(x) via bracketing.
 *
 * Given a function `f` and a range `[a .. b]` such that `f(a)`
 * and `f(b)` have opposite signs or at least one of them equals ±0,
 * returns the value of `x` in the range which is closest to a root
 * of `f(x)`.  If `f(x)` has more than one root in the range, one will
 * be chosen arbitrarily.  If `f(x)` returns NaN, NaN will be returned;
 * otherwise, this algorithm is guaranteed to succeed.
 *
 * If `f(a)` or `f(b)` is exactly zero, that endpoint is returned
 * without any further evaluation of `f`.
 *
 * Uses an algorithm based on TOMS748, which uses inverse cubic
 * interpolation whenever possible, otherwise reverting to parabolic
 * or secant interpolation. Compared to TOMS748, this implementation
 * improves worst-case performance by a factor of more than 100, and
 * typical performance by a factor of 2. For 80-bit reals, most
 * problems require 8 to 15 calls to `f(x)` to achieve full machine
 * precision. The worst-case performance (pathological cases) is
 * approximately twice the number of bits.
 *
 * References: "On Enclosing Simple Roots of Nonlinear Equations",
 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61,
 * pp733-744 (1993).  Fortran code available from $(HTTP
 * www.netlib.org,www.netlib.org) as algorithm TOMS748.
 *
 */
T findRoot(T, DF, DT)(scope DF f, const T a, const T b,
    scope DT tolerance) //= (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R, R) && isFloatingPoint!R
    )
{
    // An endpoint that already is a root is returned immediately.
    immutable fa = f(a);
    if (fa == 0)
    {
        return a;
    }

    immutable fb = f(b);
    if (fb == 0)
    {
        return b;
    }

    // bounds = (lower x, upper x, f(lower x), f(upper x))
    immutable bounds = findRoot(f, a, b, fa, fb, tolerance);

    // Prefer the upper bound only when its function value is strictly
    // smaller in magnitude; ties and NaN select the lower bound.
    if (fabs(bounds[2]) > fabs(bounds[3]))
    {
        return bounds[1];
    }
    return bounds[0];
}
1117
///ditto
T findRoot(T, DF)(scope DF f, const T a, const T b)
{
    // A tolerance that never accepts: iterate to full machine precision.
    auto neverAccept = (T lo, T hi) => false;
    return findRoot(f, a, b, neverAccept);
}
1123
1124 /** Find root of a real function f(x) by bracketing, allowing the
1125 * termination condition to be specified.
1126 *
1127 * Params:
1128 *
1129 * f = Function to be analyzed
1130 *
1131 * ax = Left bound of initial range of `f` known to contain the
1132 * root.
1133 *
1134 * bx = Right bound of initial range of `f` known to contain the
1135 * root.
1136 *
1137 * fax = Value of `f(ax)`.
1138 *
1139 * fbx = Value of `f(bx)`. `fax` and `fbx` should have opposite signs.
1140 * (`f(ax)` and `f(bx)` are commonly known in advance.)
1141 *
1142 *
1143 * tolerance = Defines an early termination condition. Receives the
1144 * current upper and lower bounds on the root. The
1145 * delegate must return `true` when these bounds are
1146 * acceptable. If this function always returns `false`,
1147 * full machine precision will be achieved.
1148 *
1149 * Returns:
1150 *
1151 * A tuple consisting of two ranges. The first two elements are the
1152 * range (in `x`) of the root, while the second pair of elements
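1153 * are the corresponding function values at those points. If `fax` or
1154 * `fbx` is already 0 (or NaN) on entry, both of the first two elements
1155 * will contain that endpoint and both function values will be equal.
 * If an exact root (or NaN) is hit during the iteration, it is returned as
 * the first element with its function value as the third element, while the
 * second and fourth elements keep the last upper bound and its function value.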
1156 */
1157 Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f,
1158 const T ax, const T bx, const R fax, const R fbx,
1159 scope DT tolerance) // = (T a, T b) => false)
1160 if (
1161 isFloatingPoint!T &&
1162 is(typeof(tolerance(T.init, T.init)) : bool) &&
1163 is(typeof(f(T.init)) == R) && isFloatingPoint!R
1164 )
1165 in
1166 {
// Preconditions: neither limit is NaN, and the supplied function values
// differ in their sign bit.
1167 assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN");
1168 assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root.");
1169 }
1170 do
1171 {
1172 // Author: Don Clugston. This code is (heavily) modified from TOMS748
1173 // (www.netlib.org). The changes to improve the worst-case performance are
1174 // entirely original.
1175
1176 T a, b, d; // [a .. b] is our current bracket. d is the third best guess.
1177 R fa, fb, fd; // Values of f at a, b, d.
1178 bool done = false; // Has a root been found?
1179
1180 // Allow ax and bx to be provided in reverse order
1181 if (ax <= bx)
1182 {
1183 a = ax; fa = fax;
1184 b = bx; fb = fbx;
1185 }
1186 else
1187 {
1188 a = bx; fa = fbx;
1189 b = ax; fb = fax;
1190 }
1191
1192 // Test the function at point c; update brackets accordingly
1193 void bracket(T c)
1194 {
1195 R fc = f(c);
1196 if (fc == 0 || fc.isNaN()) // Exact solution, or NaN
1197 {
// Only a (and d) are moved onto c here; b and fb keep their previous
// values, so the returned tuple holds c in its first slot only.
1198 a = c;
1199 fa = fc;
1200 d = c;
1201 fd = fc;
1202 done = true;
1203 return;
1204 }
1205
1206 // Determine new enclosing interval
1207 if (signbit(fa) != signbit(fc))
1208 {
// Sign change lies in [a .. c]: c becomes the new b, the old b is kept in d.
1209 d = b;
1210 fd = fb;
1211 b = c;
1212 fb = fc;
1213 }
1214 else
1215 {
// Otherwise c becomes the new a, and the old a is kept in d.
1216 d = a;
1217 fd = fa;
1218 a = c;
1219 fa = fc;
1220 }
1221 }
1222
1223 /* Perform a secant interpolation. If the result would lie on a or b, or if
1224 a and b differ so wildly in magnitude that the result would be meaningless,
1225 perform a bisection instead.
1226 */
1227 static T secant_interpolate(T a, T b, R fa, R fb)
1228 {
// (a - b) == a or (b - a) == b means the other endpoint is too small to
// affect the difference at this precision.
1229 if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b)))
1230 {
1231 // Catastrophic cancellation
1232 if (a == 0)
1233 a = copysign(T(0), b);
1234 else if (b == 0)
1235 b = copysign(T(0), a);
1236 else if (signbit(a) != signbit(b))
1237 return 0;
1238 T c = ieeeMean(a, b);
1239 return c;
1240 }
1241 // avoid overflow
1242 if (b - a > T.max)
1243 return b / 2 + a / 2;
1244 if (fb - fa > R.max)
1245 return a - (b - a) / 2;
1246 T c = a - (fa / (fb - fa)) * (b - a);
// A secant result that coincides with an endpoint is replaced by the midpoint.
1247 if (c == a || c == b)
1248 return (a + b) / 2;
1249 return c;
1250 }
1251
1252 /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the
1253 quadratic polynomial interpolating f(x) at a, b, and d.
1254 Returns:
1255 The approximate zero in [a .. b] of the quadratic polynomial.
1256 */
1257 T newtonQuadratic(int numsteps)
1258 {
1259 // Find the coefficients of the quadratic polynomial.
1260 immutable T a0 = fa;
1261 immutable T a1 = (fb - fa)/(b - a);
1262 immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a);
1263
1264 // Determine the starting point of newton steps.
1265 T c = oppositeSigns(a2, fa) ? a : b;
1266
1267 // start the safeguarded newton steps.
1268 foreach (int i; 0 .. numsteps)
1269 {
// pc is the quadratic evaluated at c, pdc its derivative with respect to c.
1270 immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a);
1271 immutable T pdc = a1 + a2*((2 * c) - (a + b));
// Zero derivative: fall back to a step from a using only a0 and a1.
1272 if (pdc == 0)
1273 return a - a0 / a1;
1274 else
1275 c = c - pc / pdc;
1276 }
1277 return c;
1278 }
1279
1280 // On the first iteration we take a secant step:
// ... unless an endpoint value is already 0 or NaN, in which case the
// bracket collapses onto that endpoint and no iteration is performed.
1281 if (fa == 0 || fa.isNaN())
1282 {
1283 done = true;
1284 b = a;
1285 fb = fa;
1286 }
1287 else if (fb == 0 || fb.isNaN())
1288 {
1289 done = true;
1290 a = b;
1291 fa = fb;
1292 }
1293 else
1294 {
1295 bracket(secant_interpolate(a, b, fa, fb));
1296 }
1297
1298 // Starting with the second iteration, higher-order interpolation can
1299 // be used.
1300 int itnum = 1; // Iteration number
1301 int baditer = 1; // Num bisections to take if an iteration is bad.
1302 T c, e; // e is our fourth best guess
1303 R fe;
1304
1305 whileloop:
// Loop until a root/NaN was hit, a and b are adjacent (b == nextUp(a)),
// or the caller-supplied tolerance accepts the bracket.
1306 while (!done && (b != nextUp(a)) && !tolerance(a, b))
1307 {
1308 T a0 = a, b0 = b; // record the brackets
1309
1310 // Do two higher-order (cubic or parabolic) interpolation steps.
1311 foreach (int QQ; 0 .. 2)
1312 {
1313 // Cubic inverse interpolation requires that
1314 // all four function values fa, fb, fd, and fe are distinct;
1315 // otherwise use quadratic interpolation.
1316 bool distinct = (fa != fb) && (fa != fd) && (fa != fe)
1317 && (fb != fd) && (fb != fe) && (fd != fe);
1318 // The first time, cubic interpolation is impossible.
1319 if (itnum<2) distinct = false;
1320 bool ok = distinct;
1321 if (distinct)
1322 {
1323 // Cubic inverse interpolation of f(x) at a, b, d, and e
1324 immutable q11 = (d - e) * fd / (fe - fd);
1325 immutable q21 = (b - d) * fb / (fd - fb);
1326 immutable q31 = (a - b) * fa / (fb - fa);
1327 immutable d21 = (b - d) * fd / (fd - fb);
1328 immutable d31 = (a - b) * fb / (fb - fa);
1329
1330 immutable q22 = (d21 - q11) * fb / (fe - fb);
1331 immutable q32 = (d31 - q21) * fa / (fd - fa);
1332 immutable d32 = (d31 - q21) * fd / (fd - fa);
1333 immutable q33 = (d32 - q22) * fa / (fe - fa);
1334 c = a + (q31 + q32 + q33);
1335 if (c.isNaN() || (c <= a) || (c >= b))
1336 {
1337 // DAC: If the interpolation predicts a or b, it's
1338 // probable that it's the actual root. Only allow this if
1339 // we're already close to the root.
1340 if (c == a && a - b != a)
1341 {
1342 c = nextUp(a);
1343 }
1344 else if (c == b && a - b != -b)
1345 {
1346 c = nextDown(b);
1347 }
1348 else
1349 {
1350 ok = false;
1351 }
1352 }
1353 }
1354 if (!ok)
1355 {
1356 // DAC: Alefeld doesn't explain why the number of newton steps
1357 // should vary.
1358 c = newtonQuadratic(distinct ? 3 : 2);
1359 if (c.isNaN() || (c <= a) || (c >= b))
1360 {
1361 // Failure, try a secant step:
1362 c = secant_interpolate(a, b, fa, fb);
1363 }
1364 }
1365 ++itnum;
1366 e = d;
1367 fe = fd;
1368 bracket(c);
1369 if (done || ( b == nextUp(a)) || tolerance(a, b))
1370 break whileloop;
// itnum == 2 only right after the first higher-order step: restart the
// outer loop instead of taking the second step and the steps below.
1371 if (itnum == 2)
1372 continue whileloop;
1373 }
1374
1375 // Now we take a double-length secant step:
1376 T u;
1377 R fu;
// u is the bracket end whose function value is smaller in magnitude.
1378 if (fabs(fa) < fabs(fb))
1379 {
1380 u = a;
1381 fu = fa;
1382 }
1383 else
1384 {
1385 u = b;
1386 fu = fb;
1387 }
1388 c = u - 2 * (fu / (fb - fa)) * (b - a);
1389
1390 // DAC: If the secant predicts a value equal to an endpoint, it's
1391 // probably false.
1392 if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2)
1393 {
1394 if ((a-b) == a || (b-a) == b)
1395 {
1396 if ((a>0 && b<0) || (a<0 && b>0))
1397 c = 0;
1398 else
1399 {
1400 if (a == 0)
1401 c = ieeeMean(copysign(T(0), b), b);
1402 else if (b == 0)
1403 c = ieeeMean(copysign(T(0), a), a);
1404 else
1405 c = ieeeMean(a, b);
1406 }
1407 }
1408 else
1409 {
1410 c = a + (b - a) / 2;
1411 }
1412 }
1413 e = d;
1414 fe = fd;
1415 bracket(c);
1416 if (done || (b == nextUp(a)) || tolerance(a, b))
1417 break;
1418
1419 // IMPROVE THE WORST-CASE PERFORMANCE
1420 // We must ensure that the bounds reduce by a factor of 2
1421 // in binary space! every iteration. If we haven't achieved this
1422 // yet, or if we don't yet know what the exponent is,
1423 // perform a binary chop.
1424
1425 if ((a == 0 || b == 0 ||
1426 (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a)))
1427 && (b - a) < T(0.25) * (b0 - a0))
1428 {
1429 baditer = 1;
1430 continue;
1431 }
1432
1433 // DAC: If this happens on consecutive iterations, we probably have a
1434 // pathological function. Perform a number of bisections equal to the
1435 // total number of consecutive bad iterations.
1436
1437 if ((b - a) < T(0.25) * (b0 - a0))
1438 baditer = 1;
1439 foreach (int QQ; 0 .. baditer)
1440 {
1441 e = d;
1442 fe = fd;
1443
1444 T w;
// w is the bisection point: 0 when a and b straddle zero, otherwise
// ieeeMean of the two bounds (a zero bound takes the other bound's sign).
1445 if ((a>0 && b<0) || (a<0 && b>0))
1446 w = 0;
1447 else
1448 {
1449 T usea = a;
1450 T useb = b;
1451 if (a == 0)
1452 usea = copysign(T(0), b);
1453 else if (b == 0)
1454 useb = copysign(T(0), a);
1455 w = ieeeMean(usea, useb);
1456 }
1457 bracket(w);
1458 }
1459 ++baditer;
1460 }
// Result layout: (lower x, upper x, f(lower x), f(upper x)).
1461 return Tuple!(T, T, R, R)(a, b, fa, fb);
1462 }
1463
///ditto
Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f,
    const T ax, const T bx, const R fax, const R fbx)
{
    // A tolerance that never accepts: iterate to full machine precision.
    auto neverAccept = (T lo, T hi) => false;
    return findRoot(f, ax, bx, fax, fbx, neverAccept);
}
1470
///ditto
T findRoot(T, R)(scope R delegate(T) f, const T a, const T b,
    scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false)
{
    // Spell out the template arguments so the delegate types are used as-is.
    alias Fun = R delegate(T);
    alias Tolerance = bool delegate(T lo, T hi);
    return findRoot!(T, Fun, Tolerance)(f, a, b, tolerance);
}
1477
// Exercises the bracketing findRoot overload on a set of test functions,
// including the problems from the Alefeld/Potra/Shi paper cited below.
1478 @safe nothrow unittest
1479 {
// numProblems / numCalls (and powercalls, alefeldSums further down) are only
// read by the commented-out statistics code near the end of this test.
1480 int numProblems = 0;
1481 int numCalls;
1482
// Runs findRoot on [x1, x2] with a tolerance that never accepts, then checks
// that f still changes sign across the returned bracket unless f is exactly
// 0 at the returned lower bound.
1483 void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure
1484 {
1485 //numCalls=0;
1486 //++numProblems;
1487 assert(!x1.isNaN() && !x2.isNaN());
1488 assert(signbit(f(x1)) != signbit(f(x2)));
1489 auto result = findRoot(f, x1, x2, f(x1), f(x2),
1490 (real lo, real hi) { return false; });
1491
1492 auto flo = f(result[0]);
1493 auto fhi = f(result[1]);
1494 if (flo != 0)
1495 {
1496 assert(oppositeSigns(flo, fhi));
1497 }
1498 }
1499
1500 // Test functions
1501 real cubicfn(real x) @nogc @safe nothrow pure
1502 {
1503 //++numCalls;
// Clamp x to [-float.max, float.max] before evaluating the cubic.
1504 if (x>float.max)
1505 x = float.max;
1506 if (x<-float.max)
1507 x = -float.max;
1508 // This has a single real root at -59.286543284815
1509 return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
1510 }
1511 // Test a function with more than one root.
1512 real multisine(real x) { ++numCalls; return sin(x); }
1513 testFindRoot( &multisine, 6, 90);
1514 testFindRoot(&cubicfn, -100, 100);
1515 testFindRoot( &cubicfn, -double.max, real.max);
1516
1517
1518 /* Tests from the paper:
1519 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra,
1520 * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
1521 */
1522 // Parameters common to many alefeld tests.
// The nested test functions below read n, ale_a and ale_b at call time, so
// each assignment must precede the testFindRoot call it parameterizes.
1523 int n;
1524 real ale_a, ale_b;
1525
1526 int powercalls = 0;
1527
1528 real power(real x)
1529 {
1530 ++powercalls;
1531 ++numCalls;
1532 return pow(x, n) + double.min_normal;
1533 }
1534 int [] power_nvals = [3, 5, 7, 9, 19, 25];
1535 // Alefeld paper states that pow(x,n) is a very poor case, where bisection
1536 // outperforms his method, and gives total numcalls =
1537 // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit),
1538 // 2624 for brent (6.8/bit)
1539 // ... but that is for double, not real80.
1540 // This poor performance seems mainly due to catastrophic cancellation,
1541 // which is avoided here by the use of ieeeMean().
1542 // I get: 231 (0.48/bit).
1543 // IE this is 10X faster in Alefeld's worst case
1544 numProblems=0;
1545 foreach (k; power_nvals)
1546 {
1547 n = k;
1548 testFindRoot(&power, -1, 10);
1549 }
1550
1551 int powerProblems = numProblems;
1552
1553 // Tests from Alefeld paper
1554
1555 int [9] alefeldSums;
1556 real alefeld0(real x)
1557 {
1558 ++alefeldSums[0];
1559 ++numCalls;
1560 real q = sin(x) - x/2;
1561 for (int i=1; i<20; ++i)
1562 q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i));
1563 return q;
1564 }
1565 real alefeld1(real x)
1566 {
1567 ++numCalls;
1568 ++alefeldSums[1];
1569 return ale_a*x + exp(ale_b * x);
1570 }
1571 real alefeld2(real x)
1572 {
1573 ++numCalls;
1574 ++alefeldSums[2];
1575 return pow(x, n) - ale_a;
1576 }
1577 real alefeld3(real x)
1578 {
1579 ++numCalls;
1580 ++alefeldSums[3];
1581 return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2);
1582 }
1583 real alefeld4(real x)
1584 {
1585 ++numCalls;
1586 ++alefeldSums[4];
1587 return x*x - pow(1-x, n);
1588 }
1589 real alefeld5(real x)
1590 {
1591 ++numCalls;
1592 ++alefeldSums[5];
1593 return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4);
1594 }
1595 real alefeld6(real x)
1596 {
1597 ++numCalls;
1598 ++alefeldSums[6];
1599 return exp(-n*x)*(x-1.01L) + pow(x, n);
1600 }
1601 real alefeld7(real x)
1602 {
1603 ++numCalls;
1604 ++alefeldSums[7];
1605 return (n*x-1)/((n-1)*x);
1606 }
1607
1608 numProblems=0;
1609 testFindRoot(&alefeld0, PI_2, PI);
// Here n only selects the interval; alefeld0 itself does not read n.
1610 for (n=1; n <= 10; ++n)
1611 {
1612 testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L);
1613 }
1614 ale_a = -40; ale_b = -1;
1615 testFindRoot(&alefeld1, -9, 31);
1616 ale_a = -100; ale_b = -2;
1617 testFindRoot(&alefeld1, -9, 31);
1618 ale_a = -200; ale_b = -3;
1619 testFindRoot(&alefeld1, -9, 31);
1620 int [] nvals_3 = [1, 2, 5, 10, 15, 20];
1621 int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20];
1622 int [] nvals_6 = [1, 5, 10, 15, 20];
1623 int [] nvals_7 = [2, 5, 15, 20];
1624
1625 for (int i=4; i<12; i+=2)
1626 {
1627 n = i;
1628 ale_a = 0.2;
1629 testFindRoot(&alefeld2, 0, 5);
1630 ale_a=1;
1631 testFindRoot(&alefeld2, 0.95, 4.05);
1632 testFindRoot(&alefeld2, 0, 1.5);
1633 }
1634 foreach (i; nvals_3)
1635 {
1636 n=i;
1637 testFindRoot(&alefeld3, 0, 1);
1638 }
1639 foreach (i; nvals_3)
1640 {
1641 n=i;
1642 testFindRoot(&alefeld4, 0, 1);
1643 }
1644 foreach (i; nvals_5)
1645 {
1646 n=i;
1647 testFindRoot(&alefeld5, 0, 1);
1648 }
1649 foreach (i; nvals_6)
1650 {
1651 n=i;
1652 testFindRoot(&alefeld6, 0, 1);
1653 }
1654 foreach (i; nvals_7)
1655 {
1656 n=i;
1657 testFindRoot(&alefeld7, 0.01L, 1);
1658 }
// Step function that changes sign at 0.3*real.max, searched over the whole
// finite range of real.
1659 real worstcase(real x)
1660 {
1661 ++numCalls;
1662 return x<0.3*real.max? -0.999e-3 : 1.0;
1663 }
1664 testFindRoot(&worstcase, -real.max, real.max);
1665
1666 // just check that the double + float cases compile
1667 findRoot((double x){ return 0.0; }, -double.max, double.max);
1668 findRoot((float x){ return 0.0f; }, -float.max, float.max);
1669
1670 /*
1671 int grandtotal=0;
1672 foreach (calls; alefeldSums)
1673 {
1674 grandtotal+=calls;
1675 }
1676 grandtotal-=2*numProblems;
1677 printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
1678 grandtotal, (1.0*grandtotal)/numProblems);
1679 powercalls -= 2*powerProblems;
1680 printf("POWER TOTAL = %d avg = %f ", powercalls,
1681 (1.0*powercalls)/powerProblems);
1682 */
1683 // https://issues.dlang.org/show_bug.cgi?id=14231
// Both calls have a zero endpoint (+0 and -0 respectively); the results are
// not inspected, the calls only need to complete.
1684 auto xp = findRoot((float x) => x, 0f, 1f);
1685 auto xn = findRoot((float x) => x, -1f, -0f);
1686 }
1687
//regression control
@system unittest
{
    // @system due to the case in the 2nd line

    // float argument, real-valued function
    enum floatArgRealResult = __traits(compiles,
        findRoot((float x) => cast(real) x, float.init, float.init));
    static assert(floatArgRealResult);

    // explicit T, untyped lambda parameter
    enum explicitRealUntypedLambda = __traits(compiles,
        findRoot!real((x) => cast(double) x, real.init, real.init));
    static assert(explicitRealUntypedLambda);

    // real argument, double-valued function
    enum realArgDoubleResult = __traits(compiles,
        findRoot((real x) => cast(double) x, real.init, real.init));
    static assert(realArgDoubleResult);
}
1696
/++
    Find a real minimum of a real function `f(x)` via bracketing.
    Given a function `f` and a range `(ax .. bx)`,
    returns the value of `x` in the range which is closest to a minimum of `f(x)`.
    `f` is never evaluated at the endpoints of `ax` and `bx`.
    If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily.
    If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned;
    otherwise, this algorithm is guaranteed to succeed.

    Params:
        f = Function to be analyzed
        ax = Left bound of initial range of f known to contain the minimum.
        bx = Right bound of initial range of f known to contain the minimum.
        relTolerance = Relative tolerance.
        absTolerance = Absolute tolerance.

    Preconditions:
        `ax` and `bx` shall be finite reals. $(BR)
        `relTolerance` shall be normal positive real. $(BR)
        `absTolerance` shall be normal positive real no less than `T.epsilon*2`.

    Returns:
        A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`.
        Note that, as the formula shows, `absTolerance` is scaled by `fabs(x)`
        while `relTolerance` enters as a constant term.
        If the midpoint of the bracket cannot be computed as a finite value
        (invalid input with contracts disabled), `typeof(return).init` is returned.

        The method used is a combination of golden section search and
    successive parabolic interpolation. Convergence is never much slower
    than that for a Fibonacci search.

    References:
        "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973)

    See_Also: $(LREF findRoot), $(REF isNormal, std,math)
+/
Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error")
findLocalMin(T, DF)(
        scope DF f,
        const T ax,
        const T bx,
        const T relTolerance = sqrt(T.epsilon),
        const T absTolerance = sqrt(T.epsilon),
        )
if (isFloatingPoint!T
    && __traits(compiles, {T _ = DF.init(T.init);}))
in
{
    assert(isFinite(ax), "ax is not finite");
    assert(isFinite(bx), "bx is not finite");
    assert(isNormal(relTolerance), "relTolerance is not normal floating point number");
    assert(isNormal(absTolerance), "absTolerance is not normal floating point number");
    // The message now names the parameter that is actually checked.
    assert(relTolerance >= 0, "relTolerance is not positive");
    assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`");
}
out (result)
{
    assert(isFinite(result.x));
}
do
{
    alias R = Unqual!(CommonType!(ReturnType!DF, T));
    // c is the squared inverse of the golden ratio
    // (3 - sqrt(5))/2
    // Value obtained from Wolfram Alpha.
    enum T c   = 0x0.61c8864680b583ea0c633f9fa31237p+0L;
    enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L;
    R tolerance;
    // Accept the bounds in either order: a is the smaller, b the larger.
    T a = ax > bx ? bx : ax;
    T b = ax > bx ? ax : bx;
    // sequence of declarations suitable for SIMD instructions
    T v = a * cm1 + b * c;
    assert(isFinite(v));
    R fv = f(v);
    if (isNaN(fv) || fv == -T.infinity)
    {
        return typeof(return)(v, fv, T.init);
    }
    // x: best point so far, w: second best, v: previous value of w.
    T w = v;
    R fw = fv;
    T x = v;
    R fx = fv;
    // d: current step, e: step taken before the previous one.
    for (R d = 0, e = 0;;)
    {
        T m = (a + b) / 2;
        // This fix is not part of the original algorithm
        if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R.
        {
            m = a / 2 + b / 2;
            if (!isFinite(m)) // fast-math compiler switch is enabled
            {
                //SIMD instructions can be used by compiler, do not reduce declarations
                int a_exp = void;
                int b_exp = void;
                immutable an = frexp(a, a_exp);
                immutable bn = frexp(b, b_exp);
                immutable am = ldexp(an, a_exp-1);
                immutable bm = ldexp(bn, b_exp-1);
                m = am + bm;
                if (!isFinite(m)) // wrong input: constraints are disabled in release mode
                {
                    return typeof(return).init;
                }
            }
        }
        tolerance = absTolerance * fabs(x) + relTolerance;
        immutable t2 = tolerance * 2;
        // check stopping criterion
        if (!(fabs(x - m) > t2 - (b - a) / 2))
        {
            break;
        }
        R p = 0;
        R q = 0;
        R r = 0;
        // fit parabola
        if (fabs(e) > tolerance)
        {
            immutable  xw =  x -  w;
            immutable fxw = fx - fw;
            immutable  xv =  x -  v;
            immutable fxv = fx - fv;
            immutable xwfxv = xw * fxv;
            immutable xvfxw = xv * fxw;
            p = xv * xvfxw - xw * xwfxv;
            q = (xvfxw - xwfxv) * 2;
            if (q > 0)
                p = -p;
            else
                q = -q;
            r = e;
            e = d;
        }
        T u;
        // a parabolic-interpolation step
        if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x))
        {
            d = p / q;
            u = x + d;
            // f must not be evaluated too close to a or b
            if (u - a < t2 || b - u < t2)
                d = x < m ? tolerance : -tolerance;
        }
        // a golden-section step
        else
        {
            e = (x < m ? b : a) - x;
            d = c * e;
        }
        // f must not be evaluated too close to x
        u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance);
        immutable fu = f(u);
        if (isNaN(fu) || fu == -T.infinity)
        {
            return typeof(return)(u, fu, T.init);
        }
        //  update  a, b, v, w, and x
        if (fu <= fx)
        {
            // u is the new best point; the old best x becomes a bound.
            (u < x ? b : a) = x;
            v = w; fv = fw;
            w = x; fw = fx;
            x = u; fx = fu;
        }
        else
        {
            // x stays the best point; u tightens the bracket on its side.
            (u < x ? a : b) = u;
            if (fu <= fw || w == x)
            {
                v = w; fv = fw;
                w = u; fw = fu;
            }
            else if (fu <= fv || v == x || v == w)
            { // do not remove this braces
                v = u; fv = fu;
            }
        }
    }
    return typeof(return)(x, fx, tolerance * 3);
}
1876
///
@safe unittest
{
    import std.math.operations : isClose;

    // Parabola with its minimum at x = 4, searched over a wide interval.
    auto minimum = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7);
    assert(isClose(minimum.x, 4.0));
    assert(isClose(minimum.y, 0.0, 0.0, 1e-10));
}
1886
@safe unittest
{
    import std.meta : AliasSeq;

    static foreach (T; AliasSeq!(double, float, real))
    {
        // Parabola, default tolerances.
        {
            auto parabola = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7);
            assert(parabola.x.isClose(T(4)));
            assert(parabola.y.isClose(T(0), 0.0, T.epsilon));
        }
        // |x - 1| over a very wide interval, tightest allowed tolerances.
        {
            auto kink = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon);
            assert(isClose(kink.x, T(1)));
            assert(isClose(kink.y, T(0), 0.0, T.epsilon));
            assert(kink.error <= 10 * T.epsilon);
        }
        // Function that always yields NaN.
        {
            auto nanValued = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon);
            assert(!nanValued.x.isNaN);
            assert(nanValued.y.isNaN);
            assert(nanValued.error.isNaN);
        }
        // log(x) on [0, 1].
        {
            auto logUnit = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon);
            assert(logUnit.error < 3.00001 * ((2*T.epsilon)*fabs(logUnit.x)+ T.min_normal));
            assert(logUnit.x >= 0 && logUnit.x <= logUnit.error);
        }
        // log(x) on [0, T.max].
        {
            auto logWide = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon);
            assert(logWide.y < -18);
            assert(logWide.error < 5e-08);
            assert(logWide.x >= 0 && logWide.x <= logWide.error);
        }
        // -|x| on [-1, 1].
        {
            auto tent = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon);
            assert(tent.x.fabs.isClose(T(1)));
            assert(tent.y.fabs.isClose(T(1)));
            assert(tent.error.isClose(T(0), 0.0, 100*T.epsilon));
        }
    }
}
1928
/**
Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance,
Euclidean distance) between input ranges `a` and
`b`. The two ranges must have the same length. The three-parameter
version stops computation as soon as the distance is greater than or
equal to `limit` (this is useful to save computation if a small
distance is sought).
*/
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);

    // Running sum of squared element differences.
    Unqual!(typeof(return)) sumSq = 0;
    while (!a.empty)
    {
        immutable diff = a.front - b.front;
        sumSq += diff * diff;
        a.popFront();
        b.popFront();
    }

    // Without lengths, equality of lengths can only be checked at the end.
    static if (!haveLen) assert(b.empty);
    return sqrt(sumSq);
}
1952
/// Ditto
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    // Compare squared quantities so the square root is taken only once.
    limit *= limit;
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);

    Unqual!(typeof(return)) sumSq = 0;
    while (!a.empty)
    {
        immutable diff = a.front - b.front;
        sumSq += diff * diff;
        // Early exit: the limit has been reached, the rest is not inspected.
        if (sumSq >= limit)
            return sqrt(sumSq);
        a.popFront();
        b.popFront();
    }

    static if (!haveLen) assert(b.empty);
    return sqrt(sumSq);
}
1975
@safe unittest
{
    import std.meta : AliasSeq;

    static foreach (Elem; AliasSeq!(double, const double, immutable double))
    {{
        Elem[] lhs = [ 1.0, 2.0, ];
        Elem[] rhs = [ 4.0, 6.0, ];

        // Full distance.
        assert(euclideanDistance(lhs, rhs) == 5);

        // Limits at or below the true distance ...
        assert(euclideanDistance(lhs, rhs, 6) == 5);
        assert(euclideanDistance(lhs, rhs, 5) == 5);
        assert(euclideanDistance(lhs, rhs, 4) == 5);
        // ... and one small enough to stop after the first element pair.
        assert(euclideanDistance(lhs, rhs, 2) == 3);
    }}
}
1990
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product,
dot product) of input ranges `a` and `b`. The two ranges must have the
same length. If both ranges define length, the check is done once up
front; otherwise, it is done after the ranges have been traversed.
*/
CommonType!(ElementType!(Range1), ElementType!(Range2))
dotProduct(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2) &&
    !(isArray!(Range1) && isArray!(Range2)))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);

    Unqual!(typeof(return)) total = 0;
    while (!a.empty)
    {
        total += a.front * b.front;
        a.popFront();
        b.popFront();
    }

    static if (!haveLen) assert(b.empty);
    return total;
}
2013
/// Ditto
CommonType!(F1, F2)
dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
{
    immutable len = avector.length;
    assert(len == bvector.length);

    auto pa = avector.ptr;
    auto pb = bvector.ptr;
    // Two accumulators: `even` takes even offsets, `odd` takes odd offsets.
    Unqual!(typeof(return)) even = 0, odd = 0;

    const pEnd = pa + len;
    const pEnd4 = pa + (len & ~3);     // end of the last full block of 4
    const pEnd16 = pa + (len & ~15);   // end of the last full block of 16

    // Blocks of 16 elements, fully unrolled.
    while (pa != pEnd16)
    {
        static foreach (k; 0 .. 8)
        {
            even += pa[2 * k] * pb[2 * k];
            odd += pa[2 * k + 1] * pb[2 * k + 1];
        }
        pa += 16;
        pb += 16;
    }

    // Remaining blocks of 4 elements.
    while (pa != pEnd4)
    {
        static foreach (k; 0 .. 2)
        {
            even += pa[2 * k] * pb[2 * k];
            odd += pa[2 * k + 1] * pb[2 * k + 1];
        }
        pa += 4;
        pb += 4;
    }

    even += odd;

    /* Do trailing portion in naive loop. */
    for (; pa != pEnd; ++pa, ++pb)
    {
        even += *pa * *pb;
    }

    return even;
}
2067
/// ditto
F dotProduct(F, uint N)(const ref scope F[N] a, const ref scope F[N] b)
if (N <= 16)
{
    // Even-indexed and odd-indexed products are accumulated separately.
    F evenSum = 0;
    F oddSum = 0;

    static foreach (pair; 0 .. N / 2)
    {
        evenSum += a[2 * pair] * b[2 * pair];
        oddSum += a[2 * pair + 1] * b[2 * pair + 1];
    }

    // Odd N: the last element has no partner.
    static if (N & 1)
    {
        evenSum += a[N - 1] * b[N - 1];
    }

    return evenSum + oddSum;
}
2085
@system unittest
{
    // @system due to dotProduct and assertCTFEable
    import std.exception : assertCTFEable;
    import std.meta : AliasSeq;

    // Run the same checks for mutable, const and immutable element types.
    static foreach (Elem; AliasSeq!(double, const double, immutable double))
    {{
        // Dynamic arrays.
        Elem[] lhs = [ 1.0, 2.0, ];
        Elem[] rhs = [ 4.0, 6.0, ];
        assert(dotProduct(lhs, rhs) == 16);
        assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);

        // Fixed-length arrays of even and odd length.
        Elem[2] fixedA = [ 1.0, 2.0, ];
        Elem[2] fixedB = [ 4.0, 6.0, ];
        assert(dotProduct(fixedA, fixedB) == 16);
        Elem[3] oddA = [1, 3, -5];
        Elem[3] oddB = [4, -2, -1];
        assert(dotProduct(oddA, oddB) == 3);
    }}

    // 22 elements: one 16-wide pass, one 4-wide pass and two leftovers,
    // so every loop of the slice overload runs. Also checked under CTFE.
    static const left =
        [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22];
    static const right =
        [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23];
    assertCTFEable!({ assert(dotProduct(left, right) == 4048); });
}
2113
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity,
cosine similarity) of input ranges `a` and `b`. The two ranges must
have the same length. If both ranges define `length`, the lengths are
compared once up front; otherwise `b` is only checked for exhaustion
after `a` has been fully consumed. If either range has all-zero
elements, returns 0.
*/
CommonType!(ElementType!(Range1), ElementType!(Range2))
cosineSimilarity(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    alias Acc = Unqual!(typeof(return));
    enum bool lengthsKnown = hasLength!(Range1) && hasLength!(Range2);
    static if (lengthsKnown)
        assert(a.length == b.length);

    Acc sqNormA = 0;
    Acc sqNormB = 0;
    Acc cross = 0;
    while (!a.empty)
    {
        immutable x = a.front;
        immutable y = b.front;
        sqNormA += x * x;
        sqNormB += y * y;
        cross += x * y;
        a.popFront();
        b.popFront();
    }
    static if (!lengthsKnown)
        assert(b.empty);

    // A zero squared norm would make the quotient below 0 / 0.
    if (sqNormA == 0 || sqNormB == 0)
        return 0;
    return cross / sqrt(sqNormA * sqNormB);
}
2139
@safe unittest
{
    import std.meta : AliasSeq;

    // Run the same check for mutable, const and immutable element types.
    static foreach (Elem; AliasSeq!(double, const double, immutable double))
    {{
        Elem[] u = [ 1.0, 2.0, ];
        Elem[] v = [ 4.0, 3.0, ];
        // u . v == 10, |u|^2 == 5, |v|^2 == 25
        immutable expected = 10.0 / sqrt(5.0 * 25);
        assert(isClose(cosineSimilarity(u, v), expected, 0.01));
    }}
}
2152
/**
Normalizes values in `range` by multiplying each element with a
number chosen such that values sum up to `sum`. If elements in
`range` sum to zero, assigns `sum` divided by the number of elements
to all of them. Normalization makes sense only if all elements in
`range` are positive. `normalize` assumes that is the case without
checking it (apart from an `assert` that the total is non-negative).

Params:
    range = forward range with assignable (`ref`) elements; it does not
            need to provide `length`
    sum = the total the elements should add up to afterwards

Returns: `true` if normalization completed normally, `false` if
all elements in `range` were zero or if `range` is empty.
*/
bool normalize(R)(R range, ElementType!(R) sum = 1)
if (isForwardRange!(R))
{
    ElementType!(R) total = 0;
    // Step 1: Compute sum and length of the range. The pass runs over
    // `range.save` so that reference-type forward ranges are not consumed
    // before the write-back pass below.
    static if (hasLength!(R))
    {
        const length = range.length;
        foreach (e; range.save)
        {
            total += e;
        }
    }
    else
    {
        // No `length` primitive: count the elements while summing.
        size_t length = 0;
        foreach (e; range.save)
        {
            total += e;
            ++length;
        }
    }
    // Step 2: perform normalization
    if (total == 0)
    {
        if (length)
        {
            // Use the locally computed `length`; `range.length` does not
            // exist for ranges that took the counting branch above.
            immutable f = sum / length;
            foreach (ref e; range) e = f;
        }
        return false;
    }
    // The path most traveled
    assert(total >= 0);
    immutable f = sum / total;
    foreach (ref e; range)
        e *= f;
    return true;
}
2202
2203 ///
@safe unittest
{
    // An empty range is reported as "not normalized".
    double[] values = [];
    assert(!normalize(values));

    // Default target sum of 1.
    values = [ 1.0, 3.0 ];
    assert(normalize(values));
    assert(values == [ 0.25, 0.75 ]);

    // Explicit target sum.
    assert(normalize!(typeof(values))(values, 50)); // values = [12.5, 37.5]

    // All-zero input: elements are set to sum / length, result is false.
    values = [ 0.0, 0.0 ];
    assert(!normalize(values));
    assert(values == [ 0.5, 0.5 ]);
}
2216
/**
Compute the sum of binary logarithms of the input range `r`.
The error of this method is much smaller than with a naive sum of log2.

Returns: the sum of `log2` over all elements, or NaN as soon as a
negative element is encountered.
*/
ElementType!Range sumOfLog2s(Range)(Range r)
if (isInputRange!Range && isFloatingPoint!(ElementType!Range))
{
    // The running product is kept split as mantissa * 2^^exponent, so only
    // one log2 call is needed at the very end.
    long exponent = 0;
    Unqual!(typeof(return)) mantissa = 1;
    for (; !r.empty; r.popFront())
    {
        const value = r.front;
        if (value < 0)
            return typeof(return).nan;

        int partExp = void; // filled in by frexp
        mantissa *= frexp(value, partExp);
        exponent += partExp;

        // Rescale so the mantissa does not keep shrinking as more
        // factors are multiplied in.
        if (mantissa < 0.5)
        {
            mantissa *= 2;
            --exponent;
        }
    }
    return exponent + log2(mantissa);
}
2241
2242 ///
@safe unittest
{
    import std.math.traits : isNaN;

    // Empty input: empty sum.
    assert(sumOfLog2s(new double[0]) == 0);

    // Zeros of either sign.
    assert(sumOfLog2s([0.0L]) == -real.infinity);
    assert(sumOfLog2s([-0.0L]) == -real.infinity);

    // Ordinary values, positive and negative.
    assert(sumOfLog2s([2.0L]) == 1);
    assert(sumOfLog2s([-2.0L]).isNaN());

    // NaN of either sign.
    assert(sumOfLog2s([real.nan]).isNaN());
    assert(sumOfLog2s([-real.nan]).isNaN());

    // Infinities.
    assert(sumOfLog2s([real.infinity]) == real.infinity);
    assert(sumOfLog2s([-real.infinity]).isNaN());

    // -2 + -2 + -2 + -3
    assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9);
}
2258
/**
Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory),
_entropy) of input range `r` in bits. This function assumes (without
checking) that the values in `r` are all in $(D [0, 1]). For the
entropy to be meaningful, often `r` should be normalized too (i.e.,
its values should sum to 1). Elements equal to zero contribute
nothing. The two-parameter version stops evaluating as soon as the
intermediate result is greater than or equal to `max`.
*/
ElementType!Range entropy(Range)(Range r)
if (isInputRange!Range)
{
    Unqual!(typeof(return)) bits = 0.0;
    while (!r.empty)
    {
        // Zero elements are skipped, so log2 is never called with 0.
        if (r.front)
        {
            bits -= r.front * log2(r.front);
        }
        r.popFront;
    }
    return bits;
}
2279
2280 /// Ditto
ElementType!Range entropy(Range, F)(Range r, F max)
if (isInputRange!Range &&
    !is(CommonType!(ElementType!Range, F) == void))
{
    // Same accumulation as the one-argument overload, but it stops as soon
    // as the running total reaches `max`.
    Unqual!(typeof(return)) bits = 0.0;
    while (!r.empty)
    {
        if (r.front)
        {
            bits -= r.front * log2(r.front);
            // The total only changes for non-zero elements, so the cut-off
            // is only tested here.
            if (bits >= max) break;
        }
        r.popFront;
    }
    return bits;
}
2294
@safe unittest
{
    import std.meta : AliasSeq;

    // Run the same checks for mutable, const and immutable element types.
    static foreach (Elem; AliasSeq!(double, const double, immutable double))
    {{
        // All mass on one outcome.
        Elem[] dist = [ 0.0, 0, 0, 1 ];
        assert(entropy(dist) == 0);

        // Uniform over four outcomes: two bits.
        dist = [ 0.25, 0.25, 0.25, 0.25 ];
        assert(entropy(dist) == 2);

        // With a cut-off of 1 the result stops at 1.
        assert(entropy(dist, 1) == 1);
    }}
}
2307
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence,
Kullback-Leibler divergence) between input ranges `a` and `b`, which
is the sum $(D ai * log(ai / bi)). The base of logarithm is 2. The
ranges are assumed to contain elements in $(D [0, 1]). Usually the
ranges are normalized probability distributions, but this is not
required or checked by `kullbackLeiblerDivergence`. If any element
`bi` is zero and the corresponding element `ai` nonzero, returns
infinity. (Otherwise, if $(D ai == 0), the term
$(D ai * log(ai / bi)) is considered zero.) If the inputs are
normalized, the result is positive.
*/
CommonType!(ElementType!Range1, ElementType!Range2)
kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool lengthsKnown = hasLength!(Range1) && hasLength!(Range2);
    static if (lengthsKnown)
        assert(a.length == b.length);

    Unqual!(typeof(return)) divergence = 0;
    while (!a.empty)
    {
        immutable pa = a.front;
        // A zero `ai` contributes nothing; `bi` is not even read then.
        if (pa != 0)
        {
            immutable pb = b.front;
            if (pb == 0)
                return divergence.infinity;
            assert(pa > 0 && pb > 0);
            divergence += pa * log2(pa / pb);
        }
        a.popFront();
        b.popFront();
    }
    static if (!lengthsKnown)
        assert(b.empty);
    return divergence;
}
2340
2341 ///
@safe unittest
{
    import std.math.operations : isClose;

    double[] point = [ 0.0, 0, 0, 1 ];
    double[] uniform = [ 0.25, 0.25, 0.25, 0.25 ];
    double[] skewed = [ 0.2, 0.2, 0.2, 0.4 ];

    // A distribution compared with itself gives zero.
    assert(kullbackLeiblerDivergence(point, point) == 0);
    assert(kullbackLeiblerDivergence(uniform, uniform) == 0);

    // The result depends on argument order; a zero `bi` paired with a
    // nonzero `ai` gives infinity.
    assert(kullbackLeiblerDivergence(point, uniform) == 2);
    assert(kullbackLeiblerDivergence(uniform, point) == double.infinity);

    assert(isClose(kullbackLeiblerDivergence(uniform, skewed), 0.0719281, 1e-5));
    assert(isClose(kullbackLeiblerDivergence(skewed, uniform), 0.0780719, 1e-5));
}
2356
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence,
Jensen-Shannon divergence) between `a` and `b`, which is the sum
$(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * bi / (ai + bi))) / 2).
The base of logarithm is 2. The ranges are assumed to contain elements
in $(D [0, 1]). Usually the ranges are normalized probability
distributions, but this is not required or checked by
`jensenShannonDivergence`. If the inputs are normalized, the result is
bounded within $(D [0, 1]). The three-parameter version stops
evaluations as soon as the intermediate result is greater than or
equal to `limit`.
*/
CommonType!(ElementType!Range1, ElementType!Range2)
jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!Range1 && isInputRange!Range2 &&
    is(CommonType!(ElementType!Range1, ElementType!Range2)))
{
    enum bool lengthsKnown = hasLength!(Range1) && hasLength!(Range2);
    static if (lengthsKnown)
        assert(a.length == b.length);

    // Accumulates twice the divergence; halved once on return.
    Unqual!(typeof(return)) doubled = 0;
    while (!a.empty)
    {
        immutable p = a.front;
        immutable q = b.front;
        immutable mid = (p + q) / 2;
        // A zero element contributes nothing and is skipped.
        if (p != 0)
            doubled += p * log2(p / mid);
        if (q != 0)
            doubled += q * log2(q / mid);
        a.popFront();
        b.popFront();
    }
    static if (!lengthsKnown)
        assert(b.empty);
    return doubled / 2;
}
2394
2395 /// Ditto
CommonType!(ElementType!Range1, ElementType!Range2)
jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!Range1 && isInputRange!Range2 &&
    is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init
    >= F.init) : bool))
{
    enum bool lengthsKnown = hasLength!(Range1) && hasLength!(Range2);
    static if (lengthsKnown)
        assert(a.length == b.length);

    // The accumulator holds twice the divergence, so the cut-off is
    // doubled to match.
    Unqual!(typeof(return)) doubled = 0;
    limit *= 2;
    while (!a.empty)
    {
        immutable p = a.front;
        immutable q = b.front;
        immutable mid = (p + q) / 2;
        if (p != 0)
            doubled += p * log2(p / mid);
        if (q != 0)
            doubled += q * log2(q / mid);
        // Early exit: the current pair is not popped in that case.
        if (doubled >= limit) break;
        a.popFront();
        b.popFront();
    }
    static if (!lengthsKnown)
        assert(b.empty);
    return doubled / 2;
}
2424
2425 ///
@safe unittest
{
    import std.math.operations : isClose;

    double[] point = [ 0.0, 0, 0, 1 ];
    double[] uniform = [ 0.25, 0.25, 0.25, 0.25 ];
    double[] skewed = [ 0.2, 0.2, 0.2, 0.4 ];

    // A distribution compared with itself gives zero.
    assert(jensenShannonDivergence(point, point) == 0);
    assert(jensenShannonDivergence(uniform, uniform) == 0);

    assert(isClose(jensenShannonDivergence(uniform, point), 0.548795, 1e-5));

    // Swapping the arguments gives the same result.
    assert(isClose(jensenShannonDivergence(uniform, skewed), 0.0186218, 1e-5));
    assert(isClose(jensenShannonDivergence(skewed, uniform), 0.0186218, 1e-5));

    // With a limit the evaluation stops early.
    assert(isClose(jensenShannonDivergence(skewed, uniform, 0.005), 0.00602366, 1e-5));
}
2440
2441 /**
2442 The so-called "all-lengths gap-weighted string kernel" computes a
2443 similarity measure between `s` and `t` based on all of their
2444 common subsequences of all lengths. Gapped subsequences are also
2445 included.
2446
2447 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes,
2448 consider first the case $(D lambda = 1) and the strings $(D s =
2449 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new",
2450 "world"]). In that case, `gapWeightedSimilarity` counts the
2451 following matches:
2452
2453 $(OL $(LI three matches of length 1, namely `"Hello"`, `"new"`,
2454 and `"world"`;) $(LI three matches of length 2, namely ($(D
2455 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));)
2456 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).))
2457
2458 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of
2459 these matches and adds them up, returning 7.
2460
2461 ----
2462 string[] s = ["Hello", "brave", "new", "world"];
2463 string[] t = ["Hello", "new", "world"];
2464 assert(gapWeightedSimilarity(s, t, 1) == 7);
2465 ----
2466
2467 Note how the gaps in matching are simply ignored, for example ($(D
2468 "Hello", "new")) is deemed as good a match as ($(D "new",
2469 "world")). This may be too permissive for some applications. To
2470 eliminate gapped matches entirely, use $(D lambda = 0):
2471
2472 ----
2473 string[] s = ["Hello", "brave", "new", "world"];
2474 string[] t = ["Hello", "new", "world"];
2475 assert(gapWeightedSimilarity(s, t, 0) == 4);
2476 ----
2477
2478 The call above eliminated the gapped matches ($(D "Hello", "new")),
2479 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the
2480 tally. That leaves only 4 matches.
2481
2482 The most interesting case is when gapped matches still participate in
2483 the result, but not as strongly as ungapped matches. The result will
2484 be a smooth, fine-grained similarity measure between the input
2485 strings. This is where values of `lambda` between 0 and 1 enter
2486 into play: gapped matches are $(I exponentially penalized with the
2487 number of gaps) with base `lambda`. This means that an ungapped
2488 match adds 1 to the return value; a match with one gap in either
2489 string adds `lambda` to the return value; ...; a match with a total
2490 of `n` gaps in both strings adds $(D pow(lambda, n)) to the return
2491 value. In the example above, we have 4 matches without gaps, 2 matches
2492 with one gap, and 1 match with three gaps. The latter match is ($(D
2493 "Hello", "world")), which has two gaps in the first string and one gap
2494 in the second string, totaling to three gaps. Summing these up we get
2495 $(D 4 + 2 * lambda + pow(lambda, 3)).
2496
2497 ----
2498 string[] s = ["Hello", "brave", "new", "world"];
2499 string[] t = ["Hello", "new", "world"];
2500 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125);
2501 ----
2502
2503 `gapWeightedSimilarity` is useful wherever a smooth similarity
2504 measure between sequences allowing for approximate matches is
2505 needed. The examples above are given with words, but any sequences
2506 with elements comparable for equality are allowed, e.g. characters or
2507 numbers. `gapWeightedSimilarity` uses a highly optimized dynamic
2508 programming implementation that needs $(D 16 * min(s.length,
2509 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time
2510 to complete.
2511 */
F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda)
if (isRandomAccessRange!(R1) && hasLength!(R1) &&
    isRandomAccessRange!(R2) && hasLength!(R2))
{
    import core.exception : onOutOfMemoryError;
    import core.stdc.stdlib : malloc, free;
    import std.algorithm.mutation : swap;
    import std.functional : binaryFun;

    // The scratch buffers are sized by t.length, so make `t` the shorter
    // sequence. `comp` must be forwarded explicitly: without it the swapped
    // call would fall back to the default "a == b" and ignore a
    // user-supplied predicate. After the swap `comp` receives its arguments
    // in the opposite order, so it has to be symmetric.
    if (s.length < t.length) return gapWeightedSimilarity!(comp)(t, s, lambda);
    if (!t.length) return 0;

    // One allocation holds two rows of t.length entries each.
    auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length);
    if (!dpvi)
        onOutOfMemoryError();

    auto dpvi1 = dpvi + t.length;
    // The two row pointers are swapped after every row; the lower address
    // is always the start of the malloc'd block.
    scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1);
    dpvi[0 .. t.length] = 0;
    dpvi1[0] = 0;
    immutable lambda2 = lambda * lambda;

    F result = 0;
    foreach (i; 0 .. s.length)
    {
        const si = s[i];
        for (size_t j = 0;;)
        {
            F dpsij = void;
            if (binaryFun!(comp)(si, t[j]))
            {
                // A match extends what the previous row stored at j and
                // also counts as a match on its own.
                dpsij = 1 + dpvi[j];
                result += dpsij;
            }
            else
            {
                dpsij = 0;
            }
            immutable j1 = j + 1;
            if (j1 == t.length) break;
            // Fill in the current row's next entry from this row, the
            // previous row, and the diagonal.
            dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) -
                lambda2 * dpvi[j];
            j = j1;
        }
        // The row just written becomes the previous row.
        swap(dpvi, dpvi1);
    }
    return result;
}
2560
@system unittest
{
    string[] longer = ["Hello", "brave", "new", "world"];
    string[] shorter = ["Hello", "new", "world"];

    // lambda == 1: every match counts fully.
    assert(gapWeightedSimilarity(longer, shorter, 1) == 7);
    // lambda == 0: only the four ungapped matches remain.
    assert(gapWeightedSimilarity(longer, shorter, 0) == 4);
    // lambda == 0.5: 4 + 2 * lambda + lambda^^3.
    assert(gapWeightedSimilarity(longer, shorter, 0.5) == 4 + 2 * 0.5 + 0.125);
}
2569
/**
The similarity per `gapWeightedSimilarity` has an issue in that it
grows with the lengths of the two strings, even though the strings are
not actually very similar. For example, the range $(D ["Hello",
"world"]) is increasingly similar with the range $(D ["Hello",
"world", "world", "world",...]) as more instances of `"world"` are
appended. To prevent that, `gapWeightedSimilarityNormalized`
computes a normalized version of the similarity that is computed as
$(D gapWeightedSimilarity(s, t, lambda) /
sqrt(gapWeightedSimilarity(s, s, lambda) * gapWeightedSimilarity(t, t,
lambda))). The function `gapWeightedSimilarityNormalized` (a
so-called normalized kernel) is bounded in $(D [0, 1]), reaches `0`
only for ranges that don't match in any position, and `1` only for
identical ranges.

The optional parameters `sSelfSim` and `tSelfSim` are meant for
avoiding duplicate computation. Many applications may have already
computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D
gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed
as `sSelfSim` and `tSelfSim`, respectively. A parameter left at
`F.init` (NaN for floating-point `F`, zero otherwise) is treated as
not supplied and is computed here.
*/
Select!(isFloatingPoint!(F), F, double)
gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F)
        (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init)
if (isRandomAccessRange!(R1) && hasLength!(R1) &&
    isRandomAccessRange!(R2) && hasLength!(R2))
{
    alias Result = typeof(return);

    // True when the caller left a self-similarity argument at F.init.
    static bool isUnset(F value)
    {
        static if (isFloatingPoint!(F))
            return isNaN(value);
        else
            return value == value.init;
    }

    if (isUnset(sSelfSim))
        sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda);
    // A zero self-similarity would make the denominator zero.
    if (sSelfSim == 0)
        return 0;

    if (isUnset(tSelfSim))
        tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda);
    if (tSelfSim == 0)
        return 0;

    return gapWeightedSimilarity!(comp)(s, t, lambda) /
        sqrt(cast(Result) sSelfSim * tSelfSim);
}
2614
2615 ///
@system unittest
{
    import std.math.operations : isClose;
    import std.math.algebraic : sqrt;

    string[] longer = ["Hello", "brave", "new", "world"];
    string[] shorter = ["Hello", "new", "world"];

    // Raw similarities: each range with itself, then with the other.
    assert(gapWeightedSimilarity(longer, longer, 1) == 15);
    assert(gapWeightedSimilarity(shorter, shorter, 1) == 7);
    assert(gapWeightedSimilarity(longer, shorter, 1) == 7);

    // Normalized: cross / sqrt(self(longer) * self(shorter)).
    immutable expected = 7.0 / sqrt(15.0 * 7);
    assert(isClose(gapWeightedSimilarityNormalized(longer, shorter, 1),
                    expected, 0.01));
}
2629
/**
Similar to `gapWeightedSimilarity`, just works in an incremental
manner by first revealing the matches of length 1, then gapped matches
of length 2, and so on. The memory requirement is $(BIGOH s.length *
t.length). The time complexity is $(BIGOH s.length * t.length) time
for computing each step. Continuing on the previous example:

The implementation is based on the pseudocode in Fig. 4 of the paper
$(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf,
"Efficient Computation of Gapped Substring Kernels on Large Alphabets")
by Rousu et al., with additional algorithmic and systems-level
optimizations.
*/
struct GapWeightedSimilarityIncremental(Range, F = double)
if (isRandomAccessRange!(Range) && hasLength!(Range))
{
    import core.stdc.stdlib : malloc, realloc, alloca, free;

private:
    // The two inputs, trimmed by the constructor to the span that contains
    // matches.
    Range s, t;
    // Value reported by `front`; zero means the range is empty.
    F currentValue = 0;
    // malloc'd table of s.length * t.length entries, indexed as
    // kl[i * t.length + j]. It is released by `empty` once `currentValue`
    // is zero.
    // NOTE(review): the struct has no destructor, so the table is only
    // freed if iteration runs until `empty` returns true.
    F* kl;
    // Number of `popFront` steps performed so far.
    size_t gram = void;
    F lambda = void, lambda2 = void;

public:
/**
Constructs an object given two ranges `s` and `t` and a penalty
`lambda`. Constructor completes in $(BIGOH s.length * t.length)
time and computes all matches of length 1.
 */
    this(Range s, Range t, F lambda)
    {
        import core.exception : onOutOfMemoryError;

        assert(lambda > 0);
        this.gram = 0;
        this.lambda = lambda;
        this.lambda2 = lambda * lambda; // for efficiency only

        size_t iMin = size_t.max, jMin = size_t.max,
            iMax = 0, jMax = 0;
        /* initialize */
        // Temporary list of the (i, j) index pairs where s[i] == t[j].
        Tuple!(size_t, size_t) * k0;
        size_t k0len;
        scope(exit) free(k0);
        currentValue = 0;
        foreach (i, si; s)
        {
            foreach (j; 0 .. t.length)
            {
                if (si != t[j]) continue;
                // Grow the list by one entry. The result is checked before
                // k0 is overwritten, so a failed realloc keeps the old
                // block reachable for the scope(exit) above and does not
                // lead to a null dereference below.
                auto grown = cast(typeof(k0)) realloc(k0, (k0len + 1) * (*k0).sizeof);
                if (!grown)
                    onOutOfMemoryError();
                k0 = grown;
                ++k0len;
                with (k0[k0len - 1])
                {
                    field[0] = i;
                    field[1] = j;
                }
                // Maintain the minimum and maximum i and j
                if (iMin > i) iMin = i;
                if (iMax < i) iMax = i;
                if (jMin > j) jMin = j;
                if (jMax < j) jMax = j;
            }
        }

        // iMin is still size_t.max when there was no match at all.
        if (iMin > iMax) return;
        assert(k0len);

        currentValue = k0len;
        // Chop strings down to the useful sizes
        s = s[iMin .. iMax + 1];
        t = t[jMin .. jMax + 1];
        this.s = s;
        this.t = t;

        kl = cast(F*) malloc(s.length * t.length * F.sizeof);
        if (!kl)
            onOutOfMemoryError();

        kl[0 .. s.length * t.length] = 0;
        foreach (pos; 0 .. k0len)
        {
            with (k0[pos])
            {
                // Mark every matching cell, using indices relative to the
                // trimmed ranges.
                kl[(field[0] - iMin) * t.length + field[1] -jMin] = lambda2;
            }
        }
    }

    /**
    Returns: `this`.
     */
    ref GapWeightedSimilarityIncremental opSlice()
    {
        return this;
    }

    /**
    Computes the match of the next length. Completes in $(BIGOH s.length *
    t.length) time.
     */
    void popFront()
    {
        import std.algorithm.mutation : swap;

        // This is a large source of optimization: if similarity at
        // the gram-1 level was 0, then we can safely assume
        // similarity at the gram level is 0 as well.
        if (empty) return;

        // Now attempt to match gapped substrings of length `gram'
        ++gram;
        currentValue = 0;

        // Scratch row of t.length entries, allocated with alloca.
        auto Si = cast(F*) alloca(t.length * F.sizeof);
        Si[0 .. t.length] = 0;
        foreach (i; 0 .. s.length)
        {
            const si = s[i];
            F Sij_1 = 0;
            F Si_1j_1 = 0;
            auto kli = kl + i * t.length;
            for (size_t j = 0;;)
            {
                const klij = kli[j];
                const Si_1j = Si[j];
                const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1;
                // now update kl and currentValue
                if (si == t[j])
                    currentValue += kli[j] = lambda2 * Si_1j_1;
                else
                    kli[j] = 0;
                // commit to Si
                Si[j] = tmp;
                if (++j == t.length) break;
                // get ready for the next step; virtually increment j,
                // so essentially stuffj_1 <-- stuffj
                Si_1j_1 = Si_1j;
                Sij_1 = tmp;
            }
        }
        // Divide out the powers of lambda accumulated in the table.
        currentValue /= pow(lambda, 2 * (gram + 1));

        version (none)
        {
            Si_1[0 .. t.length] = 0;
            kl[0 .. min(t.length, maxPerimeter + 1)] = 0;
            foreach (i; 1 .. min(s.length, maxPerimeter + 1))
            {
                auto kli = kl + i * t.length;
                assert(s.length > i);
                const si = s[i];
                auto kl_1i_1 = kl_1 + (i - 1) * t.length;
                kli[0] = 0;
                F lastS = 0;
                foreach (j; 1 .. min(maxPerimeter - i + 1, t.length))
                {
                    immutable j_1 = j - 1;
                    immutable tmp = kl_1i_1[j_1]
                        + lambda * (Si_1[j] + lastS)
                        - lambda2 * Si_1[j_1];
                    kl_1i_1[j_1] = float.nan;
                    Si_1[j_1] = lastS;
                    lastS = tmp;
                    if (si == t[j])
                    {
                        currentValue += kli[j] = lambda2 * lastS;
                    }
                    else
                    {
                        kli[j] = 0;
                    }
                }
                Si_1[t.length - 1] = lastS;
            }
            currentValue /= pow(lambda, 2 * (gram + 1));
            // get ready for the next computation
            swap(kl, kl_1);
        }
    }

    /**
    Returns: The gapped similarity at the current match length (initially
    1, grows with each call to `popFront`).
     */
    @property F front() { return currentValue; }

    /**
    Returns: Whether there are more matches.

    As a side effect, the internal table is freed the first time this
    returns `true`.
     */
    @property bool empty()
    {
        if (currentValue) return false;
        if (kl)
        {
            free(kl);
            kl = null;
        }
        return true;
    }
}
2832
2833 /**
2834 Ditto
2835 */
GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F)
(R r1, R r2, F penalty)
{
    // Convenience factory: `R` and `F` are deduced from the arguments.
    alias Result = GapWeightedSimilarityIncremental!(R, F);
    return Result(r1, r2, penalty);
}
2841
2842 ///
@system unittest
{
    string[] longer = ["Hello", "brave", "new", "world"];
    string[] shorter = ["Hello", "new", "world"];
    auto steps = gapWeightedSimilarityIncremental(longer, shorter, 1.0);

    // Each popFront moves on to matches one element longer.
    assert(steps.front == 3); // three 1-length matches
    steps.popFront();
    assert(steps.front == 3); // three 2-length matches
    steps.popFront();
    assert(steps.front == 1); // one 3-length match
    steps.popFront();
    assert(steps.empty);     // no more match
}
2856
@system unittest
{
    import std.conv : text;

    // Same scenario as the documented example, with a diagnostic message.
    string[] first = ["Hello", "brave", "new", "world"];
    string[] second = ["Hello", "new", "world"];
    auto steps = gapWeightedSimilarityIncremental(first, second, 1.0);
    //foreach (e; steps) writeln(e);
    assert(steps.front == 3); // three 1-length matches
    steps.popFront();
    assert(steps.front == 3, text(steps.front)); // three 2-length matches
    steps.popFront();
    assert(steps.front == 1); // one 3-length matches
    steps.popFront();
    assert(steps.empty);     // no more match

    // Nothing in common: empty from the start.
    first = ["Hello"];
    second = ["bye"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.empty);

    // Identical single-element ranges.
    first = ["Hello"];
    second = ["Hello"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.front == 1); // one match
    steps.popFront();
    assert(steps.empty);

    // One range longer than the other, single shared element.
    first = ["Hello", "world"];
    second = ["Hello"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.front == 1); // one match
    steps.popFront();
    assert(steps.empty);

    // Two shared elements separated by a gap in the second range.
    first = ["Hello", "world"];
    second = ["Hello", "yah", "world"];
    steps = gapWeightedSimilarityIncremental(first, second, 0.5);
    assert(steps.front == 2); // two 1-gram matches
    steps.popFront();
    assert(steps.front == 0.5, text(steps.front)); // one 2-gram match, 1 gap
}
2898
@system unittest
{
    alias Sim = GapWeightedSimilarityIncremental!(string[]);

    // Iterate with foreach and compare each step against a witness list.
    Sim sim = Sim(
        ["nyuk", "I", "have", "no", "chocolate", "giba"],
        ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"],
        0.5);
    double[] expected = [ 7.0, 4.03125, 0, 0 ];
    foreach (value; sim)
    {
        //writeln(value);
        assert(value == expected.front);
        expected.popFront();
    }

    // Second scenario; here the witness list must be used up exactly.
    expected = [ 3.0, 1.3125, 0.25 ];
    sim = Sim(
        ["I", "have", "no", "chocolate"],
        ["I", "have", "some", "chocolate"],
        0.5);
    foreach (value; sim)
    {
        //writeln(value);
        assert(value == expected.front);
        expected.popFront();
    }
    assert(expected.empty);
}
2926
2927 /**
2928 Computes the greatest common divisor of `a` and `b` by using
2929 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's)
2930 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm.
2931
2932 Params:
2933 a = Integer value of any numerical type that supports the modulo operator `%`.
2934 If bit-shifting `<<` and `>>` are also supported, Stein's algorithm will
2935 be used; otherwise, Euclid's algorithm is used as _a fallback.
2936 b = Integer value of any equivalent numerical type.
2937
2938 Returns:
2939 The greatest common divisor of the given arguments.
2940 */
typeof(Unqual!(T).init % Unqual!(U).init) gcd(T, U)(T a, U b)
if (isIntegral!T && isIntegral!U)
{
    // Operate on a common type between the two arguments.
    // UCT is the unsigned version of that common type, so the magnitude of
    // either argument can be stored in it.
    alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U));

    // `std.math.abs` doesn't support unsigned integers, and `T.min` is undefined.
    // Take the magnitude of `a` by hand. Types that convert to `short` are
    // widened to `int` before negating and then cast to UCT; all other
    // types are converted to UCT first and negated there.
    static if (is(T : immutable short) || is(T : immutable byte))
        UCT ax = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a);
    else
        UCT ax = (isUnsigned!T || a >= 0) ? a : -UCT(a);

    // Same magnitude computation for `b`.
    static if (is(U : immutable short) || is(U : immutable byte))
        UCT bx = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b);
    else
        UCT bx = (isUnsigned!U || b >= 0) ? b : -UCT(b);

    // Special cases.
    // If one magnitude is zero the other one is returned (so gcd(0, 0) is
    // 0); gcdImpl is only called with two non-zero values.
    if (ax == 0)
        return bx;
    if (bx == 0)
        return ax;

    return gcdImpl(ax, bx);
}
2966
private typeof(T.init % T.init) gcdImpl(T)(T a, T b)
if (isIntegral!T)
{
    pragma(inline, true);
    import core.bitop : bsf;
    import std.algorithm.mutation : swap;

    // Shift-and-subtract GCD. `bsf` gives the index of the lowest set bit.
    // The two-argument `gcd` for built-in integers returns early when
    // either magnitude is zero, so it never passes a zero here.

    // Factors of two shared by both operands; restored on return.
    immutable uint commonTwos = bsf(a | b);

    // Make `a` odd; it stays odd for the rest of the loop.
    a >>= bsf(a);
    do
    {
        // Make `b` odd as well, then subtract the smaller from the larger.
        b >>= bsf(b);
        if (b < a)
            swap(a, b);
        b -= a;
    } while (b != 0);

    return a << commonTwos;
}
2986
2987 ///
@safe unittest
{
    // Shared prime factors are 5 and 7.
    assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7);

    // Const-qualified arguments work too; the only shared factor is 13.
    const int first = 5 * 13 * 23 * 23;
    const int second = 13 * 59;
    assert(gcd(first, second) == 13);
}
2994
@safe unittest
{
    import std.meta : AliasSeq;

    // Every combination of the first-argument types with the
    // second-argument types is exercised, including qualified ones.
    static foreach (A; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                 const byte, const short, const int, const long,
                                 immutable ubyte, immutable ushort, immutable uint, immutable ulong))
    {
        static foreach (B; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                     const ubyte, const ushort, const uint, const ulong,
                                     immutable byte, immutable short, immutable int, immutable long))
        {
            // Signed and unsigned tests.
            // The static ifs keep each literal within range of both types.
            static if (A.max > byte.max && B.max > byte.max)
                assert(gcd(A(200), B(200)) == 200);
            static if (A.max > ubyte.max)
            {
                assert(gcd(A(2000), B(20)) == 20);
                assert(gcd(A(2011), B(17)) == 1);
            }
            static if (A.max > ubyte.max && B.max > ubyte.max)
                assert(gcd(A(1071), B(462)) == 21);

            // Values that fit every tested type, including the zero cases.
            assert(gcd(A(0), B(13)) == 13);
            assert(gcd(A(29), B(0)) == 29);
            assert(gcd(A(0), B(0)) == 0);
            assert(gcd(A(1), B(2)) == 1);
            assert(gcd(A(9), B(6)) == 3);
            assert(gcd(A(3), B(4)) == 1);
            assert(gcd(A(32), B(24)) == 8);
            assert(gcd(A(5), B(6)) == 1);
            assert(gcd(A(54), B(36)) == 18);

            // Int and Long tests.
            static if (A.max > short.max && B.max > short.max)
                assert(gcd(A(46391), B(62527)) == 2017);
            static if (A.max > ushort.max && B.max > ushort.max)
                assert(gcd(A(63245986), B(39088169)) == 1);
            static if (A.max > uint.max && B.max > uint.max)
            {
                assert(gcd(A(77160074263), B(47687519812)) == 1);
                assert(gcd(A(77160074264), B(47687519812)) == 4);
            }

            // Negative tests.
            static if (A.min < 0)
            {
                assert(gcd(A(-21), B(28)) == 7);
                assert(gcd(A(-3), B(4)) == 1);
            }
            static if (B.min < 0)
            {
                assert(gcd(A(1), B(-2)) == 1);
                assert(gcd(A(33), B(-44)) == 11);
            }
            static if (A.min < 0 && B.min < 0)
            {
                assert(gcd(A(-5), B(-6)) == 1);
                assert(gcd(A(-50), B(-60)) == 10);
            }
        }
    }
}
3057
3058 // https://issues.dlang.org/show_bug.cgi?id=21834
@safe unittest
{
    // Mixed signed and unsigned arguments.
    assert(gcd(-120, 10U) == 10);
    assert(gcd(120U, -10) == 10);

    // int.min has no positive counterpart in int; the result is its
    // magnitude in the wider type.
    assert(gcd(int.min, 0L) == 1L + int.max);
    assert(gcd(0L, int.min) == 1L + int.max);
    assert(gcd(int.min, 0L + int.min) == 1L + int.max);
    assert(gcd(int.min, 1L + int.max) == 1L + int.max);

    // The same boundary for short.
    assert(gcd(short.min, 1U + short.max) == 1U + short.max);
}
3069
// Overload for numeric types that are not builtin integers, such as BigInt
// or user-defined types.
/// ditto
auto gcd(T)(T a, T b)
if (!isIntegral!T &&
        is(typeof(T.init % T.init)) &&
        is(typeof(T.init == 0 || T.init > 0)))
{
    static if (is(T == Unqual!T))
    {
        // Work on magnitudes: replace negative arguments by their negation.
        a = a >= 0 ? a : -a;
        b = b >= 0 ? b : -b;

        // gcd(0, x) == x and gcd(x, 0) == x.
        if (a == 0)
            return b;
        if (b == 0)
            return a;

        return gcdImpl(a, b);
    }
    else
    {
        // Qualified (const/immutable) type: forward to the mutable
        // instantiation, which is free to modify its by-value parameters.
        return gcd!(Unqual!T)(a, b);
    }
}
3097
// Shared gcd kernel for non-builtin numeric types.
//
// The callers visible in this module (`gcd` and `lcm`) pass non-negative,
// non-zero values: they negate negative arguments and return early when
// either argument is zero.
//
// When `T` supports every operation the binary (Stein) algorithm needs, that
// algorithm is used; otherwise the function falls back to the Euclidean
// algorithm, which only requires `%` and comparison against zero.
private auto gcdImpl(T)(T a, T b)
if (!isIntegral!T)
{
    pragma(inline, true);
    import std.algorithm.mutation : swap;

    // Probe every operation the binary branch below actually performs, so a
    // type lacking one of them selects the Euclidean fallback instead of
    // failing to compile inside the binary branch.
    enum canUseBinaryGcd = is(typeof(() {
        T t, u;
        uint n;
        t <<= 1;
        t >>= 1;
        t -= u;
        bool b = (t & 1) == 0;
        bool ordered = t > u;
        bool nonZero = t != 0;
        auto shifted = t << n;
        swap(t, u);
    }));

    static if (canUseBinaryGcd)
    {
        // Strip the power of two common to both operands; it is restored
        // when the result is returned.
        uint shift = 0;
        while ((a & 1) == 0 && (b & 1) == 0)
        {
            a >>= 1;
            b >>= 1;
            shift++;
        }

        // At least one operand is odd now; make sure it is `a`.
        if ((a & 1) == 0) swap(a, b);

        do
        {
            assert((a & 1) != 0);
            // Factors of two in `b` cannot be shared with the odd `a`.
            while ((b & 1) == 0)
                b >>= 1;
            // Keep a <= b so the subtraction never goes negative.
            if (a > b)
                swap(a, b);
            b -= a;
        } while (b != 0); // explicit comparison: no bool conversion needed

        return a << shift;
    }
    else
    {
        // The only thing we have is %; fallback to Euclidean algorithm.
        while (b != 0)
        {
            auto t = b;
            b = a % b;
            a = t;
        }
        return a;
    }
}
3148
// https://issues.dlang.org/show_bug.cgi?id=7102
// gcd must accept BigInt operands, including zero on either side.
@system pure unittest
{
    import std.bigint : BigInt;

    auto lhs = BigInt("71_000_000_000_000_000_000");
    auto rhs = BigInt("31_000_000_000_000_000_000");
    auto expected = BigInt("1_000_000_000_000_000_000");
    assert(gcd(lhs, rhs) == expected);

    auto nonZero = BigInt(1234567);
    assert(gcd(BigInt(0), nonZero) == BigInt(1234567));
    assert(gcd(nonZero, BigInt(0)) == BigInt(1234567));
}
3160
@safe pure nothrow unittest
{
    // Minimal numeric type offering only `%`, unary `-` and comparisons, so
    // that gcd is forced onto its Euclidean fallback.
    struct CrippledInt
    {
        int impl;

        CrippledInt opBinary(string op : "%")(CrippledInt i)
        {
            return CrippledInt(impl % i.impl);
        }

        CrippledInt opUnary(string op : "-")()
        {
            return CrippledInt(-impl);
        }

        int opEquals(CrippledInt i) { return impl == i.impl; }
        int opEquals(int i) { return impl == i; }

        int opCmp(int i)
        {
            if (impl < i)
                return -1;
            return (impl > i) ? 1 : 0;
        }
    }

    assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77));

    // Negative operands on either side yield a positive result.
    assert(gcd(CrippledInt(-120), CrippledInt(10U)) == CrippledInt(10));
    assert(gcd(CrippledInt(120U), CrippledInt(-10)) == CrippledInt(10));
}
3184
// https://issues.dlang.org/show_bug.cgi?id=19514
// An even first operand paired with an odd second operand.
@system pure unittest
{
    import std.bigint : BigInt;

    auto two = BigInt(2);
    auto one = BigInt(1);
    assert(gcd(two, one) == BigInt(1));
}
3191
// Issue 20924
// gcd must be callable with const-qualified BigInt arguments.
@safe unittest
{
    import std.bigint : BigInt;

    const lhs = BigInt("123143238472389492934020");
    const rhs = BigInt("902380489324729338420924");
    assert(__traits(compiles, gcd(lhs, rhs)));
}
3200
// https://issues.dlang.org/show_bug.cgi?id=21834
// BigInt counterpart of the builtin-integer test for negative arguments.
@safe unittest
{
    import std.bigint : BigInt;

    assert(gcd(BigInt(-120), BigInt(10U)) == BigInt(10));
    assert(gcd(BigInt(120U), BigInt(-10)) == BigInt(10));

    // Every combination below is expected to yield |int.min|.
    assert(gcd(BigInt(int.min), BigInt(0L)) == BigInt(1L + int.max));
    assert(gcd(BigInt(0L), BigInt(int.min)) == BigInt(1L + int.max));
    assert(gcd(BigInt(int.min), BigInt(0L + int.min)) == BigInt(1L + int.max));
    assert(gcd(BigInt(int.min), BigInt(1L + int.max)) == BigInt(1L + int.max));

    // Same idea for |short.min|.
    assert(gcd(BigInt(short.min), BigInt(1U + short.max)) == BigInt(1U + short.max));
}
3213
3214
/**
Computes the least common multiple of `a` and `b`.
Arguments are the same as $(MYREF gcd).

Returns:
    The least common multiple of the given arguments.
 */
typeof(Unqual!(T).init % Unqual!(U).init) lcm(T, U)(T a, U b)
if (isIntegral!T && isIntegral!U)
{
    // All arithmetic happens in the unsigned version of the common type.
    alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U));

    // Magnitude of `a`. `std.math.abs` cannot be used: it rejects unsigned
    // integers and cannot represent the magnitude of `T.min`.
    static if (is(T : immutable short) || is(T : immutable byte))
        UCT absA = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a);
    else
        UCT absA = (isUnsigned!T || a >= 0) ? a : -UCT(a);

    // Magnitude of `b`, computed the same way.
    static if (is(U : immutable short) || is(U : immutable byte))
        UCT absB = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b);
    else
        UCT absB = (isUnsigned!U || b >= 0) ? b : -UCT(b);

    // A zero argument makes the result zero.
    if (absA == 0 || absB == 0)
        return 0;

    // Divide before multiplying to keep the intermediate value small.
    return (absA / gcdImpl(absA, absB)) * absB;
}
3247
///
@safe unittest
{
    // Coprime pairs: the lcm is simply the product.
    assert(lcm(1, 2) == 2);
    assert(lcm(3, 4) == 12);
    assert(lcm(5, 6) == 30);
}
3255
// Exercises lcm over every pairing of the builtin integer types below,
// including const and immutable variants.
@safe unittest
{
    import std.meta : AliasSeq;

    alias FirstTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                 const byte, const short, const int, const long,
                                 immutable ubyte, immutable ushort, immutable uint, immutable ulong);
    alias SecondTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                  const ubyte, const ushort, const uint, const ulong,
                                  immutable byte, immutable short, immutable int, immutable long);

    static foreach (T; FirstTypes)
    {
        static foreach (U; SecondTypes)
        {
            assert(lcm(T(21), U(6)) == 42);

            // Zero on either side gives zero.
            assert(lcm(T(41), U(0)) == 0);
            assert(lcm(T(0), U(7)) == 0);
            assert(lcm(T(0), U(0)) == 0);

            assert(lcm(T(1U), U(2)) == 2);
            assert(lcm(T(3), U(4U)) == 12);
            assert(lcm(T(5U), U(6U)) == 30);

            // Negative input is only possible for signed T.
            static if (T.min < 0)
                assert(lcm(T(-42), U(21U)) == 42);
        }
    }
}
3279
/// ditto
auto lcm(T)(T a, T b)
if (!isIntegral!T &&
        is(typeof(T.init % T.init)) &&
        is(typeof(T.init == 0 || T.init > 0)))
{
    static if (!is(T == Unqual!T))
    {
        // Qualified (const/immutable) arguments cannot be reassigned below,
        // so forward to the mutable instantiation, as the generic `gcd`
        // overload does.
        return lcm!(Unqual!T)(a, b);
    }
    else
    {
        // Work on magnitudes: replace negative arguments by their negation.
        a = a >= 0 ? a : -a;
        b = b >= 0 ? b : -b;

        // A zero argument makes the result zero.
        if (a == 0)
            return a;
        if (b == 0)
            return b;

        // Divide before multiplying to keep the intermediate value small.
        return (a / gcdImpl(a, b)) * b;
    }
}
3298
// BigInt counterpart of the builtin-integer lcm tests.
@safe unittest
{
    import std.bigint : BigInt;

    assert(lcm(BigInt(21), BigInt(6)) == BigInt(42));

    // Zero on either side gives zero.
    assert(lcm(BigInt(41), BigInt(0)) == BigInt(0));
    assert(lcm(BigInt(0), BigInt(7)) == BigInt(0));
    assert(lcm(BigInt(0), BigInt(0)) == BigInt(0));

    assert(lcm(BigInt(1U), BigInt(2)) == BigInt(2));
    assert(lcm(BigInt(3), BigInt(4U)) == BigInt(12));
    assert(lcm(BigInt(5U), BigInt(6U)) == BigInt(30));

    // A negative argument still yields a positive result.
    assert(lcm(BigInt(-42), BigInt(21U)) == BigInt(42));
}
3311
// Element type of the FFT sine lookup tables. Kept as an alias so the
// speed/size vs. accuracy tradeoff can be tweaked in one place. Floats seem
// accurate enough for all practical purposes, since they pass the
// "isClose(inverseFft(fft(arr)), arr)" test even for size 2 ^^ 22.
private alias lookup_t = float;
3317
/**A class for performing fast Fourier transforms of power of two sizes.
 * This class encapsulates a large amount of state that is reusable when
 * performing multiple FFTs of sizes smaller than or equal to that specified
 * in the constructor.  This results in substantial speedups when performing
 * multiple FFTs with a known maximum size.  However,
 * a free function API is provided for convenience if you need to perform a
 * one-off FFT.
 *
 * References:
 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm)
 */
final class Fft
{
    import core.bitop : bsf;
    import std.algorithm.iteration : map;
    import std.array : uninitializedArray;

private:
    // negSinLookup[k] holds -sin values sampled at 2 ^^ k points per period
    // (row 0 is never filled in by the constructor). `butterfly` selects the
    // row matching the length of the buffer it is combining.
    immutable lookup_t[][] negSinLookup;

    // Asserts that `range` is no longer than the size this object was built
    // for.
    void enforceSize(R)(R range) const
    {
        import std.conv : text;
        assert(range.length <= size, text(
            "FFT size mismatch. Expected ", size, ", got ", range.length));
    }

    // Recursive decimation-in-time step: transform the even-indexed and the
    // odd-indexed elements of `range` into the lower and upper halves of
    // `buf`, then combine the two halves with `butterfly`.
    void fftImpl(Ret, R)(Stride!R range, Ret buf) const
    in
    {
        assert(range.length >= 4);
        assert(isPowerOf2(range.length));
    }
    do
    {
        // Doubling the stride selects every other element (the even ones).
        auto recurseRange = range;
        recurseRange.doubleSteps();

        if (buf.length > 4)
        {
            fftImpl(recurseRange, buf[0..$ / 2]);
            // Advance by half the (doubled) stride to reach the odd elements.
            recurseRange.popHalf();
            fftImpl(recurseRange, buf[$ / 2..$]);
        }
        else
        {
            // Do this here instead of in another recursion to save on
            // recursion overhead.
            slowFourier2(recurseRange, buf[0..$ / 2]);
            recurseRange.popHalf();
            slowFourier2(recurseRange, buf[$ / 2..$]);
        }

        butterfly(buf);
    }

    // This algorithm works by performing the even and odd parts of our FFT
    // using the "two for the price of one" method mentioned at
    // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521
    // by making the odd terms into the imaginary components of our new FFT,
    // and then using symmetry to recombine them.
    void fftImplPureReal(Ret, R)(R range, Ret buf) const
    in
    {
        assert(range.length >= 4);
        assert(isPowerOf2(range.length));
    }
    do
    {
        alias E = ElementType!R;

        // Converts odd indices of range to the imaginary components of
        // a range half the size.  The even indices become the real components.
        static if (isArray!R && isFloatingPoint!E)
        {
            // Then the memory layout of complex numbers provides a dirt
            // cheap way to convert.  This is a common case, so take advantage.
            auto oddsImag = cast(Complex!E[]) range;
        }
        else
        {
            // General case:  Use a higher order range.  We can assume
            // source.length is even because it has to be a power of 2.
            static struct OddToImaginary
            {
                R source;
                alias C = Complex!(CommonType!(E, typeof(buf[0].re)));

                @property
                {
                    C front()
                    {
                        return C(source[0], source[1]);
                    }

                    C back()
                    {
                        immutable n = source.length;
                        return C(source[n - 2], source[n - 1]);
                    }

                    typeof(this) save()
                    {
                        return typeof(this)(source.save);
                    }

                    bool empty()
                    {
                        return source.empty;
                    }

                    size_t length()
                    {
                        return source.length / 2;
                    }
                }

                // Each complex element consumes two source elements.
                void popFront()
                {
                    source.popFront();
                    source.popFront();
                }

                void popBack()
                {
                    source.popBack();
                    source.popBack();
                }

                C opIndex(size_t index)
                {
                    return C(source[index * 2], source[index * 2 + 1]);
                }

                typeof(this) opSlice(size_t lower, size_t upper)
                {
                    return typeof(this)(source[lower * 2 .. upper * 2]);
                }
            }

            auto oddsImag = OddToImaginary(range);
        }

        // One half-size complex FFT, written into the lower half of buf.
        fft(oddsImag, buf[0..$ / 2]);
        auto evenFft = buf[0..$ / 2];
        auto oddFft = buf[$ / 2..$];
        immutable halfN = evenFft.length;
        oddFft[0].re = buf[0].im;
        oddFft[0].im = 0;
        evenFft[0].im = 0;
        // evenFft[0].re is already right b/c it's aliased with buf[0].re.

        // Separate the packed result into the transforms of the even and odd
        // samples. bufk/bufnk are copied first because evenFft aliases buf.
        foreach (k; 1 .. halfN / 2 + 1)
        {
            immutable bufk = buf[k];
            immutable bufnk = buf[buf.length / 2 - k];
            evenFft[k].re = 0.5 * (bufk.re + bufnk.re);
            evenFft[halfN - k].re = evenFft[k].re;
            evenFft[k].im = 0.5 * (bufk.im - bufnk.im);
            evenFft[halfN - k].im = -evenFft[k].im;

            oddFft[k].re = 0.5 * (bufk.im + bufnk.im);
            oddFft[halfN - k].re = oddFft[k].re;
            oddFft[k].im = 0.5 * (bufnk.re - bufk.re);
            oddFft[halfN - k].im = -oddFft[k].im;
        }

        butterfly(buf);
    }

    // Combines the two half-size transforms stored in the lower and upper
    // halves of `buf` into one full-size transform, in place, using twiddle
    // factors read from the lookup row for buf.length.
    void butterfly(R)(R buf) const
    in
    {
        assert(isPowerOf2(buf.length));
    }
    do
    {
        immutable n = buf.length;
        immutable localLookup = negSinLookup[bsf(n)];
        assert(localLookup.length == n);

        // Masking with n - 1 wraps the shifted index, as n is a power of two.
        immutable cosMask = n - 1;
        immutable cosAdd = n / 4 * 3;

        lookup_t negSinFromLookup(size_t index) pure nothrow
        {
            return localLookup[index];
        }

        lookup_t cosFromLookup(size_t index) pure nothrow
        {
            // cos is just -sin shifted by PI * 3 / 2.
            return localLookup[(index + cosAdd) & cosMask];
        }

        immutable halfLen = n / 2;

        // This loop is unrolled and the two iterations are interleaved
        // relative to the textbook FFT to increase ILP.  This gives roughly 5%
        // speedups on DMD.
        for (size_t k = 0; k < halfLen; k += 2)
        {
            immutable cosTwiddle1 = cosFromLookup(k);
            immutable sinTwiddle1 = negSinFromLookup(k);
            immutable cosTwiddle2 = cosFromLookup(k + 1);
            immutable sinTwiddle2 = negSinFromLookup(k + 1);

            immutable realLower1 = buf[k].re;
            immutable imagLower1 = buf[k].im;
            immutable realLower2 = buf[k + 1].re;
            immutable imagLower2 = buf[k + 1].im;

            immutable upperIndex1 = k + halfLen;
            immutable upperIndex2 = upperIndex1 + 1;
            immutable realUpper1 = buf[upperIndex1].re;
            immutable imagUpper1 = buf[upperIndex1].im;
            immutable realUpper2 = buf[upperIndex2].re;
            immutable imagUpper2 = buf[upperIndex2].im;

            // Complex product of the twiddle factor and the upper element.
            immutable realAdd1 = cosTwiddle1 * realUpper1
                               - sinTwiddle1 * imagUpper1;
            immutable imagAdd1 = sinTwiddle1 * realUpper1
                               + cosTwiddle1 * imagUpper1;
            immutable realAdd2 = cosTwiddle2 * realUpper2
                               - sinTwiddle2 * imagUpper2;
            immutable imagAdd2 = sinTwiddle2 * realUpper2
                               + cosTwiddle2 * imagUpper2;

            // Lower output = lower + product; upper output = lower - product.
            buf[k].re += realAdd1;
            buf[k].im += imagAdd1;
            buf[k + 1].re += realAdd2;
            buf[k + 1].im += imagAdd2;

            buf[upperIndex1].re = realLower1 - realAdd1;
            buf[upperIndex1].im = imagLower1 - imagAdd1;
            buf[upperIndex2].re = realLower2 - realAdd2;
            buf[upperIndex2].im = imagLower2 - imagAdd2;
        }
    }

    // This constructor is used within this module for allocating the
    // buffer space elsewhere besides the GC heap.  It's definitely **NOT**
    // part of the public API and definitely **IS** subject to change.
    //
    // The transform size is memSpace.length / 2; an empty buffer leaves the
    // object with size 0.
    //
    // Also, this is unsafe because the memSpace buffer will be cast
    // to immutable.
    //
    // Public b/c of https://issues.dlang.org/show_bug.cgi?id=4636.
    public this(lookup_t[] memSpace)
    {
        immutable size = memSpace.length / 2;

        /* Create a lookup table of all negative sine values at a resolution of
         * size and all smaller power of two resolutions.  This may seem
         * inefficient, but having all the lookups be next to each other in
         * memory at every level of iteration is a huge win performance-wise.
         */
        if (size == 0)
        {
            return;
        }

        assert(isPowerOf2(size),
            "Can only do FFTs on ranges with a size that is a power of two.");

        auto table = new lookup_t[][bsf(size) + 1];

        // The full-resolution row takes the top half of memSpace; the smaller
        // rows are carved off the end of the remaining lower half.
        table[$ - 1] = memSpace[$ - size..$];
        memSpace = memSpace[0 .. size];

        auto lastRow = table[$ - 1];
        lastRow[0] = 0;  // -sin(0) == 0.
        foreach (ptrdiff_t i; 1 .. size)
        {
            // The hard coded cases are for improved accuracy and to prevent
            // annoying non-zeroness when stuff should be zero.

            if (i == size / 4)
                lastRow[i] = -1;  // -sin(pi / 2) == -1.
            else if (i == size / 2)
                lastRow[i] = 0;   // -sin(pi) == 0.
            else if (i == size * 3 / 4)
                lastRow[i] = 1;  // -sin(pi * 3 / 2) == 1
            else
                lastRow[i] = -sin(i * 2.0L * PI / size);
        }

        // Fill in all the other rows with strided versions.
        foreach (i; 1 .. table.length - 1)
        {
            immutable strideLength = size / (2 ^^ i);
            auto strided = Stride!(lookup_t[])(lastRow, strideLength);
            table[i] = memSpace[$ - strided.length..$];
            memSpace = memSpace[0..$ - strided.length];

            size_t copyIndex;
            foreach (elem; strided)
            {
                table[i][copyIndex++] = elem;
            }
        }

        negSinLookup = cast(immutable) table;
    }

public:
    /**Create an `Fft` object for computing fast Fourier transforms of
     * power of two sizes of `size` or smaller.  `size` must be a
     * power of two.
     */
    this(size_t size)
    {
        // Allocate all twiddle factor buffers in one contiguous block so that,
        // when one is done being used, the next one is next in cache.
        auto memSpace = uninitializedArray!(lookup_t[])(2 * size);
        this(memSpace);
    }

    /// The largest transform size this object supports (0 if it was built
    /// from an empty buffer).
    @property size_t size() const
    {
        return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length;
    }

    /**Compute the Fourier transform of range using the $(BIGOH N log N)
     * Cooley-Tukey Algorithm.  `range` must be a random-access range with
     * slicing and a length no greater than `size` as provided at the
     * construction of this object; lengths of 0, 1 and 2 are handled as
     * special cases and longer ranges must have a power of two length.
     * The contents of range can be either numeric types,
     * which will be interpreted as pure real values, or complex types with
     * properties or members `.re` and `.im` that can be read.
     *
     * Note:  Pure real FFTs are automatically detected and the relevant
     *        optimizations are performed.
     *
     * Returns:  An array of complex numbers representing the transformed data in
     *           the frequency domain.
     *
     * Conventions: The exponent is negative and the factor is one,
     *              i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ].
     */
    Complex!F[] fft(F = double, R)(R range) const
    if (isFloatingPoint!F && isRandomAccessRange!R)
    {
        enforceSize(range);
        Complex!F[] ret;
        if (range.length == 0)
        {
            return ret;
        }

        // Don't waste time initializing the memory for ret.
        ret = uninitializedArray!(Complex!F[])(range.length);

        fft(range,  ret);
        return ret;
    }

    /**Same as the overload, but allows for the results to be stored in a user-
     * provided buffer.  The buffer must be of the same length as range, must be
     * a random-access range, must have slicing, and must contain elements that are
     * complex-like.  This means that they must have a .re and a .im member or
     * property that can be both read and written and are floating point numbers.
     */
    void fft(Ret, R)(R range, Ret buf) const
    if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
    {
        assert(buf.length == range.length);
        enforceSize(range);

        if (range.length == 0)
        {
            return;
        }
        else if (range.length == 1)
        {
            buf[0] = range[0];
            return;
        }
        else if (range.length == 2)
        {
            slowFourier2(range, buf);
            return;
        }
        else
        {
            alias E = ElementType!R;
            static if (is(E : real))
            {
                // Purely real input: use the half-size complex FFT trick.
                return fftImplPureReal(range, buf);
            }
            else
            {
                static if (is(R : Stride!R))
                    return fftImpl(range, buf);
                else
                    return fftImpl(Stride!R(range, 1), buf);
            }
        }
    }

    /**
     * Computes the inverse Fourier transform of a range.  The range must be a
     * random access range with slicing, have a length no greater than the size
     * provided at construction of this object, and contain elements that are
     * either of type std.complex.Complex or have essentially
     * the same compile-time interface.
     *
     * Returns:  The time-domain signal.
     *
     * Conventions: The exponent is positive and the factor is 1/N, i.e.,
     *              output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ].
     */
    Complex!F[] inverseFft(F = double, R)(R range) const
    if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F)
    {
        enforceSize(range);
        Complex!F[] ret;
        if (range.length == 0)
        {
            return ret;
        }

        // Don't waste time initializing the memory for ret.
        ret = uninitializedArray!(Complex!F[])(range.length);

        inverseFft(range, ret);
        return ret;
    }

    /**
     * Inverse FFT that allows a user-supplied buffer to be provided.  The buffer
     * must be a random access range with slicing, and its elements
     * must be some complex-like type.
     */
    void inverseFft(Ret, R)(R range, Ret buf) const
    if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
    {
        enforceSize(range);

        // The inverse transform is computed with the forward transform:
        // swap re/im on the way in, then swap back and scale by 1/N below.
        auto swapped = map!swapRealImag(range);
        fft(swapped,  buf);

        immutable lenNeg1 = 1.0 / buf.length;
        foreach (ref elem; buf)
        {
            immutable temp = elem.re * lenNeg1;
            elem.re = elem.im * lenNeg1;
            elem.im = temp;
        }
    }
}
3768
// This mixin creates an Fft object in the scope it's mixed into such that all
// memory owned by the object is deterministically destroyed at the end of that
// scope.
//
// It expects a variable named `range` in the enclosing scope and introduces
// `lookupBuf` (malloc'ed lookup-table storage, freed at scope exit) and
// `fftObj` (the scoped Fft instance).
private enum string MakeLocalFft = q{
    import core.stdc.stdlib;
    import core.exception : onOutOfMemoryError;

    auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof))
                     [0 .. 2 * range.length];
    // malloc(0) is allowed to return null, so a null pointer only signals
    // exhaustion when a non-empty buffer was requested.
    if (!lookupBuf.ptr && lookupBuf.length != 0)
        onOutOfMemoryError();

    // free(null) is a no-op, so this is safe for the empty case as well.
    scope(exit) free(cast(void*) lookupBuf.ptr);
    auto fftObj = scoped!Fft(lookupBuf);
};
3784
/**Convenience functions that create an `Fft` object, run the FFT or inverse
 * FFT and return the result.  Useful for one-off FFTs.
 *
 * Note:  In addition to convenience, these functions are slightly more
 *        efficient than manually creating an Fft object for a single use,
 *        as the Fft object is deterministically destroyed before these
 *        functions return.
 */
Complex!F[] fft(F = double, R)(R range)
{
    // Brings a temporary `fftObj` sized for `range` into scope.
    mixin(MakeLocalFft);

    auto transformed = fftObj.fft!(F, R)(range);
    return transformed;
}
3798
/// ditto
void fft(Ret, R)(R range, Ret buf)
{
    // Brings a temporary `fftObj` sized for `range` into scope.
    mixin(MakeLocalFft);

    fftObj.fft!(Ret, R)(range, buf);
}
3805
/// ditto
Complex!F[] inverseFft(F = double, R)(R range)
{
    // Brings a temporary `fftObj` sized for `range` into scope.
    mixin(MakeLocalFft);

    auto restored = fftObj.inverseFft!(F, R)(range);
    return restored;
}
3812
/// ditto
void inverseFft(Ret, R)(R range, Ret buf)
{
    // Brings a temporary `fftObj` sized for `range` into scope.
    mixin(MakeLocalFft);

    fftObj.inverseFft!(Ret, R)(range, buf);
}
3819
@system unittest
{
    import std.algorithm;
    import std.conv;
    import std.range;

    // Expected values come from R and Octave.
    auto realInput = [1,2,3,4,5,6,7,8];
    auto realSpectrum = fft(realInput);
    assert(isClose(map!"a.re"(realSpectrum),
        [36.0, -4, -4, -4, -4, -4, -4, -4], 1e-4));
    assert(isClose(map!"a.im"(realSpectrum),
        [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568], 1e-4));

    // Same data fed through a non-array range, in reverse order.
    auto reversedSpectrum = fft(retro(realInput));
    assert(isClose(map!"a.re"(reversedSpectrum),
        [36.0, 4, 4, 4, 4, 4, 4, 4], 1e-4));
    assert(isClose(map!"a.im"(reversedSpectrum),
        [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568], 1e-4));

    // Floating-point array input must agree with the integer input.
    auto floatSpectrum = fft(to!(float[])(realInput));
    assert(isClose(map!"a.re"(realSpectrum), map!"a.re"(floatSpectrum)));
    assert(isClose(map!"a.im"(realSpectrum), map!"a.im"(floatSpectrum)));

    // Complex input.
    alias C = Complex!float;
    auto complexInput = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10),
        C(11,12), C(13,14), C(15,16)];
    auto complexSpectrum = fft(complexInput);
    assert(isClose(map!"a.re"(complexSpectrum),
        [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137], 1e-4));
    assert(isClose(map!"a.im"(complexSpectrum),
        [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137], 1e-4));

    // Round trips through the inverse transform.
    auto realRoundTrip = inverseFft(realSpectrum);
    assert(isClose(map!"a.re"(realRoundTrip), realInput, 1e-6));
    assert(reduce!max(map!"a.im"(realRoundTrip)) < 1e-10);

    auto complexRoundTrip = inverseFft(complexSpectrum);
    assert(isClose(map!"a.re"(complexRoundTrip), map!"a.re"(complexInput)));
    assert(isClose(map!"a.im"(complexRoundTrip), map!"a.im"(complexInput)));

    // FFTs of size 0, 1 and 2 are handled as special cases.  Test them here.
    ushort[] noElems;
    assert(fft(noElems) == null);
    assert(inverseFft(fft(noElems)) == null);

    real[] singleElem = [4.5L];
    auto singleSpectrum = fft(singleElem);
    assert(singleSpectrum.length == 1);
    assert(singleSpectrum[0].re == 4.5L);
    assert(singleSpectrum[0].im == 0);

    auto singleRoundTrip = inverseFft(singleSpectrum);
    assert(singleRoundTrip.length == 1);
    assert(isClose(singleRoundTrip[0].re, 4.5));
    assert(isClose(singleRoundTrip[0].im, 0, 0.0, 1e-10));

    long[2] pair = [8, 4];
    auto pairSpectrum = fft(pair[]);
    assert(pairSpectrum.length == 2);
    assert(isClose(pairSpectrum[0].re, 12));
    assert(isClose(pairSpectrum[0].im, 0, 0.0, 1e-10));
    assert(isClose(pairSpectrum[1].re, 4));
    assert(isClose(pairSpectrum[1].im, 0, 0.0, 1e-10));

    auto pairRoundTrip = inverseFft(pairSpectrum);
    assert(isClose(pairRoundTrip[0].re, 8));
    assert(isClose(pairRoundTrip[0].im, 0, 0.0, 1e-10));
    assert(isClose(pairRoundTrip[1].re, 4));
    assert(isClose(pairRoundTrip[1].im, 0, 0.0, 1e-10));
}
3889
// Returns a copy of `input` whose real and imaginary parts are exchanged.
// The inverse FFT uses this to reuse the forward transform.
C swapRealImag(C)(C input)
{
    auto swapped = C(input.im, input.re);
    return swapped;
}
3896
/** This function transforms `decimal` value into a value in the factorial number
system stored in `fac`.

A factorial number with `n` digits is stored most significant digit first, so
its value is:
$(D fac[0] * (n - 1)! + fac[1] * (n - 2)! + ... + fac[n - 1] * 0!)

Params:
    decimal = The decimal value to convert into the factorial number system.
    fac = The array to store the factorial number. The array is of size 21 as
        `ulong.max` requires 21 digits in the factorial number system.
Returns:
    A variable storing the number of digits of the factorial number stored in
    `fac`.
*/
size_t decimalToFactorial(ulong decimal, ref ubyte[21] fac)
    @safe pure nothrow @nogc
{
    import std.algorithm.mutation : reverse;

    size_t digitCount = 0;
    ulong radix = 1;

    // Peel off digits, least significant first: the k-th digit is the
    // remainder of dividing by k + 1.
    while (decimal != 0)
    {
        fac[digitCount] = cast(ubyte)(decimal % radix);
        decimal /= radix;
        ++digitCount;
        ++radix;
    }

    // Zero is represented by the single digit 0.
    if (digitCount == 0)
    {
        fac[0] = 0;
        digitCount = 1;
    }

    // Put the most significant digit first.
    reverse(fac[0 .. digitCount]);

    // The 0! digit, now last in the array, is always zero.
    assert(fac[digitCount - 1] == 0);

    return digitCount;
}
3936
///
@safe pure @nogc unittest
{
    ubyte[21] fac;
    size_t idx = decimalToFactorial(2982, fac);

    // 2982 == 4 * 6! + 0 * 5! + 4 * 4! + 1 * 3! + 0 * 2! + 0 * 1! + 0 * 0!
    static immutable ubyte[7] digits = [4, 0, 4, 1, 0, 0, 0];
    foreach (pos, digit; digits)
    {
        assert(fac[pos] == digit);
    }
}
3951
@safe pure unittest
{
    ubyte[21] fac;

    // Zero converts to the single digit 0.
    size_t idx = decimalToFactorial(0UL, fac);
    assert(idx == 1);
    assert(fac[0] == 0);

    // The largest input fills all 21 digits.
    fac[] = 0;
    idx = 0;
    idx = decimalToFactorial(ulong.max, fac);
    assert(idx == 21);
    auto expected = [7, 11, 12, 4, 3, 15, 3, 5, 3, 5, 0, 8, 3, 5, 0, 0, 0, 2, 1, 1, 0];
    foreach (pos, digit; fac[0 .. 21])
    {
        assert(digit == expected[pos]);
    }

    // A mid-sized value.
    fac[] = 0;
    idx = decimalToFactorial(2982, fac);

    assert(idx == 7);
    expected = [4, 0, 4, 1, 0, 0, 0];
    foreach (pos, digit; fac[0 .. idx])
    {
        assert(digit == expected[pos]);
    }
}
3979
3980 private:
// A strided view over a random-access range. std.algorithm was not usable
// here because its stride length can't be modified on the fly, and because
// this range has grown some performance hacks for powers of 2.
struct Stride(R)
{
    import core.bitop : bsf;
    Unqual!R range;
    size_t _nSteps;
    size_t _length;
    alias E = ElementType!(R);

    this(R range, size_t nStepsIn)
    {
        this.range = range;
        _nSteps = nStepsIn;
        // Number of elements visited, rounding up.
        _length = (this.range.length + nStepsIn - 1) / nStepsIn;
    }

    size_t length() const @property
    {
        return _length;
    }

    bool empty() const @property
    {
        return _length == 0;
    }

    size_t nSteps() const @property
    {
        return _nSteps;
    }

    // Changes the stride and recomputes the length. The shift by bsf is only
    // equivalent to a division when newVal is a power of two.
    size_t nSteps(size_t newVal) @property
    {
        _nSteps = newVal;

        // Using >> bsf(nSteps) is a few cycles faster than / nSteps.
        _length = (range.length + newVal - 1) >> bsf(newVal);
        return newVal;
    }

    typeof(this) save() @property
    {
        auto copy = this;
        copy.range = copy.range.save;
        return copy;
    }

    E front() @property
    {
        return range[0];
    }

    E opIndex(size_t index)
    {
        return range[index * _nSteps];
    }

    void popFront()
    {
        if (range.length < _nSteps)
        {
            // Fewer than one full step left: the view becomes empty.
            range = range[0 .. 0];
            _length = 0;
            return;
        }

        range = range[_nSteps .. range.length];
        --_length;
    }

    // Pops half the range's stride.
    void popHalf()
    {
        immutable halfStep = _nSteps / 2;
        range = range[halfStep .. range.length];
    }

    // Doubles the stride, which halves the number of elements visited.
    void doubleSteps()
    {
        _nSteps *= 2;
        _length /= 2;
    }
}
4066
// Size-2 DFT, used as the hard-coded base case of the FFT recursion.  This is
// actually a TON faster than using a generic slow DFT and seems to be the best
// base case.  (Size 1 can be coded inline as buf[0] = range[0]).
void slowFourier2(Ret, R)(R range, Ret buf)
{
    assert(range.length == 2);
    assert(buf.length == 2);

    enum first = 0;
    enum second = 1;

    // Sum and difference of the two inputs.
    buf[first] = range[first] + range[second];
    buf[second] = range[first] - range[second];
}
4077
// Size-4 DFT written out term by term.  Doesn't work as well as the size
// 2 case as a base case for the FFT.
void slowFourier4(Ret, R)(R range, Ret buf)
{
    alias C = ElementType!Ret;

    assert(range.length == 4);
    assert(buf.length == 4);

    // The imaginary unit in the buffer's element type.
    auto i = C(0, 1);

    buf[0] = range[0] + range[1] + range[2] + range[3];
    buf[1] = range[0] - range[1] * i - range[2] + range[3] * i;
    buf[2] = range[0] - range[1] + range[2] - range[3];
    buf[3] = range[0] + range[1] * i - range[2] - range[3] * i;
}
4091
// Returns `num` with only its highest set bit kept, i.e. the largest power of
// two that does not exceed a positive `num`.  Zero is returned unchanged.
N roundDownToPowerOf2(N)(N num)
if (isScalarType!N && !isFloatingPoint!N)
{
    import core.bitop : bsr;

    // bsr has no defined result for 0, so never hand it one.
    if (num == 0)
        return num;

    return num & (cast(N) 1 << bsr(num));
}
4098
@safe unittest
{
    // A non-power of two rounds down; an exact power of two is unchanged.
    enum notPowerOfTwo = 7;
    enum powerOfTwo = 4;
    assert(roundDownToPowerOf2(notPowerOfTwo) == 4);
    assert(roundDownToPowerOf2(powerOfTwo) == 4);
}
4104
// True when values of type `T` expose readable `.re` and `.im` members or
// properties.
enum bool isComplexLike(T) = is(typeof(T.init.re)) && is(typeof(T.init.im));
4110
@safe unittest
{
    // Complex numbers qualify; plain integers do not.
    static assert(isComplexLike!(Complex!double) && !isComplexLike!(uint));
}