]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Improve typing of `op.execute`
authorMihail Milushev <mihail@lanzz.org>
Thu, 31 Aug 2023 20:02:22 +0000 (16:02 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 14 Sep 2023 19:56:34 +0000 (21:56 +0200)
Update type annotation for `sqltext` argument of `op.execute` to support
all the documented acceptable types.
Add unit tests for `str` and `TextClause` use cases for `sqltext` argument.
Small repetition cleanup of documentation.

Fixes: #1277
Fixes: #1058
Closes: #1278
Pull-request: https://github.com/sqlalchemy/alembic/pull/1278
Pull-request-sha: c506f99d3b26d55cbc42ae34f55dfdbcd33af234

Change-Id: I405d968d7349760d99f86d846173e75e9f61d908

alembic/context.pyi
alembic/ddl/impl.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/ops.py
alembic/runtime/environment.py
alembic/runtime/migration.py
docs/build/unreleased/op_execute.rst [new file with mode: 0644]
tests/test_op.py
tools/write_pyi.py

index 469797631bc5735d440459a39ce3ea8c320b968e..5c0930129c143d0a56da57794ab984b08995277c 100644 (file)
@@ -21,7 +21,7 @@ from typing import Union
 if TYPE_CHECKING:
     from sqlalchemy.engine.base import Connection
     from sqlalchemy.engine.url import URL
-    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql import Executable
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import FetchedValue
     from sqlalchemy.sql.schema import MetaData
@@ -629,7 +629,7 @@ def configure(
     """
 
 def execute(
-    sql: Union[ClauseElement, str], execution_options: Optional[dict] = None
+    sql: Union[Executable, str], execution_options: Optional[dict] = None
 ) -> None:
     """Execute the given SQL using the current change context.
 
index 5ae5f2f93f3a69358ef4c7d69de4333467aaf659..8a7c75d46170ae2738a931b387f5080b404f010c 100644 (file)
@@ -32,7 +32,8 @@ if TYPE_CHECKING:
     from sqlalchemy.engine import Dialect
     from sqlalchemy.engine.cursor import CursorResult
     from sqlalchemy.engine.reflection import Inspector
-    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql import ClauseElement
+    from sqlalchemy.sql import Executable
     from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import quoted_name
     from sqlalchemy.sql.schema import Column
@@ -159,7 +160,7 @@ class DefaultImpl(metaclass=ImplMeta):
 
     def _exec(
         self,
-        construct: Union[ClauseElement, str],
+        construct: Union[Executable, str],
         execution_options: Optional[dict[str, Any]] = None,
         multiparams: Sequence[dict] = (),
         params: Dict[str, Any] = util.immutabledict(),
@@ -171,6 +172,7 @@ class DefaultImpl(metaclass=ImplMeta):
                 # TODO: coverage
                 raise Exception("Execution arguments not allowed with as_sql")
 
+            compile_kw: dict[str, Any]
             if self.literal_binds and not isinstance(
                 construct, schema.DDLElement
             ):
@@ -178,9 +180,9 @@ class DefaultImpl(metaclass=ImplMeta):
             else:
                 compile_kw = {}
 
-            compiled = construct.compile(
-                dialect=self.dialect, **compile_kw  # type: ignore[arg-type]
-            )
+            if TYPE_CHECKING:
+                assert isinstance(construct, ClauseElement)
+            compiled = construct.compile(dialect=self.dialect, **compile_kw)
             self.static_output(
                 str(compiled).replace("\t", "    ").strip()
                 + self.command_terminator
@@ -195,13 +197,11 @@ class DefaultImpl(metaclass=ImplMeta):
                 assert isinstance(multiparams, tuple)
                 multiparams += (params,)
 
-            return conn.execute(  # type: ignore[call-overload]
-                construct, multiparams
-            )
+            return conn.execute(construct, multiparams)
 
     def execute(
         self,
-        sql: Union[ClauseElement, str],
+        sql: Union[Executable, str],
         execution_options: Optional[dict[str, Any]] = None,
     ) -> None:
         self._exec(sql, execution_options)
@@ -578,13 +578,10 @@ class DefaultImpl(metaclass=ImplMeta):
 
         """
 
-        compile_kw = {
-            "compile_kwargs": {"literal_binds": True, "include_table": False}
-        }
+        compile_kw = {"literal_binds": True, "include_table": False}
+
         return str(
-            expr.compile(
-                dialect=self.dialect, **compile_kw  # type: ignore[arg-type]
-            )
+            expr.compile(dialect=self.dialect, compile_kwargs=compile_kw)
         )
 
     def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
index d2721d829b4d918c1bc34bc8cad78c53bdbc3bb4..944b5ae16a64be6670f32b811e897027de2210e8 100644 (file)
@@ -19,14 +19,13 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
-from sqlalchemy.sql.expression import TableClause
-from sqlalchemy.sql.expression import Update
-
 if TYPE_CHECKING:
     from sqlalchemy.engine import Connection
+    from sqlalchemy.sql import Executable
     from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import conv
     from sqlalchemy.sql.elements import TextClause
+    from sqlalchemy.sql.expression import TableClause
     from sqlalchemy.sql.functions import Function
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Computed
@@ -1024,7 +1023,7 @@ def drop_table_comment(
     """
 
 def execute(
-    sqltext: Union[str, TextClause, Update],
+    sqltext: Union[Executable, str],
     *,
     execution_options: Optional[dict[str, Any]] = None,
 ) -> None:
@@ -1093,9 +1092,8 @@ def execute(
     * a string
     * a :func:`sqlalchemy.sql.expression.text` construct.
     * a :func:`sqlalchemy.sql.expression.insert` construct.
-    * a :func:`sqlalchemy.sql.expression.update`,
-      :func:`sqlalchemy.sql.expression.insert`,
-      or :func:`sqlalchemy.sql.expression.delete`  construct.
+    * a :func:`sqlalchemy.sql.expression.update` construct.
+    * a :func:`sqlalchemy.sql.expression.delete` construct.
     * Any "executable" described in SQLAlchemy Core documentation,
       noting that no result set is returned.
 
index 6a279ee63c24e9f545656a5ecfa8d82b7b7fdf87..e3207be765f0fc6c9fdc5949e79bd7c10cb8d6f0 100644 (file)
@@ -35,10 +35,10 @@ if TYPE_CHECKING:
 
     from sqlalchemy import Table
     from sqlalchemy.engine import Connection
+    from sqlalchemy.sql import Executable
     from sqlalchemy.sql.expression import ColumnElement
     from sqlalchemy.sql.expression import TableClause
     from sqlalchemy.sql.expression import TextClause
-    from sqlalchemy.sql.expression import Update
     from sqlalchemy.sql.functions import Function
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Computed
@@ -1433,7 +1433,7 @@ class Operations(AbstractOperations):
 
         def execute(
             self,
-            sqltext: Union[str, TextClause, Update],
+            sqltext: Union[Executable, str],
             *,
             execution_options: Optional[dict[str, Any]] = None,
         ) -> None:
@@ -1502,9 +1502,8 @@ class Operations(AbstractOperations):
             * a string
             * a :func:`sqlalchemy.sql.expression.text` construct.
             * a :func:`sqlalchemy.sql.expression.insert` construct.
-            * a :func:`sqlalchemy.sql.expression.update`,
-              :func:`sqlalchemy.sql.expression.insert`,
-              or :func:`sqlalchemy.sql.expression.delete`  construct.
+            * a :func:`sqlalchemy.sql.expression.update` construct.
+            * a :func:`sqlalchemy.sql.expression.delete` construct.
             * Any "executable" described in SQLAlchemy Core documentation,
               noting that no result set is returned.
 
@@ -1822,7 +1821,7 @@ class BatchOperations(AbstractOperations):
 
         def execute(
             self,
-            sqltext: Union[str, TextClause, Update],
+            sqltext: Union[Executable, str],
             *,
             execution_options: Optional[dict[str, Any]] = None,
         ) -> None:
index bef6e81f4a61ee92ca7b7d0f75acd3b14c639ac0..fe681217520385031b1f6f0c665178a44a503c13 100644 (file)
@@ -28,8 +28,7 @@ from ..util import sqla_compat
 if TYPE_CHECKING:
     from typing import Literal
 
-    from sqlalchemy.sql.dml import Insert
-    from sqlalchemy.sql.dml import Update
+    from sqlalchemy.sql import Executable
     from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import conv
     from sqlalchemy.sql.elements import quoted_name
@@ -2423,7 +2422,7 @@ class ExecuteSQLOp(MigrateOperation):
 
     def __init__(
         self,
-        sqltext: Union[Update, str, Insert, TextClause],
+        sqltext: Union[Executable, str],
         *,
         execution_options: Optional[dict[str, Any]] = None,
     ) -> None:
@@ -2434,7 +2433,7 @@ class ExecuteSQLOp(MigrateOperation):
     def execute(
         cls,
         operations: Operations,
-        sqltext: Union[str, TextClause, Update],
+        sqltext: Union[Executable, str],
         *,
         execution_options: Optional[dict[str, Any]] = None,
     ) -> None:
@@ -2503,9 +2502,8 @@ class ExecuteSQLOp(MigrateOperation):
         * a string
         * a :func:`sqlalchemy.sql.expression.text` construct.
         * a :func:`sqlalchemy.sql.expression.insert` construct.
-        * a :func:`sqlalchemy.sql.expression.update`,
-          :func:`sqlalchemy.sql.expression.insert`,
-          or :func:`sqlalchemy.sql.expression.delete`  construct.
+        * a :func:`sqlalchemy.sql.expression.update` construct.
+        * a :func:`sqlalchemy.sql.expression.delete` construct.
         * Any "executable" described in SQLAlchemy Core documentation,
           noting that no result set is returned.
 
@@ -2526,7 +2524,7 @@ class ExecuteSQLOp(MigrateOperation):
     def batch_execute(
         cls,
         operations: Operations,
-        sqltext: Union[str, TextClause, Update],
+        sqltext: Union[Executable, str],
         *,
         execution_options: Optional[dict[str, Any]] = None,
     ) -> None:
index d729da1990f599cd88f8a5d522abc2a404655e0c..18840470a2f256338a1985fee6ad70c69cfdb0f0 100644 (file)
@@ -27,7 +27,7 @@ from ..operations import Operations
 if TYPE_CHECKING:
     from sqlalchemy.engine import URL
     from sqlalchemy.engine.base import Connection
-    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql import Executable
     from sqlalchemy.sql.schema import MetaData
     from sqlalchemy.sql.schema import SchemaItem
     from sqlalchemy.sql.type_api import TypeEngine
@@ -938,7 +938,7 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     def execute(
         self,
-        sql: Union[ClauseElement, str],
+        sql: Union[Executable, str],
         execution_options: Optional[dict] = None,
     ) -> None:
         """Execute the given SQL using the current change context.
index c9374c227bd252ff2e981e2cadd80a3c7ac28bbb..24e3d6449f4e650c6ab78e8abcc27a118023eba1 100644 (file)
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
     from sqlalchemy.engine.base import Connection
     from sqlalchemy.engine.base import Transaction
     from sqlalchemy.engine.mock import MockConnection
-    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql import Executable
 
     from .environment import EnvironmentContext
     from ..config import Config
@@ -651,7 +651,7 @@ class MigrationContext:
 
     def execute(
         self,
-        sql: Union[ClauseElement, str],
+        sql: Union[Executable, str],
         execution_options: Optional[dict] = None,
     ) -> None:
         """Execute a SQL construct or string statement.
diff --git a/docs/build/unreleased/op_execute.rst b/docs/build/unreleased/op_execute.rst
new file mode 100644 (file)
index 0000000..aac62c7
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: typing
+    :tickets: 1058, 1277
+
+    Properly type ``op.execute`` method.
+    Pull request curtesy of Mihail Milushev.
index 67d419478114bc73825524141ccdad4662925ca9..f1b8d27d3651901f65b144e9c263ee1301730675 100644 (file)
@@ -1079,6 +1079,55 @@ class OpTest(TestBase):
             "FOREIGN KEY(foo_bar) REFERENCES foo (bar))"
         )
 
+    def test_execute_delete(self):
+        context = op_fixture()
+
+        account = table(
+            "account", column("name", String), column("id", Integer)
+        )
+        op.execute(account.delete().where(account.c.name == "account 1"))
+        context.assert_(
+            "DELETE FROM account WHERE account.name = :name_1",
+        )
+
+    def test_execute_insert(self):
+        context = op_fixture()
+
+        account = table(
+            "account", column("name", String), column("id", Integer)
+        )
+        op.execute(account.insert().values(name="account 1"))
+        context.assert_(
+            "INSERT INTO account (name) VALUES (:name)",
+        )
+
+    def test_execute_update(self):
+        context = op_fixture()
+
+        account = table(
+            "account", column("name", String), column("id", Integer)
+        )
+        op.execute(
+            account.update()
+            .where(account.c.name == "account 1")
+            .values({"name": "account 2"})
+        )
+        context.assert_(
+            "UPDATE account SET name=:name " "WHERE account.name = :name_1",
+        )
+
+    def test_execute_str(self):
+        context = op_fixture()
+
+        op.execute("SELECT 'test'")
+        context.assert_("SELECT 'test'")
+
+    def test_execute_textclause(self):
+        context = op_fixture()
+
+        op.execute(text("SELECT 'test'"))
+        context.assert_("SELECT 'test'")
+
     def test_inline_literal(self):
         context = op_fixture()
 
index 499d830fe79966b54d05f2cfda188f8083dd7aeb..5abb26ef1eb30f422153ba5e3a0b97d3d85feee9 100644 (file)
@@ -29,18 +29,19 @@ if True:  # avoid flake/zimports messing with the order
     import sqlalchemy as sa
 
 TRIM_MODULE = [
-    "alembic.runtime.migration.",
+    "alembic.autogenerate.api.",
     "alembic.operations.base.",
     "alembic.operations.ops.",
-    "alembic.autogenerate.api.",
+    "alembic.runtime.migration.",
     "sqlalchemy.engine.base.",
     "sqlalchemy.engine.url.",
+    "sqlalchemy.sql.base.",
+    "sqlalchemy.sql.dml.",
+    "sqlalchemy.sql.elements.",
+    "sqlalchemy.sql.functions.",
     "sqlalchemy.sql.schema.",
     "sqlalchemy.sql.selectable.",
-    "sqlalchemy.sql.elements.",
     "sqlalchemy.sql.type_api.",
-    "sqlalchemy.sql.functions.",
-    "sqlalchemy.sql.dml.",
     "typing.",
 ]
 ADDITIONAL_ENV = {