From: Mike Bayer Date: Thu, 11 Dec 2025 22:20:21 +0000 (-0500) Subject: organize into a "plugin" directory structure X-Git-Tag: rel_1_18_0~7^2 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=e532a7e39cb6b0e91fbe045778f31a997ded452d;p=thirdparty%2Fsqlalchemy%2Falembic.git organize into a "plugin" directory structure we attempt to move autogen functions into independent units that are more obviously pluggable, and we add support for arbitrary "plugin" entrypoints that could add more pluggable units into autogenerate or anywhere else Change-Id: Id606a76dc6d12a308028f6cfdad690e0e63a43e5 --- diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e31b2a2a..5f45d054 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: - id: black - repo: https://github.com/sqlalchemyorg/zimports - rev: v0.6.2 + rev: v0.7.0 hooks: - id: zimports args: diff --git a/alembic/__init__.py b/alembic/__init__.py index 4ba57ecd..500f21c2 100644 --- a/alembic/__init__.py +++ b/alembic/__init__.py @@ -1,4 +1,6 @@ from . import context from . import op +from .runtime import plugins -__version__ = "1.17.3" + +__version__ = "1.18.0" diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 811462e8..b2e3faef 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import logging from typing import Any from typing import Dict from typing import Iterator @@ -17,11 +18,9 @@ from . import compare from . import render from .. import util from ..operations import ops +from ..runtime.plugins import Plugin from ..util import sqla_compat -"""Provide the 'autogenerate' feature which can produce migration operations -automatically.""" - if TYPE_CHECKING: from sqlalchemy.engine import Connection from sqlalchemy.engine import Dialect @@ -42,6 +41,10 @@ if TYPE_CHECKING: from ..script.base import Script from ..script.base import ScriptDirectory from ..script.revision import _GetRevArg + from ..util import PriorityDispatcher + + +log = logging.getLogger(__name__) def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any: @@ -304,7 +307,7 @@ class AutogenContext: """ - dialect: Optional[Dialect] = None + dialect: Dialect """The :class:`~sqlalchemy.engine.Dialect` object currently in use. This is normally obtained from the @@ -326,9 +329,11 @@ class AutogenContext: """ - migration_context: MigrationContext = None # type: ignore[assignment] + migration_context: MigrationContext """The :class:`.MigrationContext` established by the ``env.py`` script.""" + comparators: PriorityDispatcher + def __init__( self, migration_context: MigrationContext, @@ -346,6 +351,19 @@ class AutogenContext: "the database for schema information" ) + # branch off from the "global" comparators. This collection + # is empty in Alembic except that it is populated by third party + # extensions that don't use the plugin system. so we will build + # off of whatever is in there. + if autogenerate: + self.comparators = compare.comparators.branch() + Plugin.populate_autogenerate_priority_dispatch( + self.comparators, + include_plugins=migration_context.opts.get( + "autogenerate_plugins", ["alembic.autogenerate.*"] + ), + ) + if opts is None: opts = migration_context.opts @@ -380,9 +398,8 @@ class AutogenContext: self._name_filters = name_filters self.migration_context = migration_context - if self.migration_context is not None: - self.connection = self.migration_context.bind - self.dialect = self.migration_context.dialect + self.connection = self.migration_context.bind + self.dialect = self.migration_context.dialect self.imports = set() self.opts: Dict[str, Any] = opts diff --git a/alembic/autogenerate/compare/__init__.py b/alembic/autogenerate/compare/__init__.py new file mode 100644 index 00000000..a49640cf --- /dev/null +++ b/alembic/autogenerate/compare/__init__.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from . import comments +from . import constraints +from . import schema +from . import server_defaults +from . import tables +from . import types +from ... import util +from ...runtime.plugins import Plugin + +if TYPE_CHECKING: + from ..api import AutogenContext + from ...operations.ops import MigrationScript + from ...operations.ops import UpgradeOps + + +log = logging.getLogger(__name__) + +comparators = util.PriorityDispatcher() +"""global registry which alembic keeps empty, but copies when creating +a new AutogenContext. + +This is to support a variety of third party plugins that hook their autogen +functionality onto this collection. + +""" + + +def _populate_migration_script( + autogen_context: AutogenContext, migration_script: MigrationScript +) -> None: + upgrade_ops = migration_script.upgrade_ops_list[-1] + downgrade_ops = migration_script.downgrade_ops_list[-1] + + _produce_net_changes(autogen_context, upgrade_ops) + upgrade_ops.reverse_into(downgrade_ops) + + +def _produce_net_changes( + autogen_context: AutogenContext, upgrade_ops: UpgradeOps +) -> None: + assert autogen_context.dialect is not None + + autogen_context.comparators.dispatch( + "autogenerate", qualifier=autogen_context.dialect.name + )(autogen_context, upgrade_ops) + + +Plugin.setup_plugin_from_module(schema, "alembic.autogenerate.schemas") +Plugin.setup_plugin_from_module(tables, "alembic.autogenerate.tables") +Plugin.setup_plugin_from_module(types, "alembic.autogenerate.types") +Plugin.setup_plugin_from_module( + constraints, "alembic.autogenerate.constraints" +) +Plugin.setup_plugin_from_module( + server_defaults, "alembic.autogenerate.defaults" +) +Plugin.setup_plugin_from_module(comments, "alembic.autogenerate.comments") diff --git a/alembic/autogenerate/compare/comments.py b/alembic/autogenerate/compare/comments.py new file mode 100644 index 00000000..70de74e2 --- /dev/null +++ b/alembic/autogenerate/compare/comments.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import logging +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from ...operations import ops +from ...util import PriorityDispatchResult + +if TYPE_CHECKING: + + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Table + + from ..api import AutogenContext + from ...operations.ops import AlterColumnOp + from ...operations.ops import ModifyTableOps + from ...runtime.plugins import Plugin + +log = logging.getLogger(__name__) + + +def _compare_column_comment( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: quoted_name, + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + assert autogen_context.dialect is not None + if not autogen_context.dialect.supports_comments: + return PriorityDispatchResult.CONTINUE + + metadata_comment = metadata_col.comment + conn_col_comment = conn_col.comment + if conn_col_comment is None and metadata_comment is None: + return PriorityDispatchResult.CONTINUE + + alter_column_op.existing_comment = conn_col_comment + + if conn_col_comment != metadata_comment: + alter_column_op.modify_comment = metadata_comment + log.info("Detected column comment '%s.%s'", tname, cname) + + return PriorityDispatchResult.STOP + else: + return PriorityDispatchResult.CONTINUE + + +def _compare_table_comment( + autogen_context: AutogenContext, + modify_table_ops: ModifyTableOps, + schema: Optional[str], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], +) -> PriorityDispatchResult: + assert autogen_context.dialect is not None + if not autogen_context.dialect.supports_comments: + return PriorityDispatchResult.CONTINUE + + # if we're doing CREATE TABLE, comments will be created inline + # with the create_table op. + if conn_table is None or metadata_table is None: + return PriorityDispatchResult.CONTINUE + + if conn_table.comment is None and metadata_table.comment is None: + return PriorityDispatchResult.CONTINUE + + if metadata_table.comment is None and conn_table.comment is not None: + modify_table_ops.ops.append( + ops.DropTableCommentOp( + tname, existing_comment=conn_table.comment, schema=schema + ) + ) + return PriorityDispatchResult.STOP + elif metadata_table.comment != conn_table.comment: + modify_table_ops.ops.append( + ops.CreateTableCommentOp( + tname, + metadata_table.comment, + existing_comment=conn_table.comment, + schema=schema, + ) + ) + return PriorityDispatchResult.STOP + + return PriorityDispatchResult.CONTINUE + + +def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + _compare_column_comment, + "column", + "comments", + ) + plugin.add_autogenerate_comparator( + _compare_table_comment, + "table", + "comments", + ) diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare/constraints.py similarity index 51% rename from alembic/autogenerate/compare.py rename to alembic/autogenerate/compare/constraints.py index a9adda1c..0b524b97 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare/constraints.py @@ -1,495 +1,52 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics +# mypy: allow-untyped-defs, allow-untyped-calls, allow-incomplete-defs from __future__ import annotations -import contextlib import logging -import re from typing import Any from typing import cast from typing import Dict -from typing import Iterator from typing import Mapping from typing import Optional -from typing import Set -from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy import event -from sqlalchemy import inspect from sqlalchemy import schema as sa_schema from sqlalchemy import text -from sqlalchemy import types as sqltypes from sqlalchemy.sql import expression -from sqlalchemy.sql.elements import conv from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index from sqlalchemy.sql.schema import UniqueConstraint -from sqlalchemy.util import OrderedSet -from .. import util -from ..ddl._autogen import is_index_sig -from ..ddl._autogen import is_uq_sig -from ..operations import ops -from ..util import sqla_compat +from .util import _InspectorConv +from ... import util +from ...ddl._autogen import is_index_sig +from ...ddl._autogen import is_uq_sig +from ...operations import ops +from ...util import PriorityDispatchResult +from ...util import sqla_compat if TYPE_CHECKING: - from typing import Literal - - from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Table - from alembic.autogenerate.api import AutogenContext - from alembic.ddl.impl import DefaultImpl - from alembic.operations.ops import AlterColumnOp - from alembic.operations.ops import MigrationScript - from alembic.operations.ops import ModifyTableOps - from alembic.operations.ops import UpgradeOps - from ..ddl._autogen import _constraint_sig - - -log = logging.getLogger(__name__) - - -def _populate_migration_script( - autogen_context: AutogenContext, migration_script: MigrationScript -) -> None: - upgrade_ops = migration_script.upgrade_ops_list[-1] - downgrade_ops = migration_script.downgrade_ops_list[-1] - - _produce_net_changes(autogen_context, upgrade_ops) - upgrade_ops.reverse_into(downgrade_ops) - - -comparators = util.Dispatcher(uselist=True) - - -def _produce_net_changes( - autogen_context: AutogenContext, upgrade_ops: UpgradeOps -) -> None: - connection = autogen_context.connection - assert connection is not None - include_schemas = autogen_context.opts.get("include_schemas", False) - - inspector: Inspector = inspect(connection) - - default_schema = connection.dialect.default_schema_name - schemas: Set[Optional[str]] - if include_schemas: - schemas = set(inspector.get_schema_names()) - # replace default schema name with None - schemas.discard("information_schema") - # replace the "default" schema with None - schemas.discard(default_schema) - schemas.add(None) - else: - schemas = {None} - - schemas = { - s for s in schemas if autogen_context.run_name_filters(s, "schema", {}) - } - - assert autogen_context.dialect is not None - comparators.dispatch("schema", autogen_context.dialect.name)( - autogen_context, upgrade_ops, schemas - ) - - -@comparators.dispatch_for("schema") -def _autogen_for_tables( - autogen_context: AutogenContext, - upgrade_ops: UpgradeOps, - schemas: Union[Set[None], Set[Optional[str]]], -) -> None: - inspector = autogen_context.inspector - - conn_table_names: Set[Tuple[Optional[str], str]] = set() - - version_table_schema = ( - autogen_context.migration_context.version_table_schema - ) - version_table = autogen_context.migration_context.version_table - - for schema_name in schemas: - tables = set(inspector.get_table_names(schema=schema_name)) - if schema_name == version_table_schema: - tables = tables.difference( - [autogen_context.migration_context.version_table] - ) - - conn_table_names.update( - (schema_name, tname) - for tname in tables - if autogen_context.run_name_filters( - tname, "table", {"schema_name": schema_name} - ) - ) - - metadata_table_names = OrderedSet( - [(table.schema, table.name) for table in autogen_context.sorted_tables] - ).difference([(version_table_schema, version_table)]) - - _compare_tables( - conn_table_names, - metadata_table_names, - inspector, - upgrade_ops, - autogen_context, - ) - - -def _compare_tables( - conn_table_names: set, - metadata_table_names: set, - inspector: Inspector, - upgrade_ops: UpgradeOps, - autogen_context: AutogenContext, -) -> None: - default_schema = inspector.bind.dialect.default_schema_name - - # tables coming from the connection will not have "schema" - # set if it matches default_schema_name; so we need a list - # of table names from local metadata that also have "None" if schema - # == default_schema_name. Most setups will be like this anyway but - # some are not (see #170) - metadata_table_names_no_dflt_schema = OrderedSet( - [ - (schema if schema != default_schema else None, tname) - for schema, tname in metadata_table_names - ] - ) - - # to adjust for the MetaData collection storing the tables either - # as "schemaname.tablename" or just "tablename", create a new lookup - # which will match the "non-default-schema" keys to the Table object. - tname_to_table = { - no_dflt_schema: autogen_context.table_key_to_table[ - sa_schema._get_table_key(tname, schema) - ] - for no_dflt_schema, (schema, tname) in zip( - metadata_table_names_no_dflt_schema, metadata_table_names - ) - } - metadata_table_names = metadata_table_names_no_dflt_schema - - for s, tname in metadata_table_names.difference(conn_table_names): - name = "%s.%s" % (s, tname) if s else tname - metadata_table = tname_to_table[(s, tname)] - if autogen_context.run_object_filters( - metadata_table, tname, "table", False, None - ): - upgrade_ops.ops.append( - ops.CreateTableOp.from_table(metadata_table) - ) - log.info("Detected added table %r", name) - modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) - - comparators.dispatch("table")( - autogen_context, - modify_table_ops, - s, - tname, - None, - metadata_table, - ) - if not modify_table_ops.is_empty(): - upgrade_ops.ops.append(modify_table_ops) - - removal_metadata = sa_schema.MetaData() - for s, tname in conn_table_names.difference(metadata_table_names): - name = sa_schema._get_table_key(tname, s) - exists = name in removal_metadata.tables - t = sa_schema.Table(tname, removal_metadata, schema=s) - - if not exists: - event.listen( - t, - "column_reflect", - # fmt: off - autogen_context.migration_context.impl. - _compat_autogen_column_reflect - (inspector), - # fmt: on - ) - _InspectorConv(inspector).reflect_table(t, include_columns=None) - if autogen_context.run_object_filters(t, tname, "table", True, None): - modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) - - comparators.dispatch("table")( - autogen_context, modify_table_ops, s, tname, t, None - ) - if not modify_table_ops.is_empty(): - upgrade_ops.ops.append(modify_table_ops) - - upgrade_ops.ops.append(ops.DropTableOp.from_table(t)) - log.info("Detected removed table %r", name) - - existing_tables = conn_table_names.intersection(metadata_table_names) - - existing_metadata = sa_schema.MetaData() - conn_column_info = {} - for s, tname in existing_tables: - name = sa_schema._get_table_key(tname, s) - exists = name in existing_metadata.tables - t = sa_schema.Table(tname, existing_metadata, schema=s) - if not exists: - event.listen( - t, - "column_reflect", - # fmt: off - autogen_context.migration_context.impl. - _compat_autogen_column_reflect(inspector), - # fmt: on - ) - _InspectorConv(inspector).reflect_table(t, include_columns=None) - - conn_column_info[(s, tname)] = t - - for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])): - s = s or None - name = "%s.%s" % (s, tname) if s else tname - metadata_table = tname_to_table[(s, tname)] - conn_table = existing_metadata.tables[name] - - if autogen_context.run_object_filters( - metadata_table, tname, "table", False, conn_table - ): - modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) - with _compare_columns( - s, - tname, - conn_table, - metadata_table, - modify_table_ops, - autogen_context, - inspector, - ): - comparators.dispatch("table")( - autogen_context, - modify_table_ops, - s, - tname, - conn_table, - metadata_table, - ) - - if not modify_table_ops.is_empty(): - upgrade_ops.ops.append(modify_table_ops) - - -_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict( - { - "asc": expression.asc, - "desc": expression.desc, - "nulls_first": expression.nullsfirst, - "nulls_last": expression.nullslast, - "nullsfirst": expression.nullsfirst, # 1_3 name - "nullslast": expression.nullslast, # 1_3 name - } -) - - -def _make_index( - impl: DefaultImpl, params: Dict[str, Any], conn_table: Table -) -> Optional[Index]: - exprs: list[Union[Column[Any], TextClause]] = [] - sorting = params.get("column_sorting") - - for num, col_name in enumerate(params["column_names"]): - item: Union[Column[Any], TextClause] - if col_name is None: - assert "expressions" in params - name = params["expressions"][num] - item = text(name) - else: - name = col_name - item = conn_table.c[col_name] - if sorting and name in sorting: - for operator in sorting[name]: - if operator in _IndexColumnSortingOps: - item = _IndexColumnSortingOps[operator](item) - exprs.append(item) - ix = sa_schema.Index( - params["name"], - *exprs, - unique=params["unique"], - _table=conn_table, - **impl.adjust_reflected_dialect_options(params, "index"), - ) - if "duplicates_constraint" in params: - ix.info["duplicates_constraint"] = params["duplicates_constraint"] - return ix - - -def _make_unique_constraint( - impl: DefaultImpl, params: Dict[str, Any], conn_table: Table -) -> UniqueConstraint: - uq = sa_schema.UniqueConstraint( - *[conn_table.c[cname] for cname in params["column_names"]], - name=params["name"], - **impl.adjust_reflected_dialect_options(params, "unique_constraint"), - ) - if "duplicates_index" in params: - uq.info["duplicates_index"] = params["duplicates_index"] - - return uq - - -def _make_foreign_key( - params: Dict[str, Any], conn_table: Table -) -> ForeignKeyConstraint: - tname = params["referred_table"] - if params["referred_schema"]: - tname = "%s.%s" % (params["referred_schema"], tname) - - options = params.get("options", {}) - - const = sa_schema.ForeignKeyConstraint( - [conn_table.c[cname] for cname in params["constrained_columns"]], - ["%s.%s" % (tname, n) for n in params["referred_columns"]], - onupdate=options.get("onupdate"), - ondelete=options.get("ondelete"), - deferrable=options.get("deferrable"), - initially=options.get("initially"), - name=params["name"], - ) - # needed by 0.7 - conn_table.append_constraint(const) - return const - - -@contextlib.contextmanager -def _compare_columns( - schema: Optional[str], - tname: Union[quoted_name, str], - conn_table: Table, - metadata_table: Table, - modify_table_ops: ModifyTableOps, - autogen_context: AutogenContext, - inspector: Inspector, -) -> Iterator[None]: - name = "%s.%s" % (schema, tname) if schema else tname - metadata_col_names = OrderedSet( - c.name for c in metadata_table.c if not c.system - ) - metadata_cols_by_name = { - c.name: c for c in metadata_table.c if not c.system - } - - conn_col_names = { - c.name: c - for c in conn_table.c - if autogen_context.run_name_filters( - c.name, "column", {"table_name": tname, "schema_name": schema} - ) - } - - for cname in metadata_col_names.difference(conn_col_names): - if autogen_context.run_object_filters( - metadata_cols_by_name[cname], cname, "column", False, None - ): - modify_table_ops.ops.append( - ops.AddColumnOp.from_column_and_tablename( - schema, tname, metadata_cols_by_name[cname] - ) - ) - log.info("Detected added column '%s.%s'", name, cname) - - for colname in metadata_col_names.intersection(conn_col_names): - metadata_col = metadata_cols_by_name[colname] - conn_col = conn_table.c[colname] - if not autogen_context.run_object_filters( - metadata_col, colname, "column", False, conn_col - ): - continue - alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema) - - comparators.dispatch("column")( - autogen_context, - alter_column_op, - schema, - tname, - colname, - conn_col, - metadata_col, - ) - - if alter_column_op.has_changes(): - modify_table_ops.ops.append(alter_column_op) - - yield - - for cname in set(conn_col_names).difference(metadata_col_names): - if autogen_context.run_object_filters( - conn_table.c[cname], cname, "column", True, None - ): - modify_table_ops.ops.append( - ops.DropColumnOp.from_column_and_tablename( - schema, tname, conn_table.c[cname] - ) - ) - log.info("Detected removed column '%s.%s'", name, cname) + from ...autogenerate.api import AutogenContext + from ...ddl._autogen import _constraint_sig + from ...ddl.impl import DefaultImpl + from ...operations.ops import AlterColumnOp + from ...operations.ops import ModifyTableOps + from ...runtime.plugins import Plugin _C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index]) -class _InspectorConv: - __slots__ = ("inspector",) - - def __init__(self, inspector): - self.inspector = inspector - - def _apply_reflectinfo_conv(self, consts): - if not consts: - return consts - for const in consts: - if const["name"] is not None and not isinstance( - const["name"], conv - ): - const["name"] = conv(const["name"]) - return consts - - def _apply_constraint_conv(self, consts): - if not consts: - return consts - for const in consts: - if const.name is not None and not isinstance(const.name, conv): - const.name = conv(const.name) - return consts - - def get_indexes(self, *args, **kw): - return self._apply_reflectinfo_conv( - self.inspector.get_indexes(*args, **kw) - ) - - def get_unique_constraints(self, *args, **kw): - return self._apply_reflectinfo_conv( - self.inspector.get_unique_constraints(*args, **kw) - ) - - def get_foreign_keys(self, *args, **kw): - return self._apply_reflectinfo_conv( - self.inspector.get_foreign_keys(*args, **kw) - ) - - def reflect_table(self, table, *, include_columns): - self.inspector.reflect_table(table, include_columns=include_columns) - - # I had a cool version of this using _ReflectInfo, however that doesn't - # work in 1.4 and it's not public API in 2.x. Then this is just a two - # liner. So there's no competition... - self._apply_constraint_conv(table.constraints) - self._apply_constraint_conv(table.indexes) +log = logging.getLogger(__name__) -@comparators.dispatch_for("table") def _compare_indexes_and_uniques( autogen_context: AutogenContext, modify_ops: ModifyTableOps, @@ -497,7 +54,7 @@ def _compare_indexes_and_uniques( tname: Union[quoted_name, str], conn_table: Optional[Table], metadata_table: Optional[Table], -) -> None: +) -> PriorityDispatchResult: inspector = autogen_context.inspector is_create_table = conn_table is None is_drop_table = metadata_table is None @@ -636,8 +193,13 @@ def _compare_indexes_and_uniques( if c.is_named } - conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig] - conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig] + conn_uniques_by_name: Dict[ + sqla_compat._ConstraintName, + _constraint_sig[sa_schema.UniqueConstraint], + ] + conn_indexes_by_name: Dict[ + sqla_compat._ConstraintName, _constraint_sig[sa_schema.Index] + ] conn_uniques_by_name = {c.name: c for c in conn_unique_constraints} conn_indexes_by_name = {c.name: c for c in conn_indexes_sig} @@ -676,7 +238,12 @@ def _compare_indexes_and_uniques( # 4. The backend may double up indexes as unique constraints and # vice versa (e.g. MySQL, Postgresql) - def obj_added(obj: _constraint_sig): + def obj_added( + obj: ( + _constraint_sig[sa_schema.UniqueConstraint] + | _constraint_sig[sa_schema.Index] + ), + ): if is_index_sig(obj): if autogen_context.run_object_filters( obj.const, obj.name, "index", False, None @@ -709,7 +276,12 @@ def _compare_indexes_and_uniques( else: assert False - def obj_removed(obj: _constraint_sig): + def obj_removed( + obj: ( + _constraint_sig[sa_schema.UniqueConstraint] + | _constraint_sig[sa_schema.Index] + ), + ): if is_index_sig(obj): if obj.is_unique and not supports_unique_constraints: # many databases double up unique constraints @@ -742,8 +314,14 @@ def _compare_indexes_and_uniques( assert False def obj_changed( - old: _constraint_sig, - new: _constraint_sig, + old: ( + _constraint_sig[sa_schema.UniqueConstraint] + | _constraint_sig[sa_schema.Index] + ), + new: ( + _constraint_sig[sa_schema.UniqueConstraint] + | _constraint_sig[sa_schema.Index] + ), msg: str, ): if is_index_sig(old): @@ -815,6 +393,13 @@ def _compare_indexes_and_uniques( obj_removed(conn_obj) obj_added(metadata_obj) else: + # TODO: for plugins, let's do is_index_sig / is_uq_sig + # here so we know index or unique, then + # do a sub-dispatch, + # autogen_context.comparators.dispatch("index") + # or + # autogen_context.comparators.dispatch("unique_constraint") + # comparison = metadata_obj.compare_to_reflected(conn_obj) if comparison.is_different: @@ -843,6 +428,8 @@ def _compare_indexes_and_uniques( if uq_sig not in conn_uniques_by_sig: obj_added(unnamed_metadata_uniques[uq_sig]) + return PriorityDispatchResult.CONTINUE + def _correct_for_uq_duplicates_uix( conn_unique_constraints, @@ -907,307 +494,87 @@ def _correct_for_uq_duplicates_uix( conn_indexes.discard(conn_ix_names[overlap]) -@comparators.dispatch_for("column") -def _compare_nullable( - autogen_context: AutogenContext, - alter_column_op: AlterColumnOp, - schema: Optional[str], - tname: Union[quoted_name, str], - cname: Union[quoted_name, str], - conn_col: Column[Any], - metadata_col: Column[Any], -) -> None: - metadata_col_nullable = metadata_col.nullable - conn_col_nullable = conn_col.nullable - alter_column_op.existing_nullable = conn_col_nullable - - if conn_col_nullable is not metadata_col_nullable: - if ( - sqla_compat._server_default_is_computed( - metadata_col.server_default, conn_col.server_default - ) - and sqla_compat._nullability_might_be_unset(metadata_col) - or ( - sqla_compat._server_default_is_identity( - metadata_col.server_default, conn_col.server_default - ) - ) - ): - log.info( - "Ignoring nullable change on identity column '%s.%s'", - tname, - cname, - ) - else: - alter_column_op.modify_nullable = metadata_col_nullable - log.info( - "Detected %s on column '%s.%s'", - "NULL" if metadata_col_nullable else "NOT NULL", - tname, - cname, - ) - - -@comparators.dispatch_for("column") -def _setup_autoincrement( - autogen_context: AutogenContext, - alter_column_op: AlterColumnOp, - schema: Optional[str], - tname: Union[quoted_name, str], - cname: quoted_name, - conn_col: Column[Any], - metadata_col: Column[Any], -) -> None: - if metadata_col.table._autoincrement_column is metadata_col: - alter_column_op.kw["autoincrement"] = True - elif metadata_col.autoincrement is True: - alter_column_op.kw["autoincrement"] = True - elif metadata_col.autoincrement is False: - alter_column_op.kw["autoincrement"] = False - - -@comparators.dispatch_for("column") -def _compare_type( - autogen_context: AutogenContext, - alter_column_op: AlterColumnOp, - schema: Optional[str], - tname: Union[quoted_name, str], - cname: Union[quoted_name, str], - conn_col: Column[Any], - metadata_col: Column[Any], -) -> None: - conn_type = conn_col.type - alter_column_op.existing_type = conn_type - metadata_type = metadata_col.type - if conn_type._type_affinity is sqltypes.NullType: - log.info( - "Couldn't determine database type " "for column '%s.%s'", - tname, - cname, - ) - return - if metadata_type._type_affinity is sqltypes.NullType: - log.info( - "Column '%s.%s' has no type within " "the model; can't compare", - tname, - cname, - ) - return - - isdiff = autogen_context.migration_context._compare_type( - conn_col, metadata_col - ) +_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict( + { + "asc": expression.asc, + "desc": expression.desc, + "nulls_first": expression.nullsfirst, + "nulls_last": expression.nullslast, + "nullsfirst": expression.nullsfirst, # 1_3 name + "nullslast": expression.nullslast, # 1_3 name + } +) - if isdiff: - alter_column_op.modify_type = metadata_type - log.info( - "Detected type change from %r to %r on '%s.%s'", - conn_type, - metadata_type, - tname, - cname, - ) +def _make_index( + impl: DefaultImpl, params: Dict[str, Any], conn_table: Table +) -> Optional[Index]: + exprs: list[Union[Column[Any], TextClause]] = [] + sorting = params.get("column_sorting") -def _render_server_default_for_compare( - metadata_default: Optional[Any], autogen_context: AutogenContext -) -> Optional[str]: - if isinstance(metadata_default, sa_schema.DefaultClause): - if isinstance(metadata_default.arg, str): - metadata_default = metadata_default.arg + for num, col_name in enumerate(params["column_names"]): + item: Union[Column[Any], TextClause] + if col_name is None: + assert "expressions" in params + name = params["expressions"][num] + item = text(name) else: - metadata_default = str( - metadata_default.arg.compile( - dialect=autogen_context.dialect, - compile_kwargs={"literal_binds": True}, - ) - ) - if isinstance(metadata_default, str): - return metadata_default - else: - return None - - -def _normalize_computed_default(sqltext: str) -> str: - """we want to warn if a computed sql expression has changed. however - we don't want false positives and the warning is not that critical. - so filter out most forms of variability from the SQL text. - - """ - - return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower() - - -def _compare_computed_default( - autogen_context: AutogenContext, - alter_column_op: AlterColumnOp, - schema: Optional[str], - tname: str, - cname: str, - conn_col: Column[Any], - metadata_col: Column[Any], -) -> None: - rendered_metadata_default = str( - cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile( - dialect=autogen_context.dialect, - compile_kwargs={"literal_binds": True}, - ) + name = col_name + item = conn_table.c[col_name] + if sorting and name in sorting: + for operator in sorting[name]: + if operator in _IndexColumnSortingOps: + item = _IndexColumnSortingOps[operator](item) + exprs.append(item) + ix = sa_schema.Index( + params["name"], + *exprs, + unique=params["unique"], + _table=conn_table, + **impl.adjust_reflected_dialect_options(params, "index"), ) + if "duplicates_constraint" in params: + ix.info["duplicates_constraint"] = params["duplicates_constraint"] + return ix - # since we cannot change computed columns, we do only a crude comparison - # here where we try to eliminate syntactical differences in order to - # get a minimal comparison just to emit a warning. - rendered_metadata_default = _normalize_computed_default( - rendered_metadata_default +def _make_unique_constraint( + impl: DefaultImpl, params: Dict[str, Any], conn_table: Table +) -> UniqueConstraint: + uq = sa_schema.UniqueConstraint( + *[conn_table.c[cname] for cname in params["column_names"]], + name=params["name"], + **impl.adjust_reflected_dialect_options(params, "unique_constraint"), ) + if "duplicates_index" in params: + uq.info["duplicates_index"] = params["duplicates_index"] - if isinstance(conn_col.server_default, sa_schema.Computed): - rendered_conn_default = str( - conn_col.server_default.sqltext.compile( - dialect=autogen_context.dialect, - compile_kwargs={"literal_binds": True}, - ) - ) - if rendered_conn_default is None: - rendered_conn_default = "" - else: - rendered_conn_default = _normalize_computed_default( - rendered_conn_default - ) - else: - rendered_conn_default = "" - - if rendered_metadata_default != rendered_conn_default: - _warn_computed_not_supported(tname, cname) + return uq -def _warn_computed_not_supported(tname: str, cname: str) -> None: - util.warn("Computed default on %s.%s cannot be modified" % (tname, cname)) +def _make_foreign_key( + params: Dict[str, Any], conn_table: Table +) -> ForeignKeyConstraint: + tname = params["referred_table"] + if params["referred_schema"]: + tname = "%s.%s" % (params["referred_schema"], tname) + options = params.get("options", {}) -def _compare_identity_default( - autogen_context, - alter_column_op, - schema, - tname, - cname, - conn_col, - metadata_col, -): - impl = autogen_context.migration_context.impl - diff, ignored_attr, is_alter = impl._compare_identity_default( - metadata_col.server_default, conn_col.server_default + const = sa_schema.ForeignKeyConstraint( + [conn_table.c[cname] for cname in params["constrained_columns"]], + ["%s.%s" % (tname, n) for n in params["referred_columns"]], + onupdate=options.get("onupdate"), + ondelete=options.get("ondelete"), + deferrable=options.get("deferrable"), + initially=options.get("initially"), + name=params["name"], ) - - return diff, is_alter - - -@comparators.dispatch_for("column") -def _compare_server_default( - autogen_context: AutogenContext, - alter_column_op: AlterColumnOp, - schema: Optional[str], - tname: Union[quoted_name, str], - cname: Union[quoted_name, str], - conn_col: Column[Any], - metadata_col: Column[Any], -) -> Optional[bool]: - metadata_default = metadata_col.server_default - conn_col_default = conn_col.server_default - if conn_col_default is None and metadata_default is None: - return False - - if sqla_compat._server_default_is_computed(metadata_default): - return _compare_computed_default( # type:ignore[func-returns-value] - autogen_context, - alter_column_op, - schema, - tname, - cname, - conn_col, - metadata_col, - ) - if sqla_compat._server_default_is_computed(conn_col_default): - _warn_computed_not_supported(tname, cname) - return False - - if sqla_compat._server_default_is_identity( - metadata_default, conn_col_default - ): - alter_column_op.existing_server_default = conn_col_default - diff, is_alter = _compare_identity_default( - autogen_context, - alter_column_op, - schema, - tname, - cname, - conn_col, - metadata_col, - ) - if is_alter: - alter_column_op.modify_server_default = metadata_default - if diff: - log.info( - "Detected server default on column '%s.%s': " - "identity options attributes %s", - tname, - cname, - sorted(diff), - ) - else: - rendered_metadata_default = _render_server_default_for_compare( - metadata_default, autogen_context - ) - - rendered_conn_default = ( - cast(Any, conn_col_default).arg.text if conn_col_default else None - ) - - alter_column_op.existing_server_default = conn_col_default - - is_diff = autogen_context.migration_context._compare_server_default( - conn_col, - metadata_col, - rendered_metadata_default, - rendered_conn_default, - ) - if is_diff: - alter_column_op.modify_server_default = metadata_default - log.info("Detected server default on column '%s.%s'", tname, cname) - - return None - - -@comparators.dispatch_for("column") -def _compare_column_comment( - autogen_context: AutogenContext, - alter_column_op: AlterColumnOp, - schema: Optional[str], - tname: Union[quoted_name, str], - cname: quoted_name, - conn_col: Column[Any], - metadata_col: Column[Any], -) -> Optional[Literal[False]]: - assert autogen_context.dialect is not None - if not autogen_context.dialect.supports_comments: - return None - - metadata_comment = metadata_col.comment - conn_col_comment = conn_col.comment - if conn_col_comment is None and metadata_comment is None: - return False - - alter_column_op.existing_comment = conn_col_comment - - if conn_col_comment != metadata_comment: - alter_column_op.modify_comment = metadata_comment - log.info("Detected column comment '%s.%s'", tname, cname) - - return None + # needed by 0.7 + conn_table.append_constraint(const) + return const -@comparators.dispatch_for("table") def _compare_foreign_keys( autogen_context: AutogenContext, modify_table_ops: ModifyTableOps, @@ -1215,11 +582,11 @@ def _compare_foreign_keys( tname: Union[quoted_name, str], conn_table: Table, metadata_table: Table, -) -> None: +) -> PriorityDispatchResult: # if we're doing CREATE TABLE, all FKs are created # inline within the table def if conn_table is None or metadata_table is None: - return + return PriorityDispatchResult.CONTINUE inspector = autogen_context.inspector metadata_fks = { @@ -1316,7 +683,7 @@ def _compare_foreign_keys( if removed_sig not in metadata_fks_by_sig: compare_to = ( metadata_fks_by_name[const.name].const - if const.name in metadata_fks_by_name + if const.name and const.name in metadata_fks_by_name else None ) _remove_fk(const, compare_to) @@ -1326,45 +693,71 @@ def _compare_foreign_keys( if added_sig not in conn_fks_by_sig: compare_to = ( conn_fks_by_name[const.name].const - if const.name in conn_fks_by_name + if const.name and const.name in conn_fks_by_name else None ) _add_fk(const, compare_to) + return PriorityDispatchResult.CONTINUE -@comparators.dispatch_for("table") -def _compare_table_comment( + +def _compare_nullable( autogen_context: AutogenContext, - modify_table_ops: ModifyTableOps, + alter_column_op: AlterColumnOp, schema: Optional[str], tname: Union[quoted_name, str], - conn_table: Optional[Table], - metadata_table: Optional[Table], -) -> None: - assert autogen_context.dialect is not None - if not autogen_context.dialect.supports_comments: - return - - # if we're doing CREATE TABLE, comments will be created inline - # with the create_table op. - if conn_table is None or metadata_table is None: - return - - if conn_table.comment is None and metadata_table.comment is None: - return + cname: Union[quoted_name, str], + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + metadata_col_nullable = metadata_col.nullable + conn_col_nullable = conn_col.nullable + alter_column_op.existing_nullable = conn_col_nullable - if metadata_table.comment is None and conn_table.comment is not None: - modify_table_ops.ops.append( - ops.DropTableCommentOp( - tname, existing_comment=conn_table.comment, schema=schema + if conn_col_nullable is not metadata_col_nullable: + if ( + sqla_compat._server_default_is_computed( + metadata_col.server_default, conn_col.server_default ) - ) - elif metadata_table.comment != conn_table.comment: - modify_table_ops.ops.append( - ops.CreateTableCommentOp( + and sqla_compat._nullability_might_be_unset(metadata_col) + or ( + sqla_compat._server_default_is_identity( + metadata_col.server_default, conn_col.server_default + ) + ) + ): + log.info( + "Ignoring nullable change on identity column '%s.%s'", tname, - metadata_table.comment, - existing_comment=conn_table.comment, - schema=schema, + cname, ) - ) + else: + alter_column_op.modify_nullable = metadata_col_nullable + log.info( + "Detected %s on column '%s.%s'", + "NULL" if metadata_col_nullable else "NOT NULL", + tname, + cname, + ) + # column nullablity changed, no further nullable checks needed + return PriorityDispatchResult.STOP + + return PriorityDispatchResult.CONTINUE + + +def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + _compare_indexes_and_uniques, + "table", + "indexes", + ) + plugin.add_autogenerate_comparator( + _compare_foreign_keys, + "table", + "foreignkeys", + ) + plugin.add_autogenerate_comparator( + _compare_nullable, + "column", + "nullable", + ) diff --git a/alembic/autogenerate/compare/schema.py b/alembic/autogenerate/compare/schema.py new file mode 100644 index 00000000..1f46aff4 --- /dev/null +++ b/alembic/autogenerate/compare/schema.py @@ -0,0 +1,62 @@ +# mypy: allow-untyped-calls + +from __future__ import annotations + +import logging +from typing import Optional +from typing import Set +from typing import TYPE_CHECKING + +from sqlalchemy import inspect + +from ...util import PriorityDispatchResult + +if TYPE_CHECKING: + from sqlalchemy.engine.reflection import Inspector + + from ...autogenerate.api import AutogenContext + from ...operations.ops import UpgradeOps + from ...runtime.plugins import Plugin + + +log = logging.getLogger(__name__) + + +def _produce_net_changes( + autogen_context: AutogenContext, upgrade_ops: UpgradeOps +) -> PriorityDispatchResult: + connection = autogen_context.connection + assert connection is not None + include_schemas = autogen_context.opts.get("include_schemas", False) + + inspector: Inspector = inspect(connection) + + default_schema = connection.dialect.default_schema_name + schemas: Set[Optional[str]] + if include_schemas: + schemas = set(inspector.get_schema_names()) + # replace default schema name with None + schemas.discard("information_schema") + # replace the "default" schema with None + schemas.discard(default_schema) + schemas.add(None) + else: + schemas = {None} + + schemas = { + s for s in schemas if autogen_context.run_name_filters(s, "schema", {}) + } + + assert autogen_context.dialect is not None + autogen_context.comparators.dispatch( + "schema", qualifier=autogen_context.dialect.name + )(autogen_context, upgrade_ops, schemas) + + return PriorityDispatchResult.CONTINUE + + +def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + _produce_net_changes, + "autogenerate", + ) diff --git a/alembic/autogenerate/compare/server_defaults.py b/alembic/autogenerate/compare/server_defaults.py new file mode 100644 index 00000000..e48f4c8e --- /dev/null +++ b/alembic/autogenerate/compare/server_defaults.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import logging +import re +from types import NoneType +from typing import Any +from typing import cast +from typing import Optional +from typing import Sequence +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import schema as sa_schema +from sqlalchemy.sql.schema import DefaultClause + +from ... import util +from ...util import DispatchPriority +from ...util import PriorityDispatchResult +from ...util import sqla_compat + +if TYPE_CHECKING: + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import Column + + from ...autogenerate.api import AutogenContext + from ...operations.ops import AlterColumnOp + from ...runtime.plugins import Plugin + +log = logging.getLogger(__name__) + + +def _render_server_default_for_compare( + metadata_default: Optional[Any], autogen_context: AutogenContext +) -> Optional[str]: + if isinstance(metadata_default, sa_schema.DefaultClause): + if isinstance(metadata_default.arg, str): + metadata_default = metadata_default.arg + else: + metadata_default = str( + metadata_default.arg.compile( + dialect=autogen_context.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + if isinstance(metadata_default, str): + return metadata_default + else: + return None + + +def _normalize_computed_default(sqltext: str) -> str: + """we want to warn if a computed sql expression has changed. however + we don't want false positives and the warning is not that critical. + so filter out most forms of variability from the SQL text. + + """ + + return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower() + + +def _compare_computed_default( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: str, + cname: str, + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + + metadata_default = metadata_col.server_default + conn_col_default = conn_col.server_default + if conn_col_default is None and metadata_default is None: + return PriorityDispatchResult.CONTINUE + + if sqla_compat._server_default_is_computed( + conn_col_default + ) and not sqla_compat._server_default_is_computed(metadata_default): + _warn_computed_not_supported(tname, cname) + return PriorityDispatchResult.STOP + + if not sqla_compat._server_default_is_computed(metadata_default): + return PriorityDispatchResult.CONTINUE + + rendered_metadata_default = str( + cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile( + dialect=autogen_context.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + + # since we cannot change computed columns, we do only a crude comparison + # here where we try to eliminate syntactical differences in order to + # get a minimal comparison just to emit a warning. + + rendered_metadata_default = _normalize_computed_default( + rendered_metadata_default + ) + + if isinstance(conn_col.server_default, sa_schema.Computed): + rendered_conn_default = str( + conn_col.server_default.sqltext.compile( + dialect=autogen_context.dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + rendered_conn_default = _normalize_computed_default( + rendered_conn_default + ) + else: + rendered_conn_default = "" + + if rendered_metadata_default != rendered_conn_default: + _warn_computed_not_supported(tname, cname) + + return PriorityDispatchResult.STOP + + +def _warn_computed_not_supported(tname: str, cname: str) -> None: + util.warn("Computed default on %s.%s cannot be modified" % (tname, cname)) + + +def _compare_identity_default( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column[Any], + metadata_col: Column[Any], + skip: Sequence[str] = ( + "order", + "on_null", + "oracle_order", + "oracle_on_null", + ), +) -> PriorityDispatchResult: + + metadata_default = metadata_col.server_default + conn_col_default = conn_col.server_default + if ( + conn_col_default is None + and metadata_default is None + or not sqla_compat._server_default_is_identity( + metadata_default, conn_col_default + ) + ): + return PriorityDispatchResult.CONTINUE + + assert isinstance( + metadata_col.server_default, + (sa_schema.Identity, sa_schema.Sequence, NoneType), + ) + assert isinstance( + conn_col.server_default, + (sa_schema.Identity, sa_schema.Sequence, NoneType), + ) + + impl = autogen_context.migration_context.impl + diff, _, is_alter = impl._compare_identity_default( # type: ignore[no-untyped-call] # noqa: E501 + metadata_col.server_default, conn_col.server_default + ) + + if is_alter: + alter_column_op.modify_server_default = metadata_default + if diff: + log.info( + "Detected server default on column '%s.%s': " + "identity options attributes %s", + tname, + cname, + sorted(diff), + ) + + return PriorityDispatchResult.STOP + + return PriorityDispatchResult.CONTINUE + + +def _user_compare_server_default( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + + metadata_default = metadata_col.server_default + conn_col_default = conn_col.server_default + if conn_col_default is None and metadata_default is None: + return PriorityDispatchResult.CONTINUE + + alter_column_op.existing_server_default = conn_col_default + + migration_context = autogen_context.migration_context + + if migration_context._user_compare_server_default is False: + return PriorityDispatchResult.STOP + + if not callable(migration_context._user_compare_server_default): + return PriorityDispatchResult.CONTINUE + + rendered_metadata_default = _render_server_default_for_compare( + metadata_default, autogen_context + ) + rendered_conn_default = ( + cast(Any, conn_col_default).arg.text if conn_col_default else None + ) + + is_diff = migration_context._user_compare_server_default( + migration_context, + conn_col, + metadata_col, + rendered_conn_default, + metadata_col.server_default, + rendered_metadata_default, + ) + if is_diff: + alter_column_op.modify_server_default = metadata_default + log.info( + "User defined function %s detected " + "server default on column '%s.%s'", + migration_context._user_compare_server_default, + tname, + cname, + ) + return PriorityDispatchResult.STOP + return PriorityDispatchResult.CONTINUE + + +def _dialect_impl_compare_server_default( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + """use dialect.impl.compare_server_default. + + This would in theory not be needed. however we dont know if any + third party libraries haven't made their own alembic dialect and + implemented this method. + + """ + metadata_default = metadata_col.server_default + conn_col_default = conn_col.server_default + if conn_col_default is None and metadata_default is None: + return PriorityDispatchResult.CONTINUE + + # this is already done by _user_compare_server_default, + # but doing it here also for unit tests that want to call + # _dialect_impl_compare_server_default directly + alter_column_op.existing_server_default = conn_col_default + + if not isinstance( + metadata_default, (DefaultClause, NoneType) + ) or not isinstance(conn_col_default, (DefaultClause, NoneType)): + return PriorityDispatchResult.CONTINUE + + migration_context = autogen_context.migration_context + + rendered_metadata_default = _render_server_default_for_compare( + metadata_default, autogen_context + ) + rendered_conn_default = ( + cast(Any, conn_col_default).arg.text if conn_col_default else None + ) + + is_diff = migration_context.impl.compare_server_default( # type: ignore[no-untyped-call] # noqa: E501 + conn_col, + metadata_col, + rendered_metadata_default, + rendered_conn_default, + ) + if is_diff: + alter_column_op.modify_server_default = metadata_default + log.info( + "Dialect impl %s detected server default on column '%s.%s'", + migration_context.impl, + tname, + cname, + ) + return PriorityDispatchResult.STOP + return PriorityDispatchResult.CONTINUE + + +def _setup_autoincrement( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: quoted_name, + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + if metadata_col.table._autoincrement_column is metadata_col: + alter_column_op.kw["autoincrement"] = True + elif metadata_col.autoincrement is True: + alter_column_op.kw["autoincrement"] = True + elif metadata_col.autoincrement is False: + alter_column_op.kw["autoincrement"] = False + + return PriorityDispatchResult.CONTINUE + + +def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + _user_compare_server_default, + "column", + "server_default", + priority=DispatchPriority.FIRST, + ) + plugin.add_autogenerate_comparator( + _compare_computed_default, + "column", + "server_default", + ) + + plugin.add_autogenerate_comparator( + _compare_identity_default, + "column", + "server_default", + ) + + plugin.add_autogenerate_comparator( + _setup_autoincrement, + "column", + "server_default", + ) + plugin.add_autogenerate_comparator( + _dialect_impl_compare_server_default, + "column", + "server_default", + priority=DispatchPriority.LAST, + ) diff --git a/alembic/autogenerate/compare/tables.py b/alembic/autogenerate/compare/tables.py new file mode 100644 index 00000000..0847ff5e --- /dev/null +++ b/alembic/autogenerate/compare/tables.py @@ -0,0 +1,303 @@ +# mypy: allow-untyped-calls + +from __future__ import annotations + +import contextlib +import logging +from typing import Iterator +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import event +from sqlalchemy import schema as sa_schema +from sqlalchemy.util import OrderedSet + +from .util import _InspectorConv +from ...operations import ops +from ...util import PriorityDispatchResult + +if TYPE_CHECKING: + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import Table + + from ...autogenerate.api import AutogenContext + from ...operations.ops import ModifyTableOps + from ...operations.ops import UpgradeOps + from ...runtime.plugins import Plugin + + +log = logging.getLogger(__name__) + + +def _autogen_for_tables( + autogen_context: AutogenContext, + upgrade_ops: UpgradeOps, + schemas: Set[Optional[str]], +) -> PriorityDispatchResult: + inspector = autogen_context.inspector + + conn_table_names: Set[Tuple[Optional[str], str]] = set() + + version_table_schema = ( + autogen_context.migration_context.version_table_schema + ) + version_table = autogen_context.migration_context.version_table + + for schema_name in schemas: + tables = set(inspector.get_table_names(schema=schema_name)) + if schema_name == version_table_schema: + tables = tables.difference( + [autogen_context.migration_context.version_table] + ) + + conn_table_names.update( + (schema_name, tname) + for tname in tables + if autogen_context.run_name_filters( + tname, "table", {"schema_name": schema_name} + ) + ) + + metadata_table_names = OrderedSet( + [(table.schema, table.name) for table in autogen_context.sorted_tables] + ).difference([(version_table_schema, version_table)]) + + _compare_tables( + conn_table_names, + metadata_table_names, + inspector, + upgrade_ops, + autogen_context, + ) + + return PriorityDispatchResult.CONTINUE + + +def _compare_tables( + conn_table_names: set[tuple[str | None, str]], + metadata_table_names: set[tuple[str | None, str]], + inspector: Inspector, + upgrade_ops: UpgradeOps, + autogen_context: AutogenContext, +) -> None: + default_schema = inspector.bind.dialect.default_schema_name + + # tables coming from the connection will not have "schema" + # set if it matches default_schema_name; so we need a list + # of table names from local metadata that also have "None" if schema + # == default_schema_name. Most setups will be like this anyway but + # some are not (see #170) + metadata_table_names_no_dflt_schema = OrderedSet( + [ + (schema if schema != default_schema else None, tname) + for schema, tname in metadata_table_names + ] + ) + + # to adjust for the MetaData collection storing the tables either + # as "schemaname.tablename" or just "tablename", create a new lookup + # which will match the "non-default-schema" keys to the Table object. + tname_to_table = { + no_dflt_schema: autogen_context.table_key_to_table[ + sa_schema._get_table_key(tname, schema) + ] + for no_dflt_schema, (schema, tname) in zip( + metadata_table_names_no_dflt_schema, metadata_table_names + ) + } + metadata_table_names = metadata_table_names_no_dflt_schema + + for s, tname in metadata_table_names.difference(conn_table_names): + name = "%s.%s" % (s, tname) if s else tname + metadata_table = tname_to_table[(s, tname)] + if autogen_context.run_object_filters( + metadata_table, tname, "table", False, None + ): + upgrade_ops.ops.append( + ops.CreateTableOp.from_table(metadata_table) + ) + log.info("Detected added table %r", name) + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + + autogen_context.comparators.dispatch( + "table", qualifier=autogen_context.dialect.name + )( + autogen_context, + modify_table_ops, + s, + tname, + None, + metadata_table, + ) + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) + + removal_metadata = sa_schema.MetaData() + for s, tname in conn_table_names.difference(metadata_table_names): + name = sa_schema._get_table_key(tname, s) + exists = name in removal_metadata.tables + t = sa_schema.Table(tname, removal_metadata, schema=s) + + if not exists: + event.listen( + t, + "column_reflect", + # fmt: off + autogen_context.migration_context.impl. + _compat_autogen_column_reflect + (inspector), + # fmt: on + ) + _InspectorConv(inspector).reflect_table(t, include_columns=None) + if autogen_context.run_object_filters(t, tname, "table", True, None): + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + + autogen_context.comparators.dispatch( + "table", qualifier=autogen_context.dialect.name + )(autogen_context, modify_table_ops, s, tname, t, None) + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) + + upgrade_ops.ops.append(ops.DropTableOp.from_table(t)) + log.info("Detected removed table %r", name) + + existing_tables = conn_table_names.intersection(metadata_table_names) + + existing_metadata = sa_schema.MetaData() + conn_column_info = {} + for s, tname in existing_tables: + name = sa_schema._get_table_key(tname, s) + exists = name in existing_metadata.tables + t = sa_schema.Table(tname, existing_metadata, schema=s) + if not exists: + event.listen( + t, + "column_reflect", + # fmt: off + autogen_context.migration_context.impl. + _compat_autogen_column_reflect(inspector), + # fmt: on + ) + _InspectorConv(inspector).reflect_table(t, include_columns=None) + + conn_column_info[(s, tname)] = t + + for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])): + s = s or None + name = "%s.%s" % (s, tname) if s else tname + metadata_table = tname_to_table[(s, tname)] + conn_table = existing_metadata.tables[name] + + if autogen_context.run_object_filters( + metadata_table, tname, "table", False, conn_table + ): + modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) + with _compare_columns( + s, + tname, + conn_table, + metadata_table, + modify_table_ops, + autogen_context, + inspector, + ): + autogen_context.comparators.dispatch( + "table", qualifier=autogen_context.dialect.name + )( + autogen_context, + modify_table_ops, + s, + tname, + conn_table, + metadata_table, + ) + + if not modify_table_ops.is_empty(): + upgrade_ops.ops.append(modify_table_ops) + + +@contextlib.contextmanager +def _compare_columns( + schema: Optional[str], + tname: Union[quoted_name, str], + conn_table: Table, + metadata_table: Table, + modify_table_ops: ModifyTableOps, + autogen_context: AutogenContext, + inspector: Inspector, +) -> Iterator[None]: + name = "%s.%s" % (schema, tname) if schema else tname + metadata_col_names = OrderedSet( + c.name for c in metadata_table.c if not c.system + ) + metadata_cols_by_name = { + c.name: c for c in metadata_table.c if not c.system + } + + conn_col_names = { + c.name: c + for c in conn_table.c + if autogen_context.run_name_filters( + c.name, "column", {"table_name": tname, "schema_name": schema} + ) + } + + for cname in metadata_col_names.difference(conn_col_names): + if autogen_context.run_object_filters( + metadata_cols_by_name[cname], cname, "column", False, None + ): + modify_table_ops.ops.append( + ops.AddColumnOp.from_column_and_tablename( + schema, tname, metadata_cols_by_name[cname] + ) + ) + log.info("Detected added column '%s.%s'", name, cname) + + for colname in metadata_col_names.intersection(conn_col_names): + metadata_col = metadata_cols_by_name[colname] + conn_col = conn_table.c[colname] + if not autogen_context.run_object_filters( + metadata_col, colname, "column", False, conn_col + ): + continue + alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema) + + autogen_context.comparators.dispatch( + "column", qualifier=autogen_context.dialect.name + )( + autogen_context, + alter_column_op, + schema, + tname, + colname, + conn_col, + metadata_col, + ) + + if alter_column_op.has_changes(): + modify_table_ops.ops.append(alter_column_op) + + yield + + for cname in set(conn_col_names).difference(metadata_col_names): + if autogen_context.run_object_filters( + conn_table.c[cname], cname, "column", True, None + ): + modify_table_ops.ops.append( + ops.DropColumnOp.from_column_and_tablename( + schema, tname, conn_table.c[cname] + ) + ) + log.info("Detected removed column '%s.%s'", name, cname) + + +def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + _autogen_for_tables, + "schema", + "tables", + ) diff --git a/alembic/autogenerate/compare/types.py b/alembic/autogenerate/compare/types.py new file mode 100644 index 00000000..1d5d160a --- /dev/null +++ b/alembic/autogenerate/compare/types.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import logging +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import types as sqltypes + +from ...util import DispatchPriority +from ...util import PriorityDispatchResult + +if TYPE_CHECKING: + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import Column + + from ...autogenerate.api import AutogenContext + from ...operations.ops import AlterColumnOp + from ...runtime.plugins import Plugin + + +log = logging.getLogger(__name__) + + +def _compare_type_setup( + alter_column_op: AlterColumnOp, + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column[Any], + metadata_col: Column[Any], +) -> bool: + + conn_type = conn_col.type + alter_column_op.existing_type = conn_type + metadata_type = metadata_col.type + if conn_type._type_affinity is sqltypes.NullType: + log.info( + "Couldn't determine database type for column '%s.%s'", + tname, + cname, + ) + return False + if metadata_type._type_affinity is sqltypes.NullType: + log.info( + "Column '%s.%s' has no type within the model; can't compare", + tname, + cname, + ) + return False + + return True + + +def _user_compare_type( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + + migration_context = autogen_context.migration_context + + if migration_context._user_compare_type is False: + return PriorityDispatchResult.STOP + + if not _compare_type_setup( + alter_column_op, tname, cname, conn_col, metadata_col + ): + return PriorityDispatchResult.CONTINUE + + if not callable(migration_context._user_compare_type): + return PriorityDispatchResult.CONTINUE + + is_diff = migration_context._user_compare_type( + migration_context, + conn_col, + metadata_col, + conn_col.type, + metadata_col.type, + ) + if is_diff: + alter_column_op.modify_type = metadata_col.type + log.info( + "Detected type change from %r to %r on '%s.%s'", + conn_col.type, + metadata_col.type, + tname, + cname, + ) + return PriorityDispatchResult.STOP + elif is_diff is False: + # if user compare type returns False and not None, + # it means "dont do any more type comparison" + return PriorityDispatchResult.STOP + + return PriorityDispatchResult.CONTINUE + + +def _dialect_impl_compare_type( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: Optional[str], + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column[Any], + metadata_col: Column[Any], +) -> PriorityDispatchResult: + + if not _compare_type_setup( + alter_column_op, tname, cname, conn_col, metadata_col + ): + return PriorityDispatchResult.CONTINUE + + migration_context = autogen_context.migration_context + is_diff = migration_context.impl.compare_type(conn_col, metadata_col) + + if is_diff: + alter_column_op.modify_type = metadata_col.type + log.info( + "Detected type change from %r to %r on '%s.%s'", + conn_col.type, + metadata_col.type, + tname, + cname, + ) + return PriorityDispatchResult.STOP + + return PriorityDispatchResult.CONTINUE + + +def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + _user_compare_type, + "column", + "types", + priority=DispatchPriority.FIRST, + ) + plugin.add_autogenerate_comparator( + _dialect_impl_compare_type, + "column", + "types", + priority=DispatchPriority.LAST, + ) diff --git a/alembic/autogenerate/compare/util.py b/alembic/autogenerate/compare/util.py new file mode 100644 index 00000000..199d8280 --- /dev/null +++ b/alembic/autogenerate/compare/util.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from sqlalchemy.sql.elements import conv + + +class _InspectorConv: + __slots__ = ("inspector",) + + def __init__(self, inspector): + self.inspector = inspector + + def _apply_reflectinfo_conv(self, consts): + if not consts: + return consts + for const in consts: + if const["name"] is not None and not isinstance( + const["name"], conv + ): + const["name"] = conv(const["name"]) + return consts + + def _apply_constraint_conv(self, consts): + if not consts: + return consts + for const in consts: + if const.name is not None and not isinstance(const.name, conv): + const.name = conv(const.name) + return consts + + def get_indexes(self, *args, **kw): + return self._apply_reflectinfo_conv( + self.inspector.get_indexes(*args, **kw) + ) + + def get_unique_constraints(self, *args, **kw): + return self._apply_reflectinfo_conv( + self.inspector.get_unique_constraints(*args, **kw) + ) + + def get_foreign_keys(self, *args, **kw): + return self._apply_reflectinfo_conv( + self.inspector.get_foreign_keys(*args, **kw) + ) + + def reflect_table(self, table, *, include_columns): + self.inspector.reflect_table(table, include_columns=include_columns) + + # I had a cool version of this using _ReflectInfo, however that doesn't + # work in 1.4 and it's not public API in 2.x. Then this is just a two + # liner. So there's no competition... + self._apply_constraint_conv(table.constraints) + self._apply_constraint_conv(table.indexes) diff --git a/alembic/context.pyi b/alembic/context.pyi index 9117c31e..6045d8b3 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -203,6 +203,7 @@ def configure( None, ] ] = None, + autogenerate_plugins: Optional[Sequence[str]] = None, **kw: Any, ) -> None: """Configure a :class:`.MigrationContext` within this @@ -622,6 +623,25 @@ def configure( :paramref:`.command.revision.process_revision_directives` + :param autogenerate_plugins: A list of string names of "plugins" that + should participate in this autogenerate run. Defaults to the list + ``["alembic.autogenerate.*"]``, which indicates that Alembic's default + autogeneration plugins will be used. + + See the section :ref:`plugins_autogenerate` for complete background + on how to use this parameter. + + .. versionadded:: 1.18.0 Added a new plugin system for autogenerate + compare directives. + + .. seealso:: + + :ref:`plugins_autogenerate` - background on enabling/disabling + autogenerate plugins + + :ref:`alembic.plugins.toplevel` - Introduction and documentation + to the plugin system + Parameters specific to individual backends: :param mssql_batch_separator: The "batch separator" which will diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 80ca2b6c..5817e2d9 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -441,6 +441,7 @@ class EnvironmentContext(util.ModuleClsProxy): sqlalchemy_module_prefix: str = "sa.", user_module_prefix: Optional[str] = None, on_version_apply: Optional[OnVersionApplyFn] = None, + autogenerate_plugins: Sequence[str] | None = None, **kw: Any, ) -> None: """Configure a :class:`.MigrationContext` within this @@ -860,6 +861,25 @@ class EnvironmentContext(util.ModuleClsProxy): :paramref:`.command.revision.process_revision_directives` + :param autogenerate_plugins: A list of string names of "plugins" that + should participate in this autogenerate run. Defaults to the list + ``["alembic.autogenerate.*"]``, which indicates that Alembic's default + autogeneration plugins will be used. + + See the section :ref:`plugins_autogenerate` for complete background + on how to use this parameter. + + .. versionadded:: 1.18.0 Added a new plugin system for autogenerate + compare directives. + + .. seealso:: + + :ref:`plugins_autogenerate` - background on enabling/disabling + autogenerate plugins + + :ref:`alembic.plugins.toplevel` - Introduction and documentation + to the plugin system + Parameters specific to individual backends: :param mssql_batch_separator: The "batch separator" which will @@ -903,6 +923,9 @@ class EnvironmentContext(util.ModuleClsProxy): opts["process_revision_directives"] = process_revision_directives opts["on_version_apply"] = util.to_tuple(on_version_apply, default=()) + if autogenerate_plugins is not None: + opts["autogenerate_plugins"] = autogenerate_plugins + if render_item is not None: opts["render_item"] = render_item opts["compare_type"] = compare_type diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index c1c7b0fc..3fccf22a 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -21,7 +21,6 @@ from typing import Tuple from typing import TYPE_CHECKING from typing import Union -from sqlalchemy import Column from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy.engine import Engine @@ -706,54 +705,6 @@ class MigrationContext: else: return None - def _compare_type( - self, inspector_column: Column[Any], metadata_column: Column - ) -> bool: - if self._user_compare_type is False: - return False - - if callable(self._user_compare_type): - user_value = self._user_compare_type( - self, - inspector_column, - metadata_column, - inspector_column.type, - metadata_column.type, - ) - if user_value is not None: - return user_value - - return self.impl.compare_type(inspector_column, metadata_column) - - def _compare_server_default( - self, - inspector_column: Column[Any], - metadata_column: Column[Any], - rendered_metadata_default: Optional[str], - rendered_column_default: Optional[str], - ) -> bool: - if self._user_compare_server_default is False: - return False - - if callable(self._user_compare_server_default): - user_value = self._user_compare_server_default( - self, - inspector_column, - metadata_column, - rendered_column_default, - metadata_column.server_default, - rendered_metadata_default, - ) - if user_value is not None: - return user_value - - return self.impl.compare_server_default( - inspector_column, - metadata_column, - rendered_metadata_default, - rendered_column_default, - ) - class HeadMaintainer: def __init__(self, context: MigrationContext, heads: Any) -> None: diff --git a/alembic/runtime/plugins.py b/alembic/runtime/plugins.py new file mode 100644 index 00000000..4d47443f --- /dev/null +++ b/alembic/runtime/plugins.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from importlib import metadata +import logging +import re +from types import ModuleType +from typing import Callable +from typing import Pattern +from typing import TYPE_CHECKING + +from .. import util +from ..util import DispatchPriority +from ..util import PriorityDispatcher + +if TYPE_CHECKING: + from ..util import PriorityDispatchResult + +_all_plugins = {} + + +log = logging.getLogger("__name__") + + +class Plugin: + """Describe a series of functions that are pulled in as a plugin. + + This is initially to provide for portable lists of autogenerate + comparison functions, however the setup for a plugin can run any + other kinds of global registration as well. + + .. versionadded:: 1.18.0 + + """ + + def __init__(self, name: str): + self.name = name + log.info("setup plugin %s", name) + if name in _all_plugins: + raise ValueError(f"A plugin named {name} is already registered") + _all_plugins[name] = self + self.autogenerate_comparators = PriorityDispatcher() + + def remove(self) -> None: + """remove this plugin""" + + del _all_plugins[self.name] + + def add_autogenerate_comparator( + self, + fn: Callable[..., PriorityDispatchResult], + compare_target: str, + compare_element: str | None = None, + *, + qualifier: str = "default", + priority: DispatchPriority = DispatchPriority.MEDIUM, + ) -> None: + """Register an autogenerate comparison function. + + See the section :ref:`plugins_registering_autogenerate` for detailed + examples on how to use this method. + + :param fn: The comparison function to register. The function receives + arguments specific to the type of comparison being performed and + should return a :class:`.PriorityDispatchResult` value. + + :param compare_target: The type of comparison being performed + (e.g., ``"table"``, ``"column"``, ``"type"``). + + :param compare_element: Optional sub-element being compared within + the target type. + + :param qualifier: Database dialect qualifier. Use ``"default"`` for + all dialects, or specify a dialect name like ``"postgresql"`` to + register a dialect-specific handler. Defaults to ``"default"``. + + :param priority: Execution priority for this comparison function. + Functions are executed in priority order from + :attr:`.DispatchPriority.FIRST` to :attr:`.DispatchPriority.LAST`. + Defaults to :attr:`.DispatchPriority.MEDIUM`. + + """ + self.autogenerate_comparators.dispatch_for( + compare_target, + subgroup=compare_element, + priority=priority, + qualifier=qualifier, + )(fn) + + @classmethod + def populate_autogenerate_priority_dispatch( + cls, comparators: PriorityDispatcher, include_plugins: list[str] + ) -> None: + """Populate all current autogenerate comparison functions into + a given PriorityDispatcher.""" + + exclude: set[Pattern[str]] = set() + include: dict[str, Pattern[str]] = {} + + matched_expressions: set[str] = set() + + for name in include_plugins: + if name.startswith("~"): + exclude.add(_make_re(name[1:])) + else: + include[name] = _make_re(name) + + for plugin in _all_plugins.values(): + if any(excl.match(plugin.name) for excl in exclude): + continue + + include_matches = [ + incl for incl in include if include[incl].match(plugin.name) + ] + if not include_matches: + continue + else: + matched_expressions.update(include_matches) + + log.info("setting up autogenerate plugin %s", plugin.name) + comparators.populate_with(plugin.autogenerate_comparators) + + never_matched = set(include).difference(matched_expressions) + if never_matched: + raise util.CommandError( + f"Did not locate plugins: {', '.join(never_matched)}" + ) + + @classmethod + def setup_plugin_from_module(cls, module: ModuleType, name: str) -> None: + """Call the ``setup()`` function of a plugin module, identified by + passing the module object itself. + + E.g.:: + + from alembic.runtime.plugins import Plugin + import myproject.alembic_plugin + + # Register the plugin manually + Plugin.setup_plugin_from_module( + myproject.alembic_plugin, + "myproject.custom_operations" + ) + + This will generate a new :class:`.Plugin` object with the given + name, which will register itself in the global list of plugins. + Then the module's ``setup()`` function is invoked, passing that + :class:`.Plugin` object. + + This exact process is invoked automatically at import time for any + plugin module that is published via the ``alembic.plugins`` entrypoint. + + """ + module.setup(Plugin(name)) + + +def _make_re(name: str) -> Pattern[str]: + tokens = name.split(".") + + reg = r"" + for token in tokens: + if token == "*": + reg += r"\..+?" + elif token.isidentifier(): + reg += r"\." + token + else: + raise ValueError(f"Invalid plugin expression {name!r}") + + # omit leading r'\.' + return re.compile(f"^{reg[2:]}$") + + +def _setup() -> None: + # setup third party plugins + for entrypoint in metadata.entry_points(group="alembic.plugins"): + for mod in entrypoint.load(): + Plugin.setup_plugin_from_module(mod, entrypoint.name) + + +_setup() diff --git a/alembic/testing/suite/_autogen_fixtures.py b/alembic/testing/suite/_autogen_fixtures.py index ed4acb26..8329a1ac 100644 --- a/alembic/testing/suite/_autogen_fixtures.py +++ b/alembic/testing/suite/_autogen_fixtures.py @@ -2,6 +2,8 @@ from __future__ import annotations from typing import Any from typing import Dict +from typing import Literal +from typing import overload from typing import Set from sqlalchemy import CHAR @@ -381,17 +383,46 @@ class AutogenTest(_ComparesFKs): class AutogenFixtureTest(_ComparesFKs): + + @overload + def _fixture( + self, + m1: MetaData, + m2: MetaData, + include_schemas=..., + opts=..., + object_filters=..., + name_filters=..., + *, + return_ops: Literal[True], + max_identifier_length=..., + ) -> ops.UpgradeOps: ... + + @overload def _fixture( self, - m1, - m2, + m1: MetaData, + m2: MetaData, + include_schemas=..., + opts=..., + object_filters=..., + name_filters=..., + *, + return_ops: Literal[False] = ..., + max_identifier_length=..., + ) -> list[Any]: ... + + def _fixture( + self, + m1: MetaData, + m2: MetaData, include_schemas=False, opts=None, object_filters=_default_object_filters, name_filters=_default_name_filters, - return_ops=False, + return_ops: bool = False, max_identifier_length=None, - ): + ) -> ops.UpgradeOps | list[Any]: if max_identifier_length: dialect = self.bind.dialect existing_length = dialect.max_identifier_length diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index c1411157..8f3f685b 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -6,11 +6,14 @@ from .langhelpers import _with_legacy_names as _with_legacy_names from .langhelpers import asbool as asbool from .langhelpers import dedupe_tuple as dedupe_tuple from .langhelpers import Dispatcher as Dispatcher +from .langhelpers import DispatchPriority as DispatchPriority from .langhelpers import EMPTY_DICT as EMPTY_DICT from .langhelpers import immutabledict as immutabledict from .langhelpers import memoized_property as memoized_property from .langhelpers import ModuleClsProxy as ModuleClsProxy from .langhelpers import not_none as not_none +from .langhelpers import PriorityDispatcher as PriorityDispatcher +from .langhelpers import PriorityDispatchResult as PriorityDispatchResult from .langhelpers import rev_id as rev_id from .langhelpers import to_list as to_list from .langhelpers import to_tuple as to_tuple diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index baba898f..cf0df239 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -2,6 +2,7 @@ from __future__ import annotations import collections from collections.abc import Iterable +import enum import textwrap from typing import Any from typing import Callable @@ -17,9 +18,7 @@ from typing import Sequence from typing import Set from typing import Tuple from typing import Type -from typing import TYPE_CHECKING from typing import TypeVar -from typing import Union import uuid import warnings @@ -264,25 +263,63 @@ def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]: return tuple(unique_list(tup)) +class PriorityDispatchResult(enum.Enum): + """indicate an action after running a function within a + :class:`.PriorityDispatcher` + + .. versionadded:: 1.18.0 + + """ + + CONTINUE = enum.auto() + """Continue running more functions. + + Any return value that is not PriorityDispatchResult.STOP is equivalent + to this. + + """ + + STOP = enum.auto() + """Stop running any additional functions within the subgroup""" + + +class DispatchPriority(enum.IntEnum): + """Indicate which of three sub-collections a function inside a + :class:`.PriorityDispatcher` should be placed. + + .. versionadded:: 1.18.0 + + """ + + FIRST = 50 + """Run the funciton in the first batch of functions (highest priority)""" + + MEDIUM = 25 + """Run the function at normal priority (this is the default)""" + + LAST = 10 + """Run the function in the last batch of functions""" + + class Dispatcher: - def __init__(self, uselist: bool = False) -> None: + def __init__(self) -> None: self._registry: Dict[Tuple[Any, ...], Any] = {} - self.uselist = uselist def dispatch_for( - self, target: Any, qualifier: str = "default", replace: bool = False + self, + target: Any, + *, + qualifier: str = "default", + replace: bool = False, ) -> Callable[[_C], _C]: def decorate(fn: _C) -> _C: - if self.uselist: - self._registry.setdefault((target, qualifier), []).append(fn) - else: - if (target, qualifier) in self._registry and not replace: - raise ValueError( - "Can not set dispatch function for object " - f"{target!r}: key already exists. To replace " - "existing function, use replace=True." - ) - self._registry[(target, qualifier)] = fn + if (target, qualifier) in self._registry and not replace: + raise ValueError( + "Can not set dispatch function for object " + f"{target!r}: key already exists. To replace " + "existing function, use replace=True." + ) + self._registry[(target, qualifier)] = fn return fn return decorate @@ -295,42 +332,113 @@ class Dispatcher: else: targets = type(obj).__mro__ - for spcls in targets: - if qualifier != "default" and (spcls, qualifier) in self._registry: - return self._fn_or_list(self._registry[(spcls, qualifier)]) - elif (spcls, "default") in self._registry: - return self._fn_or_list(self._registry[(spcls, "default")]) + if qualifier != "default": + qualifiers = [qualifier, "default"] else: - raise ValueError("no dispatch function for object: %s" % obj) + qualifiers = ["default"] - def _fn_or_list( - self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]] - ) -> Callable[..., Any]: - if self.uselist: - - def go(*arg: Any, **kw: Any) -> None: - if TYPE_CHECKING: - assert isinstance(fn_or_list, Sequence) - for fn in fn_or_list: - fn(*arg, **kw) - - return go + for spcls in targets: + for qualifier in qualifiers: + if (spcls, qualifier) in self._registry: + return self._registry[(spcls, qualifier)] else: - return fn_or_list # type: ignore + raise ValueError("no dispatch function for object: %s" % obj) def branch(self) -> Dispatcher: """Return a copy of this dispatcher that is independently writable.""" d = Dispatcher() - if self.uselist: - d._registry.update( - (k, [fn for fn in self._registry[k]]) for k in self._registry + d._registry.update(self._registry) + return d + + +class PriorityDispatcher: + """registers lists of functions at multiple levels of priorty and provides + a target to invoke them in priority order. + + .. versionadded:: 1.18.0 - PriorityDispatcher replaces the job + of Dispatcher(uselist=True) + + """ + + def __init__(self) -> None: + self._registry: dict[tuple[Any, ...], Any] = collections.defaultdict( + list + ) + + def dispatch_for( + self, + target: str, + *, + priority: DispatchPriority = DispatchPriority.MEDIUM, + qualifier: str = "default", + subgroup: str | None = None, + ) -> Callable[[_C], _C]: + """return a decorator callable that registers a function at a + given priority, with a given qualifier, to fire off for a given + subgroup. + + It's important this remains as a decorator to support third party + plugins who are populating the dispatcher using that style. + + """ + + def decorate(fn: _C) -> _C: + self._registry[(target, qualifier, priority)].append( + (fn, subgroup) ) + return fn + + return decorate + + def dispatch( + self, target: str, *, qualifier: str = "default" + ) -> Callable[..., None]: + """Provide a callable for the given target and qualifier.""" + + if qualifier != "default": + qualifiers = [qualifier, "default"] else: - d._registry.update(self._registry) + qualifiers = ["default"] + + def go(*arg: Any, **kw: Any) -> Any: + results_by_subgroup: dict[str, PriorityDispatchResult] = {} + for priority in DispatchPriority: + for qualifier in qualifiers: + for fn, subgroup in self._registry.get( + (target, qualifier, priority), () + ): + if ( + results_by_subgroup.get( + subgroup, PriorityDispatchResult.CONTINUE + ) + is PriorityDispatchResult.STOP + ): + continue + + result = fn(*arg, **kw) + results_by_subgroup[subgroup] = result + + return go + + def branch(self) -> PriorityDispatcher: + """Return a copy of this dispatcher that is independently + writable.""" + + d = PriorityDispatcher() + d.populate_with(self) return d + def populate_with(self, other: PriorityDispatcher) -> None: + """Populate this PriorityDispatcher with the contents of another one. + + Additive, does not remove existing contents. + """ + for k in other._registry: + new_list = other._registry[k] + self._registry[k].extend(new_list) + def not_none(value: Optional[_T]) -> _T: assert value is not None diff --git a/docs/build/api/autogenerate.rst b/docs/build/api/autogenerate.rst index 7d8043e8..dcee6fa4 100644 --- a/docs/build/api/autogenerate.rst +++ b/docs/build/api/autogenerate.rst @@ -459,20 +459,33 @@ routines to be able to locate, which can include any object such as custom DDL objects representing views, triggers, special constraints, or anything else we want to support. +.. _autogenerate_global_comparison_function: -Registering a Comparison Function ---------------------------------- +Registering a Comparison Function Globally +------------------------------------------ We now need to register a comparison hook, which will be used to compare the database to our model and produce ``CreateSequenceOp`` and ``DropSequenceOp`` directives to be included in our migration -script. Note that we are assuming a -Postgresql backend:: +script. The example below illustrates registering a comparison function +using the **global** dispatch:: from alembic.autogenerate import comparators + from alembic.autogenerate.api import AutogenContext + from alembic.operations.ops import UpgradeOps + # new in Alembic 1.18.0 - for older versions, no return value is needed + from alembic.util import PriorityDispatchResult + + # the global dispatch includes a decorator function. + # for plugin level dispatch, use Plugin.add_autogenerate_comparator() + # instead. @comparators.dispatch_for("schema") - def compare_sequences(autogen_context, upgrade_ops, schemas): + def compare_sequences( + autogen_context: AutogenContext, + upgrade_ops: UpgradeOps, + schemas: set[str | None] + ) -> PriorityDispatchResult: all_conn_sequences = set() for sch in schemas: @@ -510,12 +523,21 @@ Postgresql backend:: DropSequenceOp(name, schema=sch) ) + return PriorityDispatchResult.CONTINUE + Above, we've built a new function ``compare_sequences()`` and registered it as a "schema" level comparison function with autogenerate. The job that it performs is that it compares the list of sequence names present in each database schema with that of a list of sequence names that we are maintaining in our :class:`~sqlalchemy.schema.MetaData` object. +The registration of our function at the scope of "schema" means our +autogenerate comparison function is called outside of the context of any +specific table or column. The four available scopes are "autogenerate" (new in +1.18.0), "schema", "table", and "column"; these scopes are described fully in +the section :ref:`plugins_registering_autogenerate`, which details their use in +terms of a custom plugin, however the interfaces are the same. + When autogenerate completes, it will have a series of ``CreateSequenceOp`` and ``DropSequenceOp`` directives in the list of "upgrade" operations; the list of "downgrade" operations is generated @@ -523,54 +545,21 @@ directly from these using the ``CreateSequenceOp.reverse()`` and ``DropSequenceOp.reverse()`` methods that we've implemented on these objects. -The registration of our function at the scope of "schema" means our -autogenerate comparison function is called outside of the context -of any specific table or column. The three available scopes -are "schema", "table", and "column", summarized as follows: - -* **Schema level** - these hooks are passed a :class:`.AutogenContext`, - an :class:`.UpgradeOps` collection, and a collection of string schema - names to be operated upon. If the - :class:`.UpgradeOps` collection contains changes after all - hooks are run, it is included in the migration script: - - :: - - @comparators.dispatch_for("schema") - def compare_schema_level(autogen_context, upgrade_ops, schemas): - pass - -* **Table level** - these hooks are passed a :class:`.AutogenContext`, - a :class:`.ModifyTableOps` collection, a schema name, table name, - a :class:`~sqlalchemy.schema.Table` reflected from the database if any - or ``None``, and a :class:`~sqlalchemy.schema.Table` present in the - local :class:`~sqlalchemy.schema.MetaData`. If the - :class:`.ModifyTableOps` collection contains changes after all - hooks are run, it is included in the migration script: - - :: - - @comparators.dispatch_for("table") - def compare_table_level(autogen_context, modify_ops, - schemaname, tablename, conn_table, metadata_table): - pass - -* **Column level** - these hooks are passed a :class:`.AutogenContext`, - an :class:`.AlterColumnOp` object, a schema name, table name, - column name, a :class:`~sqlalchemy.schema.Column` reflected from the - database and a :class:`~sqlalchemy.schema.Column` present in the - local table. If the :class:`.AlterColumnOp` contains changes after - all hooks are run, it is included in the migration script; - a "change" is considered to be present if any of the ``modify_`` attributes - are set to a non-default value, or there are any keys - in the ``.kw`` collection with the prefix ``"modify_"``: - - :: - - @comparators.dispatch_for("column") - def compare_column_level(autogen_context, alter_column_op, - schemaname, tname, cname, conn_col, metadata_col): - pass +The example above illustrates registration with the so-called **global** +autogenerate dispatch, at ``alembic.autogenerate.comparators``. Alembic as of +version 1.18 also includes a **plugin level** dispatch, where comparison +functions are instead registered using +:meth:`.Plugin.add_autogenerate_comparator`. Comparison functions registered at +the plugin level operate in the same way as those registered globally, with the +exception that custom autogenerate compare functions must also be enabled at +the environment level within the +:attr:`.EnvironmentContext.configure.autogenerate_plugins` parameter, and also +have the ability to be omitted from an autogenerate run. + +.. seealso:: + + :ref:`plugins_registering_autogenerate` - newer plugin-level means of + registering autogenerate compare functions. The :class:`.AutogenContext` passed to these hooks is documented below. diff --git a/docs/build/api/index.rst b/docs/build/api/index.rst index 50a543b5..3d080339 100644 --- a/docs/build/api/index.rst +++ b/docs/build/api/index.rst @@ -29,5 +29,6 @@ to run commands programmatically, as discussed in the section :doc:`/api/command autogenerate script ddl + plugins exceptions diff --git a/docs/build/api/plugins.rst b/docs/build/api/plugins.rst new file mode 100644 index 00000000..d6d10fc9 --- /dev/null +++ b/docs/build/api/plugins.rst @@ -0,0 +1,466 @@ +.. _alembic.plugins.toplevel: + +======= +Plugins +======= + +.. versionadded:: 1.18.0 + +Alembic provides a plugin system that allows third-party extensions to +integrate with Alembic's functionality. Plugins can register custom operations, +operation implementations, autogenerate comparison functions, and other +extension points to add new capabilities to Alembic. + +The plugin system provides a structured way to organize and distribute these +extensions, allowing them to be discovered automatically using Python +entry points. + +Overview +======== + +The :class:`.Plugin` class provides the foundation for creating plugins. +A plugin's ``setup()`` function can perform various types of registrations: + +* **Custom operations** - Register new operation directives using + :meth:`.Operations.register_operation` (e.g., ``op.create_view()``) +* **Operation implementations** - Provide database-specific implementations + using :meth:`.Operations.implementation_for` +* **Autogenerate comparators** - Add comparison functions for detecting + schema differences during autogeneration +* **Other extensions** - Register any other global handlers or customizations + +A single plugin can register handlers across all of these categories. For +example, a plugin for custom database objects might register both the +operations to create/drop those objects and the autogenerate logic to +detect changes to them. + +.. seealso:: + + :ref:`replaceable_objects` - Cookbook recipe demonstrating custom + operations and implementations that would be suitable for packaging + as a plugin + +Installing and Using Plugins +============================ + +Third-party plugins are typically distributed as Python packages that can be +installed via pip or other package managers:: + + pip install mycompany-alembic-plugin + +Once installed, plugins that use Python's entry point system are automatically +discovered and loaded by Alembic at startup, which calls the plugin's +``setup()`` function to perform any registrations. + +Enable Autogenerate Plugins +--------------------------- + +For plugins that provide autogenerate comparison functions via the +:meth:`.Plugin.add_autogenerate_comparator` hook, the specific autogenerate +functionality registered by the plugin must be enabled with +:paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter, which +by default indicates that only Alembic's built-in plugins should be used. +Note that this step does not apply to older plugins that may be registering +autogenerate comparison functions globally. + +See the section :ref:`plugins_autogenerate` for background on enabling +autogenerate comparison plugins per environment. + +Using Plugins without entry points (such as local plugin code) +-------------------------------------------------------------- + +Plugins do not need to be published with entry points to be used. A plugin +can be manually registered by calling :meth:`.Plugin.setup_plugin_from_module` +in the ``env.py`` file:: + + from alembic.runtime.plugins import Plugin + import myproject.alembic_plugin + + # Register the plugin manually + Plugin.setup_plugin_from_module( + myproject.alembic_plugin, + "myproject.custom_operations" + ) + +This approach is useful for project-specific plugins that are not intended +for distribution, or for testing plugins during development. + +.. _plugins_autogenerate: + +Enabling Autogenerate Plugins in env.py +======================================= + +If a plugin provides autogenerate functionality that's registered via the +:meth:`.Plugin.add_autogenerate_comparator` hook, it can be selectively enabled +or disabled using the +:paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter in the +:meth:`.EnvironmentContext.configure` call, typically as used within the +``env.py`` file. This parameter is passed as a list of strings each naming a +specific plugin or a matching wildcard. The default value is +``["alembic.autogenerate.*"]`` which indicates that the full set of Alembic's +internal plugins should be used. + +The :paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter +accepts a list of string patterns: + +* Simple names match plugin names exactly: ``"alembic.autogenerate.tables"`` +* Wildcards match multiple plugins: ``"alembic.autogenerate.*"`` matches all + built-in plugins +* Negation patterns exclude plugins: ``"~alembic.autogenerate.comments"`` + excludes the comments plugin + +For example, to use all built-in plugins except comments, plus a custom +plugin:: + + context.configure( + # ... + autogenerate_plugins=[ + "alembic.autogenerate.*", + "~alembic.autogenerate.comments", + "mycompany.custom_types", + ] + ) + +The wildcard syntax using ``*`` indicates that tokens in that segment +of the name (separated by period characters) will match any name. For +Alembic's ``alembic.autogenerate.*`` namespace, the built in names being +invoked are: + +* ``alembic.autogenerate.schemas`` - Schema creation and dropping +* ``alembic.autogenerate.tables`` - Table creation, dropping, and modification. + This plugin depends on the ``schemas`` plugin in order to iterate through + tables. +* ``alembic.autogenerate.types`` - Column type changes. This plugin depends on + the ``tables`` plugin in order to iterate through columns. +* ``alembic.autogenerate.constraints`` - Constraint creation and dropping. This + plugin depends on the ``tables`` plugin in order to iterate through columns. +* ``alembic.autogenerate.defaults`` - Server default changes. This plugin + depends on the ``tables`` plugin in order to iterate through columns. +* ``alembic.autogenerate.comments`` - Table and column comment changes. This + plugin depends on the ``tables`` plugin in order to iterate through columns. + +While these names can be specified individually, they are subject to change +as Alembic evolves. Using the wildcard pattern is recommended. + +Omitting the built-in plugins entirely would prevent autogeneration from +proceeding, unless other plugins were provided that replaced its functionality +(which is possible!). Additionally, as noted above, the column-oriented plugins +rely on the table- and schema- oriented plugins in order to receive iterated +columns. + +The :paramref:`.EnvironmentContext.configure.autogenerate_plugins` +parameter only controls which plugins participate in autogenerate +operations. Other plugin functionality, such as custom operations +registered with :meth:`.Operations.register_operation`, is available +regardless of this setting. + + + + +Writing a Plugin +================ + +Creating a Plugin Module +------------------------- + +A plugin module must define a ``setup()`` function that accepts a +:class:`.Plugin` instance. This function is called when the plugin is +loaded, either automatically via entry points or manually via +:meth:`.Plugin.setup_plugin_from_module`:: + + from alembic import op + from alembic.operations import Operations + from alembic.runtime.plugins import Plugin + from alembic.util import DispatchPriority + + def setup(plugin: Plugin) -> None: + """Setup function called by Alembic when loading the plugin.""" + + # Register custom operations + Operations.register_operation("create_view")(CreateViewOp) + Operations.implementation_for(CreateViewOp)(create_view_impl) + + # Register autogenerate comparison functions + plugin.add_autogenerate_comparator( + _compare_views, + "view", + qualifier="default", + priority=DispatchPriority.MEDIUM, + ) + +The ``setup()`` function serves as the entry point for all plugin +registrations. It can call various Alembic APIs to extend functionality. + +Publishing a Plugin +------------------- + +To make a plugin available for installation via pip, create a package with +an entry point in ``pyproject.toml``:: + + [project.entry-points."alembic.plugins"] + mycompany.plugin_name = "mycompany.alembic_plugin" + +Where ``mycompany.alembic_plugin`` is the module containing the ``setup()`` +function. + +When the package is installed, Alembic automatically discovers and loads the +plugin through the entry point system. If the plugin provides autogenerate +functionality, users can then enable it by adding its name +``mycompany.plugin_name`` to the ``autogenerate_plugins`` list in their +``env.py``. + +Registering Custom Operations +------------------------------ + +Plugins can register new operation directives that become available as +``op.custom_operation()`` in migration scripts. This is done using +:meth:`.Operations.register_operation` and +:meth:`.Operations.implementation_for`. + +Example from the :ref:`replaceable_objects` recipe:: + + from alembic.operations import Operations, MigrateOperation + + class CreateViewOp(MigrateOperation): + def __init__(self, view_name, select_stmt): + self.view_name = view_name + self.select_stmt = select_stmt + + @Operations.register_operation("create_view") + class CreateViewOp(CreateViewOp): + pass + + @Operations.implementation_for(CreateViewOp) + def create_view(operations, operation): + operations.execute( + f"CREATE VIEW {operation.view_name} AS {operation.select_stmt}" + ) + +These registrations can be performed in the plugin's ``setup()`` function, +making the custom operations available globally. + +.. seealso:: + + :ref:`replaceable_objects` - Complete example of registering custom + operations + + :ref:`operation_plugins` - Documentation on the operations plugin system + +.. _plugins_registering_autogenerate: + +Registering Autogenerate Comparators at the Plugin Level +-------------------------------------------------------- + +Plugins can register comparison functions that participate in the autogenerate +process, detecting differences between database schema and SQLAlchemy metadata. +These functions may be registered globally, where they take place +unconditionally as documented at +:ref:`autogenerate_global_comparison_function`; for older versions of Alembic +prior to 1.18.0 this is the only registration system available. However when +targeting Alembic 1.18.0 or higher, the :class:`.Plugin` approach provides a +more configurable version of these registration hooks. + +Plugin level comparison functions are registered using +:meth:`.Plugin.add_autogenerate_comparator`. Each comparison function +establishes itself as part of a named "target", which is invoked by a parent +handler. For example, if a handler establishes itself as part of the +``"column"`` target, it will be invoked when the +``alembic.autogenerate.tables`` plugin proceeds through SQLAlchemy ``Table`` +objects and invokes comparison operations for pairs of same-named columns. + +For an example of a complete comparison function, see the example at +:ref:`autogenerate_global_comparison_function`. + +The current levels of comparison are the same between global and plugin-level +comparison functions, and include: + +* ``"autogenerate"`` - this target is invoked at the top of the autogenerate + chain. These hooks are passed a :class:`.AutogenContext` and an + :class:`.UpgradeOps` collection. Functions that subscribe to the + ``autogenerate`` target should look like:: + + from alembic.autogenerate.api import AutogenContext + from alembic.operations.ops import UpgradeOps + from alembic.runtime.plugins import Plugin + from alembic.util import PriorityDispatchResult + + def autogen_toplevel( + autogen_context: AutogenContext, upgrade_ops: UpgradeOps + ) -> PriorityDispatchResult: + # ... + + + def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator(autogen_toplevel, "autogenerate") + + The function should return either :attr:`.PriorityDispatchResult.CONTINUE` or + :attr:`.PriorityDispatchResult.STOP` to halt any further comparisons from + proceeding, and should respond to detected changes by mutating the given + :class:`.UpgradeOps` collection in place (the :class:`.DowngradeOps` version + is produced later by reversing the :class:`.UpgradeOps`). + + An autogenerate compare function that seeks to run entirely independently of + Alembic's built-in autogenerate plugins, or to replace them completely, would + register at the ``"autogenerate"`` level. The remaining levels indicated + below are all invoked from within Alembic's own autogenerate plugins and will + not take place if ``alembic.autogenerate.*`` is not enabled. + + .. versionadded:: 1.18.0 The ``"autogenerate"`` comparison scope was + introduced, replacing ``"schema"`` as the topmost comparison scope. + +* ``"schema"`` - this target is invoked for each individual "schema" being + compared, and hooks are passed a :class:`.AutogenContext`, an + :class:`.UpgradeOps` collection, and a set of schema names, featuring the + value ``None`` for the "default" schema. Functions that subscribe to the + ``"schema"`` target should look like:: + + from alembic.autogenerate.api import AutogenContext + from alembic.operations.ops import UpgradeOps + from alembic.runtime.plugins import Plugin + from alembic.util import PriorityDispatchResult + + def autogen_for_tables( + autogen_context: AutogenContext, + upgrade_ops: UpgradeOps, + schemas: set[str | None], + ) -> PriorityDispatchResult: + # ... + + def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator( + autogen_for_tables, + "schema", + "tables", + ) + + The function should normally return :attr:`.PriorityDispatchResult.CONTINUE` + and should respond to detected changes by mutating the given + :class:`.UpgradeOps` collection in place (the :class:`.DowngradeOps` version + is produced later by reversing the :class:`.UpgradeOps`). + + The registration example above includes the ``"tables"`` "compare element", + which is optional. This indicates that the comparison function is part of a + chain called "tables", which is what Alembic's own + ``alembic.autogenerate.tables`` plugin uses. If our custom comparison + function were to return the value :attr:`.PriorityDispatchResult.STOP`, + further comparison functions in the ``"tables"`` chain would not be called. + Similarly, if another plugin in the ``"tables"`` chain returned + :attr:`.PriorityDispatchResult.STOP`, then our plugin would not be called. + Making use of :attr:`.PriorityDispatchResult.STOP` in terms of other plugins + in the same "compare element" may be assisted by placing our function in the + comparator chain using :attr:`.DispatchPriority.FIRST` or + :attr:`.DispatchPriority.LAST` when registering. + +* ``"table"`` - this target is invoked per ``Table`` being compared between a + database autoloaded version and the local metadata version. These hooks are + passed an :class:`.AutogenContext`, a :class:`.ModifyTableOps` collection, a + schema name, table name, a :class:`~sqlalchemy.schema.Table` reflected from + the database if any or ``None``, and a :class:`~sqlalchemy.schema.Table` + present in the local :class:`~sqlalchemy.schema.MetaData`. If the + :class:`.ModifyTableOps` collection contains changes after all hooks are run, + it is included in the migration script:: + + from sqlalchemy import quoted_name + from sqlalchemy import Table + + from alembic.autogenerate.api import AutogenContext + from alembic.operations.ops import ModifyTableOps + from alembic.runtime.plugins import Plugin + from alembic.util import PriorityDispatchResult + + def compare_tables( + autogen_context: AutogenContext, + modify_table_ops: ModifyTableOps, + schema: str | None, + tname: quoted_name | str, + conn_table: Table | None, + metadata_table: Table | None, + ) -> PriorityDispatchResult: + # ... + + + def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator(compare_tables, "table") + + This hook may be used to compare elements of tables, such as comments + or database-specific storage configurations. It should mutate the given + :class:`.ModifyTableOps` object in place to add new change operations. + +* ``"column"`` - this target is invoked per ``Column`` being compared between a + database autoloaded version and the local metadata version. + These hooks are passed an :class:`.AutogenContext`, + an :class:`.AlterColumnOp` object, a schema name, table name, + column name, a :class:`~sqlalchemy.schema.Column` reflected from the + database and a :class:`~sqlalchemy.schema.Column` present in the + local table. If the :class:`.AlterColumnOp` contains changes after + all hooks are run, it is included in the migration script; + a "change" is considered to be present if any of the ``modify_`` attributes + are set to a non-default value, or there are any keys + in the ``.kw`` collection with the prefix ``"modify_"``:: + + from typing import Any + from sqlalchemy import quoted_name + from sqlalchemy import Table + + from alembic.autogenerate.api import AutogenContext + from alembic.operations.ops import AlterColumnOp + from alembic.runtime.plugins import Plugin + from alembic.util import PriorityDispatchResult + + def compare_columns( + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, + schema: str | None, + tname: quoted_name | str, + cname: quoted_name | str, + conn_col: Column[Any], + metadata_col: Column[Any], + ) -> PriorityDispatchResult: + # ... + + + def setup(plugin: Plugin) -> None: + plugin.add_autogenerate_comparator(compare_columns, "column") + + Pre-existing compare chains within the ``"column"`` target include + ``"comment"``, ``"server_default"``, and ``"types"``. Comparison functions + here should mutate the given :class:`.AlterColumnOp` object in place to add + new change operations. + +.. seealso:: + + :ref:`alembic.autogenerate.toplevel` - Detailed documentation on the + autogenerate system + + :ref:`autogenerate_global_comparison_function` - a companion section + to this one which explains autogenerate comparison functions in terms of + the older "global" dispatch, but also includes a complete example of a + comparison function. + + :ref:`customizing_revision` - Customizing autogenerate behavior + + +Plugin API Reference +==================== + +.. autoclass:: alembic.runtime.plugins.Plugin + :members: + +.. autoclass:: alembic.util.langhelpers.PriorityDispatchResult + :members: + +.. autoclass:: alembic.util.langhelpers.DispatchPriority + :members: + +.. seealso:: + + :paramref:`.EnvironmentContext.configure.autogenerate_plugins` - + Configuration parameter for enabling autogenerate plugins + + :ref:`operation_plugins` - Documentation on custom operations + + :ref:`replaceable_objects` - Example of custom operations suitable + for a plugin + + :ref:`customizing_revision` - General information on customizing + autogenerate behavior diff --git a/docs/build/changelog.rst b/docs/build/changelog.rst index 79cfdf96..580e0d15 100644 --- a/docs/build/changelog.rst +++ b/docs/build/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== .. changelog:: - :version: 1.17.3 + :version: 1.18.0 :include_notes_from: unreleased .. changelog:: diff --git a/docs/build/cookbook.rst b/docs/build/cookbook.rst index 7f0bf3aa..ba8929ed 100644 --- a/docs/build/cookbook.rst +++ b/docs/build/cookbook.rst @@ -450,12 +450,21 @@ that run straight into :meth:`.Operations.execute`:: def drop_sp(operations, operation): operations.execute("DROP FUNCTION %s" % operation.target.name) +Publish the Extensions +---------------------- + All of the above code can be present anywhere within an application's source tree; the only requirement is that when the ``env.py`` script is invoked, it includes imports that ultimately call upon these classes as well as the :meth:`.Operations.register_operation` and :meth:`.Operations.implementation_for` sequences. +Alternatively, custom operations and autogenerate support can be organized +into reusable plugins using Alembic's plugin system. This allows extensions +to be packaged and distributed independently, and automatically discovered +via Python entry points. See :ref:`alembic.plugins.toplevel` for information +on writing and publishing plugins. + Create Initial Migrations ------------------------- diff --git a/docs/build/unreleased/plugins.rst b/docs/build/unreleased/plugins.rst new file mode 100644 index 00000000..efd4bfea --- /dev/null +++ b/docs/build/unreleased/plugins.rst @@ -0,0 +1,33 @@ +.. change:: + :tags: feature, autogenerate + + Release 1.18.0 introduces a plugin system that allows for automatic + loading of third-party extensions as well as configurable autogenerate + compare functionality on a per-environment basis. + + The :class:`.Plugin` class provides a common interface for extensions that + register handlers among Alembic's existing extension points such as + :meth:`.Operations.register_operation` and + :meth:`.Operations.implementation_for`. A new interface for registering + autogenerate comparison handlers, + :meth:`.Plugin.add_autogenerate_comparator`, provides for autogenerate + compare functionality that may be custom-configured on a per-environment + basis using the new + :paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter. + + The change does not impact well known Alembic add-ons such as + ``alembic-utils``, which continue to work as before; however, such add-ons + have the option to provide plugin entrypoints going forward. + + As part of this change, Alembic's autogenerate compare functionality is + reorganized into a series of internal plugins under the + ``alembic.autogenerate`` namespace, which may be individually or + collectively identified for inclusion and/or exclusion within the + :meth:`.EnvironmentContext.configure` call using a new parameter + :paramref:`.EnvironmentContext.configure.autogenerate_plugins`. This + parameter is also where third party comparison plugins may also be + indicated. + + See :ref:`alembic.plugins.toplevel` for complete documentation on + the new :class:`.Plugin` class as well as autogenerate-specific usage + instructions. diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py index cd9d54ae..ee99fc31 100644 --- a/tests/test_autogen_diffs.py +++ b/tests/test_autogen_diffs.py @@ -39,6 +39,7 @@ from sqlalchemy.types import VARBINARY from alembic import autogenerate from alembic import testing from alembic.autogenerate import api +from alembic.autogenerate.compare.tables import _compare_tables from alembic.migration import MigrationContext from alembic.operations import ops from alembic.testing import assert_raises_message @@ -58,6 +59,11 @@ from alembic.testing.suite._autogen_fixtures import AutogenTest from alembic.testing.suite._autogen_fixtures import ModelOne from alembic.util import CommandError +if True: + from alembic.autogenerate.compare.types import ( + _dialect_impl_compare_type as _compare_type, + ) + # TODO: we should make an adaptation of CompareMetadataToInspectorTest that is # more well suited towards generic backends (2021-06-10) @@ -500,7 +506,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): def test_skip_null_type_comparison_reflected(self): ac = ops.AlterColumnOp("sometable", "somecol") - autogenerate.compare._compare_type( + _compare_type( self.autogen_context, ac, None, @@ -514,7 +520,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): def test_skip_null_type_comparison_local(self): ac = ops.AlterColumnOp("sometable", "somecol") - autogenerate.compare._compare_type( + _compare_type( self.autogen_context, ac, None, @@ -531,7 +537,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): impl = Integer ac = ops.AlterColumnOp("sometable", "somecol") - autogenerate.compare._compare_type( + _compare_type( self.autogen_context, ac, None, @@ -544,7 +550,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): assert not ac.has_changes() ac = ops.AlterColumnOp("sometable", "somecol") - autogenerate.compare._compare_type( + _compare_type( self.autogen_context, ac, None, @@ -567,7 +573,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): return dialect.type_descriptor(CHAR(32)) uo = ops.AlterColumnOp("sometable", "somecol") - autogenerate.compare._compare_type( + _compare_type( self.autogen_context, uo, None, @@ -583,7 +589,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): inspector = inspect(self.bind) uo = ops.UpgradeOps(ops=[]) - autogenerate.compare._compare_tables( + _compare_tables( OrderedSet([(None, "extra"), (None, "user")]), OrderedSet(), inspector, @@ -1204,12 +1210,12 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase): my_compare_type.return_value = False self.context._user_compare_type = my_compare_type - diffs = [] ctx = self.autogen_context - diffs = [] - autogenerate._produce_net_changes(ctx, diffs) - eq_(diffs, []) + uo = ops.UpgradeOps(ops=[]) + autogenerate._produce_net_changes(ctx, uo) + + eq_(uo.as_diffs(), []) def test_column_type_modified_custom_compare_type_returns_True(self): my_compare_type = mock.Mock() diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py new file mode 100644 index 00000000..eff3676e --- /dev/null +++ b/tests/test_dispatch.py @@ -0,0 +1,404 @@ +"""Test the Dispatcher and PriorityDispatcher utilities.""" + +from alembic import testing +from alembic.testing import eq_ +from alembic.testing.fixtures import TestBase +from alembic.util import Dispatcher +from alembic.util import DispatchPriority +from alembic.util import PriorityDispatcher +from alembic.util import PriorityDispatchResult + + +class DispatcherTest(TestBase): + """Tests for the Dispatcher class.""" + + def test_dispatch_for_decorator(self): + """Test basic decorator registration.""" + dispatcher = Dispatcher() + + @dispatcher.dispatch_for("target1") + def handler1(): + return "handler1" + + fn = dispatcher.dispatch("target1") + eq_(fn(), "handler1") + + def test_dispatch_with_args_kwargs(self): + """Test that arguments are passed through to handlers.""" + dispatcher = Dispatcher() + + @dispatcher.dispatch_for("target1") + def handler(arg1, kwarg1=None): + return (arg1, kwarg1) + + fn = dispatcher.dispatch("target1") + result = fn("value1", kwarg1="value2") + eq_(result, ("value1", "value2")) + + def test_dispatch_for_qualifier(self): + """Test registration with qualifier.""" + dispatcher = Dispatcher() + + @dispatcher.dispatch_for("target1", qualifier="postgresql") + def handler_pg(): + return "postgresql" + + @dispatcher.dispatch_for("target1", qualifier="default") + def handler_default(): + return "default" + + fn_pg = dispatcher.dispatch("target1", qualifier="postgresql") + eq_(fn_pg(), "postgresql") + + fn_default = dispatcher.dispatch("target1", qualifier="default") + eq_(fn_default(), "default") + + def test_dispatch_qualifier_fallback(self): + """Test that non-default qualifier falls back to default.""" + dispatcher = Dispatcher() + + @dispatcher.dispatch_for("target1") + def handler_default(): + return "default" + + # Request with specific qualifier should fallback to default + fn = dispatcher.dispatch("target1", qualifier="mysql") + eq_(fn(), "default") + + def test_dispatch_type_target(self): + """Test dispatching with type targets using MRO.""" + dispatcher = Dispatcher() + + class Base: + pass + + class Child(Base): + pass + + @dispatcher.dispatch_for(Base) + def handler_base(): + return "base" + + # Dispatching with Child should find Base handler via MRO + fn = dispatcher.dispatch(Child()) + eq_(fn(), "base") + + def test_dispatch_type_class_vs_instance(self): + """Test dispatching with type vs instance.""" + dispatcher = Dispatcher() + + class MyClass: + pass + + @dispatcher.dispatch_for(MyClass) + def handler(): + return "handler" + + # Both class and instance should work + fn_class = dispatcher.dispatch(MyClass) + eq_(fn_class(), "handler") + + fn_instance = dispatcher.dispatch(MyClass()) + eq_(fn_instance(), "handler") + + def test_dispatch_no_match_raises(self): + """Test that dispatching with no match raises ValueError.""" + dispatcher = Dispatcher() + + with testing.expect_raises_message(ValueError, "no dispatch function"): + dispatcher.dispatch("nonexistent") + + def test_dispatch_replace_false_raises(self): + """Test that duplicate registration raises ValueError.""" + dispatcher = Dispatcher() + + @dispatcher.dispatch_for("target1") + def handler1(): + return "handler1" + + with testing.expect_raises_message(ValueError, "key already exists"): + + @dispatcher.dispatch_for("target1") + def handler2(): + return "handler2" + + def test_dispatch_replace_true_works(self): + """Test that replace=True allows overwriting.""" + dispatcher = Dispatcher() + + @dispatcher.dispatch_for("target1") + def handler1(): + return "handler1" + + @dispatcher.dispatch_for("target1", replace=True) + def handler2(): + return "handler2" + + fn = dispatcher.dispatch("target1") + eq_(fn(), "handler2") + + def test_branch(self): + """Test that branch creates independent copy.""" + dispatcher = Dispatcher() + + @dispatcher.dispatch_for("target1") + def handler1(): + return "handler1" + + dispatcher2 = dispatcher.branch() + + # Add to branch should not affect original + @dispatcher2.dispatch_for("target2") + def handler2(): + return "handler2" + + # Original should not have target2 + with testing.expect_raises(ValueError): + dispatcher.dispatch("target2") + + # Branch should have both + fn1 = dispatcher2.dispatch("target1") + eq_(fn1(), "handler1") + fn2 = dispatcher2.dispatch("target2") + eq_(fn2(), "handler2") + + +class PriorityDispatcherTest(TestBase): + """Tests for the PriorityDispatcher class.""" + + def test_dispatch_for_decorator(self): + """Test basic decorator registration.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1") + def handler1(): + results.append("handler1") + + fn = dispatcher.dispatch("target1") + fn() + eq_(results, ["handler1"]) + + def test_dispatch_target_not_registered(self): + """Test that dispatching unregistered target returns noop.""" + dispatcher = PriorityDispatcher() + + # Unlike regular Dispatcher, PriorityDispatcher returns a noop + # function for unregistered targets + fn = dispatcher.dispatch("nonexistent") + # Should not raise, just return a callable that does nothing + fn() + + def test_dispatch_with_priority(self): + """Test that handlers execute in priority order.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1", priority=DispatchPriority.LAST) + def handler_last(): + results.append("last") + + @dispatcher.dispatch_for("target1", priority=DispatchPriority.FIRST) + def handler_first(): + results.append("first") + + @dispatcher.dispatch_for("target1", priority=DispatchPriority.MEDIUM) + def handler_medium(): + results.append("medium") + + fn = dispatcher.dispatch("target1") + fn() + eq_(results, ["first", "medium", "last"]) + + def test_dispatch_with_subgroup(self): + """Test that subgroups track results independently.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1", subgroup="group1") + def handler1(): + results.append("group1") + return PriorityDispatchResult.CONTINUE + + @dispatcher.dispatch_for("target1", subgroup="group2") + def handler2(): + results.append("group2") + return PriorityDispatchResult.CONTINUE + + fn = dispatcher.dispatch("target1") + fn() + eq_(results, ["group1", "group2"]) + + def test_dispatch_stop_result(self): + """Test that STOP prevents further execution in subgroup.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for( + "target1", priority=DispatchPriority.FIRST, subgroup="group1" + ) + def handler1(): + results.append("handler1") + return PriorityDispatchResult.STOP + + @dispatcher.dispatch_for( + "target1", priority=DispatchPriority.MEDIUM, subgroup="group1" + ) + def handler2(): + results.append("handler2") # Should not execute + return PriorityDispatchResult.CONTINUE + + @dispatcher.dispatch_for( + "target1", priority=DispatchPriority.FIRST, subgroup="group2" + ) + def handler3(): + results.append("handler3") # Should execute + return PriorityDispatchResult.CONTINUE + + fn = dispatcher.dispatch("target1") + fn() + # handler2 should not run because handler1 returned STOP for group1 + # handler3 should run because it's in a different subgroup + eq_(results, ["handler1", "handler3"]) + + def test_dispatch_with_qualifier(self): + """Test dispatching with qualifiers includes both specific and + default.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1", qualifier="postgresql") + def handler_pg(): + results.append("postgresql") + + @dispatcher.dispatch_for("target1", qualifier="default") + def handler_default(): + results.append("default") + + fn_pg = dispatcher.dispatch("target1", qualifier="postgresql") + fn_pg() + # Should run both postgresql and default handlers + eq_(results, ["postgresql", "default"]) + + def test_dispatch_qualifier_fallback(self): + """Test that non-default qualifier also executes default handlers.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1", qualifier="default") + def handler_default(): + results.append("default") + + # Request with specific qualifier should also run default + fn = dispatcher.dispatch("target1", qualifier="mysql") + fn() + eq_(results, ["default"]) + + def test_dispatch_with_args_kwargs(self): + """Test that arguments are passed through to handlers.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1") + def handler(arg1, kwarg1=None): + results.append((arg1, kwarg1)) + + fn = dispatcher.dispatch("target1") + fn("value1", kwarg1="value2") + eq_(results, [("value1", "value2")]) + + def test_multiple_handlers_same_priority(self): + """Test multiple handlers at same priority execute in order.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1", priority=DispatchPriority.MEDIUM) + def handler1(): + results.append("handler1") + + @dispatcher.dispatch_for("target1", priority=DispatchPriority.MEDIUM) + def handler2(): + results.append("handler2") + + fn = dispatcher.dispatch("target1") + fn() + # Both should execute + eq_(results, ["handler1", "handler2"]) + + def test_branch(self): + """Test that branch creates independent copy.""" + dispatcher = PriorityDispatcher() + results1 = [] + + @dispatcher.dispatch_for("target1") + def handler1(): + results1.append("handler1") + + dispatcher2 = dispatcher.branch() + results2 = [] + + @dispatcher2.dispatch_for("target2") + def handler2(): + results2.append("handler2") + + # Original should have target1 + fn1 = dispatcher.dispatch("target1") + fn1() + eq_(results1, ["handler1"]) + + # Branch should have both + fn1_branch = dispatcher2.dispatch("target1") + fn2_branch = dispatcher2.dispatch("target2") + fn1_branch() + fn2_branch() + eq_(results1, ["handler1", "handler1"]) + eq_(results2, ["handler2"]) + + def test_populate_with(self): + """Test populate_with method.""" + dispatcher1 = PriorityDispatcher() + results = [] + + @dispatcher1.dispatch_for("target1") + def handler1(): + results.append("handler1") + + dispatcher2 = PriorityDispatcher() + + @dispatcher2.dispatch_for("target2") + def handler2(): + results.append("handler2") + + # Populate dispatcher2 with dispatcher1's handlers + dispatcher2.populate_with(dispatcher1) + + # dispatcher2 should now have both handlers + fn1 = dispatcher2.dispatch("target1") + fn2 = dispatcher2.dispatch("target2") + fn1() + fn2() + eq_(results, ["handler1", "handler2"]) + + def test_none_subgroup(self): + """Test that None subgroup is tracked separately.""" + dispatcher = PriorityDispatcher() + results = [] + + @dispatcher.dispatch_for("target1", subgroup=None) + def handler1(): + results.append("none") + return PriorityDispatchResult.STOP + + @dispatcher.dispatch_for("target1", subgroup=None) + def handler2(): + results.append("none2") # Should not execute + return PriorityDispatchResult.CONTINUE + + @dispatcher.dispatch_for("target1", subgroup="other") + def handler3(): + results.append("other") # Should execute + return PriorityDispatchResult.CONTINUE + + fn = dispatcher.dispatch("target1") + fn() + eq_(results, ["none", "other"]) diff --git a/tests/test_mysql.py b/tests/test_mysql.py index c15b70e3..399cd34d 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -20,7 +20,7 @@ from alembic import autogenerate from alembic import op from alembic import util from alembic.autogenerate import api -from alembic.autogenerate import compare +from alembic.autogenerate.compare.constraints import _compare_nullable from alembic.migration import MigrationContext from alembic.operations import ops from alembic.testing import assert_raises_message @@ -33,6 +33,14 @@ from alembic.testing.fixtures import AlterColRoundTripFixture from alembic.testing.fixtures import op_fixture from alembic.testing.fixtures import TestBase +if True: + from alembic.autogenerate.compare.server_defaults import ( + _user_compare_server_default, + ) + from alembic.autogenerate.compare.types import ( + _dialect_impl_compare_type as _compare_type, + ) + class MySQLOpTest(TestBase): def test_create_table_with_comment(self): @@ -229,9 +237,13 @@ class MySQLOpTest(TestBase): operation = ops.AlterColumnOp("t", "c") for fn in ( - compare._compare_nullable, - compare._compare_type, - compare._compare_server_default, + _compare_nullable, + _compare_type, + # note that _user_compare_server_default does not actually + # do a server default compare here, compare_server_default + # is False so this just assigns the existing default to the + # AlterColumnOp + _user_compare_server_default, ): fn( autogen_context, @@ -240,8 +252,9 @@ class MySQLOpTest(TestBase): "t", "c", Column("c", Float(), nullable=False, server_default=text("0")), - Column("c", Float(), nullable=True, default=0), + Column("c", Float(), nullable=True, server_default=text("0")), ) + op.invoke(operation) context.assert_("ALTER TABLE t MODIFY c FLOAT NULL DEFAULT 0") diff --git a/tests/test_plugin.py b/tests/test_plugin.py new file mode 100644 index 00000000..b7a8d1b5 --- /dev/null +++ b/tests/test_plugin.py @@ -0,0 +1,313 @@ +"""Test the Plugin class and plugin system.""" + +from types import ModuleType +from unittest import mock + +from alembic import testing +from alembic import util +from alembic.runtime.plugins import _all_plugins +from alembic.runtime.plugins import _make_re +from alembic.runtime.plugins import Plugin +from alembic.testing import eq_ +from alembic.testing.fixtures import TestBase +from alembic.util import DispatchPriority +from alembic.util import PriorityDispatcher +from alembic.util import PriorityDispatchResult + + +class PluginTest(TestBase): + """Tests for the Plugin class.""" + + @testing.fixture(scope="function", autouse=True) + def _clear_plugin_registry(self): + """Clear plugin registry before each test and restore after.""" + # Save original plugins + original_plugins = _all_plugins.copy() + _all_plugins.clear() + + yield + + # Restore plugin registry after test + _all_plugins.clear() + _all_plugins.update(original_plugins) + + def test_plugin_creation(self): + """Test basic plugin creation.""" + plugin = Plugin("test.plugin") + eq_(plugin.name, "test.plugin") + assert "test.plugin" in _all_plugins + eq_(_all_plugins["test.plugin"], plugin) + + def test_plugin_creation_duplicate_raises(self): + """Test that duplicate plugin names raise ValueError.""" + Plugin("test.plugin") + with testing.expect_raises_message( + ValueError, "A plugin named test.plugin is already registered" + ): + Plugin("test.plugin") + + def test_plugin_remove(self): + """Test plugin removal.""" + plugin = Plugin("test.plugin") + assert "test.plugin" in _all_plugins + plugin.remove() + assert "test.plugin" not in _all_plugins + + def test_add_autogenerate_comparator(self): + """Test adding autogenerate comparison functions.""" + plugin = Plugin("test.plugin") + + def my_comparator(): + return PriorityDispatchResult.CONTINUE + + plugin.add_autogenerate_comparator( + my_comparator, + "table", + "column", + qualifier="postgresql", + priority=DispatchPriority.FIRST, + ) + + # Verify it was registered in the dispatcher + fn = plugin.autogenerate_comparators.dispatch( + "table", qualifier="postgresql" + ) + # The dispatcher returns a callable, call it to verify + fn() + + def test_populate_autogenerate_priority_dispatch_simple(self): + """Test populating dispatcher with simple include pattern.""" + plugin1 = Plugin("test.plugin1") + plugin2 = Plugin("test.plugin2") + + mock1 = mock.Mock() + mock2 = mock.Mock() + + plugin1.add_autogenerate_comparator(mock1, "test") + plugin2.add_autogenerate_comparator(mock2, "test") + + dispatcher = PriorityDispatcher() + Plugin.populate_autogenerate_priority_dispatch( + dispatcher, ["test.plugin1"] + ) + + # Should have plugin1's handler, but not plugin2's + fn = dispatcher.dispatch("test") + fn() + eq_(mock1.mock_calls, [mock.call()]) + eq_(mock2.mock_calls, []) + + def test_populate_autogenerate_priority_dispatch_wildcard(self): + """Test populating dispatcher with wildcard pattern.""" + plugin1_alpha = Plugin("test.plugin1.alpha") + plugin1_beta = Plugin("test.plugin1.beta") + plugin2_gamma = Plugin("test.plugin2.gamma") + + mock_alpha = mock.Mock() + mock_beta = mock.Mock() + mock_gamma = mock.Mock() + + plugin1_alpha.add_autogenerate_comparator(mock_alpha, "test") + plugin1_beta.add_autogenerate_comparator(mock_beta, "test") + plugin2_gamma.add_autogenerate_comparator(mock_gamma, "test") + + dispatcher = PriorityDispatcher() + Plugin.populate_autogenerate_priority_dispatch( + dispatcher, ["test.plugin1.*"] + ) + + # Both test.plugin1.* should be included + # test.plugin2.* should not be included + fn = dispatcher.dispatch("test") + fn() + eq_(mock_alpha.mock_calls, [mock.call()]) + eq_(mock_beta.mock_calls, [mock.call()]) + eq_(mock_gamma.mock_calls, []) + + def test_populate_autogenerate_priority_dispatch_exclude(self): + """Test populating dispatcher with exclude pattern.""" + plugin1 = Plugin("test.plugin1") + plugin2 = Plugin("test.plugin2") + + mock1 = mock.Mock() + mock2 = mock.Mock() + + plugin1.add_autogenerate_comparator(mock1, "test") + plugin2.add_autogenerate_comparator(mock2, "test") + + dispatcher = PriorityDispatcher() + Plugin.populate_autogenerate_priority_dispatch( + dispatcher, ["test.*", "~test.plugin2"] + ) + + # Should have plugin1's handler, but not plugin2's (excluded) + fn = dispatcher.dispatch("test") + fn() + eq_(mock1.mock_calls, [mock.call()]) + eq_(mock2.mock_calls, []) + + def test_populate_autogenerate_priority_dispatch_not_found(self): + """Test that non-matching pattern raises error.""" + Plugin("test.plugin1") + + dispatcher = PriorityDispatcher() + with testing.expect_raises_message( + util.CommandError, + "Did not locate plugins.*test.nonexistent", + ): + Plugin.populate_autogenerate_priority_dispatch( + dispatcher, ["test.nonexistent"] + ) + + def test_populate_autogenerate_priority_dispatch_wildcard_not_found( + self, + ): + """Test that non-matching wildcard pattern raises error.""" + Plugin("test.plugin1") + + dispatcher = PriorityDispatcher() + with testing.expect_raises_message( + util.CommandError, + "Did not locate plugins", + ): + Plugin.populate_autogenerate_priority_dispatch( + dispatcher, ["other.*"] + ) + + def test_populate_autogenerate_priority_dispatch_multiple_includes(self): + """Test populating with multiple include patterns.""" + Plugin("test.plugin1") + Plugin("other.plugin2") + + dispatcher = PriorityDispatcher() + Plugin.populate_autogenerate_priority_dispatch( + dispatcher, ["test.plugin1", "other.plugin2"] + ) + # Should not raise error + + def test_setup_plugin_from_module(self): + """Test setting up plugin from a module.""" + # Create a mock module with a setup function + mock_module = ModuleType("mock_plugin") + + def setup(plugin): + eq_(plugin.name, "mock.plugin") + # Register a comparator to verify setup was called + plugin.add_autogenerate_comparator( + lambda: PriorityDispatchResult.CONTINUE, + "test_target", + ) + + mock_module.setup = setup + + Plugin.setup_plugin_from_module(mock_module, "mock.plugin") + + # Verify plugin was created + assert "mock.plugin" in _all_plugins + + def test_autogenerate_comparators_dispatcher(self): + """Test that autogenerate_comparators is a PriorityDispatcher.""" + plugin = Plugin("test.plugin") + assert isinstance(plugin.autogenerate_comparators, PriorityDispatcher) + + def test_populate_with_real_handlers(self): + """Test populating dispatcher with actual comparison handlers.""" + plugin = Plugin("test.plugin") + results = [] + + def compare_tables( + autogen_context, upgrade_ops, schemas + ): # pragma: no cover + results.append(("compare_tables", autogen_context)) + return PriorityDispatchResult.CONTINUE + + def compare_types( + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, + ): # pragma: no cover + results.append(("compare_types", tname)) + return PriorityDispatchResult.CONTINUE + + plugin.add_autogenerate_comparator(compare_tables, "table") + plugin.add_autogenerate_comparator(compare_types, "type") + + dispatcher = PriorityDispatcher() + Plugin.populate_autogenerate_priority_dispatch( + dispatcher, ["test.plugin"] + ) + + # Verify handlers are in dispatcher + fn_table = dispatcher.dispatch("table") + fn_type = dispatcher.dispatch("type") + + # Call them to verify they work + fn_table("autogen_ctx", "upgrade_ops", "schemas") + fn_type( + "autogen_ctx", + "alter_op", + "schema", + "tablename", + "colname", + "conn_col", + "meta_col", + ) + + eq_(results[0][0], "compare_tables") + eq_(results[1][0], "compare_types") + eq_(results[1][1], "tablename") + + +class MakeReTest(TestBase): + """Tests for the _make_re helper function.""" + + def test_simple_name(self): + """Test regex generation for simple dotted names.""" + pattern = _make_re("test.plugin") + assert pattern.match("test.plugin") + + # Partial matches dont work; use a * for this + assert not pattern.match("test.plugin.extra") + + # other tokens don't match either + assert not pattern.match("test.pluginfoo") + + assert not pattern.match("other.plugin") + assert not pattern.match("test") + + def test_wildcard(self): + """Test regex generation with wildcard.""" + pattern = _make_re("test.*") + assert pattern.match("test.plugin") + assert pattern.match("test.plugin.extra") + assert not pattern.match("test") + assert not pattern.match("other.plugin") + + def test_multiple_wildcards(self): + """Test regex generation with multiple wildcards.""" + pattern = _make_re("test.*.sub.*") + assert pattern.match("test.plugin.sub.item") + assert pattern.match("test.a.sub.b") + assert not pattern.match("test.plugin") + + def test_invalid_pattern_raises(self): + """Test that invalid patterns raise ValueError.""" + with testing.expect_raises_message( + ValueError, "Invalid plugin expression" + ): + _make_re("test.plugin-name") + + def test_valid_underscore(self): + """Test that underscores are valid in names.""" + pattern = _make_re("test.my_plugin") + assert pattern.match("test.my_plugin") + + def test_valid_mixed_case(self): + """Test that mixed case is valid in names.""" + pattern = _make_re("test.MyPlugin") + assert pattern.match("test.MyPlugin") + assert not pattern.match("test.myplugin") diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 0f43cf3b..6001d5d1 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -40,9 +40,7 @@ from alembic import op from alembic import testing from alembic import util from alembic.autogenerate import api -from alembic.autogenerate.compare import _compare_server_default -from alembic.autogenerate.compare import _compare_tables -from alembic.autogenerate.compare import _render_server_default_for_compare +from alembic.autogenerate.compare.tables import _compare_tables from alembic.migration import MigrationContext from alembic.operations import ops from alembic.script import ScriptDirectory @@ -67,6 +65,15 @@ from alembic.testing.fixtures import TestBase from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest +if True: + from alembic.autogenerate.compare.server_defaults import ( + _render_server_default_for_compare, + ) # noqa: E501 + from alembic.autogenerate.compare.server_defaults import ( + _dialect_impl_compare_server_default as _compare_server_default, + ) + + class PostgresqlOpTest(TestBase): def test_rename_table_postgresql(self): context = op_fixture("postgresql") diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index e8ee21df..45778ae7 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -15,7 +15,6 @@ from sqlalchemy.sql import column from alembic import autogenerate from alembic import op from alembic.autogenerate import api -from alembic.autogenerate.compare import _compare_server_default from alembic.migration import MigrationContext from alembic.operations import ops from alembic.testing import assert_raises_message @@ -29,6 +28,11 @@ from alembic.testing.fixtures import op_fixture from alembic.testing.fixtures import TestBase from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest +if True: + from alembic.autogenerate.compare.server_defaults import ( + _dialect_impl_compare_server_default as _compare_server_default, + ) + class SQLiteTest(TestBase): def test_add_column(self):