]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-109218: Refactor tests for the complex() constructor (GH-119635)
authorSerhiy Storchaka <storchaka@gmail.com>
Thu, 30 May 2024 17:35:59 +0000 (20:35 +0300)
committerGitHub <noreply@github.com>
Thu, 30 May 2024 17:35:59 +0000 (17:35 +0000)
* Share common classes.
* Use exactly representable floats and exact tests.
* Check the sign of zero components.
* Remove duplicated tests (mostly left after merging int and long).
* Reorder tests in more consistent way.
* Test more error messages.
* Add tests for missed cases.

Lib/test/test_complex.py

index fa3017b24e16c84d8a77d91c9ecdeaae72d40f24..f29b7d3ebd31abf4374cd397faed54dad992606f 100644 (file)
@@ -5,7 +5,7 @@ from test.test_grammar import (VALID_UNDERSCORE_LITERALS,
                                INVALID_UNDERSCORE_LITERALS)
 
 from random import random
-from math import atan2, isnan, copysign
+from math import isnan, copysign
 import operator
 
 INF = float("inf")
@@ -21,6 +21,27 @@ ZERO_DIVISION = (
     (1, 0+0j),
 )
 
+class WithIndex:
+    def __init__(self, value):
+        self.value = value
+    def __index__(self):
+        return self.value
+
+class WithFloat:
+    def __init__(self, value):
+        self.value = value
+    def __float__(self):
+        return self.value
+
+class ComplexSubclass(complex):
+    pass
+
+class WithComplex:
+    def __init__(self, value):
+        self.value = value
+    def __complex__(self):
+        return self.value
+
 class ComplexTest(unittest.TestCase):
 
     def assertAlmostEqual(self, a, b):
@@ -340,137 +361,90 @@ class ComplexTest(unittest.TestCase):
         self.assertClose(complex(5.3, 9.8).conjugate(), 5.3-9.8j)
 
     def test_constructor(self):
