]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
fix(typings): improve typing for server_default
authorSebastian Kreft <911768+sk-@users.noreply.github.com>
Sun, 18 Jan 2026 15:26:05 +0000 (10:26 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 27 Jan 2026 18:43:15 +0000 (13:43 -0500)
Fixed typing issue where the :paramref:`.AlterColumnOp.server_default` and
:paramref:`.AlterColumnOp.existing_server_default` parameters failed to
accommodate common SQLAlchemy SQL constructs such as ``null()`` and
``text()``.   Pull request courtesy Sebastian Kreft.

this sets up a standard type for the server default argument using
an alias, and adds modifications to write_pyi for extremely basic
ability to render type aliases (with limitations).

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Fixes: #1669
Closes: #1670
Pull-request: https://github.com/sqlalchemy/alembic/pull/1670
Pull-request-sha: e6464647b6e33e077e7baf4bbc5c7549ab570a06

Change-Id: Id25bf7fd706f91aa637adf9b67f0529f1d7d1080

12 files changed:
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/postgresql.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/toimpl.py
docs/build/unreleased/1669.rst [new file with mode: 0644]
tools/write_pyi.py

index ad2847eb2f76066264f2218ede2e173032082f92..550fe147fe006ed5e5dfc1ece880026eb9dfe623 100644 (file)
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 import functools
+from typing import Any
 from typing import Optional
 from typing import TYPE_CHECKING
 from typing import Union
@@ -14,7 +15,10 @@ from sqlalchemy import types as sqltypes
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.schema import Column
 from sqlalchemy.schema import DDLElement
+from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.elements import quoted_name
+from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.schema import FetchedValue
 
 from ..util.sqla_compat import _columns_for_constraint  # noqa
 from ..util.sqla_compat import _find_columns  # noqa
@@ -23,20 +27,16 @@ from ..util.sqla_compat import _is_type_bound  # noqa
 from ..util.sqla_compat import _table_for_constraint  # noqa
 
 if TYPE_CHECKING:
-    from typing import Any
 
     from sqlalchemy import Computed
     from sqlalchemy import Identity
     from sqlalchemy.sql.compiler import Compiled
     from sqlalchemy.sql.compiler import DDLCompiler
-    from sqlalchemy.sql.elements import TextClause
-    from sqlalchemy.sql.functions import Function
-    from sqlalchemy.sql.schema import FetchedValue
     from sqlalchemy.sql.type_api import TypeEngine
 
     from .impl import DefaultImpl
 
-_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
+_ServerDefaultType = Union[FetchedValue, str, TextClause, ColumnElement[Any]]
 
 
 class AlterTable(DDLElement):
@@ -75,7 +75,7 @@ class AlterColumn(AlterTable):
         schema: Optional[str] = None,
         existing_type: Optional[TypeEngine] = None,
         existing_nullable: Optional[bool] = None,
-        existing_server_default: Optional[_ServerDefault] = None,
+        existing_server_default: Optional[_ServerDefaultType] = None,
         existing_comment: Optional[str] = None,
     ) -> None:
         super().__init__(name, schema=schema)
@@ -119,7 +119,7 @@ class ColumnDefault(AlterColumn):
         self,
         name: str,
         column_name: str,
-        default: Optional[_ServerDefault],
+        default: Optional[_ServerDefaultType],
         **kw,
     ) -> None:
         super().__init__(name, column_name, **kw)
