]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Don't import greenlet at all until it's needed
authorFederico Caselli <cfederico87@gmail.com>
Sun, 3 Sep 2023 19:24:45 +0000 (21:24 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Nov 2023 19:28:34 +0000 (14:28 -0500)
Added an initialize step to the import of
``sqlalchemy.ext.asyncio`` so that ``greenlet`` will
be imported only when the asyncio extension is first imported.
Alternatively, the ``greenlet`` library is still imported lazily on
first use to support use case that don't make direct use of the
SQLAlchemy asyncio extension.

Fixes: #10296
Change-Id: I97162a01aa29adb3e3fee97b718ab9567b2f6124

doc/build/changelog/unreleased_21/10296.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/__init__.py
lib/sqlalchemy/util/_concurrency_py3k.py [deleted file]
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/concurrency.py
setup.cfg
test/base/_concurrency_fixtures.py [new file with mode: 0644]
test/base/test_concurrency.py [moved from test/base/test_concurrency_py3k.py with 89% similarity]
tox.ini

diff --git a/doc/build/changelog/unreleased_21/10296.rst b/doc/build/changelog/unreleased_21/10296.rst
new file mode 100644 (file)
index 0000000..c674ecb
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: change, asyncio
+    :tickets: 10296
+
+    Added an initialize step to the import of
+    ``sqlalchemy.ext.asyncio`` so that ``greenlet`` will
+    be imported only when the asyncio extension is first imported.
+    Alternatively, the ``greenlet`` library is still imported lazily on
+    first use to support use case that don't make direct use of the
+    SQLAlchemy asyncio extension.
\ No newline at end of file
index 8564db6f22ee3f981053546b0b47d10d3c4c6e1c..ce146dbdab93fd495311aabceee54cbade32c95b 100644 (file)
@@ -23,3 +23,7 @@ from .session import AsyncAttrs as AsyncAttrs
 from .session import AsyncSession as AsyncSession
 from .session import AsyncSessionTransaction as AsyncSessionTransaction
 from .session import close_all_sessions as close_all_sessions
+from ...util import concurrency
+
+concurrency._concurrency_shim._initialize()
+del concurrency
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py
deleted file mode 100644 (file)
index 71d10a6..0000000
+++ /dev/null
@@ -1,260 +0,0 @@
-# util/_concurrency_py3k.py
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
-
-from __future__ import annotations
-
-import asyncio
-from contextvars import Context
-import sys
-import typing
-from typing import Any
-from typing import Awaitable
-from typing import Callable
-from typing import Coroutine
-from typing import Optional
-from typing import TYPE_CHECKING
-from typing import TypeVar
-
-from .langhelpers import memoized_property
-from .. import exc
-from ..util.typing import Protocol
-from ..util.typing import TypeGuard
-
-_T = TypeVar("_T")
-
-if typing.TYPE_CHECKING:
-
-    class greenlet(Protocol):
-        dead: bool
-        gr_context: Optional[Context]
-
-        def __init__(self, fn: Callable[..., Any], driver: greenlet):
-            ...
-
-        def throw(self, *arg: Any) -> Any:
-            return None
-
-        def switch(self, value: Any) -> Any:
-            return None
-
-    def getcurrent() -> greenlet:
-        ...
-
-else:
-    from greenlet import getcurrent
-    from greenlet import greenlet
-
-
-# If greenlet.gr_context is present in current version of greenlet,
-# it will be set with the current context on creation.
-# Refs: https://github.com/python-greenlet/greenlet/pull/198
-_has_gr_context = hasattr(getcurrent(), "gr_context")
-
-
-def is_exit_exception(e: BaseException) -> bool:
-    # note asyncio.CancelledError is already BaseException
-    # so was an exit exception in any case
-    return not isinstance(e, Exception) or isinstance(
-        e, (asyncio.TimeoutError, asyncio.CancelledError)
-    )
-
-
-# implementation based on snaury gist at
-# https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
-# Issue for context: https://github.com/python-greenlet/greenlet/issues/173
-
-
-class _AsyncIoGreenlet(greenlet):
-    dead: bool
-
-    def __init__(self, fn: Callable[..., Any], driver: greenlet):
-        greenlet.__init__(self, fn, driver)
-        self.driver = driver
-        if _has_gr_context:
-            self.gr_context = driver.gr_context
-
-
-_T_co = TypeVar("_T_co", covariant=True)
-
-if TYPE_CHECKING:
-
-    def iscoroutine(
-        awaitable: Awaitable[_T_co],
-    ) -> TypeGuard[Coroutine[Any, Any, _T_co]]:
-        ...
-
-else:
-    iscoroutine = asyncio.iscoroutine
-
-
-def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
-    # https://docs.python.org/3/reference/datamodel.html#coroutine.close
-
-    if iscoroutine(awaitable):
-        awaitable.close()
-
-
-def await_only(awaitable: Awaitable[_T]) -> _T:
-    """Awaits an async function in a sync method.
-
-    The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_only` calls cannot be nested.
-
-    :param awaitable: The coroutine to call.
-
-    """
-    # this is called in the context greenlet while running fn
-    current = getcurrent()
-    if not isinstance(current, _AsyncIoGreenlet):
-        _safe_cancel_awaitable(awaitable)
-
-        raise exc.MissingGreenlet(
-            "greenlet_spawn has not been called; can't call await_only() "
-            "here. Was IO attempted in an unexpected place?"
-        )
-
-    # returns the control to the driver greenlet passing it
-    # a coroutine to run. Once the awaitable is done, the driver greenlet
-    # switches back to this greenlet with the result of awaitable that is
-    # then returned to the caller (or raised as error)
-    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
-
-
-def await_fallback(awaitable: Awaitable[_T]) -> _T:
-    """Awaits an async function in a sync method.
-
-    The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_fallback` calls cannot be nested.
-
-    :param awaitable: The coroutine to call.
-
-    """
-
-    # this is called in the context greenlet while running fn
-    current = getcurrent()
-    if not isinstance(current, _AsyncIoGreenlet):
-        loop = get_event_loop()
-        if loop.is_running():
-            _safe_cancel_awaitable(awaitable)
-
-            raise exc.MissingGreenlet(
-                "greenlet_spawn has not been called and asyncio event "
-                "loop is already running; can't call await_fallback() here. "
-                "Was IO attempted in an unexpected place?"
-            )
-        return loop.run_until_complete(awaitable)
-
-    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
-
-
-async def greenlet_spawn(
-    fn: Callable[..., _T],
-    *args: Any,
-    _require_await: bool = False,
-    **kwargs: Any,
-) -> _T:
-    """Runs a sync function ``fn`` in a new greenlet.
-
-    The sync function can then use :func:`await_only` to wait for async
-    functions.
-
-    :param fn: The sync callable to call.
-    :param \\*args: Positional arguments to pass to the ``fn`` callable.
-    :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
-    """
-
-    result: Any
-    context = _AsyncIoGreenlet(fn, getcurrent())
-    # runs the function synchronously in gl greenlet. If the execution
-    # is interrupted by await_only, context is not dead and result is a
-    # coroutine to wait. If the context is dead the function has
-    # returned, and its result can be returned.
-    switch_occurred = False
-    try:
-        result = context.switch(*args, **kwargs)
-        while not context.dead:
-            switch_occurred = True
-            try:
-                # wait for a coroutine from await_only and then return its
-                # result back to it.
-                value = await result
-            except BaseException:
-                # this allows an exception to be raised within
-                # the moderated greenlet so that it can continue
-                # its expected flow.
-                result = context.throw(*sys.exc_info())
-            else:
-                result = context.switch(value)
-    finally:
-        # clean up to avoid cycle resolution by gc
-        del context.driver
-    if _require_await and not switch_occurred:
-        raise exc.AwaitRequired(
-            "The current operation required an async execution but none was "
-            "detected. This will usually happen when using a non compatible "
-            "DBAPI driver. Please ensure that an async DBAPI is used."
-        )
-    return result  # type: ignore[no-any-return]
-
-
-class AsyncAdaptedLock:
-    @memoized_property
-    def mutex(self) -> asyncio.Lock:
-        # there should not be a race here for coroutines creating the
-        # new lock as we are not using await, so therefore no concurrency
-        return asyncio.Lock()
-
-    def __enter__(self) -> bool:
-        # await is used to acquire the lock only after the first calling
-        # coroutine has created the mutex.
-        return await_fallback(self.mutex.acquire())
-
-    def __exit__(self, *arg: Any, **kw: Any) -> None:
-        self.mutex.release()
-
-
-def _util_async_run_coroutine_function(
-    fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
-) -> Any:
-    """for test suite/ util only"""
-
-    loop = get_event_loop()
-    if loop.is_running():
-        raise Exception(
-            "for async run coroutine we expect that no greenlet or event "
-            "loop is running when we start out"
-        )
-    return loop.run_until_complete(fn(*args, **kwargs))
-
-
-def _util_async_run(
-    fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
-) -> Any:
-    """for test suite/ util only"""
-
-    loop = get_event_loop()
-    if not loop.is_running():
-        return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
-    else:
-        # allow for a wrapped test function to call another
-        assert isinstance(getcurrent(), _AsyncIoGreenlet)
-        return fn(*args, **kwargs)
-
-
-def get_event_loop() -> asyncio.AbstractEventLoop:
-    """vendor asyncio.get_event_loop() for python 3.7 and above.
-
-    Python 3.10 deprecates get_event_loop() as a standalone.
-
-    """
-    try:
-        return asyncio.get_running_loop()
-    except RuntimeError:
-        # avoid "During handling of the above exception, another exception..."
-        pass
-    return asyncio.get_event_loop_policy().get_event_loop()
index 7cbaa24069f57afd4cdca5b57c0ca2f44f3707a9..1bc89970313e64c4f6d93ecbff322e6296c2ffc0 100644 (file)
@@ -166,7 +166,7 @@ else:
 
 def importlib_metadata_get(group):
     ep = importlib_metadata.entry_points()
-    if hasattr(ep, "select"):
+    if typing.TYPE_CHECKING or hasattr(ep, "select"):
         return ep.select(group=group)
     else:
         return ep.get(group, ())
index 53a70070b765c0f840f9de68892be02a3185eb6c..084374040f82e944061111fc0fc3cfe0330c4a6c 100644 (file)
 
 from __future__ import annotations
 
-import asyncio  # noqa
-import typing
-
-have_greenlet = False
-greenlet_error = None
-try:
-    import greenlet  # type: ignore # noqa: F401
-except ImportError as e:
-    greenlet_error = str(e)
-    pass
-else:
-    have_greenlet = True
-    from ._concurrency_py3k import await_only as await_only
-    from ._concurrency_py3k import await_fallback as await_fallback
-    from ._concurrency_py3k import greenlet_spawn as greenlet_spawn
-    from ._concurrency_py3k import is_exit_exception as is_exit_exception
-    from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock
-    from ._concurrency_py3k import (
-        _util_async_run as _util_async_run,
-    )  # noqa: F401
-    from ._concurrency_py3k import (
-        _util_async_run_coroutine_function as _util_async_run_coroutine_function,  # noqa: F401, E501
+import asyncio
+from contextvars import Context
+import sys
+from typing import Any
+from typing import Awaitable
+from typing import Callable
+from typing import Coroutine
+from typing import Optional
+from typing import Protocol
+from typing import TYPE_CHECKING
+from typing import TypeVar
+
+from .langhelpers import memoized_property
+from .. import exc
+from ..util.typing import TypeGuard
+
+_T = TypeVar("_T")
+
+
+def is_exit_exception(e: BaseException) -> bool:
+    # note asyncio.CancelledError is already BaseException
+    # so was an exit exception in any case
+    return not isinstance(e, Exception) or isinstance(
+        e, (asyncio.TimeoutError, asyncio.CancelledError)
     )
 
-if not typing.TYPE_CHECKING and not have_greenlet:
 
-    def _not_implemented():
-        # this conditional is to prevent pylance from considering
-        # greenlet_spawn() etc as "no return" and dimming out code below it
-        if have_greenlet:
+_ERROR_MESSAGE = (
+    "The SQLAlchemy asyncio module requires that the Python 'greenlet' "
+    "library is installed.  In order to ensure this dependency is "
+    "available, use the 'sqlalchemy[asyncio]' install target:  "
+    "'pip install sqlalchemy[asyncio]'"
+)
+
+
+if TYPE_CHECKING:
+
+    class greenlet(Protocol):
+        dead: bool
+        gr_context: Optional[Context]
+
+        def __init__(self, fn: Callable[..., Any], driver: greenlet):
+            ...
+
+        def throw(self, *arg: Any) -> Any:
+            return None
+
+        def switch(self, value: Any) -> Any:
             return None
 
-        raise ValueError(
-            "the greenlet library is required to use this function."
-            " %s" % greenlet_error
-            if greenlet_error
-            else ""
+    def getcurrent() -> greenlet:
+        ...
+
+
+class _concurrency_shim_cls:
+    """Late import shim for greenlet"""
+
+    __slots__ = (
+        "greenlet",
+        "_AsyncIoGreenlet",
+        "getcurrent",
+        "_util_async_run",
+    )
+
+    def _initialize(self, *, raise_: bool = True) -> None:
+        """Import greenlet and initialize the class"""
+        if "greenlet" in globals():
+            return
+
+        if not TYPE_CHECKING:
+            global getcurrent, greenlet, _AsyncIoGreenlet, _has_gr_context
+
+        try:
+            from greenlet import getcurrent
+            from greenlet import greenlet
+        except ImportError as e:
+            self._initialize_no_greenlet()
+            if raise_:
+                raise ImportError(_ERROR_MESSAGE) from e
+        else:
+            self._initialize_greenlet()
+
+    def _initialize_greenlet(self) -> None:
+        # If greenlet.gr_context is present in current version of greenlet,
+        # it will be set with the current context on creation.
+        # Refs: https://github.com/python-greenlet/greenlet/pull/198
+        _has_gr_context = hasattr(getcurrent(), "gr_context")
+
+        # implementation based on snaury gist at
+        # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
+        # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 # noqa: E501
+
+        class _AsyncIoGreenlet(greenlet):
+            dead: bool
+
+            def __init__(self, fn: Callable[..., Any], driver: greenlet):
+                greenlet.__init__(self, fn, driver)
+                self.driver = driver
+                if _has_gr_context:
+                    self.gr_context = driver.gr_context
+
+        self.greenlet = greenlet
+        self.getcurrent = getcurrent
+        self._AsyncIoGreenlet = _AsyncIoGreenlet
+        self._util_async_run = self._greenlet_util_async_run
+
+    def _initialize_no_greenlet(self):
+        self._util_async_run = self._no_greenlet_util_async_run
+
+    def __getattr__(self, key: str) -> Any:
+        if key in self.__slots__:
+            self._initialize(raise_=not key.startswith("_util"))
+            return getattr(self, key)
+        else:
+            raise AttributeError(key)
+
+    def _greenlet_util_async_run(
+        self, fn: Callable[..., Any], *args: Any, **kwargs: Any
+    ) -> Any:
+        """for test suite/ util only"""
+
+        loop = get_event_loop()
+        if not loop.is_running():
+            return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
+        else:
+            # allow for a wrapped test function to call another
+            assert isinstance(
+                _concurrency_shim.getcurrent(),
+                _concurrency_shim._AsyncIoGreenlet,
+            )
+            return fn(*args, **kwargs)
+
+    def _no_greenlet_util_async_run(
+        self, fn: Callable[..., Any], *args: Any, **kwargs: Any
+    ) -> Any:
+        """for test suite/ util only"""
+
+        return fn(*args, **kwargs)
+
+
+_concurrency_shim = _concurrency_shim_cls()
+
+if TYPE_CHECKING:
+    _T_co = TypeVar("_T_co", covariant=True)
+
+    def iscoroutine(
+        awaitable: Awaitable[_T_co],
+    ) -> TypeGuard[Coroutine[Any, Any, _T_co]]:
+        ...
+
+else:
+    iscoroutine = asyncio.iscoroutine
+
+
+def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
+    # https://docs.python.org/3/reference/datamodel.html#coroutine.close
+
+    if iscoroutine(awaitable):
+        awaitable.close()
+
+
+def await_only(awaitable: Awaitable[_T]) -> _T:
+    """Awaits an async function in a sync method.
+
+    The sync method must be inside a :func:`greenlet_spawn` context.
+    :func:`await_only` calls cannot be nested.
+
+    :param awaitable: The coroutine to call.
+
+    """
+    # this is called in the context greenlet while running fn
+    current = _concurrency_shim.getcurrent()
+    if not isinstance(current, _concurrency_shim._AsyncIoGreenlet):
+        _safe_cancel_awaitable(awaitable)
+
+        raise exc.MissingGreenlet(
+            "greenlet_spawn has not been called; can't call await_only() "
+            "here. Was IO attempted in an unexpected place?"
         )
 
-    def is_exit_exception(e):  # noqa: F811
-        return not isinstance(e, Exception)
+    # returns the control to the driver greenlet passing it
+    # a coroutine to run. Once the awaitable is done, the driver greenlet
+    # switches back to this greenlet with the result of awaitable that is
+    # then returned to the caller (or raised as error)
+    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
+
+
+def await_fallback(awaitable: Awaitable[_T]) -> _T:
+    """Awaits an async function in a sync method.
+
+    The sync method must be inside a :func:`greenlet_spawn` context.
+    :func:`await_fallback` calls cannot be nested.
+
+    :param awaitable: The coroutine to call.
+
+    """
+
+    # this is called in the context greenlet while running fn
+    current = _concurrency_shim.getcurrent()
+    if not isinstance(current, _concurrency_shim._AsyncIoGreenlet):
+        loop = get_event_loop()
+        if loop.is_running():
+            _safe_cancel_awaitable(awaitable)
+
+            raise exc.MissingGreenlet(
+                "greenlet_spawn has not been called and asyncio event "
+                "loop is already running; can't call await_fallback() here. "
+                "Was IO attempted in an unexpected place?"
+            )
+        return loop.run_until_complete(awaitable)
+
+    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
+
+
+async def greenlet_spawn(
+    fn: Callable[..., _T],
+    *args: Any,
+    _require_await: bool = False,
+    **kwargs: Any,
+) -> _T:
+    """Runs a sync function ``fn`` in a new greenlet.
+
+    The sync function can then use :func:`await_only` to wait for async
+    functions.
+
+    :param fn: The sync callable to call.
+    :param \\*args: Positional arguments to pass to the ``fn`` callable.
+    :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
+    """
+
+    result: Any
+    context = _concurrency_shim._AsyncIoGreenlet(
+        fn, _concurrency_shim.getcurrent()
+    )
+    # runs the function synchronously in gl greenlet. If the execution
+    # is interrupted by await_only, context is not dead and result is a
+    # coroutine to wait. If the context is dead the function has
+    # returned, and its result can be returned.
+    switch_occurred = False
+    try:
+        result = context.switch(*args, **kwargs)
+        while not context.dead:
+            switch_occurred = True
+            try:
+                # wait for a coroutine from await_only and then return its
+                # result back to it.
+                value = await result
+            except BaseException:
+                # this allows an exception to be raised within
+                # the moderated greenlet so that it can continue
+                # its expected flow.
+                result = context.throw(*sys.exc_info())
+            else:
+                result = context.switch(value)
+    finally:
+        # clean up to avoid cycle resolution by gc
+        del context.driver
+    if _require_await and not switch_occurred:
+        raise exc.AwaitRequired(
+            "The current operation required an async execution but none was "
+            "detected. This will usually happen when using a non compatible "
+            "DBAPI driver. Please ensure that an async DBAPI is used."
+        )
+    return result  # type: ignore[no-any-return]
+
+
+class AsyncAdaptedLock:
+    @memoized_property
+    def mutex(self) -> asyncio.Lock:
+        # there should not be a race here for coroutines creating the
+        # new lock as we are not using await, so therefore no concurrency
+        return asyncio.Lock()
+
+    def __enter__(self) -> bool:
+        # await is used to acquire the lock only after the first calling
+        # coroutine has created the mutex.
+        return await_fallback(self.mutex.acquire())
+
+    def __exit__(self, *arg: Any, **kw: Any) -> None:
+        self.mutex.release()
+
+
+def _util_async_run_coroutine_function(
+    fn: Callable[..., Any], *args: Any, **kwargs: Any
+) -> Any:
+    """for test suite/ util only"""
+
+    loop = get_event_loop()
+    if loop.is_running():
+        raise Exception(
+            "for async run coroutine we expect that no greenlet or event "
+            "loop is running when we start out"
+        )
+    return loop.run_until_complete(fn(*args, **kwargs))
+
+
+def _util_async_run(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
+    """for test suite/ util only"""
 
-    def await_only(thing):  # type: ignore  # noqa: F811
-        _not_implemented()
+    _util_async_run = _concurrency_shim._util_async_run
 
-    def await_fallback(thing):  # type: ignore  # noqa: F811
-        return thing
+    return _util_async_run(fn, *args, **kwargs)
 
-    def greenlet_spawn(fn, *args, **kw):  # type: ignore  # noqa: F811
-        _not_implemented()
 
-    def AsyncAdaptedLock(*args, **kw):  # type: ignore  # noqa: F811
-        _not_implemented()
+def get_event_loop() -> asyncio.AbstractEventLoop:
+    """vendor asyncio.get_event_loop() for python 3.7 and above.
 
-    def _util_async_run(fn, *arg, **kw):  # type: ignore  # noqa: F811
-        return fn(*arg, **kw)
+    Python 3.10 deprecates get_event_loop() as a standalone.
 
-    def _util_async_run_coroutine_function(fn, *arg, **kw):  # type: ignore  # noqa: F811,E501
-        _not_implemented()
+    """
+    try:
+        return asyncio.get_running_loop()
+    except RuntimeError:
+        # avoid "During handling of the above exception, another exception..."
+        pass
+    return asyncio.get_event_loop_policy().get_event_loop()
index e3ae98b770869550ff9e37d5edf6b7aa026d454a..f45bfa68e601c9098393651882b786b4c4dac419 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,6 +43,7 @@ asyncio =
     greenlet!=0.4.17
 mypy =
     mypy >= 0.910
+    types-greenlet >= 2
 mssql = pyodbc
 mssql_pymssql = pymssql
 mssql_pyodbc = pyodbc
@@ -79,7 +80,6 @@ asyncmy =
 aiosqlite =
     %(asyncio)s
     aiosqlite
-    typing_extensions!=3.10.0.1
 sqlcipher =
     sqlcipher3_binary
 
diff --git a/test/base/_concurrency_fixtures.py b/test/base/_concurrency_fixtures.py
new file mode 100644 (file)
index 0000000..587eb64
--- /dev/null
@@ -0,0 +1,59 @@
+"""Module that defines function that are run in a separate process.
+NOTE: the module must not import sqlalchemy at the top level.
+"""
+
+import asyncio  # noqa: F401
+import sys
+
+
+def greenlet_not_imported():
+    assert "greenlet" not in sys.modules
+    assert "sqlalchemy" not in sys.modules
+
+    import sqlalchemy
+    import sqlalchemy.util.concurrency  # noqa: F401
+    from sqlalchemy.util import greenlet_spawn  # noqa: F401
+    from sqlalchemy.util.concurrency import await_only  # noqa: F401
+
+    assert "greenlet" not in sys.modules
+
+
+def greenlet_setup_in_ext():
+    assert "greenlet" not in sys.modules
+    assert "sqlalchemy" not in sys.modules
+
+    import sqlalchemy.ext.asyncio  # noqa: F401
+    from sqlalchemy.util import greenlet_spawn
+
+    assert "greenlet" in sys.modules
+    value = -1
+
+    def go(arg):
+        nonlocal value
+        value = arg
+
+    async def call():
+        await greenlet_spawn(go, 42)
+
+    asyncio.run(call())
+
+    assert value == 42
+
+
+def greenlet_setup_on_call():
+    from sqlalchemy.util import greenlet_spawn
+
+    assert "greenlet" not in sys.modules
+    value = -1
+
+    def go(arg):
+        nonlocal value
+        value = arg
+
+    async def call():
+        await greenlet_spawn(go, 42)
+
+    asyncio.run(call())
+
+    assert "greenlet" in sys.modules
+    assert value == 42
similarity index 89%
rename from test/base/test_concurrency_py3k.py
rename to test/base/test_concurrency.py
index b4fb34d0259a5d35ecbd965f6c2c0f17cf1b6ea7..04d6e5208949eb8fc07099c63d0ddd54b6a04163 100644 (file)
@@ -1,19 +1,25 @@
 import asyncio
 import contextvars
+from multiprocessing import get_context
 import random
 import threading
 
 from sqlalchemy import exc
+from sqlalchemy import testing
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_true
+from sqlalchemy.testing.config import combinations
 from sqlalchemy.util import await_fallback
 from sqlalchemy.util import await_only
 from sqlalchemy.util import greenlet_spawn
 from sqlalchemy.util import queue
+from ._concurrency_fixtures import greenlet_not_imported
+from ._concurrency_fixtures import greenlet_setup_in_ext
+from ._concurrency_fixtures import greenlet_setup_on_call
 
 try:
     from greenlet import greenlet
@@ -264,3 +270,23 @@ class TestAsyncAdaptedQueue(fixtures.TestBase):
         t.join()
 
         is_true(run[0])
+
+
+class GreenletImportTests(fixtures.TestBase):
+    def _run_in_process(self, fn):
+        ctx = get_context("spawn")
+        process = ctx.Process(target=fn)
+        try:
+            process.start()
+            process.join(10)
+            eq_(process.exitcode, 0)
+        finally:
+            process.kill()
+
+    @combinations(
+        greenlet_not_imported,
+        (greenlet_setup_in_ext, testing.requires.greenlet),
+        (greenlet_setup_on_call, testing.requires.greenlet),
+    )
+    def test_concurrency_fn(self, fn):
+        self._run_in_process(fn)
diff --git a/tox.ini b/tox.ini
index bc95175597526e423bf3b6a01205b28d19b103b6..d11a88202959d99971cf442a299ba35f83f7056c 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -176,6 +176,7 @@ commands=
 deps=
      greenlet != 0.4.17
      mypy >= 1.6.0
+     types-greenlet
 commands =
     mypy  {env:MYPY_COLOR} ./lib/sqlalchemy
     # pyright changes too often with not-exactly-correct errors
@@ -189,7 +190,7 @@ deps=
      greenlet != 0.4.17
      mypy >= 1.2.0
      patch==1.*
-
+     types-greenlet
 commands =
     pytest {env:PYTEST_COLOR} -m mypy {posargs}