-        class NS:
-            def __init__(self, value): self.value = value
-            def __complex__(self): return self.value
-        self.assertEqual(complex(NS(1+10j)), 1+10j)
-        self.assertRaises(TypeError, complex, NS(None))
-        self.assertRaises(TypeError, complex, {})
-        self.assertRaises(TypeError, complex, NS(1.5))
-        self.assertRaises(TypeError, complex, NS(1))
-        self.assertRaises(TypeError, complex, object())
-        self.assertRaises(TypeError, complex, NS(4.25+0.5j), object())
-
-        self.assertAlmostEqual(complex("1+10j"), 1+10j)
-        self.assertAlmostEqual(complex(10), 10+0j)
-        self.assertAlmostEqual(complex(10.0), 10+0j)
-        self.assertAlmostEqual(complex(10), 10+0j)
-        self.assertAlmostEqual(complex(10+0j), 10+0j)
-        self.assertAlmostEqual(complex(1,10), 1+10j)
-        self.assertAlmostEqual(complex(1,10), 1+10j)
-        self.assertAlmostEqual(complex(1,10.0), 1+10j)
-        self.assertAlmostEqual(complex(1,10), 1+10j)
-        self.assertAlmostEqual(complex(1,10), 1+10j)
-        self.assertAlmostEqual(complex(1,10.0), 1+10j)
-        self.assertAlmostEqual(complex(1.0,10), 1+10j)
-        self.assertAlmostEqual(complex(1.0,10), 1+10j)
-        self.assertAlmostEqual(complex(1.0,10.0), 1+10j)
-        self.assertAlmostEqual(complex(3.14+0j), 3.14+0j)
-        self.assertAlmostEqual(complex(3.14), 3.14+0j)
-        self.assertAlmostEqual(complex(314), 314.0+0j)
-        self.assertAlmostEqual(complex(314), 314.0+0j)
-        self.assertAlmostEqual(complex(3.14+0j, 0j), 3.14+0j)
-        self.assertAlmostEqual(complex(3.14, 0.0), 3.14+0j)
-        self.assertAlmostEqual(complex(314, 0), 314.0+0j)
-        self.assertAlmostEqual(complex(314, 0), 314.0+0j)
-        self.assertAlmostEqual(complex(0j, 3.14j), -3.14+0j)
-        self.assertAlmostEqual(complex(0.0, 3.14j), -3.14+0j)
-        self.assertAlmostEqual(complex(0j, 3.14), 3.14j)
-        self.assertAlmostEqual(complex(0.0, 3.14), 3.14j)
-        self.assertAlmostEqual(complex("1"), 1+0j)
-        self.assertAlmostEqual(complex("1j"), 1j)
-        self.assertAlmostEqual(complex(),  0)
-        self.assertAlmostEqual(complex("-1"), -1)
-        self.assertAlmostEqual(complex("+1"), +1)
-        self.assertAlmostEqual(complex("(1+2j)"), 1+2j)
-        self.assertAlmostEqual(complex("(1.3+2.2j)"), 1.3+2.2j)
-        self.assertAlmostEqual(complex("3.14+1J"), 3.14+1j)
-        self.assertAlmostEqual(complex(" ( +3.14-6J )"), 3.14-6j)
-        self.assertAlmostEqual(complex(" ( +3.14-J )"), 3.14-1j)
-        self.assertAlmostEqual(complex(" ( +3.14+j )"), 3.14+1j)
-        self.assertAlmostEqual(complex("J"), 1j)
-        self.assertAlmostEqual(complex("( j )"), 1j)
-        self.assertAlmostEqual(complex("+J"), 1j)
-        self.assertAlmostEqual(complex("( -j)"), -1j)
-        self.assertAlmostEqual(complex('1e-500'), 0.0 + 0.0j)
-        self.assertAlmostEqual(complex('-1e-500j'), 0.0 - 0.0j)
-        self.assertAlmostEqual(complex('-1e-500+1e-500j'), -0.0 + 0.0j)
-        self.assertEqual(complex('1-1j'), 1.0 - 1j)
-        self.assertEqual(complex('1J'), 1j)
-
-        class complex2(complex): pass
-        self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j)
-        self.assertAlmostEqual(complex(real=17, imag=23), 17+23j)
-        self.assertAlmostEqual(complex(real=17+23j), 17+23j)
-        self.assertAlmostEqual(complex(real=17+23j, imag=23), 17+46j)
-        self.assertAlmostEqual(complex(real=1+2j, imag=3+4j), -3+5j)
+        def check(z, x, y):
+            self.assertIs(type(z), complex)
+            self.assertFloatsAreIdentical(z.real, x)
+            self.assertFloatsAreIdentical(z.imag, y)
+
+        check(complex(),  0.0, 0.0)
+        check(complex(10), 10.0, 0.0)
+        check(complex(4.25), 4.25, 0.0)
+        check(complex(4.25+0j), 4.25, 0.0)
+        check(complex(4.25+0.5j), 4.25, 0.5)
+        check(complex(ComplexSubclass(4.25+0.5j)), 4.25, 0.5)
+        check(complex(WithComplex(4.25+0.5j)), 4.25, 0.5)
+
+        check(complex(1, 10), 1.0, 10.0)
+        check(complex(1, 10.0), 1.0, 10.0)
+        check(complex(1, 4.25), 1.0, 4.25)
+        check(complex(1.0, 10), 1.0, 10.0)
+        check(complex(4.25, 10), 4.25, 10.0)
+        check(complex(1.0, 10.0), 1.0, 10.0)
+        check(complex(4.25, 0.5), 4.25, 0.5)
+
+        check(complex(4.25+0j, 0), 4.25, 0.0)
+        check(complex(ComplexSubclass(4.25+0j), 0), 4.25, 0.0)
+        check(complex(WithComplex(4.25+0j), 0), 4.25, 0.0)
+        check(complex(4.25j, 0), 0.0, 4.25)
+        check(complex(0j, 4.25), 0.0, 4.25)
+        check(complex(0, 4.25+0j), 0.0, 4.25)
+        check(complex(0, ComplexSubclass(4.25+0j)), 0.0, 4.25)
+        with self.assertRaisesRegex(TypeError,
+                "second argument must be a number, not 'WithComplex'"):
+            complex(0, WithComplex(4.25+0j))
+        check(complex(0.0, 4.25j), -4.25, 0.0)
+        check(complex(4.25+0j, 0j), 4.25, 0.0)
+        check(complex(4.25j, 0j), 0.0, 4.25)
+        check(complex(0j, 4.25+0j), 0.0, 4.25)
+        check(complex(0j, 4.25j), -4.25, 0.0)
+
+        check(complex(real=4.25), 4.25, 0.0)
+        check(complex(real=4.25+0j), 4.25, 0.0)
+        check(complex(real=4.25+1.5j), 4.25, 1.5)
+        check(complex(imag=1.5), 0.0, 1.5)
+        check(complex(real=4.25, imag=1.5), 4.25, 1.5)
+        check(complex(4.25, imag=1.5), 4.25, 1.5)
 
         # check that the sign of a zero in the real or imaginary part