@@ -308,7 +308,7 @@ def format_column_name(
 
 def format_server_default(
     compiler: DDLCompiler,
-    default: Optional[_ServerDefault],
+    default: Optional[_ServerDefaultType],
 ) -> str:
     # this can be updated to use compiler.render_default_string
     # for SQLAlchemy 2.0 and above; not in 1.4
index f75cb77a1c98db9e256d8294eed5136128e8e537..c0d1751d7163fbac8e0832d0f10d1caa3cbda7c4 100644 (file)
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.selectable import TableClause
     from sqlalchemy.sql.type_api import TypeEngine
 
-    from .base import _ServerDefault
+    from .base import _ServerDefaultType
     from ..autogenerate.api import AutogenContext
     from ..operations.batch import ApplyBatchImpl
     from ..operations.batch import BatchOperationsImpl
@@ -269,7 +269,7 @@ class DefaultImpl(metaclass=ImplMeta):
         *,
         nullable: Optional[bool] = None,
         server_default: Optional[
-            Union[_ServerDefault, Literal[False]]
+            Union[_ServerDefaultType, Literal[False]]
         ] = False,
         name: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
@@ -278,7 +278,9 @@ class DefaultImpl(metaclass=ImplMeta):
         comment: Optional[Union[str, Literal[False]]] = False,
         existing_comment: Optional[str] = None,
         existing_type: Optional[TypeEngine] = None,
-        existing_server_default: Optional[_ServerDefault] = None,
+        existing_server_default: Optional[
+            Union[_ServerDefaultType, Literal[False]]
+        ] = None,
         existing_nullable: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
         **kw: Any,
index 22bd0e4b0b45d213f6cde77f40986bda0f3442de..91cd9e428d08c366b506d709574c2ceae935eacf 100644 (file)
@@ -46,7 +46,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.selectable import TableClause
     from sqlalchemy.sql.type_api import TypeEngine
 
-    from .base import _ServerDefault
+    from .base import _ServerDefaultType
     from .impl import _ReflectedConstraint
 
 
@@ -92,14 +92,14 @@ class MSSQLImpl(DefaultImpl):
         *,
         nullable: Optional[bool] = None,
         server_default: Optional[
-            Union[_ServerDefault, Literal[False]]
+            Union[_ServerDefaultType, Literal[False]]
         ] = False,
         name: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
         schema: Optional[str] = None,
         existing_type: Optional[TypeEngine] = None,
         existing_server_default: Union[
-            _ServerDefault, Literal[False], None
+            _ServerDefaultType, Literal[False], None
         ] = None,
         existing_nullable: Optional[bool] = None,
         **kw: Any,
index 3d7cf21a49a3b724ccf3eb2335c37b90556d70d1..27f808b050541edba840dac6227e7653d6f7de51 100644 (file)
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.schema import Constraint
     from sqlalchemy.sql.type_api import TypeEngine
 
-    from .base import _ServerDefault
+    from .base import _ServerDefaultType
 
 
 class MySQLImpl(DefaultImpl):
@@ -83,13 +83,15 @@ class MySQLImpl(DefaultImpl):
         *,
         nullable: Optional[bool] = None,
         server_default: Optional[
-            Union[_ServerDefault, Literal[False]]
+            Union[_ServerDefaultType, Literal[False]]
         ] = False,
         name: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
         schema: Optional[str] = None,
         existing_type: Optional[TypeEngine] = None,
-        existing_server_default: Optional[_ServerDefault] = None,
+        existing_server_default: Optional[
+            Union[_ServerDefaultType, Literal[False]]
+        ] = None,
         existing_nullable: Optional[bool] = None,
         autoincrement: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
@@ -207,7 +209,7 @@ class MySQLImpl(DefaultImpl):
     def _is_mysql_allowed_functional_default(
         self,
         type_: Optional[TypeEngine],
-        server_default: Optional[Union[_ServerDefault, Literal[False]]],
+        server_default: Optional[Union[_ServerDefaultType, Literal[False]]],
     ) -> bool:
         return (
             type_ is not None
@@ -358,7 +360,7 @@ class MySQLAlterDefault(AlterColumn):
         self,
         name: str,
         column_name: str,
-        default: Optional[_ServerDefault],
+        default: Optional[_ServerDefaultType],
         schema: Optional[str] = None,
     ) -> None:
         super(AlterColumn, self).__init__(name, schema=schema)
@@ -375,7 +377,7 @@ class MySQLChangeColumn(AlterColumn):
         newname: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
         nullable: Optional[bool] = None,
-        default: Optional[Union[_ServerDefault, Literal[False]]] = False,
+        default: Optional[Union[_ServerDefaultType, Literal[False]]] = False,
         autoincrement: Optional[bool] = None,
         comment: Optional[Union[str, Literal[False]]] = False,
     ) -> None:
