From b2087336869511c5c6f766a1df455b76d8d23ce9 Mon Sep 17 00:00:00 2001 From: Mihail Milushev Date: Thu, 31 Aug 2023 16:02:22 -0400 Subject: [PATCH] Improve typing of `op.execute` Update type annotation for `sqltext` argument of `op.execute` to support all the documented acceptable types. Add unit tests for `str` and `TextClause` use cases for `sqltext` argument. Small repetition cleanup of documentation. Fixes: #1277 Fixes: #1058 Closes: #1278 Pull-request: https://github.com/sqlalchemy/alembic/pull/1278 Pull-request-sha: c506f99d3b26d55cbc42ae34f55dfdbcd33af234 Change-Id: I405d968d7349760d99f86d846173e75e9f61d908 --- alembic/context.pyi | 4 +-- alembic/ddl/impl.py | 27 +++++++-------- alembic/op.pyi | 12 +++---- alembic/operations/base.py | 11 +++---- alembic/operations/ops.py | 14 ++++---- alembic/runtime/environment.py | 4 +-- alembic/runtime/migration.py | 4 +-- docs/build/unreleased/op_execute.rst | 6 ++++ tests/test_op.py | 49 ++++++++++++++++++++++++++++ tools/write_pyi.py | 11 ++++--- 10 files changed, 95 insertions(+), 47 deletions(-) create mode 100644 docs/build/unreleased/op_execute.rst diff --git a/alembic/context.pyi b/alembic/context.pyi index 46979763..5c093012 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -21,7 +21,7 @@ from typing import Union if TYPE_CHECKING: from sqlalchemy.engine.base import Connection from sqlalchemy.engine.url import URL - from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql import Executable from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import FetchedValue from sqlalchemy.sql.schema import MetaData @@ -629,7 +629,7 @@ def configure( """ def execute( - sql: Union[ClauseElement, str], execution_options: Optional[dict] = None + sql: Union[Executable, str], execution_options: Optional[dict] = None ) -> None: """Execute the given SQL using the current change context. diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 5ae5f2f9..8a7c75d4 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -32,7 +32,8 @@ if TYPE_CHECKING: from sqlalchemy.engine import Dialect from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.reflection import Inspector - from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql import ClauseElement + from sqlalchemy.sql import Executable from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.schema import Column @@ -159,7 +160,7 @@ class DefaultImpl(metaclass=ImplMeta): def _exec( self, - construct: Union[ClauseElement, str], + construct: Union[Executable, str], execution_options: Optional[dict[str, Any]] = None, multiparams: Sequence[dict] = (), params: Dict[str, Any] = util.immutabledict(), @@ -171,6 +172,7 @@ class DefaultImpl(metaclass=ImplMeta): # TODO: coverage raise Exception("Execution arguments not allowed with as_sql") + compile_kw: dict[str, Any] if self.literal_binds and not isinstance( construct, schema.DDLElement ): @@ -178,9 +180,9 @@ class DefaultImpl(metaclass=ImplMeta): else: compile_kw = {} - compiled = construct.compile( - dialect=self.dialect, **compile_kw # type: ignore[arg-type] - ) + if TYPE_CHECKING: + assert isinstance(construct, ClauseElement) + compiled = construct.compile(dialect=self.dialect, **compile_kw) self.static_output( str(compiled).replace("\t", " ").strip() + self.command_terminator @@ -195,13 +197,11 @@ class DefaultImpl(metaclass=ImplMeta): assert isinstance(multiparams, tuple) multiparams += (params,) - return conn.execute( # type: ignore[call-overload] - construct, multiparams - ) + return conn.execute(construct, multiparams) def execute( self, - sql: Union[ClauseElement, str], + sql: Union[Executable, str], execution_options: Optional[dict[str, Any]] = None, ) -> None: self._exec(sql, execution_options) @@ -578,13 +578,10 @@ class DefaultImpl(metaclass=ImplMeta): """ - compile_kw = { - "compile_kwargs": {"literal_binds": True, "include_table": False} - } + compile_kw = {"literal_binds": True, "include_table": False} + return str( - expr.compile( - dialect=self.dialect, **compile_kw # type: ignore[arg-type] - ) + expr.compile(dialect=self.dialect, compile_kwargs=compile_kw) ) def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable: diff --git a/alembic/op.pyi b/alembic/op.pyi index d2721d82..944b5ae1 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -19,14 +19,13 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.sql.expression import TableClause -from sqlalchemy.sql.expression import Update - if TYPE_CHECKING: from sqlalchemy.engine import Connection + from sqlalchemy.sql import Executable from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import conv from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.expression import TableClause from sqlalchemy.sql.functions import Function from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Computed @@ -1024,7 +1023,7 @@ def drop_table_comment( """ def execute( - sqltext: Union[str, TextClause, Update], + sqltext: Union[Executable, str], *, execution_options: Optional[dict[str, Any]] = None, ) -> None: @@ -1093,9 +1092,8 @@ def execute( * a string * a :func:`sqlalchemy.sql.expression.text` construct. * a :func:`sqlalchemy.sql.expression.insert` construct. - * a :func:`sqlalchemy.sql.expression.update`, - :func:`sqlalchemy.sql.expression.insert`, - or :func:`sqlalchemy.sql.expression.delete` construct. + * a :func:`sqlalchemy.sql.expression.update` construct. + * a :func:`sqlalchemy.sql.expression.delete` construct. * Any "executable" described in SQLAlchemy Core documentation, noting that no result set is returned. diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 6a279ee6..e3207be7 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -35,10 +35,10 @@ if TYPE_CHECKING: from sqlalchemy import Table from sqlalchemy.engine import Connection + from sqlalchemy.sql import Executable from sqlalchemy.sql.expression import ColumnElement from sqlalchemy.sql.expression import TableClause from sqlalchemy.sql.expression import TextClause - from sqlalchemy.sql.expression import Update from sqlalchemy.sql.functions import Function from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Computed @@ -1433,7 +1433,7 @@ class Operations(AbstractOperations): def execute( self, - sqltext: Union[str, TextClause, Update], + sqltext: Union[Executable, str], *, execution_options: Optional[dict[str, Any]] = None, ) -> None: @@ -1502,9 +1502,8 @@ class Operations(AbstractOperations): * a string * a :func:`sqlalchemy.sql.expression.text` construct. * a :func:`sqlalchemy.sql.expression.insert` construct. - * a :func:`sqlalchemy.sql.expression.update`, - :func:`sqlalchemy.sql.expression.insert`, - or :func:`sqlalchemy.sql.expression.delete` construct. + * a :func:`sqlalchemy.sql.expression.update` construct. + * a :func:`sqlalchemy.sql.expression.delete` construct. * Any "executable" described in SQLAlchemy Core documentation, noting that no result set is returned. @@ -1822,7 +1821,7 @@ class BatchOperations(AbstractOperations): def execute( self, - sqltext: Union[str, TextClause, Update], + sqltext: Union[Executable, str], *, execution_options: Optional[dict[str, Any]] = None, ) -> None: diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index bef6e81f..fe681217 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -28,8 +28,7 @@ from ..util import sqla_compat if TYPE_CHECKING: from typing import Literal - from sqlalchemy.sql.dml import Insert - from sqlalchemy.sql.dml import Update + from sqlalchemy.sql import Executable from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import conv from sqlalchemy.sql.elements import quoted_name @@ -2423,7 +2422,7 @@ class ExecuteSQLOp(MigrateOperation): def __init__( self, - sqltext: Union[Update, str, Insert, TextClause], + sqltext: Union[Executable, str], *, execution_options: Optional[dict[str, Any]] = None, ) -> None: @@ -2434,7 +2433,7 @@ class ExecuteSQLOp(MigrateOperation): def execute( cls, operations: Operations, - sqltext: Union[str, TextClause, Update], + sqltext: Union[Executable, str], *, execution_options: Optional[dict[str, Any]] = None, ) -> None: @@ -2503,9 +2502,8 @@ class ExecuteSQLOp(MigrateOperation): * a string * a :func:`sqlalchemy.sql.expression.text` construct. * a :func:`sqlalchemy.sql.expression.insert` construct. - * a :func:`sqlalchemy.sql.expression.update`, - :func:`sqlalchemy.sql.expression.insert`, - or :func:`sqlalchemy.sql.expression.delete` construct. + * a :func:`sqlalchemy.sql.expression.update` construct. + * a :func:`sqlalchemy.sql.expression.delete` construct. * Any "executable" described in SQLAlchemy Core documentation, noting that no result set is returned. @@ -2526,7 +2524,7 @@ class ExecuteSQLOp(MigrateOperation): def batch_execute( cls, operations: Operations, - sqltext: Union[str, TextClause, Update], + sqltext: Union[Executable, str], *, execution_options: Optional[dict[str, Any]] = None, ) -> None: diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index d729da19..18840470 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -27,7 +27,7 @@ from ..operations import Operations if TYPE_CHECKING: from sqlalchemy.engine import URL from sqlalchemy.engine.base import Connection - from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql import Executable from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import SchemaItem from sqlalchemy.sql.type_api import TypeEngine @@ -938,7 +938,7 @@ class EnvironmentContext(util.ModuleClsProxy): def execute( self, - sql: Union[ClauseElement, str], + sql: Union[Executable, str], execution_options: Optional[dict] = None, ) -> None: """Execute the given SQL using the current change context. diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index c9374c22..24e3d644 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -40,7 +40,7 @@ if TYPE_CHECKING: from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Transaction from sqlalchemy.engine.mock import MockConnection - from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql import Executable from .environment import EnvironmentContext from ..config import Config @@ -651,7 +651,7 @@ class MigrationContext: def execute( self, - sql: Union[ClauseElement, str], + sql: Union[Executable, str], execution_options: Optional[dict] = None, ) -> None: """Execute a SQL construct or string statement. diff --git a/docs/build/unreleased/op_execute.rst b/docs/build/unreleased/op_execute.rst new file mode 100644 index 00000000..aac62c7d --- /dev/null +++ b/docs/build/unreleased/op_execute.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: typing + :tickets: 1058, 1277 + + Properly type ``op.execute`` method. + Pull request curtesy of Mihail Milushev. diff --git a/tests/test_op.py b/tests/test_op.py index 67d41947..f1b8d27d 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -1079,6 +1079,55 @@ class OpTest(TestBase): "FOREIGN KEY(foo_bar) REFERENCES foo (bar))" ) + def test_execute_delete(self): + context = op_fixture() + + account = table( + "account", column("name", String), column("id", Integer) + ) + op.execute(account.delete().where(account.c.name == "account 1")) + context.assert_( + "DELETE FROM account WHERE account.name = :name_1", + ) + + def test_execute_insert(self): + context = op_fixture() + + account = table( + "account", column("name", String), column("id", Integer) + ) + op.execute(account.insert().values(name="account 1")) + context.assert_( + "INSERT INTO account (name) VALUES (:name)", + ) + + def test_execute_update(self): + context = op_fixture() + + account = table( + "account", column("name", String), column("id", Integer) + ) + op.execute( + account.update() + .where(account.c.name == "account 1") + .values({"name": "account 2"}) + ) + context.assert_( + "UPDATE account SET name=:name " "WHERE account.name = :name_1", + ) + + def test_execute_str(self): + context = op_fixture() + + op.execute("SELECT 'test'") + context.assert_("SELECT 'test'") + + def test_execute_textclause(self): + context = op_fixture() + + op.execute(text("SELECT 'test'")) + context.assert_("SELECT 'test'") + def test_inline_literal(self): context = op_fixture() diff --git a/tools/write_pyi.py b/tools/write_pyi.py index 499d830f..5abb26ef 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -29,18 +29,19 @@ if True: # avoid flake/zimports messing with the order import sqlalchemy as sa TRIM_MODULE = [ - "alembic.runtime.migration.", + "alembic.autogenerate.api.", "alembic.operations.base.", "alembic.operations.ops.", - "alembic.autogenerate.api.", + "alembic.runtime.migration.", "sqlalchemy.engine.base.", "sqlalchemy.engine.url.", + "sqlalchemy.sql.base.", + "sqlalchemy.sql.dml.", + "sqlalchemy.sql.elements.", + "sqlalchemy.sql.functions.", "sqlalchemy.sql.schema.", "sqlalchemy.sql.selectable.", - "sqlalchemy.sql.elements.", "sqlalchemy.sql.type_api.", - "sqlalchemy.sql.functions.", - "sqlalchemy.sql.dml.", "typing.", ] ADDITIONAL_ENV = { -- 2.47.2