-        # is preserved when constructing from two floats.  (These checks
-        # are harmless on systems without support for signed zeros.)
-        def split_zeros(x):
-            """Function that produces different results for 0. and -0."""
-            return atan2(x, -1.)
-
-        self.assertEqual(split_zeros(complex(1., 0.).imag), split_zeros(0.))
-        self.assertEqual(split_zeros(complex(1., -0.).imag), split_zeros(-0.))
-        self.assertEqual(split_zeros(complex(0., 1.).real), split_zeros(0.))
-        self.assertEqual(split_zeros(complex(-0., 1.).real), split_zeros(-0.))
-
-        c = 3.14 + 1j
-        self.assertTrue(complex(c) is c)
-        del c
-
-        self.assertRaises(TypeError, complex, "1", "1")
-        self.assertRaises(TypeError, complex, 1, "1")
-
-        # SF bug 543840:  complex(string) accepts strings with \0
-        # Fixed in 2.3.
-        self.assertRaises(ValueError, complex, '1+1j\0j')
-
-        self.assertRaises(TypeError, int, 5+3j)
-        self.assertRaises(TypeError, int, 5+3j)
-        self.assertRaises(TypeError, float, 5+3j)
-        self.assertRaises(ValueError, complex, "")
-        self.assertRaises(TypeError, complex, None)
-        self.assertRaisesRegex(TypeError, "not 'NoneType'", complex, None)
-        self.assertRaises(ValueError, complex, "\0")
-        self.assertRaises(ValueError, complex, "3\09")
-        self.assertRaises(TypeError, complex, "1", "2")
-        self.assertRaises(TypeError, complex, "1", 42)
-        self.assertRaises(TypeError, complex, 1, "2")
-        self.assertRaises(ValueError, complex, "1+")
-        self.assertRaises(ValueError, complex, "1+1j+1j")
-        self.assertRaises(ValueError, complex, "--")
-        self.assertRaises(ValueError, complex, "(1+2j")
-        self.assertRaises(ValueError, complex, "1+2j)")
-        self.assertRaises(ValueError, complex, "1+(2j)")
-        self.assertRaises(ValueError, complex, "(1+2j)123")
-        self.assertRaises(ValueError, complex, "x")
-        self.assertRaises(ValueError, complex, "1j+2")
-        self.assertRaises(ValueError, complex, "1e1ej")
-        self.assertRaises(ValueError, complex, "1e++1ej")
-        self.assertRaises(ValueError, complex, ")1+2j(")
-        self.assertRaisesRegex(
-            TypeError,
+        # is preserved when constructing from two floats.
+        for x in 1.0, -1.0:
+            for y in 0.0, -0.0:
+                check(complex(x, y), x, y)
+                check(complex(y, x), y, x)
+
+        c = complex(4.25, 1.5)
+        self.assertIs(complex(c), c)
+        c2 = ComplexSubclass(c)
+        self.assertEqual(c2, c)
+        self.assertIs(type(c2), ComplexSubclass)
+        del c, c2
+
+        self.assertRaisesRegex(TypeError,
+            "first argument must be a string or a number, not 'dict'",
+            complex, {})
+        self.assertRaisesRegex(TypeError,
+            "first argument must be a string or a number, not 'NoneType'",
+            complex, None)
+        self.assertRaisesRegex(TypeError,
             "first argument must be a string or a number, not 'dict'",
-            complex, {1:2}, 1)
-        self.assertRaisesRegex(
-            TypeError,
+            complex, {1:2}, 0)
+        self.assertRaisesRegex(TypeError,
+            "can't take second arg if first is a string",
+            complex, '1', 0)
+        self.assertRaisesRegex(TypeError,
             "second argument must be a number, not 'dict'",
-            complex, 1, {1:2})
-        # the following three are accepted by Python 2.6
-        self.assertRaises(ValueError, complex, "1..1j")
-        self.assertRaises(ValueError, complex, "1.11.1j")
-        self.assertRaises(ValueError, complex, "1e1.1j")
-
-        # check that complex accepts long unicode strings
-        self.assertEqual(type(complex("1"*500)), complex)
-        # check whitespace processing
-        self.assertEqual(complex('\N{EM SPACE}(\N{EN SPACE}1+1j ) '), 1+1j)
-        # Invalid unicode string
-        # See bpo-34087
-        self.assertRaises(ValueError, complex, '\u3053\u3093\u306b\u3061\u306f')
+            complex, 0, {1:2})
+        self.assertRaisesRegex(TypeError,
+                "second arg can't be a string",
+            complex, 0, '1')
+
+        self.assertRaises(TypeError, complex, WithComplex(1.5))
+        self.assertRaises(TypeError, complex, WithComplex(1))
+        self.assertRaises(TypeError, complex, WithComplex(None))
+        self.assertRaises(TypeError, complex, WithComplex(4.25+0j), object())
+        self.assertRaises(TypeError, complex, WithComplex(1.5), object())
+        self.assertRaises(TypeError, complex, WithComplex(1), object())
+        self.assertRaises(TypeError, complex, WithComplex(None), object())
 
         class EvilExc(Exception):
             pass
