From: Sebastian Kreft <911768+sk-@users.noreply.github.com> Date: Sun, 18 Jan 2026 15:26:05 +0000 (-0500) Subject: fix(typings): improve typing for server_default X-Git-Tag: rel_1_18_2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=91a51160fbcb80c92cd4ff4d3ce892588131c773;p=thirdparty%2Fsqlalchemy%2Falembic.git fix(typings): improve typing for server_default Fixed typing issue where the :paramref:`.AlterColumnOp.server_default` and :paramref:`.AlterColumnOp.existing_server_default` parameters failed to accommodate common SQLAlchemy SQL constructs such as ``null()`` and ``text()``. Pull request courtesy Sebastian Kreft. this sets up a standard type for the server default argument using an alias, and adds modifications to write_pyi for extremely basic ability to render type aliases (with limitations). Co-authored-by: Mike Bayer Fixes: #1669 Closes: #1670 Pull-request: https://github.com/sqlalchemy/alembic/pull/1670 Pull-request-sha: e6464647b6e33e077e7baf4bbc5c7549ab570a06 Change-Id: Id25bf7fd706f91aa637adf9b67f0529f1d7d1080 --- diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index ad2847eb..550fe147 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -4,6 +4,7 @@ from __future__ import annotations import functools +from typing import Any from typing import Optional from typing import TYPE_CHECKING from typing import Union @@ -14,7 +15,10 @@ from sqlalchemy import types as sqltypes from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import Column from sqlalchemy.schema import DDLElement +from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.schema import FetchedValue from ..util.sqla_compat import _columns_for_constraint # noqa from ..util.sqla_compat import _find_columns # noqa @@ -23,20 +27,16 @@ from ..util.sqla_compat import _is_type_bound # noqa from ..util.sqla_compat import _table_for_constraint # noqa if TYPE_CHECKING: - from typing import Any from sqlalchemy import Computed from sqlalchemy import Identity from sqlalchemy.sql.compiler import Compiled from sqlalchemy.sql.compiler import DDLCompiler - from sqlalchemy.sql.elements import TextClause - from sqlalchemy.sql.functions import Function - from sqlalchemy.sql.schema import FetchedValue from sqlalchemy.sql.type_api import TypeEngine from .impl import DefaultImpl -_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str] +_ServerDefaultType = Union[FetchedValue, str, TextClause, ColumnElement[Any]] class AlterTable(DDLElement): @@ -75,7 +75,7 @@ class AlterColumn(AlterTable): schema: Optional[str] = None, existing_type: Optional[TypeEngine] = None, existing_nullable: Optional[bool] = None, - existing_server_default: Optional[_ServerDefault] = None, + existing_server_default: Optional[_ServerDefaultType] = None, existing_comment: Optional[str] = None, ) -> None: super().__init__(name, schema=schema) @@ -119,7 +119,7 @@ class ColumnDefault(AlterColumn): self, name: str, column_name: str, - default: Optional[_ServerDefault], + default: Optional[_ServerDefaultType], **kw, ) -> None: super().__init__(name, column_name, **kw) @@ -308,7 +308,7 @@ def format_column_name( def format_server_default( compiler: DDLCompiler, - default: Optional[_ServerDefault], + default: Optional[_ServerDefaultType], ) -> str: # this can be updated to use compiler.render_default_string # for SQLAlchemy 2.0 and above; not in 1.4 diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index f75cb77a..c0d1751d 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -58,7 +58,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.selectable import TableClause from sqlalchemy.sql.type_api import TypeEngine - from .base import _ServerDefault + from .base import _ServerDefaultType from ..autogenerate.api import AutogenContext from ..operations.batch import ApplyBatchImpl from ..operations.batch import BatchOperationsImpl @@ -269,7 +269,7 @@ class DefaultImpl(metaclass=ImplMeta): *, nullable: Optional[bool] = None, server_default: Optional[ - Union[_ServerDefault, Literal[False]] + Union[_ServerDefaultType, Literal[False]] ] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, @@ -278,7 +278,9 @@ class DefaultImpl(metaclass=ImplMeta): comment: Optional[Union[str, Literal[False]]] = False, existing_comment: Optional[str] = None, existing_type: Optional[TypeEngine] = None, - existing_server_default: Optional[_ServerDefault] = None, + existing_server_default: Optional[ + Union[_ServerDefaultType, Literal[False]] + ] = None, existing_nullable: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, **kw: Any, diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index 22bd0e4b..91cd9e42 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.selectable import TableClause from sqlalchemy.sql.type_api import TypeEngine - from .base import _ServerDefault + from .base import _ServerDefaultType from .impl import _ReflectedConstraint @@ -92,14 +92,14 @@ class MSSQLImpl(DefaultImpl): *, nullable: Optional[bool] = None, server_default: Optional[ - Union[_ServerDefault, Literal[False]] + Union[_ServerDefaultType, Literal[False]] ] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, schema: Optional[str] = None, existing_type: Optional[TypeEngine] = None, existing_server_default: Union[ - _ServerDefault, Literal[False], None + _ServerDefaultType, Literal[False], None ] = None, existing_nullable: Optional[bool] = None, **kw: Any, diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index 3d7cf21a..27f808b0 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.type_api import TypeEngine - from .base import _ServerDefault + from .base import _ServerDefaultType class MySQLImpl(DefaultImpl): @@ -83,13 +83,15 @@ class MySQLImpl(DefaultImpl): *, nullable: Optional[bool] = None, server_default: Optional[ - Union[_ServerDefault, Literal[False]] + Union[_ServerDefaultType, Literal[False]] ] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, schema: Optional[str] = None, existing_type: Optional[TypeEngine] = None, - existing_server_default: Optional[_ServerDefault] = None, + existing_server_default: Optional[ + Union[_ServerDefaultType, Literal[False]] + ] = None, existing_nullable: Optional[bool] = None, autoincrement: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, @@ -207,7 +209,7 @@ class MySQLImpl(DefaultImpl): def _is_mysql_allowed_functional_default( self, type_: Optional[TypeEngine], - server_default: Optional[Union[_ServerDefault, Literal[False]]], + server_default: Optional[Union[_ServerDefaultType, Literal[False]]], ) -> bool: return ( type_ is not None @@ -358,7 +360,7 @@ class MySQLAlterDefault(AlterColumn): self, name: str, column_name: str, - default: Optional[_ServerDefault], + default: Optional[_ServerDefaultType], schema: Optional[str] = None, ) -> None: super(AlterColumn, self).__init__(name, schema=schema) @@ -375,7 +377,7 @@ class MySQLChangeColumn(AlterColumn): newname: Optional[str] = None, type_: Optional[TypeEngine] = None, nullable: Optional[bool] = None, - default: Optional[Union[_ServerDefault, Literal[False]]] = False, + default: Optional[Union[_ServerDefaultType, Literal[False]]] = False, autoincrement: Optional[bool] = None, comment: Optional[Union[str, Literal[False]]] = False, ) -> None: @@ -464,7 +466,7 @@ def _mysql_change_column( def _mysql_colspec( compiler: MySQLDDLCompiler, nullable: Optional[bool], - server_default: Optional[Union[_ServerDefault, Literal[False]]], + server_default: Optional[Union[_ServerDefaultType, Literal[False]]], type_: TypeEngine, autoincrement: Optional[bool], comment: Optional[Union[str, Literal[False]]], diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index d55664bb..cc03f453 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -70,7 +70,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.schema import Table from sqlalchemy.sql.type_api import TypeEngine - from .base import _ServerDefault + from .base import _ServerDefaultType from .impl import _ReflectedConstraint from ..autogenerate.api import AutogenContext from ..autogenerate.render import _f_name @@ -164,14 +164,16 @@ class PostgresqlImpl(DefaultImpl): *, nullable: Optional[bool] = None, server_default: Optional[ - Union[_ServerDefault, Literal[False]] + Union[_ServerDefaultType, Literal[False]] ] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, schema: Optional[str] = None, autoincrement: Optional[bool] = None, existing_type: Optional[TypeEngine] = None, - existing_server_default: Optional[_ServerDefault] = None, + existing_server_default: Optional[ + Union[_ServerDefaultType, Literal[False]] + ] = None, existing_nullable: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, **kw: Any, diff --git a/alembic/op.pyi b/alembic/op.pyi index 96f68b82..1f2c0364 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -28,13 +28,12 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.expression import TableClause from sqlalchemy.sql.schema import Column - from sqlalchemy.sql.schema import Computed - from sqlalchemy.sql.schema import Identity from sqlalchemy.sql.schema import SchemaItem from sqlalchemy.sql.schema import Table from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.util import immutabledict + from .ddl.base import _ServerDefaultType from .operations.base import BatchOperations from .operations.ops import AddColumnOp from .operations.ops import AddConstraintOp @@ -154,14 +153,12 @@ def alter_column( *, nullable: Optional[bool] = None, comment: Union[str, Literal[False], None] = False, - server_default: Union[ - str, bool, Identity, Computed, TextClause, None - ] = False, + server_default: Union[_ServerDefaultType, None, Literal[False]] = False, new_column_name: Optional[str] = None, type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, existing_type: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, existing_server_default: Union[ - str, bool, Identity, Computed, TextClause, None + _ServerDefaultType, None, Literal[False] ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, diff --git a/alembic/operations/base.py b/alembic/operations/base.py index be3a77b2..702787e6 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -27,13 +27,13 @@ from sqlalchemy.sql.elements import conv from . import batch from . import schemaobj from .. import util +from ..ddl.base import _ServerDefaultType from ..util import sqla_compat from ..util.compat import formatannotation_fwdref from ..util.compat import inspect_formatargspec from ..util.compat import inspect_getfullargspec from ..util.sqla_compat import _literal_bindparam - if TYPE_CHECKING: from typing import Literal @@ -44,8 +44,6 @@ if TYPE_CHECKING: from sqlalchemy.sql.expression import TableClause from sqlalchemy.sql.expression import TextClause from sqlalchemy.sql.schema import Column - from sqlalchemy.sql.schema import Computed - from sqlalchemy.sql.schema import Identity from sqlalchemy.sql.schema import SchemaItem from sqlalchemy.types import TypeEngine @@ -724,7 +722,7 @@ class Operations(AbstractOperations): nullable: Optional[bool] = None, comment: Union[str, Literal[False], None] = False, server_default: Union[ - str, bool, Identity, Computed, TextClause, None + _ServerDefaultType, None, Literal[False] ] = False, new_column_name: Optional[str] = None, type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, @@ -732,7 +730,7 @@ class Operations(AbstractOperations): TypeEngine[Any], Type[TypeEngine[Any]], None ] = None, existing_server_default: Union[ - str, bool, Identity, Computed, TextClause, None + _ServerDefaultType, None, Literal[False] ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, @@ -1691,14 +1689,16 @@ class BatchOperations(AbstractOperations): *, nullable: Optional[bool] = None, comment: Union[str, Literal[False], None] = False, - server_default: Any = False, + server_default: Union[ + _ServerDefaultType, None, Literal[False] + ] = False, new_column_name: Optional[str] = None, type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, existing_type: Union[ TypeEngine[Any], Type[TypeEngine[Any]], None ] = None, existing_server_default: Union[ - str, bool, Identity, Computed, None + _ServerDefaultType, None, Literal[False] ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index fe183e9c..9b48be59 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -44,10 +44,10 @@ if TYPE_CHECKING: from sqlalchemy.engine import Dialect from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import quoted_name - from sqlalchemy.sql.functions import Function from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.type_api import TypeEngine + from ..ddl.base import _ServerDefaultType from ..ddl.impl import DefaultImpl @@ -485,7 +485,9 @@ class ApplyBatchImpl: table_name: str, column_name: str, nullable: Optional[bool] = None, - server_default: Optional[Union[Function[Any], str, bool]] = False, + server_default: Union[ + _ServerDefaultType, None, Literal[False] + ] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, autoincrement: Optional[Union[bool, Literal["auto"]]] = None, diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index c9b1526b..3bc1e835 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -39,10 +39,8 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.sql.schema import Column - from sqlalchemy.sql.schema import Computed from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint - from sqlalchemy.sql.schema import Identity from sqlalchemy.sql.schema import Index from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import PrimaryKeyConstraint @@ -53,6 +51,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.type_api import TypeEngine from ..autogenerate.rewriter import Rewriter + from ..ddl.base import _ServerDefaultType from ..runtime.migration import MigrationContext from ..script.revision import _RevIdType @@ -1696,7 +1695,9 @@ class AlterColumnOp(AlterTableOp): *, schema: Optional[str] = None, existing_type: Optional[Any] = None, - existing_server_default: Any = False, + existing_server_default: Union[ + _ServerDefaultType, None, Literal[False] + ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, modify_nullable: Optional[bool] = None, @@ -1856,7 +1857,7 @@ class AlterColumnOp(AlterTableOp): nullable: Optional[bool] = None, comment: Optional[Union[str, Literal[False]]] = False, server_default: Union[ - str, bool, Identity, Computed, TextClause, None + _ServerDefaultType, None, Literal[False] ] = False, new_column_name: Optional[str] = None, type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None, @@ -1864,7 +1865,7 @@ class AlterColumnOp(AlterTableOp): Union[TypeEngine[Any], Type[TypeEngine[Any]]] ] = None, existing_server_default: Union[ - str, bool, Identity, Computed, TextClause, None + _ServerDefaultType, None, Literal[False] ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, @@ -1980,14 +1981,16 @@ class AlterColumnOp(AlterTableOp): *, nullable: Optional[bool] = None, comment: Optional[Union[str, Literal[False]]] = False, - server_default: Any = False, + server_default: Union[ + _ServerDefaultType, None, Literal[False] + ] = False, new_column_name: Optional[str] = None, type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None, existing_type: Optional[ Union[TypeEngine[Any], Type[TypeEngine[Any]]] ] = None, - existing_server_default: Optional[ - Union[str, bool, Identity, Computed] + existing_server_default: Union[ + _ServerDefaultType, None, Literal[False] ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py index c18ec790..5f93464e 100644 --- a/alembic/operations/toimpl.py +++ b/alembic/operations/toimpl.py @@ -50,6 +50,11 @@ def alter_column( if _count_constraint(constraint): operations.impl.drop_constraint(constraint) + # some weird pyright quirk here, these have Literal[False] + # in their types, not sure why pyright thinks they could be True + assert existing_server_default is not True # type: ignore[comparison-overlap] # noqa: E501 + assert comment is not True # type: ignore[comparison-overlap] + operations.impl.alter_column( table_name, column_name, diff --git a/docs/build/unreleased/1669.rst b/docs/build/unreleased/1669.rst new file mode 100644 index 00000000..0e148c10 --- /dev/null +++ b/docs/build/unreleased/1669.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, typing + :tickets: 1669 + + Fixed typing issue where the :paramref:`.AlterColumnOp.server_default` and + :paramref:`.AlterColumnOp.existing_server_default` parameters failed to + accommodate common SQLAlchemy SQL constructs such as ``null()`` and + ``text()``. Pull request courtesy Sebastian Kreft. + diff --git a/tools/write_pyi.py b/tools/write_pyi.py index ec5a11e3..bd32b383 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -10,13 +10,17 @@ import shutil import sys from tempfile import NamedTemporaryFile import textwrap +import types import typing +from sqlalchemy.util import typing as sa_typing + sys.path.append(str(Path(__file__).parent.parent)) if True: # avoid flake/zimports messing with the order from alembic.autogenerate.api import AutogenContext + from alembic.operations.base import _ServerDefaultType from alembic.ddl.impl import DefaultImpl from alembic.runtime.migration import MigrationInfo from alembic.operations.base import BatchOperations @@ -176,6 +180,7 @@ def _generate_stub_for_meth( spec.annotations.update(annotations) except NameError as e: print(f"{cls.__name__}.{name} NameError: {e}", file=sys.stderr) + raise name_args = spec[0] assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"] @@ -184,7 +189,25 @@ def _generate_stub_for_meth( name_args[0:1] = [] def _formatannotation(annotation, base_module=None): - if getattr(annotation, "__module__", None) == "typing": + retval = None + if sa_typing.is_union(annotation): + for ta in type_aliases: + + if set(ta.__args__).issubset(annotation.__args__): + remainder = set(annotation.__args__).difference( + ta.__args__ + ) + retval = ( + f"Union[{type_aliases[ta]}, " + f"{', '.join(sorted("None" if a is types.NoneType else repr(a) for a in remainder))}]" # noqa: E501 + ) + break + + if retval is not None: + pass + elif annotation in type_aliases: + retval = type_aliases[annotation] + elif getattr(annotation, "__module__", None) == "typing": retval = repr(annotation).replace("typing.", "") elif getattr(annotation, "__module__", None) == "types": retval = repr(annotation).replace("types.", "") @@ -195,6 +218,7 @@ def _generate_stub_for_meth( elif hasattr(annotation, "__args__") and hasattr( annotation, "__origin__" ): + # generic class retval = str(annotation) else: @@ -412,6 +436,9 @@ cls_ignore = { "run_async", } +type_aliases = {_ServerDefaultType: "_ServerDefaultType"} + + cases = [ StubFileInfo( "op",