]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Define type for generic classes
authorFederico Caselli <cfederico87@gmail.com>
Tue, 16 May 2023 19:52:02 +0000 (21:52 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 16 May 2023 19:52:02 +0000 (21:52 +0200)
Fixed typing use of :class:`~sqlalchemy.schema.Column` and other
generic SQLAlchemy classes.

Fixes: #1246
Change-Id: I5ee80395d626894a52e3395c9986213289576355

18 files changed:
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/context.pyi
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/oracle.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/util/sqla_compat.py
docs/build/unreleased/1246.rst [new file with mode: 0644]
tools/write_pyi.py

index 5727891f0e522f9564f03044e0c4c586b7e5af98..031d683baa5b42cc2f5644f2bb4f790ffc972beb 100644 (file)
@@ -926,8 +926,8 @@ def _compare_nullable(
     schema: Optional[str],
     tname: Union[quoted_name, str],
     cname: Union[quoted_name, str],
-    conn_col: Column,
-    metadata_col: Column,
+    conn_col: Column[Any],
+    metadata_col: Column[Any],
 ) -> None:
 
     metadata_col_nullable = metadata_col.nullable
@@ -968,8 +968,8 @@ def _setup_autoincrement(
     schema: Optional[str],
     tname: Union[quoted_name, str],
     cname: quoted_name,
-    conn_col: Column,
-    metadata_col: Column,
+    conn_col: Column[Any],
+    metadata_col: Column[Any],
 ) -> None:
 
     if metadata_col.table._autoincrement_column is metadata_col:
@@ -987,8 +987,8 @@ def _compare_type(
     schema: Optional[str],
     tname: Union[quoted_name, str],
     cname: Union[quoted_name, str],
-    conn_col: Column,
-    metadata_col: Column,
+    conn_col: Column[Any],
+    metadata_col: Column[Any],
 ) -> None:
 
     conn_type = conn_col.type
@@ -1060,8 +1060,8 @@ def _compare_computed_default(
     schema: Optional[str],
     tname: str,
     cname: str,
-    conn_col: Column,
-    metadata_col: Column,
+    conn_col: Column[Any],
+    metadata_col: Column[Any],
 ) -> None:
     rendered_metadata_default = str(
         cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
@@ -1126,8 +1126,8 @@ def _compare_server_default(
     schema: Optional[str],
     tname: Union[quoted_name, str],
     cname: Union[quoted_name, str],
-    conn_col: Column,
-    metadata_col: Column,
+    conn_col: Column[Any],
+    metadata_col: Column[Any],
 ) -> Optional[bool]:
 
     metadata_default = metadata_col.server_default
@@ -1215,8 +1215,8 @@ def _compare_column_comment(
     schema: Optional[str],
     tname: Union[quoted_name, str],
     cname: quoted_name,
-    conn_col: Column,
-    metadata_col: Column,
+    conn_col: Column[Any],
+    metadata_col: Column[Any],
 ) -> Optional[Literal[False]]:
 
     assert autogen_context.dialect is not None
index dc841f83ef38bc10f21c60c0de704a72d4dd72b5..00d1d2fe56772a857311e4f326528512ff672987 100644 (file)
@@ -664,7 +664,9 @@ def _user_defined_render(
     return False
 
 
-def _render_column(column: Column, autogen_context: AutogenContext) -> str:
+def _render_column(
+    column: Column[Any], autogen_context: AutogenContext
+) -> str:
     rendered = _user_defined_render("column", column, autogen_context)
     if rendered is not False:
         return rendered
@@ -727,7 +729,9 @@ def _should_render_server_default_positionally(server_default: Any) -> bool:
 
 
 def _render_server_default(
-    default: Optional[Union[FetchedValue, str, TextClause, ColumnElement]],
+    default: Optional[
+        Union[FetchedValue, str, TextClause, ColumnElement[Any]]
+    ],
     autogen_context: AutogenContext,
     repr_: bool = True,
 ) -> Optional[str]:
index 621599d345ba320f7923c9b52ca6a72d7f48cec4..eedf7afd8ea8c4ef86a399cb0e7fe6e5de4a0ce4 100644 (file)
@@ -151,8 +151,8 @@ def configure(
         Callable[
             [
                 MigrationContext,
-                Column,
-                Column,
+                Column[Any],
+                Column[Any],
                 Optional[str],
                 Optional[FetchedValue],
                 Optional[str],
index 65da32f40cf24c2f93d08de20a793b572d3b5c67..339db0c4a5d9e78e7e7b608895fb30de8653b22a 100644 (file)
@@ -150,7 +150,7 @@ class AddColumn(AlterTable):
     def __init__(
         self,
         name: str,
-        column: Column,
+        column: Column[Any],
         schema: Optional[Union[quoted_name, str]] = None,
     ) -> None:
         super().__init__(name, schema=schema)
@@ -159,7 +159,7 @@ class AddColumn(AlterTable):
 
 class DropColumn(AlterTable):
     def __init__(
-        self, name: str, column: Column, schema: Optional[str] = None
+        self, name: str, column: Column[Any], schema: Optional[str] = None
     ) -> None:
         super().__init__(name, schema=schema)
         self.column = column
@@ -320,7 +320,7 @@ def alter_column(compiler: DDLCompiler, name: str) -> str:
     return "ALTER COLUMN %s" % format_column_name(compiler, name)
 
 
-def add_column(compiler: DDLCompiler, column: Column, **kw) -> str:
+def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
     text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
 
     const = " ".join(
index 84f5d86cc4674b1507c0c5713228a31dcb810f47..03f134d584fc641da9387666f337e7be5a71e690 100644 (file)
@@ -316,7 +316,7 @@ class DefaultImpl(metaclass=ImplMeta):
     def add_column(
         self,
         table_name: str,
-        column: Column,
+        column: Column[Any],
         schema: Optional[Union[str, quoted_name]] = None,
     ) -> None:
         self._exec(base.AddColumn(table_name, column, schema=schema))
@@ -324,7 +324,7 @@ class DefaultImpl(metaclass=ImplMeta):
     def drop_column(
         self,
         table_name: str,
-        column: Column,
+        column: Column[Any],
         schema: Optional[str] = None,
         **kw,
     ) -> None:
@@ -388,7 +388,7 @@ class DefaultImpl(metaclass=ImplMeta):
     def drop_table_comment(self, table: Table) -> None:
         self._exec(schema.DropTableComment(table))
 
-    def create_column_comment(self, column: ColumnElement) -> None:
+    def create_column_comment(self, column: ColumnElement[Any]) -> None:
         self._exec(schema.SetColumnComment(column))
 
     def drop_index(self, index: Index) -> None:
@@ -526,7 +526,7 @@ class DefaultImpl(metaclass=ImplMeta):
         return True
 
     def compare_type(
-        self, inspector_column: Column, metadata_column: Column
+        self, inspector_column: Column[Any], metadata_column: Column
     ) -> bool:
         """Returns True if there ARE differences between the types of the two
         columns. Takes impl.type_synonyms into account between retrospected
index ebf4db19afae99e5cd68e453f816b5ab8f66ad63..10c1a6b986c063062a8c428dbc707c2c48ec5726 100644 (file)
@@ -201,7 +201,7 @@ class MSSQLImpl(DefaultImpl):
     def drop_column(
         self,
         table_name: str,
-        column: Column,
+        column: Column[Any],
         schema: Optional[str] = None,
         **kw,
     ) -> None:
@@ -273,7 +273,7 @@ class _ExecDropConstraint(Executable, ClauseElement):
     def __init__(
         self,
         tname: str,
-        colname: Union[Column, str],
+        colname: Union[Column[Any], str],
         type_: str,
         schema: Optional[str],
     ) -> None:
@@ -287,7 +287,7 @@ class _ExecDropFKConstraint(Executable, ClauseElement):
     inherit_cache = False
 
     def __init__(
-        self, tname: str, colname: Column, schema: Optional[str]
+        self, tname: str, colname: Column[Any], schema: Optional[str]
     ) -> None:
         self.tname = tname
         self.colname = colname
@@ -347,7 +347,9 @@ def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str:
     )
 
 
-def mssql_add_column(compiler: MSDDLCompiler, column: Column, **kw) -> str:
+def mssql_add_column(
+    compiler: MSDDLCompiler, column: Column[Any], **kw
+) -> str:
     return "ADD %s" % compiler.get_column_specification(column, **kw)
 
 
index 9715c1e81a7e67ae741dbc02b75e99e41fc3c097..e56bb2102f45d0e807ec7ad908fd86f359efc2ea 100644 (file)
@@ -176,7 +176,7 @@ def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
     return "MODIFY %s" % format_column_name(compiler, name)
 
 
-def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str:
+def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str:
     return "ADD %s" % compiler.get_column_specification(column, **kw)
 
 
index 6c858e7bdf80edc6db209a0bf369e958cc8efd24..e3ada90827af2b1230b99ba6edd4b8a462fceb14 100644 (file)
@@ -486,7 +486,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         table_name: Union[str, quoted_name],
         elements: Union[
             Sequence[Tuple[str, str]],
-            Sequence[Tuple[ColumnClause, str]],
+            Sequence[Tuple[ColumnClause[Any], str]],
         ],
         where: Optional[Union[BinaryExpression, str]] = None,
         schema: Optional[str] = None,
@@ -706,7 +706,9 @@ def _exclude_constraint(
 
 
 def _render_potential_column(
-    value: Union[ColumnClause, Column, TextClause, FunctionElement],
+    value: Union[
+        ColumnClause[Any], Column[Any], TextClause, FunctionElement[Any]
+    ],
     autogen_context: AutogenContext,
 ) -> str:
     if isinstance(value, ColumnClause):
index 302a87752285141495fab04e0229d72b939c9d02..67a1c2845984f996eb7522bac48d8d78f1dd8719 100644 (file)
@@ -95,8 +95,8 @@ class SQLiteImpl(DefaultImpl):
 
     def compare_server_default(
         self,
-        inspector_column: Column,
-        metadata_column: Column,
+        inspector_column: Column[Any],
+        metadata_column: Column[Any],
         rendered_metadata_default: Optional[str],
         rendered_inspector_default: Optional[str],
     ) -> bool:
@@ -173,7 +173,7 @@ class SQLiteImpl(DefaultImpl):
 
     def cast_for_batch_migrate(
         self,
-        existing: Column,
+        existing: Column[Any],
         existing_transfer: Dict[str, Union[TypeEngine, Cast]],
         new_type: TypeEngine,
     ) -> None:
index 10e6f5914afbb149fe1cb0b81127180f8fd069d7..1eb14954e3de0077326bd186380c476f81bf765c 100644 (file)
@@ -45,7 +45,7 @@ _T = TypeVar("_T")
 ### end imports ###
 
 def add_column(
-    table_name: str, column: Column, *, schema: Optional[str] = None
+    table_name: str, column: Column[Any], *, schema: Optional[str] = None
 ) -> None:
     """Issue an "add column" instruction using the current
     migration context.
index 4e59e5ba56a3ce8f18438a0a9ec2ca4fd3a5d7fa..fa3fe1316f6139991ac2e1b303f13cb898c73dbc 100644 (file)
@@ -569,7 +569,7 @@ class Operations(AbstractOperations):
         def add_column(
             self,
             table_name: str,
-            column: Column,
+            column: Column[Any],
             *,
             schema: Optional[str] = None,
         ) -> None:
@@ -1574,7 +1574,7 @@ class BatchOperations(AbstractOperations):
 
         def add_column(
             self,
-            column: Column,
+            column: Column[Any],
             *,
             insert_before: Optional[str] = None,
             insert_after: Optional[str] = None,
index f4a058bc9a201f5039d14d938a0f38de43de7644..5b6b54775fb1a3d1ff15415b7697e4476c091031 100644 (file)
@@ -243,7 +243,7 @@ class ApplyBatchImpl:
 
     def _grab_table_elements(self) -> None:
         schema = self.table.schema
-        self.columns: Dict[str, Column] = OrderedDict()
+        self.columns: Dict[str, Column[Any]] = OrderedDict()
         for c in self.table.c:
             c_copy = _copy(c, schema=schema)
             c_copy.unique = c_copy.index = False
@@ -607,7 +607,7 @@ class ApplyBatchImpl:
     def add_column(
         self,
         table_name: str,
-        column: Column,
+        column: Column[Any],
         insert_before: Optional[str] = None,
         insert_after: Optional[str] = None,
         **kw,
@@ -621,7 +621,10 @@ class ApplyBatchImpl:
         self.column_transfers[column.name] = {}
 
     def drop_column(
-        self, table_name: str, column: Union[ColumnClause, Column], **kw
+        self,
+        table_name: str,
+        column: Union[ColumnClause[Any], Column[Any]],
+        **kw,
     ) -> None:
         if column.name in self.table.primary_key.columns:
             _remove_column_from_collection(
index 5334a01e0774511c3db93e126bb14f2f25b2466b..472c0e83cb8852e60aa994f301784c548358113c 100644 (file)
@@ -1994,7 +1994,7 @@ class AddColumnOp(AlterTableOp):
     def __init__(
         self,
         table_name: str,
-        column: Column,
+        column: Column[Any],
         *,
         schema: Optional[str] = None,
         **kw: Any,
@@ -2010,7 +2010,7 @@ class AddColumnOp(AlterTableOp):
 
     def to_diff_tuple(
         self,
-    ) -> Tuple[str, Optional[str], str, Column]:
+    ) -> Tuple[str, Optional[str], str, Column[Any]]:
         return ("add_column", self.schema, self.table_name, self.column)
 
     def to_column(self) -> Column:
@@ -2025,7 +2025,7 @@ class AddColumnOp(AlterTableOp):
         cls,
         schema: Optional[str],
         tname: str,
-        col: Column,
+        col: Column[Any],
     ) -> AddColumnOp:
         return cls(tname, col, schema=schema)
 
@@ -2034,7 +2034,7 @@ class AddColumnOp(AlterTableOp):
         cls,
         operations: Operations,
         table_name: str,
-        column: Column,
+        column: Column[Any],
         *,
         schema: Optional[str] = None,
     ) -> None:
@@ -2123,7 +2123,7 @@ class AddColumnOp(AlterTableOp):
     def batch_add_column(
         cls,
         operations: BatchOperations,
-        column: Column,
+        column: Column[Any],
         *,
         insert_before: Optional[str] = None,
         insert_after: Optional[str] = None,
@@ -2173,7 +2173,7 @@ class DropColumnOp(AlterTableOp):
 
     def to_diff_tuple(
         self,
-    ) -> Tuple[str, Optional[str], str, Column]:
+    ) -> Tuple[str, Optional[str], str, Column[Any]]:
         return (
             "remove_column",
             self.schema,
@@ -2197,7 +2197,7 @@ class DropColumnOp(AlterTableOp):
         cls,
         schema: Optional[str],
         tname: str,
-        col: Column,
+        col: Column[Any],
     ) -> DropColumnOp:
         return cls(
             tname,
index 3087377361dd4ac42cfd3435af4c25b739a95740..acd5cd1ebb2b2502ad93a676531bddc2c8bb05eb 100644 (file)
@@ -84,8 +84,8 @@ OnVersionApplyFn = Callable[
 CompareServerDefault = Callable[
     [
         MigrationContext,
-        Column,
-        Column,
+        "Column[Any]",
+        "Column[Any]",
         Optional[str],
         Optional[FetchedValue],
         Optional[str],
index 8baeaf0ba691017ea56799b6b0b9e278656b3b79..1715e8af9edcf330eba732a520d96e78e2008f39 100644 (file)
@@ -708,7 +708,7 @@ class MigrationContext:
             return None
 
     def _compare_type(
-        self, inspector_column: Column, metadata_column: Column
+        self, inspector_column: Column[Any], metadata_column: Column
     ) -> bool:
         if self._user_compare_type is False:
             return False
@@ -728,8 +728,8 @@ class MigrationContext:
 
     def _compare_server_default(
         self,
-        inspector_column: Column,
-        metadata_column: Column,
+        inspector_column: Column[Any],
+        metadata_column: Column[Any],
         rendered_metadata_default: Optional[str],
         rendered_column_default: Optional[str],
     ) -> bool:
index 37e1ee13606f0f47b877b53f8dd698610d8d01c6..376448ac34b1e7d550cc7d3adfe194a26d94f765 100644 (file)
@@ -46,7 +46,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.selectable import Select
     from sqlalchemy.sql.selectable import TableClause
 
-_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
+_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
 
 
 def _safe_int(value: str) -> Union[int, str]:
@@ -390,7 +390,7 @@ def _find_columns(clause):
 
 
 def _remove_column_from_collection(
-    collection: ColumnCollection, column: Union[Column, ColumnClause]
+    collection: ColumnCollection, column: Union[Column[Any], ColumnClause[Any]]
 ) -> None:
     """remove a column from a ColumnCollection."""
 
@@ -408,8 +408,8 @@ def _remove_column_from_collection(
 
 
 def _textual_index_column(
-    table: Table, text_: Union[str, TextClause, ColumnElement]
-) -> Union[ColumnElement, Column]:
+    table: Table, text_: Union[str, TextClause, ColumnElement[Any]]
+) -> Union[ColumnElement[Any], Column[Any]]:
     """a workaround for the Index construct's severe lack of flexibility"""
     if isinstance(text_, str):
         c = Column(text_, sqltypes.NULLTYPE)
diff --git a/docs/build/unreleased/1246.rst b/docs/build/unreleased/1246.rst
new file mode 100644 (file)
index 0000000..a33de0d
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 1246
+
+    Fixed typing use of :class:`~sqlalchemy.schema.Column` and other
+    generic SQLAlchemy classes.
index 82ceead70bd280f2704f2cf22f4db96700c618d2..da5b4845b969e86198bfb5832ae39aa5a5be14d3 100644 (file)
@@ -41,6 +41,7 @@ TRIM_MODULE = [
     "sqlalchemy.sql.type_api.",
     "sqlalchemy.sql.functions.",
     "sqlalchemy.sql.dml.",
+    "typing."
 ]
 ADDITIONAL_ENV = {
     "MigrationContext": MigrationContext,
@@ -180,6 +181,11 @@ def _generate_stub_for_meth(
             retval = annotation.__qualname__
         elif isinstance(annotation, typing.TypeVar):
             retval = annotation.__name__
+        elif hasattr(annotation, "__args__") and hasattr(
+            annotation, "__origin__"
+        ):
+            # generic class
+            retval = str(annotation)
         else:
             retval = annotation