@@ -481,33 +455,33 @@ class ComplexTest(unittest.TestCase):
 
         self.assertRaises(EvilExc, complex, evilcomplex())
 
-        class float2:
-            def __init__(self, value):
-                self.value = value
-            def __float__(self):
-                return self.value
-
-        self.assertAlmostEqual(complex(float2(42.)), 42)
-        self.assertAlmostEqual(complex(real=float2(17.), imag=float2(23.)), 17+23j)
-        self.assertRaises(TypeError, complex, float2(None))
-
-        class MyIndex:
-            def __init__(self, value):
-                self.value = value
-            def __index__(self):
-                return self.value
-
-        self.assertAlmostEqual(complex(MyIndex(42)), 42.0+0.0j)
-        self.assertAlmostEqual(complex(123, MyIndex(42)), 123.0+42.0j)
-        self.assertRaises(OverflowError, complex, MyIndex(2**2000))
-        self.assertRaises(OverflowError, complex, 123, MyIndex(2**2000))
+        check(complex(WithFloat(4.25)), 4.25, 0.0)
+        check(complex(WithFloat(4.25), 1.5), 4.25, 1.5)
+        check(complex(1.5, WithFloat(4.25)), 1.5, 4.25)
+        self.assertRaises(TypeError, complex, WithFloat(42))
+        self.assertRaises(TypeError, complex, WithFloat(42), 1.5)
+        self.assertRaises(TypeError, complex, 1.5, WithFloat(42))
+        self.assertRaises(TypeError, complex, WithFloat(None))
+        self.assertRaises(TypeError, complex, WithFloat(None), 1.5)
+        self.assertRaises(TypeError, complex, 1.5, WithFloat(None))
+
+        check(complex(WithIndex(42)), 42.0, 0.0)
+        check(complex(WithIndex(42), 1.5), 42.0, 1.5)
+        check(complex(1.5, WithIndex(42)), 1.5, 42.0)
+        self.assertRaises(OverflowError, complex, WithIndex(2**2000))
+        self.assertRaises(OverflowError, complex, WithIndex(2**2000), 1.5)
+        self.assertRaises(OverflowError, complex, 1.5, WithIndex(2**2000))
+        self.assertRaises(TypeError, complex, WithIndex(None))
+        self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
+        self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
 
         class MyInt:
             def __int__(self):
                 return 42
 
         self.assertRaises(TypeError, complex, MyInt())
