]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue 24316: Fix types.coroutine() to accept objects from Cython
authorYury Selivanov <yselivanov@sprymix.com>
Fri, 29 May 2015 13:06:05 +0000 (09:06 -0400)
committerYury Selivanov <yselivanov@sprymix.com>
Fri, 29 May 2015 13:06:05 +0000 (09:06 -0400)
Lib/test/test_types.py
Lib/types.py

index ccaf414c4495ddb8962cacfbb603f96d518d120a..956214d080762bd62b3bff85006aa9ef14946df9 100644 (file)
@@ -1196,11 +1196,39 @@ class CoroutineTests(unittest.TestCase):
                 pass
         def bar(): pass
 
-        samples = [Foo, Foo(), bar, None, int, 1]
+        samples = [None, 1, object()]
         for sample in samples:
-            with self.assertRaisesRegex(TypeError, 'expects a generator'):
+            with self.assertRaisesRegex(TypeError,
+                                        'types.coroutine.*expects a callable'):
                 types.coroutine(sample)
 
+    def test_wrong_func(self):
+        @types.coroutine
+        def foo():
+            pass
+        @types.coroutine
+        def gen():
+            def _gen(): yield
+            return _gen()
+
+        for sample in (foo, gen):
+            with self.assertRaisesRegex(TypeError,
+                                        'callable wrapped .* non-coroutine'):
+                sample()
+
+    def test_duck_coro(self):
+        class CoroLike:
+            def send(self): pass
+            def throw(self): pass
+            def close(self): pass
+            def __await__(self): pass
+
+        coro = CoroLike()
+        @types.coroutine
+        def foo():
+            return coro
+        self.assertIs(coro, foo())
+
     def test_genfunc(self):
         def gen():
             yield
index 49e4d04cb173dff36f9b55445cc79c67849f711c..e9cc7948050a8aa059083061c8c055d7eb6e5c43 100644 (file)
@@ -43,30 +43,6 @@ MemberDescriptorType = type(FunctionType.__globals__)
 del sys, _f, _g, _C,                              # Not for export
 
 
-_CO_GENERATOR = 0x20
-_CO_ITERABLE_COROUTINE = 0x100
-
-def coroutine(func):
-    """Convert regular generator function to a coroutine."""
-
-    # TODO: Implement this in C.
-
-    if (not isinstance(func, (FunctionType, MethodType)) or
-            not isinstance(getattr(func, '__code__', None), CodeType) or
-            not (func.__code__.co_flags & _CO_GENERATOR)):
-        raise TypeError('coroutine() expects a generator function')
-
-    co = func.__code__
-    func.__code__ = CodeType(
-        co.co_argcount, co.co_kwonlyargcount, co.co_nlocals, co.co_stacksize,
-        co.co_flags | _CO_ITERABLE_COROUTINE,
-        co.co_code,
-        co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
-        co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
-
-    return func
-
-
 # Provide a PEP 3115 compliant mechanism for class creation
 def new_class(name, bases=(), kwds=None, exec_body=None):
     """Create a class object dynamically using the appropriate metaclass."""
@@ -182,4 +158,46 @@ class DynamicClassAttribute:
         return result
 
 
+import functools as _functools
+import collections.abc as _collections_abc
+
+def coroutine(func):
+    """Convert regular generator function to a coroutine."""
+
+    # We don't want to import 'dis' or 'inspect' just for
+    # these constants.
+    _CO_GENERATOR = 0x20
+    _CO_ITERABLE_COROUTINE = 0x100
+
+    if not callable(func):
+        raise TypeError('types.coroutine() expects a callable')
+
+    if (isinstance(func, FunctionType) and
+        isinstance(getattr(func, '__code__', None), CodeType) and
+        (func.__code__.co_flags & _CO_GENERATOR)):
+
+        # TODO: Implement this in C.
+        co = func.__code__
+        func.__code__ = CodeType(
+            co.co_argcount, co.co_kwonlyargcount, co.co_nlocals,
+            co.co_stacksize,
+            co.co_flags | _CO_ITERABLE_COROUTINE,
+            co.co_code,
+            co.co_consts, co.co_names, co.co_varnames, co.co_filename,
+            co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars,
+            co.co_cellvars)
+        return func
+
+    @_functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        coro = func(*args, **kwargs)
+        if not isinstance(coro, _collections_abc.Coroutine):
+            raise TypeError(
+                'callable wrapped with types.coroutine() returned '
+                'non-coroutine: {!r}'.format(coro))
+        return coro
+
+    return wrapped
+
+
 __all__ = [n for n in globals() if n[:1] != '_']