]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.11] gh-90805: Make sure test_functools works with and without _functoolsmodule...
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Mon, 11 Sep 2023 16:35:41 +0000 (09:35 -0700)
committerGitHub <noreply@github.com>
Mon, 11 Sep 2023 16:35:41 +0000 (16:35 +0000)
(cherry picked from commit baa6dc8e388e71b2a00347143ecefb2ad3a8e53b)

Co-authored-by: Nikita Sobolev <mail@sobolevn.me>
Lib/test/test_functools.py

index 382e7dbffddf9d14f2e7fcb7b596dcbc114ce9b1..fb6e1860ac11fe8c9865e153ce169f184a3a1086 100644 (file)
@@ -27,10 +27,16 @@ import functools
 
 py_functools = import_helper.import_fresh_module('functools',
                                                  blocked=['_functools'])
-c_functools = import_helper.import_fresh_module('functools')
+c_functools = import_helper.import_fresh_module('functools',
+                                                fresh=['_functools'])
 
 decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
 
+_partial_types = [py_functools.partial]
+if c_functools:
+    _partial_types.append(c_functools.partial)
+
+
 @contextlib.contextmanager
 def replaced_module(name, replacement):
     original_module = sys.modules[name]
@@ -202,7 +208,7 @@ class TestPartial:
         kwargs = {'a': object(), 'b': object()}
         kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                         'b={b!r}, a={a!r}'.format_map(kwargs)]
-        if self.partial in (c_functools.partial, py_functools.partial):
+        if self.partial in _partial_types:
             name = 'functools.partial'
         else:
             name = self.partial.__name__
@@ -224,7 +230,7 @@ class TestPartial:
                        for kwargs_repr in kwargs_reprs])
 
     def test_recursive_repr(self):
-        if self.partial in (c_functools.partial, py_functools.partial):
+        if self.partial in _partial_types:
             name = 'functools.partial'
         else:
             name = self.partial.__name__
@@ -251,7 +257,7 @@ class TestPartial:
             f.__setstate__((capture, (), {}, {}))
 
     def test_pickle(self):
-        with self.AllowPickle():
+        with replaced_module('functools', self.module):
             f = self.partial(signature, ['asdf'], bar=[True])
             f.attr = []
             for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -334,7 +340,7 @@ class TestPartial:
         self.assertIs(type(r[0]), tuple)
 
     def test_recursive_pickle(self):
-        with self.AllowPickle():
+        with replaced_module('functools', self.module):
             f = self.partial(capture)
             f.__setstate__((f, (), {}, {}))
             try:
@@ -388,14 +394,9 @@ class TestPartial:
 @unittest.skipUnless(c_functools, 'requires the C _functools module')
 class TestPartialC(TestPartial, unittest.TestCase):
     if c_functools:
+        module = c_functools
         partial = c_functools.partial
 
-    class AllowPickle:
-        def __enter__(self):
-            return self
-        def __exit__(self, type, value, tb):
-            return False
-
     def test_attributes_unwritable(self):
         # attributes should not be writable
         p = self.partial(capture, 1, 2, a=10, b=20)
@@ -438,15 +439,9 @@ class TestPartialC(TestPartial, unittest.TestCase):
 
 
 class TestPartialPy(TestPartial, unittest.TestCase):
+    module = py_functools
     partial = py_functools.partial
 
-    class AllowPickle:
-        def __init__(self):
-            self._cm = replaced_module("functools", py_functools)
-        def __enter__(self):
-            return self._cm.__enter__()
-        def __exit__(self, type, value, tb):
-            return self._cm.__exit__(type, value, tb)
 
 if c_functools:
     class CPartialSubclass(c_functools.partial):
@@ -1860,9 +1855,10 @@ class TestLRU:
 def py_cached_func(x, y):
     return 3 * x + y
 
-@c_functools.lru_cache()
-def c_cached_func(x, y):
-    return 3 * x + y
+if c_functools:
+    @c_functools.lru_cache()
+    def c_cached_func(x, y):
+        return 3 * x + y
 
 
 class TestLRUPy(TestLRU, unittest.TestCase):
@@ -1879,18 +1875,20 @@ class TestLRUPy(TestLRU, unittest.TestCase):
         return 3 * x + y
 
 
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
 class TestLRUC(TestLRU, unittest.TestCase):
-    module = c_functools
-    cached_func = c_cached_func,
+    if c_functools:
+        module = c_functools
+        cached_func = c_cached_func,
 
-    @module.lru_cache()
-    def cached_meth(self, x, y):
-        return 3 * x + y
+        @module.lru_cache()
+        def cached_meth(self, x, y):
+            return 3 * x + y
 
-    @staticmethod
-    @module.lru_cache()
-    def cached_staticmeth(x, y):
-        return 3 * x + y
+        @staticmethod
+        @module.lru_cache()
+        def cached_staticmeth(x, y):
+            return 3 * x + y
 
 
 class TestSingleDispatch(unittest.TestCase):