-        self.assertRaises(TypeError, complex, 123, MyInt())
+        self.assertRaises(TypeError, complex, MyInt(), 1.5)
+        self.assertRaises(TypeError, complex, 1.5, MyInt())
 
         class complex0(complex):
             """Test usage of __complex__() when inheriting from 'complex'"""
@@ -527,9 +501,9 @@ class ComplexTest(unittest.TestCase):
             def __complex__(self):
                 return None
 
-        self.assertEqual(complex(complex0(1j)), 42j)
+        check(complex(complex0(1j)), 0.0, 42.0)
         with self.assertWarns(DeprecationWarning):
-            self.assertEqual(complex(complex1(1j)), 2j)
+            check(complex(complex1(1j)), 0.0, 2.0)
         self.assertRaises(TypeError, complex, complex2(1j))
 
     def test___complex__(self):
@@ -537,36 +511,93 @@ class ComplexTest(unittest.TestCase):
         self.assertEqual(z.__complex__(), z)
         self.assertEqual(type(z.__complex__()), complex)
 
-        class complex_subclass(complex):
-            pass
-
-        z = complex_subclass(3 + 4j)
+        z = ComplexSubclass(3 + 4j)
         self.assertEqual(z.__complex__(), 3 + 4j)
         self.assertEqual(type(z.__complex__()), complex)
 
     @support.requires_IEEE_754
     def test_constructor_special_numbers(self):
-        class complex2(complex):
-            pass
         for x in 0.0, -0.0, INF, -INF, NAN:
             for y in 0.0, -0.0, INF, -INF, NAN:
                 with self.subTest(x=x, y=y):
                     z = complex(x, y)
                     self.assertFloatsAreIdentical(z.real, x)
                     self.assertFloatsAreIdentical(z.imag, y)
-                    z = complex2(x, y)
-                    self.assertIs(type(z), complex2)
+                    z = ComplexSubclass(x, y)
+                    self.assertIs(type(z), ComplexSubclass)
                     self.assertFloatsAreIdentical(z.real, x)
                     self.assertFloatsAreIdentical(z.imag, y)
-                    z = complex(complex2(x, y))
+                    z = complex(ComplexSubclass(x, y))
                     self.assertIs(type(z), complex)
                     self.assertFloatsAreIdentical(z.real, x)
                     self.assertFloatsAreIdentical(z.imag, y)
-                    z = complex2(complex(x, y))
-                    self.assertIs(type(z), complex2)
+                    z = ComplexSubclass(complex(x, y))
+                    self.assertIs(type(z), ComplexSubclass)
                     self.assertFloatsAreIdentical(z.real, x)
                     self.assertFloatsAreIdentical(z.imag, y)
 
