]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Added ``op.run_async``.
authorFederico Caselli <cfederico87@gmail.com>
Sat, 29 Apr 2023 21:25:21 +0000 (23:25 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 12 May 2023 16:06:14 +0000 (12:06 -0400)
Added :meth:`.Operations.run_async` to the operation module to allow
running async functions in the ``upgrade`` or ``downgrade`` migration
function when running alembic using an async dialect.
This function will receive as first argument an
class:`~sqlalchemy.ext.asyncio.AsyncConnection` sharing the transaction
used in the migration context.

also restore the .execute() method to BatchOperations

Fixes: #1231
Change-Id: I3c3237d570be3c9bd9834e4c61bb3231bfb82765

alembic/op.pyi
alembic/operations/base.py
alembic/operations/ops.py
alembic/util/sqla_compat.py
docs/build/unreleased/1231.rst [new file with mode: 0644]
tests/test_op.py
tools/write_pyi.py

index aa3ad2d9d10b1cbe6883d5848d139ed08ec2c075..4395f772227a8fc1f3e8b0ea19f155fc79e25f1d 100644 (file)
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 from contextlib import contextmanager
 from typing import Any
+from typing import Awaitable
 from typing import Callable
 from typing import Dict
 from typing import Iterator
@@ -15,6 +16,7 @@ from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from sqlalchemy.sql.expression import TableClause
@@ -38,6 +40,8 @@ if TYPE_CHECKING:
     from .operations.ops import MigrateOperation
     from .runtime.migration import MigrationContext
     from .util.sqla_compat import _literal_bindparam
+
+_T = TypeVar("_T")
 ### end imports ###
 
 def add_column(
@@ -1238,3 +1242,28 @@ def rename_table(
      :class:`~sqlalchemy.sql.elements.quoted_name`.
 
     """
+
+def run_async(
+    async_function: Callable[..., Awaitable[_T]], *args: Any, **kw_args: Any
+) -> _T:
+    """Invoke the given asynchronous callable, passing an asynchronous
+    :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first
+    argument.
+
+    This method allows calling async functions from within the
+    synchronous ``upgrade()`` or ``downgrade()`` alembic migration
+    method.
+
+    The async connection passed to the callable shares the same
+    transaction as the connection running in the migration context.
+
+    Any additional arg or kw_arg passed to this function are passed
+    to the provided async function.
+
+    .. versionadded: 1.11
+
+    .. note::
+
+        This method can be called only when alembic is called using
+        an async dialect.
+    """
index 6e45a11675fc8d924762529e98bbeaa7c8c97e67..b4190dc3b18511075afa48c9d77503f78988a38c 100644 (file)
@@ -4,6 +4,7 @@ from contextlib import contextmanager
 import re
 import textwrap
 from typing import Any
+from typing import Awaitable
 from typing import Callable
 from typing import Dict
 from typing import Iterator
@@ -14,6 +15,7 @@ from typing import Sequence  # noqa
 from typing import Tuple
 from typing import Type  # noqa
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from sqlalchemy.sql.elements import conv
@@ -28,8 +30,6 @@ from ..util.compat import inspect_getfullargspec
 from ..util.sqla_compat import _literal_bindparam
 
 
-NoneType = type(None)
-
 if TYPE_CHECKING:
     from typing import Literal
 
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
     from ..ddl import DefaultImpl
     from ..runtime.migration import MigrationContext
 __all__ = ("Operations", "BatchOperations")
+_T = TypeVar("_T")
 
 
 class AbstractOperations(util.ModuleClsProxy):
@@ -483,6 +484,46 @@ class AbstractOperations(util.ModuleClsProxy):
         """
         return self.migration_context.impl.bind  # type: ignore[return-value]
 
+    def run_async(
+        self,
+        async_function: Callable[..., Awaitable[_T]],
+        *args: Any,
+        **kw_args: Any,
+    ) -> _T:
+        """Invoke the given asynchronous callable, passing an asynchronous
+        :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first
+        argument.
+
+        This method allows calling async functions from within the
+        synchronous ``upgrade()`` or ``downgrade()`` alembic migration
+        method.
+
+        The async connection passed to the callable shares the same
+        transaction as the connection running in the migration context.
+
+        Any additional arg or kw_arg passed to this function are passed
+        to the provided async function.
+
+        .. versionadded: 1.11
+
+        .. note::
+
+            This method can be called only when alembic is called using
+            an async dialect.
+        """
+        if not sqla_compat.sqla_14_18:
+            raise NotImplementedError("SQLAlchemy 1.4.18+ required")
+        sync_conn = self.get_bind()
+        if sync_conn is None:
+            raise NotImplementedError("Cannot call run_async in SQL mode")
+        if not sync_conn.dialect.is_async:
+            raise ValueError("Cannot call run_async with a sync engine")
+        from sqlalchemy.ext.asyncio import AsyncConnection
+        from sqlalchemy.util import await_only
+
+        async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
+        return await_only(async_function(async_conn, *args, **kw_args))
+
 
 class Operations(AbstractOperations):
     """Define high level migration operations.
index 99d21d9eb997f68b48e133c189cc65f7f914f19c..3a002c17769e7a376099c719888a8e4f93ceec7e 100644 (file)
@@ -2375,6 +2375,7 @@ class BulkInsertOp(MigrateOperation):
 
 
 @Operations.register_operation("execute")
+@BatchOperations.register_operation("execute")
 class ExecuteSQLOp(MigrateOperation):
     """Represent an execute SQL operation."""
 
index 00703376407ea212d687e0271e3d0182309a448f..37e1ee13606f0f47b877b53f8dd698610d8d01c6 100644 (file)
@@ -61,6 +61,8 @@ _vers = tuple(
 )
 sqla_13 = _vers >= (1, 3)
 sqla_14 = _vers >= (1, 4)
+# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
+sqla_14_18 = _vers >= (1, 4, 18)
 sqla_14_26 = _vers >= (1, 4, 26)
 sqla_2 = _vers >= (2,)
 sqlalchemy_version = __version__
diff --git a/docs/build/unreleased/1231.rst b/docs/build/unreleased/1231.rst
new file mode 100644 (file)
index 0000000..37678ca
--- /dev/null
@@ -0,0 +1,11 @@
+
+.. change::
+    :tags: usecase, asyncio
+    :tickets: 1231
+
+    Added :meth:`.Operations.run_async` to the operation module to allow
+    running async functions in the ``upgrade`` or ``downgrade`` migration
+    function when running alembic using an async dialect.
+    This function will receive as first argument an
+    :class:`~sqlalchemy.ext.asyncio.AsyncConnection` sharing the transaction
+    used in the migration context.
index 8ae22a030860a63063a08b83e70fe9d2ed6c3cfd..35adeaf54f4ba7ff014ece72e597da7d750c08b4 100644 (file)
@@ -1,5 +1,8 @@
 """Test against the builders in the op.* module."""
 
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
 from sqlalchemy import Boolean
 from sqlalchemy import CheckConstraint
 from sqlalchemy import Column
@@ -30,6 +33,7 @@ from alembic.testing import eq_
 from alembic.testing import expect_warnings
 from alembic.testing import is_not_
 from alembic.testing import mock
+from alembic.testing.assertions import expect_raises_message
 from alembic.testing.fixtures import op_fixture
 from alembic.testing.fixtures import TestBase
 from alembic.util import sqla_compat
@@ -1156,6 +1160,46 @@ class OpTest(TestBase):
             ("after_drop", "tb_test"),
         ]
 
+    @config.requirements.sqlalchemy_14
+    def test_run_async_error(self):
+        op_fixture()
+
+        async def go(conn):
+            pass
+
+        with expect_raises_message(
+            NotImplementedError, "SQLAlchemy 1.4.18. required"
+        ):
+            with patch.object(sqla_compat, "sqla_14_18", False):
+                op.run_async(go)
+        with expect_raises_message(
+            NotImplementedError, "Cannot call run_async in SQL mode"
+        ):
+            with patch.object(op._proxy, "get_bind", lambda: None):
+                op.run_async(go)
+        with expect_raises_message(
+            ValueError, "Cannot call run_async with a sync engine"
+        ):
+            op.run_async(go)
+
+    @config.requirements.sqlalchemy_14
+    def test_run_async_ok(self):
+        from sqlalchemy.ext.asyncio import AsyncConnection
+
+        op_fixture()
+        conn = op.get_bind()
+        mock_conn = MagicMock()
+        mock_fn = MagicMock()
+        with patch.object(conn.dialect, "is_async", True), patch.object(
+            AsyncConnection, "_retrieve_proxy_for_target", mock_conn
+        ), patch("sqlalchemy.util.await_only") as mock_await:
+            res = op.run_async(mock_fn, 99, foo=42)
+
+            eq_(res, mock_await.return_value)
+            mock_conn.assert_called_once_with(conn)
+            mock_await.assert_called_once_with(mock_fn.return_value)
+            mock_fn.assert_called_once_with(mock_conn.return_value, 99, foo=42)
+
 
 class SQLModeOpTest(TestBase):
     def test_auto_literals(self):
index 7d2487071f4122e8c0a73d30721864f6dd45b289..82ceead70bd280f2704f2cf22f4db96700c618d2 100644 (file)
@@ -28,7 +28,6 @@ if True:  # avoid flake/zimports messing with the order
     from alembic.operations import ops
     import sqlalchemy as sa
 
-
 TRIM_MODULE = [
     "alembic.runtime.migration.",
     "alembic.operations.base.",
@@ -179,9 +178,12 @@ def _generate_stub_for_meth(
             retval = repr(annotation).replace("typing.", "")
         elif isinstance(annotation, type):
             retval = annotation.__qualname__
+        elif isinstance(annotation, typing.TypeVar):
+            retval = annotation.__name__
         else:
             retval = annotation
 
+        retval = retval.replace("~", "")  # typevar repr as "~T"
         for trim in TRIM_MODULE:
             retval = retval.replace(trim, "")
 
@@ -371,6 +373,7 @@ cls_ignore = {
     "inline_literal",
     "invoke",
     "register_operation",
+    "run_async",
 }
 
 cases = [