from contextlib import contextmanager
from typing import Any
+from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from sqlalchemy.sql.expression import TableClause
from .operations.ops import MigrateOperation
from .runtime.migration import MigrationContext
from .util.sqla_compat import _literal_bindparam
+
+_T = TypeVar("_T")
### end imports ###
def add_column(
:class:`~sqlalchemy.sql.elements.quoted_name`.
"""
+
+def run_async(
+ async_function: Callable[..., Awaitable[_T]], *args: Any, **kw_args: Any
+) -> _T:
+ """Invoke the given asynchronous callable, passing an asynchronous
+ :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first
+ argument.
+
+ This method allows calling async functions from within the
+ synchronous ``upgrade()`` or ``downgrade()`` alembic migration
+ method.
+
+ The async connection passed to the callable shares the same
+ transaction as the connection running in the migration context.
+
+ Any additional arg or kw_arg passed to this function are passed
+ to the provided async function.
+
+ .. versionadded: 1.11
+
+ .. note::
+
+ This method can be called only when alembic is called using
+ an async dialect.
+ """
import re
import textwrap
from typing import Any
+from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import Tuple
from typing import Type # noqa
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from sqlalchemy.sql.elements import conv
from ..util.sqla_compat import _literal_bindparam
-NoneType = type(None)
-
if TYPE_CHECKING:
from typing import Literal
from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
__all__ = ("Operations", "BatchOperations")
+_T = TypeVar("_T")
class AbstractOperations(util.ModuleClsProxy):
"""
return self.migration_context.impl.bind # type: ignore[return-value]
+ def run_async(
+ self,
+ async_function: Callable[..., Awaitable[_T]],
+ *args: Any,
+ **kw_args: Any,
+ ) -> _T:
+ """Invoke the given asynchronous callable, passing an asynchronous
+ :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first
+ argument.
+
+ This method allows calling async functions from within the
+ synchronous ``upgrade()`` or ``downgrade()`` alembic migration
+ method.
+
+ The async connection passed to the callable shares the same
+ transaction as the connection running in the migration context.
+
+ Any additional arg or kw_arg passed to this function are passed
+ to the provided async function.
+
+ .. versionadded: 1.11
+
+ .. note::
+
+ This method can be called only when alembic is called using
+ an async dialect.
+ """
+ if not sqla_compat.sqla_14_18:
+ raise NotImplementedError("SQLAlchemy 1.4.18+ required")
+ sync_conn = self.get_bind()
+ if sync_conn is None:
+ raise NotImplementedError("Cannot call run_async in SQL mode")
+ if not sync_conn.dialect.is_async:
+ raise ValueError("Cannot call run_async with a sync engine")
+ from sqlalchemy.ext.asyncio import AsyncConnection
+ from sqlalchemy.util import await_only
+
+ async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
+ return await_only(async_function(async_conn, *args, **kw_args))
+
class Operations(AbstractOperations):
"""Define high level migration operations.
@Operations.register_operation("execute")
+@BatchOperations.register_operation("execute")
class ExecuteSQLOp(MigrateOperation):
"""Represent an execute SQL operation."""
)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
+# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
+sqla_14_18 = _vers >= (1, 4, 18)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)
sqlalchemy_version = __version__
--- /dev/null
+
+.. change::
+ :tags: usecase, asyncio
+ :tickets: 1231
+
+ Added :meth:`.Operations.run_async` to the operation module to allow
+ running async functions in the ``upgrade`` or ``downgrade`` migration
+ function when running alembic using an async dialect.
+ This function will receive as first argument an
+ :class:`~sqlalchemy.ext.asyncio.AsyncConnection` sharing the transaction
+ used in the migration context.
"""Test against the builders in the op.* module."""
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from alembic.testing import expect_warnings
from alembic.testing import is_not_
from alembic.testing import mock
+from alembic.testing.assertions import expect_raises_message
from alembic.testing.fixtures import op_fixture
from alembic.testing.fixtures import TestBase
from alembic.util import sqla_compat
("after_drop", "tb_test"),
]
+ @config.requirements.sqlalchemy_14
+ def test_run_async_error(self):
+ op_fixture()
+
+ async def go(conn):
+ pass
+
+ with expect_raises_message(
+ NotImplementedError, "SQLAlchemy 1.4.18. required"
+ ):
+ with patch.object(sqla_compat, "sqla_14_18", False):
+ op.run_async(go)
+ with expect_raises_message(
+ NotImplementedError, "Cannot call run_async in SQL mode"
+ ):
+ with patch.object(op._proxy, "get_bind", lambda: None):
+ op.run_async(go)
+ with expect_raises_message(
+ ValueError, "Cannot call run_async with a sync engine"
+ ):
+ op.run_async(go)
+
+ @config.requirements.sqlalchemy_14
+ def test_run_async_ok(self):
+ from sqlalchemy.ext.asyncio import AsyncConnection
+
+ op_fixture()
+ conn = op.get_bind()
+ mock_conn = MagicMock()
+ mock_fn = MagicMock()
+ with patch.object(conn.dialect, "is_async", True), patch.object(
+ AsyncConnection, "_retrieve_proxy_for_target", mock_conn
+ ), patch("sqlalchemy.util.await_only") as mock_await:
+ res = op.run_async(mock_fn, 99, foo=42)
+
+ eq_(res, mock_await.return_value)
+ mock_conn.assert_called_once_with(conn)
+ mock_await.assert_called_once_with(mock_fn.return_value)
+ mock_fn.assert_called_once_with(mock_conn.return_value, 99, foo=42)
+
class SQLModeOpTest(TestBase):
def test_auto_literals(self):
from alembic.operations import ops
import sqlalchemy as sa
-
TRIM_MODULE = [
"alembic.runtime.migration.",
"alembic.operations.base.",
retval = repr(annotation).replace("typing.", "")
elif isinstance(annotation, type):
retval = annotation.__qualname__
+ elif isinstance(annotation, typing.TypeVar):
+ retval = annotation.__name__
else:
retval = annotation
+ retval = retval.replace("~", "") # typevar repr as "~T"
for trim in TRIM_MODULE:
retval = retval.replace(trim, "")
"inline_literal",
"invoke",
"register_operation",
+ "run_async",
}
cases = [