+    def test_constructor_from_string(self):
+        def check(z, x, y):
+            self.assertIs(type(z), complex)
+            self.assertFloatsAreIdentical(z.real, x)
+            self.assertFloatsAreIdentical(z.imag, y)
+
+        check(complex("1"), 1.0, 0.0)
+        check(complex("1j"), 0.0, 1.0)
+        check(complex("-1"), -1.0, 0.0)
+        check(complex("+1"), 1.0, 0.0)
+        check(complex("1+2j"), 1.0, 2.0)
+        check(complex("(1+2j)"), 1.0, 2.0)
+        check(complex("(1.5+4.25j)"), 1.5, 4.25)
+        check(complex("4.25+1J"), 4.25, 1.0)
+        check(complex(" ( +4.25-6J )"), 4.25, -6.0)
+        check(complex(" ( +4.25-J )"), 4.25, -1.0)
+        check(complex(" ( +4.25+j )"), 4.25, 1.0)
+        check(complex("J"), 0.0, 1.0)
+        check(complex("( j )"), 0.0, 1.0)
+        check(complex("+J"), 0.0, 1.0)
+        check(complex("( -j)"), 0.0, -1.0)
+        check(complex('1-1j'), 1.0, -1.0)
+        check(complex('1J'), 0.0, 1.0)
+
+        check(complex('1e-500'), 0.0, 0.0)
+        check(complex('-1e-500j'), 0.0, -0.0)
+        check(complex('1e-500+1e-500j'), 0.0, 0.0)
+        check(complex('-1e-500+1e-500j'), -0.0, 0.0)
+        check(complex('1e-500-1e-500j'), 0.0, -0.0)
+        check(complex('-1e-500-1e-500j'), -0.0, -0.0)
+
+        # SF bug 543840:  complex(string) accepts strings with \0
+        # Fixed in 2.3.
+        self.assertRaises(ValueError, complex, '1+1j\0j')
+        self.assertRaises(ValueError, complex, "")
+        self.assertRaises(ValueError, complex, "\0")
+        self.assertRaises(ValueError, complex, "3\09")
+        self.assertRaises(ValueError, complex, "1+")
+        self.assertRaises(ValueError, complex, "1+1j+1j")
+        self.assertRaises(ValueError, complex, "--")
+        self.assertRaises(ValueError, complex, "(1+2j")
+        self.assertRaises(ValueError, complex, "1+2j)")
+        self.assertRaises(ValueError, complex, "1+(2j)")
+        self.assertRaises(ValueError, complex, "(1+2j)123")
+        self.assertRaises(ValueError, complex, "x")
+        self.assertRaises(ValueError, complex, "1j+2")
+        self.assertRaises(ValueError, complex, "1e1ej")
+        self.assertRaises(ValueError, complex, "1e++1ej")
+        self.assertRaises(ValueError, complex, ")1+2j(")
+        # the following three are accepted by Python 2.6
+        self.assertRaises(ValueError, complex, "1..1j")
+        self.assertRaises(ValueError, complex, "1.11.1j")
+        self.assertRaises(ValueError, complex, "1e1.1j")
+
+        # check that complex accepts long unicode strings
+        self.assertIs(type(complex("1"*500)), complex)
+        # check whitespace processing
+        self.assertEqual(complex('\N{EM SPACE}(\N{EN SPACE}1+1j ) '), 1+1j)
+        # Invalid unicode string
+        # See bpo-34087
+        self.assertRaises(ValueError, complex, '\u3053\u3093\u306b\u3061\u306f')
+
     def test_constructor_negative_nans_from_string(self):
         self.assertEqual(copysign(1., complex("-nan").real), -1.)
         self.assertEqual(copysign(1., complex("-nanj").imag), -1.)
@@ -645,9 +676,6 @@ class ComplexTest(unittest.TestCase):
         test(complex(-0., -0.), "(-0-0j)")
 
     def test_pos(self):
-        class ComplexSubclass(complex):
-            pass
-
         self.assertEqual(+(1+6j), 1+6j)
         self.assertEqual(+ComplexSubclass(1, 6), 1+6j)
         self.assertIs(type(+ComplexSubclass(1, 6)), complex)
@@ -667,8 +695,8 @@ class ComplexTest(unittest.TestCase):
     def test_plus_minus_0j(self):
         # test that -0j and 0j literals are not identified
         z1, z2 = 0j, -0j
-        self.assertEqual(atan2(z1.imag, -1.), atan2(0., -1.))
-        self.assertEqual(atan2(z2.imag, -1.), atan2(-0., -1.))
+        self.assertFloatsAreIdentical(z1.imag, 0.0)
+        self.assertFloatsAreIdentical(z2.imag, -0.0)
 
     @support.requires_IEEE_754
     def test_negated_imaginary_literal(self):