@@ -464,7 +466,7 @@ def _mysql_change_column(
 def _mysql_colspec(
     compiler: MySQLDDLCompiler,
     nullable: Optional[bool],
-    server_default: Optional[Union[_ServerDefault, Literal[False]]],
+    server_default: Optional[Union[_ServerDefaultType, Literal[False]]],
     type_: TypeEngine,
     autoincrement: Optional[bool],
     comment: Optional[Union[str, Literal[False]]],
index d55664bb75e814dbea56659d97fd407728839e5e..cc03f45346d81c873a3724aee1a51fd2ea2cf1a6 100644 (file)
@@ -70,7 +70,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.type_api import TypeEngine
 
-    from .base import _ServerDefault
+    from .base import _ServerDefaultType
     from .impl import _ReflectedConstraint
     from ..autogenerate.api import AutogenContext
     from ..autogenerate.render import _f_name
@@ -164,14 +164,16 @@ class PostgresqlImpl(DefaultImpl):
         *,
         nullable: Optional[bool] = None,
         server_default: Optional[
-            Union[_ServerDefault, Literal[False]]
+            Union[_ServerDefaultType, Literal[False]]
         ] = False,
         name: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
         schema: Optional[str] = None,
         autoincrement: Optional[bool] = None,
         existing_type: Optional[TypeEngine] = None,
-        existing_server_default: Optional[_ServerDefault] = None,
+        existing_server_default: Optional[
+            Union[_ServerDefaultType, Literal[False]]
+        ] = None,
         existing_nullable: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
         **kw: Any,
index 96f68b82ffadb4aa23897edcc0ca5b116cb70270..1f2c03642b084980527c1e1263fd3543eb269391 100644 (file)
@@ -28,13 +28,12 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.expression import TableClause
     from sqlalchemy.sql.schema import Column
-    from sqlalchemy.sql.schema import Computed
-    from sqlalchemy.sql.schema import Identity
     from sqlalchemy.sql.schema import SchemaItem
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.type_api import TypeEngine
     from sqlalchemy.util import immutabledict
 
+    from .ddl.base import _ServerDefaultType
     from .operations.base import BatchOperations
     from .operations.ops import AddColumnOp
     from .operations.ops import AddConstraintOp
@@ -154,14 +153,12 @@ def alter_column(
     *,
     nullable: Optional[bool] = None,
     comment: Union[str, Literal[False], None] = False,
-    server_default: Union[
-        str, bool, Identity, Computed, TextClause, None
-    ] = False,
+    server_default: Union[_ServerDefaultType, None, Literal[False]] = False,
     new_column_name: Optional[str] = None,
     type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
     existing_type: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
     existing_server_default: Union[
-        str, bool, Identity, Computed, TextClause, None
+        _ServerDefaultType, None, Literal[False]
     ] = False,
     existing_nullable: Optional[bool] = None,
     existing_comment: Optional[str] = None,
index be3a77b2ada212b7d9a238c331f8f1ce48062e3e..702787e63ca51c17ed241a0b690ba721b9bfbb24 100644 (file)
@@ -27,13 +27,13 @@ from sqlalchemy.sql.elements import conv
 from . import batch
 from . import schemaobj
 from .. import util
+from ..ddl.base import _ServerDefaultType
 from ..util import sqla_compat
 from ..util.compat import formatannotation_fwdref
 from ..util.compat import inspect_formatargspec
 from ..util.compat import inspect_getfullargspec
 from ..util.sqla_compat import _literal_bindparam
 
-
 if TYPE_CHECKING:
     from typing import Literal
 
@@ -44,8 +44,6 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.expression import TableClause
     from sqlalchemy.sql.expression import TextClause
     from sqlalchemy.sql.schema import Column
-    from sqlalchemy.sql.schema import Computed
-    from sqlalchemy.sql.schema import Identity
     from sqlalchemy.sql.schema import SchemaItem
     from sqlalchemy.types import TypeEngine
 
@@ -724,7 +722,7 @@ class Operations(AbstractOperations):
             nullable: Optional[bool] = None,
             comment: Union[str, Literal[False], None] = False,
             server_default: Union[
-                str, bool, Identity, Computed, TextClause, None
+                _ServerDefaultType, None, Literal[False]
             ] = False,
             new_column_name: Optional[str] = None,
             type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
@@ -732,7 +730,7 @@ class Operations(AbstractOperations):
                 TypeEngine[Any], Type[TypeEngine[Any]], None
             ] = None,
             existing_server_default: Union[
-                str, bool, Identity, Computed, TextClause, None
+                _ServerDefaultType, None, Literal[False]
             ] = False,
             existing_nullable: Optional[bool] = None,
             existing_comment: Optional[str] = None,
@@ -1691,14 +1689,16 @@ class BatchOperations(AbstractOperations):
             *,
             nullable: Optional[bool] = None,
             comment: Union[str, Literal[False], None] = False,
-            server_default: Any = False,
+            server_default: Union[
+                _ServerDefaultType, None, Literal[False]
+            ] = False,
             new_column_name: Optional[str] = None,
             type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
             existing_type: Union[
                 TypeEngine[Any], Type[TypeEngine[Any]], None
             ] = None,
             existing_server_default: Union[
-                str, bool, Identity, Computed, None
+                _ServerDefaultType, None, Literal[False]
             ] = False,
             existing_nullable: Optional[bool] = None,
             existing_comment: Optional[str] = None,
index fe183e9c8815b950d10a2280c9167969923e53b9..9b48be598625f32837be93ae055b14e8faac9d3e 100644 (file)
@@ -44,10 +44,10 @@ if TYPE_CHECKING:
     from sqlalchemy.engine import Dialect
     from sqlalchemy.sql.elements import ColumnClause
     from sqlalchemy.sql.elements import quoted_name
-    from sqlalchemy.sql.functions import Function
     from sqlalchemy.sql.schema import Constraint
     from sqlalchemy.sql.type_api import TypeEngine
 
+    from ..ddl.base import _ServerDefaultType
     from ..ddl.impl import DefaultImpl
 
 
@@ -485,7 +485,9 @@ class ApplyBatchImpl:
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        server_default: Optional[Union[Function[Any], str, bool]] = False,
+        server_default: Union[
+            _ServerDefaultType, None, Literal[False]
+        ] = False,
         name: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
         autoincrement: Optional[Union[bool, Literal["auto"]]] = None,
index c9b1526b61ffcb99770eeadd61d5e1e09c4e4066..3bc1e83556a8114c1b09a55ccce06d0c25274172 100644 (file)
@@ -39,10 +39,8 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.schema import CheckConstraint
     from sqlalchemy.sql.schema import Column
-    from sqlalchemy.sql.schema import Computed
     from sqlalchemy.sql.schema import Constraint
     from sqlalchemy.sql.schema import ForeignKeyConstraint
-    from sqlalchemy.sql.schema import Identity
     from sqlalchemy.sql.schema import Index
     from sqlalchemy.sql.schema import MetaData
     from sqlalchemy.sql.schema import PrimaryKeyConstraint
@@ -53,6 +51,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.type_api import TypeEngine
 
     from ..autogenerate.rewriter import Rewriter
+    from ..ddl.base import _ServerDefaultType
     from ..runtime.migration import MigrationContext
     from ..script.revision import _RevIdType
 
@@ -1696,7 +1695,9 @@ class AlterColumnOp(AlterTableOp):
         *,
         schema: Optional[str] = None,
         existing_type: Optional[Any] = None,
-        existing_server_default: Any = False,
+        existing_server_default: Union[
+            _ServerDefaultType, None, Literal[False]
+        ] = False,
         existing_nullable: Optional[bool] = None,
         existing_comment: Optional[str] = None,
         modify_nullable: Optional[bool] = None,
@@ -1856,7 +1857,7 @@ class AlterColumnOp(AlterTableOp):
         nullable: Optional[bool] = None,
         comment: Optional[Union[str, Literal[False]]] = False,
         server_default: Union[
-            str, bool, Identity, Computed, TextClause, None
+            _ServerDefaultType, None, Literal[False]
         ] = False,
         new_column_name: Optional[str] = None,
         type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
@@ -1864,7 +1865,7 @@ class AlterColumnOp(AlterTableOp):
             Union[TypeEngine[Any], Type[TypeEngine[Any]]]
         ] = None,
         existing_server_default: Union[
-            str, bool, Identity, Computed, TextClause, None
+            _ServerDefaultType, None, Literal[False]
         ] = False,
         existing_nullable: Optional[bool] = None,
         existing_comment: Optional[str] = None,
@@ -1980,14 +1981,16 @@ class AlterColumnOp(AlterTableOp):
         *,
         nullable: Optional[bool] = None,
         comment: Optional[Union[str, Literal[False]]] = False,
-        server_default: Any = False,
+        server_default: Union[
+            _ServerDefaultType, None, Literal[False]
+        ] = False,
         new_column_name: Optional[str] = None,
         type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
         existing_type: Optional[
             Union[TypeEngine[Any], Type[TypeEngine[Any]]]
         ] = None,
-        existing_server_default: Optional[
-            Union[str, bool, Identity, Computed]
+        existing_server_default: Union[
+            _ServerDefaultType, None, Literal[False]
         ] = False,
         existing_nullable: Optional[bool] = None,
         existing_comment: Optional[str] = None,
index c18ec790176d6db1a848e962f190202bbed47162..5f93464e73cdea4f4faf6bef91e138cd01cef50e 100644 (file)
@@ -50,6 +50,11 @@ def alter_column(
             if _count_constraint(constraint):
                 operations.impl.drop_constraint(constraint)
 
+    # some weird pyright quirk here, these have Literal[False]
+    # in their types, not sure why pyright thinks they could be True
+    assert existing_server_default is not True  # type: ignore[comparison-overlap]  # noqa: E501
+    assert comment is not True  # type: ignore[comparison-overlap]
+
     operations.impl.alter_column(
         table_name,
         column_name,
diff --git a/docs/build/unreleased/1669.rst b/docs/build/unreleased/1669.rst
new file mode 100644 (file)
index 0000000..0e148c1
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 1669
+
+    Fixed typing issue where the :paramref:`.AlterColumnOp.server_default` and
+    :paramref:`.AlterColumnOp.existing_server_default` parameters failed to
+    accommodate common SQLAlchemy SQL constructs such as ``null()`` and
+    ``text()``.   Pull request courtesy Sebastian Kreft.
+
index ec5a11e32fb79d6ca5e1f27e959cf62e41f61e59..bd32b38386081318f0855acbde498af24a84f42b 100644 (file)
@@ -10,13 +10,17 @@ import shutil
 import sys
 from tempfile import NamedTemporaryFile
 import textwrap
+import types
 import typing
 
+from sqlalchemy.util import typing as sa_typing
+
 sys.path.append(str(Path(__file__).parent.parent))
 
 
 if True:  # avoid flake/zimports messing with the order
     from alembic.autogenerate.api import AutogenContext
+    from alembic.operations.base import _ServerDefaultType
     from alembic.ddl.impl import DefaultImpl
     from alembic.runtime.migration import MigrationInfo
     from alembic.operations.base import BatchOperations
@@ -176,6 +180,7 @@ def _generate_stub_for_meth(
         spec.annotations.update(annotations)
     except NameError as e:
         print(f"{cls.__name__}.{name} NameError: {e}", file=sys.stderr)
+        raise
 
     name_args = spec[0]
     assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"]
@@ -184,7 +189,25 @@ def _generate_stub_for_meth(
         name_args[0:1] = []
 
     def _formatannotation(annotation, base_module=None):
-        if getattr(annotation, "__module__", None) == "typing":
+        retval = None
+        if sa_typing.is_union(annotation):
+            for ta in type_aliases:
+
+                if set(ta.__args__).issubset(annotation.__args__):
+                    remainder = set(annotation.__args__).difference(
+                        ta.__args__
+                    )
+                    retval = (
+                        f"Union[{type_aliases[ta]}, "
+                        f"{', '.join(sorted("None" if a is types.NoneType else repr(a) for a in remainder))}]"  # noqa: E501
+                    )
+                    break
+
+        if retval is not None:
+            pass
+        elif annotation in type_aliases:
+            retval = type_aliases[annotation]
+        elif getattr(annotation, "__module__", None) == "typing":
             retval = repr(annotation).replace("typing.", "")
         elif getattr(annotation, "__module__", None) == "types":
             retval = repr(annotation).replace("types.", "")
@@ -195,6 +218,7 @@ def _generate_stub_for_meth(
         elif hasattr(annotation, "__args__") and hasattr(
             annotation, "__origin__"
         ):
+
             # generic class
             retval = str(annotation)
         else:
@@ -412,6 +436,9 @@ cls_ignore = {
     "run_async",
 }
 
+type_aliases = {_ServerDefaultType: "_ServerDefaultType"}
+
+
 cases = [
     StubFileInfo(
         "op",