- id: black
- repo: https://github.com/sqlalchemyorg/zimports
- rev: v0.6.2
+ rev: v0.7.0
hooks:
- id: zimports
args:
from . import context
from . import op
+from .runtime import plugins
-__version__ = "1.17.3"
+
+__version__ = "1.18.0"
from __future__ import annotations
import contextlib
+import logging
from typing import Any
from typing import Dict
from typing import Iterator
from . import render
from .. import util
from ..operations import ops
+from ..runtime.plugins import Plugin
from ..util import sqla_compat
-"""Provide the 'autogenerate' feature which can produce migration operations
-automatically."""
-
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from ..script.base import Script
from ..script.base import ScriptDirectory
from ..script.revision import _GetRevArg
+ from ..util import PriorityDispatcher
+
+
+log = logging.getLogger(__name__)
def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
"""
- dialect: Optional[Dialect] = None
+ dialect: Dialect
"""The :class:`~sqlalchemy.engine.Dialect` object currently in use.
This is normally obtained from the
"""
- migration_context: MigrationContext = None # type: ignore[assignment]
+ migration_context: MigrationContext
"""The :class:`.MigrationContext` established by the ``env.py`` script."""
+ comparators: PriorityDispatcher
+
def __init__(
self,
migration_context: MigrationContext,
"the database for schema information"
)
+        # branch off from the "global" comparators. This collection
+        # is empty in Alembic except that it is populated by third party
+        # extensions that don't use the plugin system, so we will build
+        # off of whatever is in there.
+ if autogenerate:
+ self.comparators = compare.comparators.branch()
+ Plugin.populate_autogenerate_priority_dispatch(
+ self.comparators,
+ include_plugins=migration_context.opts.get(
+ "autogenerate_plugins", ["alembic.autogenerate.*"]
+ ),
+ )
+
if opts is None:
opts = migration_context.opts
self._name_filters = name_filters
self.migration_context = migration_context
- if self.migration_context is not None:
- self.connection = self.migration_context.bind
- self.dialect = self.migration_context.dialect
+ self.connection = self.migration_context.bind
+ self.dialect = self.migration_context.dialect
self.imports = set()
self.opts: Dict[str, Any] = opts
--- /dev/null
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from . import comments
+from . import constraints
+from . import schema
+from . import server_defaults
+from . import tables
+from . import types
+from ... import util
+from ...runtime.plugins import Plugin
+
+if TYPE_CHECKING:
+ from ..api import AutogenContext
+ from ...operations.ops import MigrationScript
+ from ...operations.ops import UpgradeOps
+
+
# Module-level logger for autogenerate comparison diagnostics.
log = logging.getLogger(__name__)

# Global comparator registry; see docstring below.
comparators = util.PriorityDispatcher()
"""Global registry which Alembic keeps empty, but copies when creating
a new AutogenContext.

This is to support a variety of third party plugins that hook their autogen
functionality onto this collection.

"""
+
+
def _populate_migration_script(
    autogen_context: AutogenContext, migration_script: MigrationScript
) -> None:
    """Populate the newest upgrade/downgrade op containers of
    *migration_script* via autogenerate comparison.

    Downgrade ops are produced by reversing the detected upgrade ops.
    """
    latest_upgrade = migration_script.upgrade_ops_list[-1]
    latest_downgrade = migration_script.downgrade_ops_list[-1]

    _produce_net_changes(autogen_context, latest_upgrade)
    latest_upgrade.reverse_into(latest_downgrade)
+
+
def _produce_net_changes(
    autogen_context: AutogenContext, upgrade_ops: UpgradeOps
) -> None:
    """Run the top-level "autogenerate" comparator chain for the active
    dialect, accumulating detected schema changes into *upgrade_ops*."""
    dialect = autogen_context.dialect
    assert dialect is not None

    comparator_chain = autogen_context.comparators.dispatch(
        "autogenerate", qualifier=dialect.name
    )
    comparator_chain(autogen_context, upgrade_ops)
+
+
# Register the built-in comparator modules as plugins, each under a
# well-known "alembic.autogenerate.*" name; these names are matched
# against the "autogenerate_plugins" option when an AutogenContext
# builds its comparator dispatch (see AutogenContext.__init__).
Plugin.setup_plugin_from_module(schema, "alembic.autogenerate.schemas")
Plugin.setup_plugin_from_module(tables, "alembic.autogenerate.tables")
Plugin.setup_plugin_from_module(types, "alembic.autogenerate.types")
Plugin.setup_plugin_from_module(
    constraints, "alembic.autogenerate.constraints"
)
Plugin.setup_plugin_from_module(
    server_defaults, "alembic.autogenerate.defaults"
)
Plugin.setup_plugin_from_module(comments, "alembic.autogenerate.comments")
--- /dev/null
+from __future__ import annotations
+
+import logging
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
+from ...operations import ops
+from ...util import PriorityDispatchResult
+
+if TYPE_CHECKING:
+
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Table
+
+ from ..api import AutogenContext
+ from ...operations.ops import AlterColumnOp
+ from ...operations.ops import ModifyTableOps
+ from ...runtime.plugins import Plugin
+
+log = logging.getLogger(__name__)
+
+
def _compare_column_comment(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: quoted_name,
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """Detect a changed COMMENT on a column, recording the change on
    *alter_column_op* when the dialect supports comments."""
    assert autogen_context.dialect is not None
    if not autogen_context.dialect.supports_comments:
        return PriorityDispatchResult.CONTINUE

    declared_comment = metadata_col.comment
    reflected_comment = conn_col.comment
    if reflected_comment is None and declared_comment is None:
        return PriorityDispatchResult.CONTINUE

    alter_column_op.existing_comment = reflected_comment

    if reflected_comment == declared_comment:
        # comments are present but identical; nothing to do
        return PriorityDispatchResult.CONTINUE

    alter_column_op.modify_comment = declared_comment
    log.info("Detected column comment '%s.%s'", tname, cname)
    return PriorityDispatchResult.STOP
+
+
def _compare_table_comment(
    autogen_context: AutogenContext,
    modify_table_ops: ModifyTableOps,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    conn_table: Optional[Table],
    metadata_table: Optional[Table],
) -> PriorityDispatchResult:
    """Detect an added, removed, or changed table-level COMMENT and
    append the corresponding op to *modify_table_ops*."""
    assert autogen_context.dialect is not None
    if not autogen_context.dialect.supports_comments:
        return PriorityDispatchResult.CONTINUE

    # if we're doing CREATE TABLE, comments will be created inline
    # with the create_table op.
    if conn_table is None or metadata_table is None:
        return PriorityDispatchResult.CONTINUE

    reflected_comment = conn_table.comment
    declared_comment = metadata_table.comment

    if reflected_comment is None and declared_comment is None:
        return PriorityDispatchResult.CONTINUE

    if declared_comment is None and reflected_comment is not None:
        # comment removed from the model: drop it
        modify_table_ops.ops.append(
            ops.DropTableCommentOp(
                tname, existing_comment=reflected_comment, schema=schema
            )
        )
        return PriorityDispatchResult.STOP

    if declared_comment != reflected_comment:
        # comment added or altered in the model: (re)create it
        modify_table_ops.ops.append(
            ops.CreateTableCommentOp(
                tname,
                declared_comment,
                existing_comment=reflected_comment,
                schema=schema,
            )
        )
        return PriorityDispatchResult.STOP

    return PriorityDispatchResult.CONTINUE
+
+
def setup(plugin: Plugin) -> None:
    """Register this module's comment comparators on *plugin*,
    both filed under the "comments" qualifier."""
    for comparator, target in (
        (_compare_column_comment, "column"),
        (_compare_table_comment, "table"),
    ):
        plugin.add_autogenerate_comparator(comparator, target, "comments")
-# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
-# mypy: no-warn-return-any, allow-any-generics
+# mypy: allow-untyped-defs, allow-untyped-calls, allow-incomplete-defs
from __future__ import annotations
-import contextlib
import logging
-import re
from typing import Any
from typing import cast
from typing import Dict
-from typing import Iterator
from typing import Mapping
from typing import Optional
-from typing import Set
-from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from sqlalchemy import event
-from sqlalchemy import inspect
from sqlalchemy import schema as sa_schema
from sqlalchemy import text
-from sqlalchemy import types as sqltypes
from sqlalchemy.sql import expression
-from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
-from sqlalchemy.util import OrderedSet
-from .. import util
-from ..ddl._autogen import is_index_sig
-from ..ddl._autogen import is_uq_sig
-from ..operations import ops
-from ..util import sqla_compat
+from .util import _InspectorConv
+from ... import util
+from ...ddl._autogen import is_index_sig
+from ...ddl._autogen import is_uq_sig
+from ...operations import ops
+from ...util import PriorityDispatchResult
+from ...util import sqla_compat
if TYPE_CHECKING:
- from typing import Literal
-
- from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Table
- from alembic.autogenerate.api import AutogenContext
- from alembic.ddl.impl import DefaultImpl
- from alembic.operations.ops import AlterColumnOp
- from alembic.operations.ops import MigrationScript
- from alembic.operations.ops import ModifyTableOps
- from alembic.operations.ops import UpgradeOps
- from ..ddl._autogen import _constraint_sig
-
-
-log = logging.getLogger(__name__)
-
-
-def _populate_migration_script(
- autogen_context: AutogenContext, migration_script: MigrationScript
-) -> None:
- upgrade_ops = migration_script.upgrade_ops_list[-1]
- downgrade_ops = migration_script.downgrade_ops_list[-1]
-
- _produce_net_changes(autogen_context, upgrade_ops)
- upgrade_ops.reverse_into(downgrade_ops)
-
-
-comparators = util.Dispatcher(uselist=True)
-
-
-def _produce_net_changes(
- autogen_context: AutogenContext, upgrade_ops: UpgradeOps
-) -> None:
- connection = autogen_context.connection
- assert connection is not None
- include_schemas = autogen_context.opts.get("include_schemas", False)
-
- inspector: Inspector = inspect(connection)
-
- default_schema = connection.dialect.default_schema_name
- schemas: Set[Optional[str]]
- if include_schemas:
- schemas = set(inspector.get_schema_names())
- # replace default schema name with None
- schemas.discard("information_schema")
- # replace the "default" schema with None
- schemas.discard(default_schema)
- schemas.add(None)
- else:
- schemas = {None}
-
- schemas = {
- s for s in schemas if autogen_context.run_name_filters(s, "schema", {})
- }
-
- assert autogen_context.dialect is not None
- comparators.dispatch("schema", autogen_context.dialect.name)(
- autogen_context, upgrade_ops, schemas
- )
-
-
-@comparators.dispatch_for("schema")
-def _autogen_for_tables(
- autogen_context: AutogenContext,
- upgrade_ops: UpgradeOps,
- schemas: Union[Set[None], Set[Optional[str]]],
-) -> None:
- inspector = autogen_context.inspector
-
- conn_table_names: Set[Tuple[Optional[str], str]] = set()
-
- version_table_schema = (
- autogen_context.migration_context.version_table_schema
- )
- version_table = autogen_context.migration_context.version_table
-
- for schema_name in schemas:
- tables = set(inspector.get_table_names(schema=schema_name))
- if schema_name == version_table_schema:
- tables = tables.difference(
- [autogen_context.migration_context.version_table]
- )
-
- conn_table_names.update(
- (schema_name, tname)
- for tname in tables
- if autogen_context.run_name_filters(
- tname, "table", {"schema_name": schema_name}
- )
- )
-
- metadata_table_names = OrderedSet(
- [(table.schema, table.name) for table in autogen_context.sorted_tables]
- ).difference([(version_table_schema, version_table)])
-
- _compare_tables(
- conn_table_names,
- metadata_table_names,
- inspector,
- upgrade_ops,
- autogen_context,
- )
-
-
-def _compare_tables(
- conn_table_names: set,
- metadata_table_names: set,
- inspector: Inspector,
- upgrade_ops: UpgradeOps,
- autogen_context: AutogenContext,
-) -> None:
- default_schema = inspector.bind.dialect.default_schema_name
-
- # tables coming from the connection will not have "schema"
- # set if it matches default_schema_name; so we need a list
- # of table names from local metadata that also have "None" if schema
- # == default_schema_name. Most setups will be like this anyway but
- # some are not (see #170)
- metadata_table_names_no_dflt_schema = OrderedSet(
- [
- (schema if schema != default_schema else None, tname)
- for schema, tname in metadata_table_names
- ]
- )
-
- # to adjust for the MetaData collection storing the tables either
- # as "schemaname.tablename" or just "tablename", create a new lookup
- # which will match the "non-default-schema" keys to the Table object.
- tname_to_table = {
- no_dflt_schema: autogen_context.table_key_to_table[
- sa_schema._get_table_key(tname, schema)
- ]
- for no_dflt_schema, (schema, tname) in zip(
- metadata_table_names_no_dflt_schema, metadata_table_names
- )
- }
- metadata_table_names = metadata_table_names_no_dflt_schema
-
- for s, tname in metadata_table_names.difference(conn_table_names):
- name = "%s.%s" % (s, tname) if s else tname
- metadata_table = tname_to_table[(s, tname)]
- if autogen_context.run_object_filters(
- metadata_table, tname, "table", False, None
- ):
- upgrade_ops.ops.append(
- ops.CreateTableOp.from_table(metadata_table)
- )
- log.info("Detected added table %r", name)
- modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
-
- comparators.dispatch("table")(
- autogen_context,
- modify_table_ops,
- s,
- tname,
- None,
- metadata_table,
- )
- if not modify_table_ops.is_empty():
- upgrade_ops.ops.append(modify_table_ops)
-
- removal_metadata = sa_schema.MetaData()
- for s, tname in conn_table_names.difference(metadata_table_names):
- name = sa_schema._get_table_key(tname, s)
- exists = name in removal_metadata.tables
- t = sa_schema.Table(tname, removal_metadata, schema=s)
-
- if not exists:
- event.listen(
- t,
- "column_reflect",
- # fmt: off
- autogen_context.migration_context.impl.
- _compat_autogen_column_reflect
- (inspector),
- # fmt: on
- )
- _InspectorConv(inspector).reflect_table(t, include_columns=None)
- if autogen_context.run_object_filters(t, tname, "table", True, None):
- modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
-
- comparators.dispatch("table")(
- autogen_context, modify_table_ops, s, tname, t, None
- )
- if not modify_table_ops.is_empty():
- upgrade_ops.ops.append(modify_table_ops)
-
- upgrade_ops.ops.append(ops.DropTableOp.from_table(t))
- log.info("Detected removed table %r", name)
-
- existing_tables = conn_table_names.intersection(metadata_table_names)
-
- existing_metadata = sa_schema.MetaData()
- conn_column_info = {}
- for s, tname in existing_tables:
- name = sa_schema._get_table_key(tname, s)
- exists = name in existing_metadata.tables
- t = sa_schema.Table(tname, existing_metadata, schema=s)
- if not exists:
- event.listen(
- t,
- "column_reflect",
- # fmt: off
- autogen_context.migration_context.impl.
- _compat_autogen_column_reflect(inspector),
- # fmt: on
- )
- _InspectorConv(inspector).reflect_table(t, include_columns=None)
-
- conn_column_info[(s, tname)] = t
-
- for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
- s = s or None
- name = "%s.%s" % (s, tname) if s else tname
- metadata_table = tname_to_table[(s, tname)]
- conn_table = existing_metadata.tables[name]
-
- if autogen_context.run_object_filters(
- metadata_table, tname, "table", False, conn_table
- ):
- modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
- with _compare_columns(
- s,
- tname,
- conn_table,
- metadata_table,
- modify_table_ops,
- autogen_context,
- inspector,
- ):
- comparators.dispatch("table")(
- autogen_context,
- modify_table_ops,
- s,
- tname,
- conn_table,
- metadata_table,
- )
-
- if not modify_table_ops.is_empty():
- upgrade_ops.ops.append(modify_table_ops)
-
-
-_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict(
- {
- "asc": expression.asc,
- "desc": expression.desc,
- "nulls_first": expression.nullsfirst,
- "nulls_last": expression.nullslast,
- "nullsfirst": expression.nullsfirst, # 1_3 name
- "nullslast": expression.nullslast, # 1_3 name
- }
-)
-
-
-def _make_index(
- impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
-) -> Optional[Index]:
- exprs: list[Union[Column[Any], TextClause]] = []
- sorting = params.get("column_sorting")
-
- for num, col_name in enumerate(params["column_names"]):
- item: Union[Column[Any], TextClause]
- if col_name is None:
- assert "expressions" in params
- name = params["expressions"][num]
- item = text(name)
- else:
- name = col_name
- item = conn_table.c[col_name]
- if sorting and name in sorting:
- for operator in sorting[name]:
- if operator in _IndexColumnSortingOps:
- item = _IndexColumnSortingOps[operator](item)
- exprs.append(item)
- ix = sa_schema.Index(
- params["name"],
- *exprs,
- unique=params["unique"],
- _table=conn_table,
- **impl.adjust_reflected_dialect_options(params, "index"),
- )
- if "duplicates_constraint" in params:
- ix.info["duplicates_constraint"] = params["duplicates_constraint"]
- return ix
-
-
-def _make_unique_constraint(
- impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
-) -> UniqueConstraint:
- uq = sa_schema.UniqueConstraint(
- *[conn_table.c[cname] for cname in params["column_names"]],
- name=params["name"],
- **impl.adjust_reflected_dialect_options(params, "unique_constraint"),
- )
- if "duplicates_index" in params:
- uq.info["duplicates_index"] = params["duplicates_index"]
-
- return uq
-
-
-def _make_foreign_key(
- params: Dict[str, Any], conn_table: Table
-) -> ForeignKeyConstraint:
- tname = params["referred_table"]
- if params["referred_schema"]:
- tname = "%s.%s" % (params["referred_schema"], tname)
-
- options = params.get("options", {})
-
- const = sa_schema.ForeignKeyConstraint(
- [conn_table.c[cname] for cname in params["constrained_columns"]],
- ["%s.%s" % (tname, n) for n in params["referred_columns"]],
- onupdate=options.get("onupdate"),
- ondelete=options.get("ondelete"),
- deferrable=options.get("deferrable"),
- initially=options.get("initially"),
- name=params["name"],
- )
- # needed by 0.7
- conn_table.append_constraint(const)
- return const
-
-
-@contextlib.contextmanager
-def _compare_columns(
- schema: Optional[str],
- tname: Union[quoted_name, str],
- conn_table: Table,
- metadata_table: Table,
- modify_table_ops: ModifyTableOps,
- autogen_context: AutogenContext,
- inspector: Inspector,
-) -> Iterator[None]:
- name = "%s.%s" % (schema, tname) if schema else tname
- metadata_col_names = OrderedSet(
- c.name for c in metadata_table.c if not c.system
- )
- metadata_cols_by_name = {
- c.name: c for c in metadata_table.c if not c.system
- }
-
- conn_col_names = {
- c.name: c
- for c in conn_table.c
- if autogen_context.run_name_filters(
- c.name, "column", {"table_name": tname, "schema_name": schema}
- )
- }
-
- for cname in metadata_col_names.difference(conn_col_names):
- if autogen_context.run_object_filters(
- metadata_cols_by_name[cname], cname, "column", False, None
- ):
- modify_table_ops.ops.append(
- ops.AddColumnOp.from_column_and_tablename(
- schema, tname, metadata_cols_by_name[cname]
- )
- )
- log.info("Detected added column '%s.%s'", name, cname)
-
- for colname in metadata_col_names.intersection(conn_col_names):
- metadata_col = metadata_cols_by_name[colname]
- conn_col = conn_table.c[colname]
- if not autogen_context.run_object_filters(
- metadata_col, colname, "column", False, conn_col
- ):
- continue
- alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema)
-
- comparators.dispatch("column")(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- colname,
- conn_col,
- metadata_col,
- )
-
- if alter_column_op.has_changes():
- modify_table_ops.ops.append(alter_column_op)
-
- yield
-
- for cname in set(conn_col_names).difference(metadata_col_names):
- if autogen_context.run_object_filters(
- conn_table.c[cname], cname, "column", True, None
- ):
- modify_table_ops.ops.append(
- ops.DropColumnOp.from_column_and_tablename(
- schema, tname, conn_table.c[cname]
- )
- )
- log.info("Detected removed column '%s.%s'", name, cname)
+ from ...autogenerate.api import AutogenContext
+ from ...ddl._autogen import _constraint_sig
+ from ...ddl.impl import DefaultImpl
+ from ...operations.ops import AlterColumnOp
+ from ...operations.ops import ModifyTableOps
+ from ...runtime.plugins import Plugin
_C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index])
-class _InspectorConv:
- __slots__ = ("inspector",)
-
- def __init__(self, inspector):
- self.inspector = inspector
-
- def _apply_reflectinfo_conv(self, consts):
- if not consts:
- return consts
- for const in consts:
- if const["name"] is not None and not isinstance(
- const["name"], conv
- ):
- const["name"] = conv(const["name"])
- return consts
-
- def _apply_constraint_conv(self, consts):
- if not consts:
- return consts
- for const in consts:
- if const.name is not None and not isinstance(const.name, conv):
- const.name = conv(const.name)
- return consts
-
- def get_indexes(self, *args, **kw):
- return self._apply_reflectinfo_conv(
- self.inspector.get_indexes(*args, **kw)
- )
-
- def get_unique_constraints(self, *args, **kw):
- return self._apply_reflectinfo_conv(
- self.inspector.get_unique_constraints(*args, **kw)
- )
-
- def get_foreign_keys(self, *args, **kw):
- return self._apply_reflectinfo_conv(
- self.inspector.get_foreign_keys(*args, **kw)
- )
-
- def reflect_table(self, table, *, include_columns):
- self.inspector.reflect_table(table, include_columns=include_columns)
-
- # I had a cool version of this using _ReflectInfo, however that doesn't
- # work in 1.4 and it's not public API in 2.x. Then this is just a two
- # liner. So there's no competition...
- self._apply_constraint_conv(table.constraints)
- self._apply_constraint_conv(table.indexes)
+log = logging.getLogger(__name__)
-@comparators.dispatch_for("table")
def _compare_indexes_and_uniques(
autogen_context: AutogenContext,
modify_ops: ModifyTableOps,
tname: Union[quoted_name, str],
conn_table: Optional[Table],
metadata_table: Optional[Table],
-) -> None:
+) -> PriorityDispatchResult:
inspector = autogen_context.inspector
is_create_table = conn_table is None
is_drop_table = metadata_table is None
if c.is_named
}
- conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
- conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
+ conn_uniques_by_name: Dict[
+ sqla_compat._ConstraintName,
+ _constraint_sig[sa_schema.UniqueConstraint],
+ ]
+ conn_indexes_by_name: Dict[
+ sqla_compat._ConstraintName, _constraint_sig[sa_schema.Index]
+ ]
conn_uniques_by_name = {c.name: c for c in conn_unique_constraints}
conn_indexes_by_name = {c.name: c for c in conn_indexes_sig}
# 4. The backend may double up indexes as unique constraints and
# vice versa (e.g. MySQL, Postgresql)
- def obj_added(obj: _constraint_sig):
+ def obj_added(
+ obj: (
+ _constraint_sig[sa_schema.UniqueConstraint]
+ | _constraint_sig[sa_schema.Index]
+ ),
+ ):
if is_index_sig(obj):
if autogen_context.run_object_filters(
obj.const, obj.name, "index", False, None
else:
assert False
- def obj_removed(obj: _constraint_sig):
+ def obj_removed(
+ obj: (
+ _constraint_sig[sa_schema.UniqueConstraint]
+ | _constraint_sig[sa_schema.Index]
+ ),
+ ):
if is_index_sig(obj):
if obj.is_unique and not supports_unique_constraints:
# many databases double up unique constraints
assert False
def obj_changed(
- old: _constraint_sig,
- new: _constraint_sig,
+ old: (
+ _constraint_sig[sa_schema.UniqueConstraint]
+ | _constraint_sig[sa_schema.Index]
+ ),
+ new: (
+ _constraint_sig[sa_schema.UniqueConstraint]
+ | _constraint_sig[sa_schema.Index]
+ ),
msg: str,
):
if is_index_sig(old):
obj_removed(conn_obj)
obj_added(metadata_obj)
else:
+ # TODO: for plugins, let's do is_index_sig / is_uq_sig
+ # here so we know index or unique, then
+ # do a sub-dispatch,
+ # autogen_context.comparators.dispatch("index")
+ # or
+ # autogen_context.comparators.dispatch("unique_constraint")
+ #
comparison = metadata_obj.compare_to_reflected(conn_obj)
if comparison.is_different:
if uq_sig not in conn_uniques_by_sig:
obj_added(unnamed_metadata_uniques[uq_sig])
+ return PriorityDispatchResult.CONTINUE
+
def _correct_for_uq_duplicates_uix(
conn_unique_constraints,
conn_indexes.discard(conn_ix_names[overlap])
-@comparators.dispatch_for("column")
-def _compare_nullable(
- autogen_context: AutogenContext,
- alter_column_op: AlterColumnOp,
- schema: Optional[str],
- tname: Union[quoted_name, str],
- cname: Union[quoted_name, str],
- conn_col: Column[Any],
- metadata_col: Column[Any],
-) -> None:
- metadata_col_nullable = metadata_col.nullable
- conn_col_nullable = conn_col.nullable
- alter_column_op.existing_nullable = conn_col_nullable
-
- if conn_col_nullable is not metadata_col_nullable:
- if (
- sqla_compat._server_default_is_computed(
- metadata_col.server_default, conn_col.server_default
- )
- and sqla_compat._nullability_might_be_unset(metadata_col)
- or (
- sqla_compat._server_default_is_identity(
- metadata_col.server_default, conn_col.server_default
- )
- )
- ):
- log.info(
- "Ignoring nullable change on identity column '%s.%s'",
- tname,
- cname,
- )
- else:
- alter_column_op.modify_nullable = metadata_col_nullable
- log.info(
- "Detected %s on column '%s.%s'",
- "NULL" if metadata_col_nullable else "NOT NULL",
- tname,
- cname,
- )
-
-
-@comparators.dispatch_for("column")
-def _setup_autoincrement(
- autogen_context: AutogenContext,
- alter_column_op: AlterColumnOp,
- schema: Optional[str],
- tname: Union[quoted_name, str],
- cname: quoted_name,
- conn_col: Column[Any],
- metadata_col: Column[Any],
-) -> None:
- if metadata_col.table._autoincrement_column is metadata_col:
- alter_column_op.kw["autoincrement"] = True
- elif metadata_col.autoincrement is True:
- alter_column_op.kw["autoincrement"] = True
- elif metadata_col.autoincrement is False:
- alter_column_op.kw["autoincrement"] = False
-
-
-@comparators.dispatch_for("column")
-def _compare_type(
- autogen_context: AutogenContext,
- alter_column_op: AlterColumnOp,
- schema: Optional[str],
- tname: Union[quoted_name, str],
- cname: Union[quoted_name, str],
- conn_col: Column[Any],
- metadata_col: Column[Any],
-) -> None:
- conn_type = conn_col.type
- alter_column_op.existing_type = conn_type
- metadata_type = metadata_col.type
- if conn_type._type_affinity is sqltypes.NullType:
- log.info(
- "Couldn't determine database type " "for column '%s.%s'",
- tname,
- cname,
- )
- return
- if metadata_type._type_affinity is sqltypes.NullType:
- log.info(
- "Column '%s.%s' has no type within " "the model; can't compare",
- tname,
- cname,
- )
- return
-
- isdiff = autogen_context.migration_context._compare_type(
- conn_col, metadata_col
- )
+_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict(
+ {
+ "asc": expression.asc,
+ "desc": expression.desc,
+ "nulls_first": expression.nullsfirst,
+ "nulls_last": expression.nullslast,
+ "nullsfirst": expression.nullsfirst, # 1_3 name
+ "nullslast": expression.nullslast, # 1_3 name
+ }
+)
- if isdiff:
- alter_column_op.modify_type = metadata_type
- log.info(
- "Detected type change from %r to %r on '%s.%s'",
- conn_type,
- metadata_type,
- tname,
- cname,
- )
+def _make_index(
+ impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
+) -> Optional[Index]:
+ exprs: list[Union[Column[Any], TextClause]] = []
+ sorting = params.get("column_sorting")
-def _render_server_default_for_compare(
- metadata_default: Optional[Any], autogen_context: AutogenContext
-) -> Optional[str]:
- if isinstance(metadata_default, sa_schema.DefaultClause):
- if isinstance(metadata_default.arg, str):
- metadata_default = metadata_default.arg
+ for num, col_name in enumerate(params["column_names"]):
+ item: Union[Column[Any], TextClause]
+ if col_name is None:
+ assert "expressions" in params
+ name = params["expressions"][num]
+ item = text(name)
else:
- metadata_default = str(
- metadata_default.arg.compile(
- dialect=autogen_context.dialect,
- compile_kwargs={"literal_binds": True},
- )
- )
- if isinstance(metadata_default, str):
- return metadata_default
- else:
- return None
-
-
-def _normalize_computed_default(sqltext: str) -> str:
- """we want to warn if a computed sql expression has changed. however
- we don't want false positives and the warning is not that critical.
- so filter out most forms of variability from the SQL text.
-
- """
-
- return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower()
-
-
-def _compare_computed_default(
- autogen_context: AutogenContext,
- alter_column_op: AlterColumnOp,
- schema: Optional[str],
- tname: str,
- cname: str,
- conn_col: Column[Any],
- metadata_col: Column[Any],
-) -> None:
- rendered_metadata_default = str(
- cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
- dialect=autogen_context.dialect,
- compile_kwargs={"literal_binds": True},
- )
+ name = col_name
+ item = conn_table.c[col_name]
+ if sorting and name in sorting:
+ for operator in sorting[name]:
+ if operator in _IndexColumnSortingOps:
+ item = _IndexColumnSortingOps[operator](item)
+ exprs.append(item)
+ ix = sa_schema.Index(
+ params["name"],
+ *exprs,
+ unique=params["unique"],
+ _table=conn_table,
+ **impl.adjust_reflected_dialect_options(params, "index"),
)
+ if "duplicates_constraint" in params:
+ ix.info["duplicates_constraint"] = params["duplicates_constraint"]
+ return ix
- # since we cannot change computed columns, we do only a crude comparison
- # here where we try to eliminate syntactical differences in order to
- # get a minimal comparison just to emit a warning.
- rendered_metadata_default = _normalize_computed_default(
- rendered_metadata_default
+def _make_unique_constraint(
+ impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
+) -> UniqueConstraint:
+ uq = sa_schema.UniqueConstraint(
+ *[conn_table.c[cname] for cname in params["column_names"]],
+ name=params["name"],
+ **impl.adjust_reflected_dialect_options(params, "unique_constraint"),
)
+ if "duplicates_index" in params:
+ uq.info["duplicates_index"] = params["duplicates_index"]
- if isinstance(conn_col.server_default, sa_schema.Computed):
- rendered_conn_default = str(
- conn_col.server_default.sqltext.compile(
- dialect=autogen_context.dialect,
- compile_kwargs={"literal_binds": True},
- )
- )
- if rendered_conn_default is None:
- rendered_conn_default = ""
- else:
- rendered_conn_default = _normalize_computed_default(
- rendered_conn_default
- )
- else:
- rendered_conn_default = ""
-
- if rendered_metadata_default != rendered_conn_default:
- _warn_computed_not_supported(tname, cname)
+ return uq
-def _warn_computed_not_supported(tname: str, cname: str) -> None:
- util.warn("Computed default on %s.%s cannot be modified" % (tname, cname))
+def _make_foreign_key(
+ params: Dict[str, Any], conn_table: Table
+) -> ForeignKeyConstraint:
+ tname = params["referred_table"]
+ if params["referred_schema"]:
+ tname = "%s.%s" % (params["referred_schema"], tname)
+ options = params.get("options", {})
-def _compare_identity_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
- impl = autogen_context.migration_context.impl
- diff, ignored_attr, is_alter = impl._compare_identity_default(
- metadata_col.server_default, conn_col.server_default
+ const = sa_schema.ForeignKeyConstraint(
+ [conn_table.c[cname] for cname in params["constrained_columns"]],
+ ["%s.%s" % (tname, n) for n in params["referred_columns"]],
+ onupdate=options.get("onupdate"),
+ ondelete=options.get("ondelete"),
+ deferrable=options.get("deferrable"),
+ initially=options.get("initially"),
+ name=params["name"],
)
-
- return diff, is_alter
-
-
-@comparators.dispatch_for("column")
-def _compare_server_default(
- autogen_context: AutogenContext,
- alter_column_op: AlterColumnOp,
- schema: Optional[str],
- tname: Union[quoted_name, str],
- cname: Union[quoted_name, str],
- conn_col: Column[Any],
- metadata_col: Column[Any],
-) -> Optional[bool]:
- metadata_default = metadata_col.server_default
- conn_col_default = conn_col.server_default
- if conn_col_default is None and metadata_default is None:
- return False
-
- if sqla_compat._server_default_is_computed(metadata_default):
- return _compare_computed_default( # type:ignore[func-returns-value]
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
- )
- if sqla_compat._server_default_is_computed(conn_col_default):
- _warn_computed_not_supported(tname, cname)
- return False
-
- if sqla_compat._server_default_is_identity(
- metadata_default, conn_col_default
- ):
- alter_column_op.existing_server_default = conn_col_default
- diff, is_alter = _compare_identity_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
- )
- if is_alter:
- alter_column_op.modify_server_default = metadata_default
- if diff:
- log.info(
- "Detected server default on column '%s.%s': "
- "identity options attributes %s",
- tname,
- cname,
- sorted(diff),
- )
- else:
- rendered_metadata_default = _render_server_default_for_compare(
- metadata_default, autogen_context
- )
-
- rendered_conn_default = (
- cast(Any, conn_col_default).arg.text if conn_col_default else None
- )
-
- alter_column_op.existing_server_default = conn_col_default
-
- is_diff = autogen_context.migration_context._compare_server_default(
- conn_col,
- metadata_col,
- rendered_metadata_default,
- rendered_conn_default,
- )
- if is_diff:
- alter_column_op.modify_server_default = metadata_default
- log.info("Detected server default on column '%s.%s'", tname, cname)
-
- return None
-
-
-@comparators.dispatch_for("column")
-def _compare_column_comment(
- autogen_context: AutogenContext,
- alter_column_op: AlterColumnOp,
- schema: Optional[str],
- tname: Union[quoted_name, str],
- cname: quoted_name,
- conn_col: Column[Any],
- metadata_col: Column[Any],
-) -> Optional[Literal[False]]:
- assert autogen_context.dialect is not None
- if not autogen_context.dialect.supports_comments:
- return None
-
- metadata_comment = metadata_col.comment
- conn_col_comment = conn_col.comment
- if conn_col_comment is None and metadata_comment is None:
- return False
-
- alter_column_op.existing_comment = conn_col_comment
-
- if conn_col_comment != metadata_comment:
- alter_column_op.modify_comment = metadata_comment
- log.info("Detected column comment '%s.%s'", tname, cname)
-
- return None
+ # needed by 0.7
+ conn_table.append_constraint(const)
+ return const
-@comparators.dispatch_for("table")
def _compare_foreign_keys(
autogen_context: AutogenContext,
modify_table_ops: ModifyTableOps,
tname: Union[quoted_name, str],
conn_table: Table,
metadata_table: Table,
-) -> None:
+) -> PriorityDispatchResult:
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
if conn_table is None or metadata_table is None:
- return
+ return PriorityDispatchResult.CONTINUE
inspector = autogen_context.inspector
metadata_fks = {
if removed_sig not in metadata_fks_by_sig:
compare_to = (
metadata_fks_by_name[const.name].const
- if const.name in metadata_fks_by_name
+ if const.name and const.name in metadata_fks_by_name
else None
)
_remove_fk(const, compare_to)
if added_sig not in conn_fks_by_sig:
compare_to = (
conn_fks_by_name[const.name].const
- if const.name in conn_fks_by_name
+ if const.name and const.name in conn_fks_by_name
else None
)
_add_fk(const, compare_to)
+ return PriorityDispatchResult.CONTINUE
-@comparators.dispatch_for("table")
-def _compare_table_comment(
+
def _compare_nullable(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: Union[quoted_name, str],
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """Column-level comparator for the ``nullable`` flag.

    Records the reflected nullability on the op, then emits a
    ``modify_nullable`` change unless the column's server default is
    Computed (with possibly-unset nullability) or Identity, in which
    case the nullable change is ignored.
    """
    metadata_col_nullable = metadata_col.nullable
    conn_col_nullable = conn_col.nullable
    alter_column_op.existing_nullable = conn_col_nullable

    if conn_col_nullable is not metadata_col_nullable:
        # note the precedence here: (computed-default AND
        # nullability-might-be-unset) OR identity-default
        if (
            sqla_compat._server_default_is_computed(
                metadata_col.server_default, conn_col.server_default
            )
            and sqla_compat._nullability_might_be_unset(metadata_col)
            or (
                sqla_compat._server_default_is_identity(
                    metadata_col.server_default, conn_col.server_default
                )
            )
        ):
            log.info(
                "Ignoring nullable change on identity column '%s.%s'",
                tname,
                cname,
            )
        else:
            alter_column_op.modify_nullable = metadata_col_nullable
            log.info(
                "Detected %s on column '%s.%s'",
                "NULL" if metadata_col_nullable else "NOT NULL",
                tname,
                cname,
            )
            # column nullability changed, no further nullable checks needed
            return PriorityDispatchResult.STOP

    return PriorityDispatchResult.CONTINUE
+
+
def setup(plugin: Plugin) -> None:
    """Register the table-level and column-level comparators provided
    by this module on the given plugin."""
    for comparator, compare_target, compare_element in (
        (_compare_indexes_and_uniques, "table", "indexes"),
        (_compare_foreign_keys, "table", "foreignkeys"),
        (_compare_nullable, "column", "nullable"),
    ):
        plugin.add_autogenerate_comparator(
            comparator, compare_target, compare_element
        )
--- /dev/null
+# mypy: allow-untyped-calls
+
+from __future__ import annotations
+
+import logging
+from typing import Optional
+from typing import Set
+from typing import TYPE_CHECKING
+
+from sqlalchemy import inspect
+
+from ...util import PriorityDispatchResult
+
+if TYPE_CHECKING:
+ from sqlalchemy.engine.reflection import Inspector
+
+ from ...autogenerate.api import AutogenContext
+ from ...operations.ops import UpgradeOps
+ from ...runtime.plugins import Plugin
+
+
+log = logging.getLogger(__name__)
+
+
def _produce_net_changes(
    autogen_context: AutogenContext, upgrade_ops: UpgradeOps
) -> PriorityDispatchResult:
    """Top-level autogenerate comparator: determine the set of schemas
    to scan, then dispatch to the "schema" level comparators.

    :param autogen_context: the current :class:`.AutogenContext`.
    :param upgrade_ops: collection to which downstream comparators
     append detected operations.
    """
    connection = autogen_context.connection
    assert connection is not None
    include_schemas = autogen_context.opts.get("include_schemas", False)

    inspector: Inspector = inspect(connection)

    default_schema = connection.dialect.default_schema_name
    schemas: Set[Optional[str]]
    if include_schemas:
        schemas = set(inspector.get_schema_names())
        # information_schema is never part of the comparison
        schemas.discard("information_schema")
        # replace the "default" schema with None
        schemas.discard(default_schema)
        schemas.add(None)
    else:
        schemas = {None}

    # apply the configured name filters to the candidate schemas
    schemas = {
        s for s in schemas if autogen_context.run_name_filters(s, "schema", {})
    }

    assert autogen_context.dialect is not None
    autogen_context.comparators.dispatch(
        "schema", qualifier=autogen_context.dialect.name
    )(autogen_context, upgrade_ops, schemas)

    return PriorityDispatchResult.CONTINUE
+
+
def setup(plugin: Plugin) -> None:
    """Register the top-level "autogenerate" comparator on the plugin."""
    plugin.add_autogenerate_comparator(_produce_net_changes, "autogenerate")
--- /dev/null
+from __future__ import annotations
+
+import logging
+import re
+from types import NoneType
+from typing import Any
+from typing import cast
+from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy.sql.schema import DefaultClause
+
+from ... import util
+from ...util import DispatchPriority
+from ...util import PriorityDispatchResult
+from ...util import sqla_compat
+
+if TYPE_CHECKING:
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import Column
+
+ from ...autogenerate.api import AutogenContext
+ from ...operations.ops import AlterColumnOp
+ from ...runtime.plugins import Plugin
+
+log = logging.getLogger(__name__)
+
+
def _render_server_default_for_compare(
    metadata_default: Optional[Any], autogen_context: AutogenContext
) -> Optional[str]:
    """Render a metadata-side server default as a plain string for
    comparison purposes, or ``None`` if it has no string form."""
    if isinstance(metadata_default, sa_schema.DefaultClause):
        clause_arg = metadata_default.arg
        if isinstance(clause_arg, str):
            metadata_default = clause_arg
        else:
            compiled = clause_arg.compile(
                dialect=autogen_context.dialect,
                compile_kwargs={"literal_binds": True},
            )
            metadata_default = str(compiled)
    return metadata_default if isinstance(metadata_default, str) else None
+
+
+def _normalize_computed_default(sqltext: str) -> str:
+ """we want to warn if a computed sql expression has changed. however
+ we don't want false positives and the warning is not that critical.
+ so filter out most forms of variability from the SQL text.
+
+ """
+
+ return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower()
+
+
def _compare_computed_default(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: str,
    cname: str,
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """Column-level comparator for ``Computed`` server defaults.

    Computed (generated) columns can't be modified, so this comparator
    never produces an operation; it warns when the computed SQL
    expression appears to have changed, and stops further
    server-default comparison once a computed default is involved.
    """
    metadata_default = metadata_col.server_default
    conn_col_default = conn_col.server_default
    if conn_col_default is None and metadata_default is None:
        return PriorityDispatchResult.CONTINUE

    # computed in the database but not in the model: warn and stop
    if sqla_compat._server_default_is_computed(
        conn_col_default
    ) and not sqla_compat._server_default_is_computed(metadata_default):
        _warn_computed_not_supported(tname, cname)
        return PriorityDispatchResult.STOP

    # not computed on the model side either; defer to the other
    # server-default comparators
    if not sqla_compat._server_default_is_computed(metadata_default):
        return PriorityDispatchResult.CONTINUE

    rendered_metadata_default = str(
        cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
            dialect=autogen_context.dialect,
            compile_kwargs={"literal_binds": True},
        )
    )

    # since we cannot change computed columns, we do only a crude comparison
    # here where we try to eliminate syntactical differences in order to
    # get a minimal comparison just to emit a warning.

    rendered_metadata_default = _normalize_computed_default(
        rendered_metadata_default
    )

    if isinstance(conn_col.server_default, sa_schema.Computed):
        rendered_conn_default = str(
            conn_col.server_default.sqltext.compile(
                dialect=autogen_context.dialect,
                compile_kwargs={"literal_binds": True},
            )
        )
        rendered_conn_default = _normalize_computed_default(
            rendered_conn_default
        )
    else:
        rendered_conn_default = ""

    if rendered_metadata_default != rendered_conn_default:
        _warn_computed_not_supported(tname, cname)

    return PriorityDispatchResult.STOP
+
+
def _warn_computed_not_supported(tname: str, cname: str) -> None:
    """Emit a warning that a Computed server default can't be altered."""
    message = "Computed default on %s.%s cannot be modified" % (tname, cname)
    util.warn(message)
+
+
def _compare_identity_default(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: Union[quoted_name, str],
    conn_col: Column[Any],
    metadata_col: Column[Any],
    skip: Sequence[str] = (
        "order",
        "on_null",
        "oracle_order",
        "oracle_on_null",
    ),
) -> PriorityDispatchResult:
    """Column-level comparator for ``Identity`` server defaults.

    Delegates the attribute-by-attribute comparison to the dialect
    impl's ``_compare_identity_default()``.  If neither side is an
    identity, falls through to the remaining comparators.
    """
    metadata_default = metadata_col.server_default
    conn_col_default = conn_col.server_default
    # nothing identity-related on either side: defer to other comparators
    if (
        conn_col_default is None
        and metadata_default is None
        or not sqla_compat._server_default_is_identity(
            metadata_default, conn_col_default
        )
    ):
        return PriorityDispatchResult.CONTINUE

    assert isinstance(
        metadata_col.server_default,
        (sa_schema.Identity, sa_schema.Sequence, NoneType),
    )
    assert isinstance(
        conn_col.server_default,
        (sa_schema.Identity, sa_schema.Sequence, NoneType),
    )

    impl = autogen_context.migration_context.impl
    diff, _, is_alter = impl._compare_identity_default(  # type: ignore[no-untyped-call] # noqa: E501
        metadata_col.server_default, conn_col.server_default
    )

    if is_alter:
        alter_column_op.modify_server_default = metadata_default
        if diff:
            log.info(
                "Detected server default on column '%s.%s': "
                "identity options attributes %s",
                tname,
                cname,
                sorted(diff),
            )

        # identity comparison is conclusive; skip the remaining
        # server-default comparators
        return PriorityDispatchResult.STOP

    return PriorityDispatchResult.CONTINUE
+
+
def _user_compare_server_default(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: Union[quoted_name, str],
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """Apply the user-configured ``compare_server_default`` hook.

    ``compare_server_default=False`` stops server-default comparison
    entirely; a user-supplied callable that reports a difference records
    the change and stops; otherwise comparison continues with the
    built-in comparators.
    """
    metadata_default = metadata_col.server_default
    conn_col_default = conn_col.server_default
    if conn_col_default is None and metadata_default is None:
        return PriorityDispatchResult.CONTINUE

    alter_column_op.existing_server_default = conn_col_default

    migration_context = autogen_context.migration_context

    # compare_server_default=False disables all server-default comparison
    if migration_context._user_compare_server_default is False:
        return PriorityDispatchResult.STOP

    # a non-callable truthy value means "use the default comparison"
    if not callable(migration_context._user_compare_server_default):
        return PriorityDispatchResult.CONTINUE

    rendered_metadata_default = _render_server_default_for_compare(
        metadata_default, autogen_context
    )
    rendered_conn_default = (
        cast(Any, conn_col_default).arg.text if conn_col_default else None
    )

    is_diff = migration_context._user_compare_server_default(
        migration_context,
        conn_col,
        metadata_col,
        rendered_conn_default,
        metadata_col.server_default,
        rendered_metadata_default,
    )
    if is_diff:
        alter_column_op.modify_server_default = metadata_default
        log.info(
            "User defined function %s detected "
            "server default on column '%s.%s'",
            migration_context._user_compare_server_default,
            tname,
            cname,
        )
        return PriorityDispatchResult.STOP
    return PriorityDispatchResult.CONTINUE
+
+
def _dialect_impl_compare_server_default(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: Union[quoted_name, str],
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """use dialect.impl.compare_server_default.

    This would in theory not be needed; however, we don't know whether
    third-party libraries have made their own alembic dialect and
    implemented this method, so it is retained here.

    """
    metadata_default = metadata_col.server_default
    conn_col_default = conn_col.server_default
    if conn_col_default is None and metadata_default is None:
        return PriorityDispatchResult.CONTINUE

    # this is already done by _user_compare_server_default,
    # but doing it here also for unit tests that want to call
    # _dialect_impl_compare_server_default directly
    alter_column_op.existing_server_default = conn_col_default

    # only plain DefaultClause (or absent) defaults are comparable here
    if not isinstance(
        metadata_default, (DefaultClause, NoneType)
    ) or not isinstance(conn_col_default, (DefaultClause, NoneType)):
        return PriorityDispatchResult.CONTINUE

    migration_context = autogen_context.migration_context

    rendered_metadata_default = _render_server_default_for_compare(
        metadata_default, autogen_context
    )
    rendered_conn_default = (
        cast(Any, conn_col_default).arg.text if conn_col_default else None
    )

    is_diff = migration_context.impl.compare_server_default(  # type: ignore[no-untyped-call] # noqa: E501
        conn_col,
        metadata_col,
        rendered_metadata_default,
        rendered_conn_default,
    )
    if is_diff:
        alter_column_op.modify_server_default = metadata_default
        log.info(
            "Dialect impl %s detected server default on column '%s.%s'",
            migration_context.impl,
            tname,
            cname,
        )
        return PriorityDispatchResult.STOP
    return PriorityDispatchResult.CONTINUE
+
+
def _setup_autoincrement(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: quoted_name,
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """Populate the ``autoincrement`` keyword on the alter-column op
    from the metadata column, when it can be determined."""
    explicit = metadata_col.autoincrement
    if (
        metadata_col.table._autoincrement_column is metadata_col
        or explicit is True
    ):
        alter_column_op.kw["autoincrement"] = True
    elif explicit is False:
        alter_column_op.kw["autoincrement"] = False

    return PriorityDispatchResult.CONTINUE
+
+
def setup(plugin: Plugin) -> None:
    """Register the server-default comparators, ordered so that the
    user-configured hook runs first and the dialect-impl fallback runs
    last."""
    registrations = (
        (_user_compare_server_default, DispatchPriority.FIRST),
        (_compare_computed_default, DispatchPriority.MEDIUM),
        (_compare_identity_default, DispatchPriority.MEDIUM),
        (_setup_autoincrement, DispatchPriority.MEDIUM),
        (_dialect_impl_compare_server_default, DispatchPriority.LAST),
    )
    for comparator, priority in registrations:
        plugin.add_autogenerate_comparator(
            comparator, "column", "server_default", priority=priority
        )
--- /dev/null
+# mypy: allow-untyped-calls
+
+from __future__ import annotations
+
+import contextlib
+import logging
+from typing import Iterator
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy import event
+from sqlalchemy import schema as sa_schema
+from sqlalchemy.util import OrderedSet
+
+from .util import _InspectorConv
+from ...operations import ops
+from ...util import PriorityDispatchResult
+
+if TYPE_CHECKING:
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import Table
+
+ from ...autogenerate.api import AutogenContext
+ from ...operations.ops import ModifyTableOps
+ from ...operations.ops import UpgradeOps
+ from ...runtime.plugins import Plugin
+
+
+log = logging.getLogger(__name__)
+
+
def _autogen_for_tables(
    autogen_context: AutogenContext,
    upgrade_ops: UpgradeOps,
    schemas: Set[Optional[str]],
) -> PriorityDispatchResult:
    """Schema-level comparator: collect table names from the database
    and from the local metadata, then compare the two sets via
    :func:`._compare_tables`.

    The alembic version table is excluded on both sides.
    """
    inspector = autogen_context.inspector

    conn_table_names: Set[Tuple[Optional[str], str]] = set()

    version_table_schema = (
        autogen_context.migration_context.version_table_schema
    )
    version_table = autogen_context.migration_context.version_table

    for schema_name in schemas:
        tables = set(inspector.get_table_names(schema=schema_name))
        # exclude the alembic version table from the reflected names
        if schema_name == version_table_schema:
            tables = tables.difference(
                [autogen_context.migration_context.version_table]
            )

        conn_table_names.update(
            (schema_name, tname)
            for tname in tables
            if autogen_context.run_name_filters(
                tname, "table", {"schema_name": schema_name}
            )
        )

    # exclude the version table from the metadata-side names as well
    metadata_table_names = OrderedSet(
        [(table.schema, table.name) for table in autogen_context.sorted_tables]
    ).difference([(version_table_schema, version_table)])

    _compare_tables(
        conn_table_names,
        metadata_table_names,
        inspector,
        upgrade_ops,
        autogen_context,
    )

    return PriorityDispatchResult.CONTINUE
+
+
def _compare_tables(
    conn_table_names: set[tuple[str | None, str]],
    metadata_table_names: set[tuple[str | None, str]],
    inspector: Inspector,
    upgrade_ops: UpgradeOps,
    autogen_context: AutogenContext,
) -> None:
    """Compare reflected table names against metadata table names.

    Generates ``CreateTableOp`` / ``DropTableOp`` for added / removed
    tables and dispatches the table-level comparators (and, for tables
    present on both sides, the column-level comparators) for each
    table.
    """
    default_schema = inspector.bind.dialect.default_schema_name

    # tables coming from the connection will not have "schema"
    # set if it matches default_schema_name; so we need a list
    # of table names from local metadata that also have "None" if schema
    # == default_schema_name. Most setups will be like this anyway but
    # some are not (see #170)
    metadata_table_names_no_dflt_schema = OrderedSet(
        [
            (schema if schema != default_schema else None, tname)
            for schema, tname in metadata_table_names
        ]
    )

    # to adjust for the MetaData collection storing the tables either
    # as "schemaname.tablename" or just "tablename", create a new lookup
    # which will match the "non-default-schema" keys to the Table object.
    tname_to_table = {
        no_dflt_schema: autogen_context.table_key_to_table[
            sa_schema._get_table_key(tname, schema)
        ]
        for no_dflt_schema, (schema, tname) in zip(
            metadata_table_names_no_dflt_schema, metadata_table_names
        )
    }
    metadata_table_names = metadata_table_names_no_dflt_schema

    # tables present only in metadata: CREATE TABLE
    for s, tname in metadata_table_names.difference(conn_table_names):
        name = "%s.%s" % (s, tname) if s else tname
        metadata_table = tname_to_table[(s, tname)]
        if autogen_context.run_object_filters(
            metadata_table, tname, "table", False, None
        ):
            upgrade_ops.ops.append(
                ops.CreateTableOp.from_table(metadata_table)
            )
            log.info("Detected added table %r", name)
            modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)

            autogen_context.comparators.dispatch(
                "table", qualifier=autogen_context.dialect.name
            )(
                autogen_context,
                modify_table_ops,
                s,
                tname,
                None,
                metadata_table,
            )
            if not modify_table_ops.is_empty():
                upgrade_ops.ops.append(modify_table_ops)

    # tables present only in the database: DROP TABLE (each table is
    # reflected first so the drop op carries its full definition)
    removal_metadata = sa_schema.MetaData()
    for s, tname in conn_table_names.difference(metadata_table_names):
        name = sa_schema._get_table_key(tname, s)
        exists = name in removal_metadata.tables
        t = sa_schema.Table(tname, removal_metadata, schema=s)

        if not exists:
            event.listen(
                t,
                "column_reflect",
                # fmt: off
                autogen_context.migration_context.impl.
                _compat_autogen_column_reflect
                (inspector),
                # fmt: on
            )
            _InspectorConv(inspector).reflect_table(t, include_columns=None)
        if autogen_context.run_object_filters(t, tname, "table", True, None):
            modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)

            autogen_context.comparators.dispatch(
                "table", qualifier=autogen_context.dialect.name
            )(autogen_context, modify_table_ops, s, tname, t, None)
            if not modify_table_ops.is_empty():
                upgrade_ops.ops.append(modify_table_ops)

            upgrade_ops.ops.append(ops.DropTableOp.from_table(t))
            log.info("Detected removed table %r", name)

    # tables on both sides: reflect them, then run column-level and
    # table-level comparators
    existing_tables = conn_table_names.intersection(metadata_table_names)

    existing_metadata = sa_schema.MetaData()
    # NOTE(review): conn_column_info is populated below but not read
    # within this function
    conn_column_info = {}
    for s, tname in existing_tables:
        name = sa_schema._get_table_key(tname, s)
        exists = name in existing_metadata.tables
        t = sa_schema.Table(tname, existing_metadata, schema=s)
        if not exists:
            event.listen(
                t,
                "column_reflect",
                # fmt: off
                autogen_context.migration_context.impl.
                _compat_autogen_column_reflect(inspector),
                # fmt: on
            )
            _InspectorConv(inspector).reflect_table(t, include_columns=None)

        conn_column_info[(s, tname)] = t

    for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
        s = s or None
        name = "%s.%s" % (s, tname) if s else tname
        metadata_table = tname_to_table[(s, tname)]
        conn_table = existing_metadata.tables[name]

        if autogen_context.run_object_filters(
            metadata_table, tname, "table", False, conn_table
        ):
            modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
            with _compare_columns(
                s,
                tname,
                conn_table,
                metadata_table,
                modify_table_ops,
                autogen_context,
                inspector,
            ):
                autogen_context.comparators.dispatch(
                    "table", qualifier=autogen_context.dialect.name
                )(
                    autogen_context,
                    modify_table_ops,
                    s,
                    tname,
                    conn_table,
                    metadata_table,
                )

            if not modify_table_ops.is_empty():
                upgrade_ops.ops.append(modify_table_ops)
+
+
@contextlib.contextmanager
def _compare_columns(
    schema: Optional[str],
    tname: Union[quoted_name, str],
    conn_table: Table,
    metadata_table: Table,
    modify_table_ops: ModifyTableOps,
    autogen_context: AutogenContext,
    inspector: Inspector,
) -> Iterator[None]:
    """Compare the columns of a single table.

    Used as a context manager: added-column and per-column alter ops
    are collected before the ``yield``; the caller runs the table-level
    comparators inside the ``with`` block; dropped-column ops are
    appended after the ``yield``.
    """
    name = "%s.%s" % (schema, tname) if schema else tname
    metadata_col_names = OrderedSet(
        c.name for c in metadata_table.c if not c.system
    )
    metadata_cols_by_name = {
        c.name: c for c in metadata_table.c if not c.system
    }

    conn_col_names = {
        c.name: c
        for c in conn_table.c
        if autogen_context.run_name_filters(
            c.name, "column", {"table_name": tname, "schema_name": schema}
        )
    }

    # columns only in metadata: ADD COLUMN
    for cname in metadata_col_names.difference(conn_col_names):
        if autogen_context.run_object_filters(
            metadata_cols_by_name[cname], cname, "column", False, None
        ):
            modify_table_ops.ops.append(
                ops.AddColumnOp.from_column_and_tablename(
                    schema, tname, metadata_cols_by_name[cname]
                )
            )
            log.info("Detected added column '%s.%s'", name, cname)

    # columns on both sides: dispatch the column-level comparators
    for colname in metadata_col_names.intersection(conn_col_names):
        metadata_col = metadata_cols_by_name[colname]
        conn_col = conn_table.c[colname]
        if not autogen_context.run_object_filters(
            metadata_col, colname, "column", False, conn_col
        ):
            continue
        alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema)

        autogen_context.comparators.dispatch(
            "column", qualifier=autogen_context.dialect.name
        )(
            autogen_context,
            alter_column_op,
            schema,
            tname,
            colname,
            conn_col,
            metadata_col,
        )

        if alter_column_op.has_changes():
            modify_table_ops.ops.append(alter_column_op)

    # let the caller's table-level comparators run here
    yield

    # columns only in the database: DROP COLUMN
    for cname in set(conn_col_names).difference(metadata_col_names):
        if autogen_context.run_object_filters(
            conn_table.c[cname], cname, "column", True, None
        ):
            modify_table_ops.ops.append(
                ops.DropColumnOp.from_column_and_tablename(
                    schema, tname, conn_table.c[cname]
                )
            )
            log.info("Detected removed column '%s.%s'", name, cname)
+
+
def setup(plugin: Plugin) -> None:
    """Register the table-set comparator at the "schema" level."""
    plugin.add_autogenerate_comparator(_autogen_for_tables, "schema", "tables")
--- /dev/null
+from __future__ import annotations
+
+import logging
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy import types as sqltypes
+
+from ...util import DispatchPriority
+from ...util import PriorityDispatchResult
+
+if TYPE_CHECKING:
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import Column
+
+ from ...autogenerate.api import AutogenContext
+ from ...operations.ops import AlterColumnOp
+ from ...runtime.plugins import Plugin
+
+
+log = logging.getLogger(__name__)
+
+
def _compare_type_setup(
    alter_column_op: AlterColumnOp,
    tname: Union[quoted_name, str],
    cname: Union[quoted_name, str],
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> bool:
    """Record the reflected type on the op; return True only when both
    the database and the model types are concrete enough to compare."""
    reflected_type = conn_col.type
    alter_column_op.existing_type = reflected_type

    if reflected_type._type_affinity is sqltypes.NullType:
        log.info(
            "Couldn't determine database type for column '%s.%s'",
            tname,
            cname,
        )
        return False

    model_type = metadata_col.type
    if model_type._type_affinity is sqltypes.NullType:
        log.info(
            "Column '%s.%s' has no type within the model; can't compare",
            tname,
            cname,
        )
        return False

    return True
+
+
def _user_compare_type(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: Union[quoted_name, str],
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """Apply the user-configured ``compare_type`` hook.

    ``compare_type=False`` disables type comparison entirely.  A
    user-supplied callable may report a difference (change recorded,
    stop), explicitly return False for "no difference, stop comparing",
    or return None to defer to the default comparison.
    """
    migration_context = autogen_context.migration_context

    # compare_type=False disables all type comparison
    if migration_context._user_compare_type is False:
        return PriorityDispatchResult.STOP

    if not _compare_type_setup(
        alter_column_op, tname, cname, conn_col, metadata_col
    ):
        return PriorityDispatchResult.CONTINUE

    # a non-callable truthy value means "use the default comparison"
    if not callable(migration_context._user_compare_type):
        return PriorityDispatchResult.CONTINUE

    is_diff = migration_context._user_compare_type(
        migration_context,
        conn_col,
        metadata_col,
        conn_col.type,
        metadata_col.type,
    )
    if is_diff:
        alter_column_op.modify_type = metadata_col.type
        log.info(
            "Detected type change from %r to %r on '%s.%s'",
            conn_col.type,
            metadata_col.type,
            tname,
            cname,
        )
        return PriorityDispatchResult.STOP
    elif is_diff is False:
        # if user compare type returns False and not None,
        # it means "dont do any more type comparison"
        return PriorityDispatchResult.STOP

    return PriorityDispatchResult.CONTINUE
+
+
def _dialect_impl_compare_type(
    autogen_context: AutogenContext,
    alter_column_op: AlterColumnOp,
    schema: Optional[str],
    tname: Union[quoted_name, str],
    cname: Union[quoted_name, str],
    conn_col: Column[Any],
    metadata_col: Column[Any],
) -> PriorityDispatchResult:
    """Compare column types using the dialect impl's ``compare_type()``.

    Runs after any user-configured ``compare_type`` hook.
    """
    if not _compare_type_setup(
        alter_column_op, tname, cname, conn_col, metadata_col
    ):
        return PriorityDispatchResult.CONTINUE

    migration_context = autogen_context.migration_context
    is_diff = migration_context.impl.compare_type(conn_col, metadata_col)

    if is_diff:
        alter_column_op.modify_type = metadata_col.type
        log.info(
            "Detected type change from %r to %r on '%s.%s'",
            conn_col.type,
            metadata_col.type,
            tname,
            cname,
        )
        return PriorityDispatchResult.STOP

    return PriorityDispatchResult.CONTINUE
+
+
def setup(plugin: Plugin) -> None:
    """Register type comparators: user hook first, dialect impl last."""
    for comparator, priority in (
        (_user_compare_type, DispatchPriority.FIRST),
        (_dialect_impl_compare_type, DispatchPriority.LAST),
    ):
        plugin.add_autogenerate_comparator(
            comparator, "column", "types", priority=priority
        )
--- /dev/null
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
+from sqlalchemy.sql.elements import conv
+
+
class _InspectorConv:
    """Proxy around an :class:`.Inspector` that wraps reflected
    constraint and index names in SQLAlchemy ``conv()`` markers."""

    __slots__ = ("inspector",)

    def __init__(self, inspector):
        self.inspector = inspector

    def _apply_reflectinfo_conv(self, consts):
        # reflection records are dicts keyed by "name"
        if not consts:
            return consts
        for record in consts:
            name = record["name"]
            if name is not None and not isinstance(name, conv):
                record["name"] = conv(name)
        return consts

    def _apply_constraint_conv(self, consts):
        # constraint / index objects carry a ``.name`` attribute
        if not consts:
            return consts
        for obj in consts:
            if obj.name is not None and not isinstance(obj.name, conv):
                obj.name = conv(obj.name)
        return consts

    def get_indexes(self, *args, **kw):
        return self._apply_reflectinfo_conv(
            self.inspector.get_indexes(*args, **kw)
        )

    def get_unique_constraints(self, *args, **kw):
        return self._apply_reflectinfo_conv(
            self.inspector.get_unique_constraints(*args, **kw)
        )

    def get_foreign_keys(self, *args, **kw):
        return self._apply_reflectinfo_conv(
            self.inspector.get_foreign_keys(*args, **kw)
        )

    def reflect_table(self, table, *, include_columns):
        self.inspector.reflect_table(table, include_columns=include_columns)

        # I had a cool version of this using _ReflectInfo, however that
        # doesn't work in 1.4 and it's not public API in 2.x.  Then this
        # is just a two liner.  So there's no competition...
        self._apply_constraint_conv(table.constraints)
        self._apply_constraint_conv(table.indexes)
None,
]
] = None,
+ autogenerate_plugins: Optional[Sequence[str]] = None,
**kw: Any,
) -> None:
"""Configure a :class:`.MigrationContext` within this
:paramref:`.command.revision.process_revision_directives`
+ :param autogenerate_plugins: A list of string names of "plugins" that
+ should participate in this autogenerate run. Defaults to the list
+ ``["alembic.autogenerate.*"]``, which indicates that Alembic's default
+ autogeneration plugins will be used.
+
+ See the section :ref:`plugins_autogenerate` for complete background
+ on how to use this parameter.
+
+ .. versionadded:: 1.18.0 Added a new plugin system for autogenerate
+ compare directives.
+
+ .. seealso::
+
+ :ref:`plugins_autogenerate` - background on enabling/disabling
+ autogenerate plugins
+
+ :ref:`alembic.plugins.toplevel` - Introduction and documentation
+ to the plugin system
+
Parameters specific to individual backends:
:param mssql_batch_separator: The "batch separator" which will
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
on_version_apply: Optional[OnVersionApplyFn] = None,
+ autogenerate_plugins: Sequence[str] | None = None,
**kw: Any,
) -> None:
"""Configure a :class:`.MigrationContext` within this
:paramref:`.command.revision.process_revision_directives`
+ :param autogenerate_plugins: A list of string names of "plugins" that
+ should participate in this autogenerate run. Defaults to the list
+ ``["alembic.autogenerate.*"]``, which indicates that Alembic's default
+ autogeneration plugins will be used.
+
+ See the section :ref:`plugins_autogenerate` for complete background
+ on how to use this parameter.
+
+ .. versionadded:: 1.18.0 Added a new plugin system for autogenerate
+ compare directives.
+
+ .. seealso::
+
+ :ref:`plugins_autogenerate` - background on enabling/disabling
+ autogenerate plugins
+
+ :ref:`alembic.plugins.toplevel` - Introduction and documentation
+ to the plugin system
+
Parameters specific to individual backends:
:param mssql_batch_separator: The "batch separator" which will
opts["process_revision_directives"] = process_revision_directives
opts["on_version_apply"] = util.to_tuple(on_version_apply, default=())
+ if autogenerate_plugins is not None:
+ opts["autogenerate_plugins"] = autogenerate_plugins
+
if render_item is not None:
opts["render_item"] = render_item
opts["compare_type"] = compare_type
from typing import TYPE_CHECKING
from typing import Union
-from sqlalchemy import Column
from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy.engine import Engine
else:
return None
- def _compare_type(
- self, inspector_column: Column[Any], metadata_column: Column
- ) -> bool:
- if self._user_compare_type is False:
- return False
-
- if callable(self._user_compare_type):
- user_value = self._user_compare_type(
- self,
- inspector_column,
- metadata_column,
- inspector_column.type,
- metadata_column.type,
- )
- if user_value is not None:
- return user_value
-
- return self.impl.compare_type(inspector_column, metadata_column)
-
- def _compare_server_default(
- self,
- inspector_column: Column[Any],
- metadata_column: Column[Any],
- rendered_metadata_default: Optional[str],
- rendered_column_default: Optional[str],
- ) -> bool:
- if self._user_compare_server_default is False:
- return False
-
- if callable(self._user_compare_server_default):
- user_value = self._user_compare_server_default(
- self,
- inspector_column,
- metadata_column,
- rendered_column_default,
- metadata_column.server_default,
- rendered_metadata_default,
- )
- if user_value is not None:
- return user_value
-
- return self.impl.compare_server_default(
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_column_default,
- )
-
class HeadMaintainer:
def __init__(self, context: MigrationContext, heads: Any) -> None:
--- /dev/null
+from __future__ import annotations
+
+from importlib import metadata
+import logging
+import re
+from types import ModuleType
+from typing import Callable
+from typing import Pattern
+from typing import TYPE_CHECKING
+
+from .. import util
+from ..util import DispatchPriority
+from ..util import PriorityDispatcher
+
+if TYPE_CHECKING:
+ from ..util import PriorityDispatchResult
+
+_all_plugins = {}
+
+
+log = logging.getLogger("__name__")
+
+
+class Plugin:
+ """Describe a series of functions that are pulled in as a plugin.
+
+ This is initially to provide for portable lists of autogenerate
+ comparison functions, however the setup for a plugin can run any
+ other kinds of global registration as well.
+
+ .. versionadded:: 1.18.0
+
+ """
+
+ def __init__(self, name: str):
+ self.name = name
+ log.info("setup plugin %s", name)
+ if name in _all_plugins:
+ raise ValueError(f"A plugin named {name} is already registered")
+ _all_plugins[name] = self
+ self.autogenerate_comparators = PriorityDispatcher()
+
+ def remove(self) -> None:
+ """remove this plugin"""
+
+ del _all_plugins[self.name]
+
+ def add_autogenerate_comparator(
+ self,
+ fn: Callable[..., PriorityDispatchResult],
+ compare_target: str,
+ compare_element: str | None = None,
+ *,
+ qualifier: str = "default",
+ priority: DispatchPriority = DispatchPriority.MEDIUM,
+ ) -> None:
+ """Register an autogenerate comparison function.
+
+ See the section :ref:`plugins_registering_autogenerate` for detailed
+ examples on how to use this method.
+
+ :param fn: The comparison function to register. The function receives
+ arguments specific to the type of comparison being performed and
+ should return a :class:`.PriorityDispatchResult` value.
+
+ :param compare_target: The type of comparison being performed
+ (e.g., ``"table"``, ``"column"``, ``"type"``).
+
+ :param compare_element: Optional sub-element being compared within
+ the target type.
+
+ :param qualifier: Database dialect qualifier. Use ``"default"`` for
+ all dialects, or specify a dialect name like ``"postgresql"`` to
+ register a dialect-specific handler. Defaults to ``"default"``.
+
+ :param priority: Execution priority for this comparison function.
+ Functions are executed in priority order from
+ :attr:`.DispatchPriority.FIRST` to :attr:`.DispatchPriority.LAST`.
+ Defaults to :attr:`.DispatchPriority.MEDIUM`.
+
+ """
+ self.autogenerate_comparators.dispatch_for(
+ compare_target,
+ subgroup=compare_element,
+ priority=priority,
+ qualifier=qualifier,
+ )(fn)
+
+ @classmethod
+ def populate_autogenerate_priority_dispatch(
+ cls, comparators: PriorityDispatcher, include_plugins: list[str]
+ ) -> None:
+ """Populate all current autogenerate comparison functions into
+ a given PriorityDispatcher."""
+
+ exclude: set[Pattern[str]] = set()
+ include: dict[str, Pattern[str]] = {}
+
+ matched_expressions: set[str] = set()
+
+ for name in include_plugins:
+ if name.startswith("~"):
+ exclude.add(_make_re(name[1:]))
+ else:
+ include[name] = _make_re(name)
+
+ for plugin in _all_plugins.values():
+ if any(excl.match(plugin.name) for excl in exclude):
+ continue
+
+ include_matches = [
+ incl for incl in include if include[incl].match(plugin.name)
+ ]
+ if not include_matches:
+ continue
+ else:
+ matched_expressions.update(include_matches)
+
+ log.info("setting up autogenerate plugin %s", plugin.name)
+ comparators.populate_with(plugin.autogenerate_comparators)
+
+ never_matched = set(include).difference(matched_expressions)
+ if never_matched:
+ raise util.CommandError(
+ f"Did not locate plugins: {', '.join(never_matched)}"
+ )
+
+ @classmethod
+ def setup_plugin_from_module(cls, module: ModuleType, name: str) -> None:
+ """Call the ``setup()`` function of a plugin module, identified by
+ passing the module object itself.
+
+ E.g.::
+
+ from alembic.runtime.plugins import Plugin
+ import myproject.alembic_plugin
+
+ # Register the plugin manually
+ Plugin.setup_plugin_from_module(
+ myproject.alembic_plugin,
+ "myproject.custom_operations"
+ )
+
+ This will generate a new :class:`.Plugin` object with the given
+ name, which will register itself in the global list of plugins.
+ Then the module's ``setup()`` function is invoked, passing that
+ :class:`.Plugin` object.
+
+ This exact process is invoked automatically at import time for any
+ plugin module that is published via the ``alembic.plugins`` entrypoint.
+
+ """
+ module.setup(Plugin(name))
+
+
+def _make_re(name: str) -> Pattern[str]:
+    """Compile a plugin-name expression such as ``pkg.mod.*`` into a regex.
+
+    Each dot-separated token must be either a plain Python identifier,
+    matched literally, or ``*``, which matches one or more characters in
+    that position.  Raises ``ValueError`` for any other token.
+    """
+    tokens = name.split(".")
+
+    reg = r""
+    for token in tokens:
+        if token == "*":
+            reg += r"\..+?"
+        elif token.isidentifier():
+            reg += r"\." + token
+        else:
+            raise ValueError(f"Invalid plugin expression {name!r}")
+
+    # omit leading r'\.'
+    return re.compile(f"^{reg[2:]}$")
+
+
+def _setup() -> None:
+ # setup third party plugins
+ for entrypoint in metadata.entry_points(group="alembic.plugins"):
+ for mod in entrypoint.load():
+ Plugin.setup_plugin_from_module(mod, entrypoint.name)
+
+
+_setup()
from typing import Any
from typing import Dict
+from typing import Literal
+from typing import overload
from typing import Set
from sqlalchemy import CHAR
class AutogenFixtureTest(_ComparesFKs):
+
+ @overload
+ def _fixture(
+ self,
+ m1: MetaData,
+ m2: MetaData,
+ include_schemas=...,
+ opts=...,
+ object_filters=...,
+ name_filters=...,
+ *,
+ return_ops: Literal[True],
+ max_identifier_length=...,
+ ) -> ops.UpgradeOps: ...
+
+ @overload
def _fixture(
self,
- m1,
- m2,
+ m1: MetaData,
+ m2: MetaData,
+ include_schemas=...,
+ opts=...,
+ object_filters=...,
+ name_filters=...,
+ *,
+ return_ops: Literal[False] = ...,
+ max_identifier_length=...,
+ ) -> list[Any]: ...
+
+ def _fixture(
+ self,
+ m1: MetaData,
+ m2: MetaData,
include_schemas=False,
opts=None,
object_filters=_default_object_filters,
name_filters=_default_name_filters,
- return_ops=False,
+ return_ops: bool = False,
max_identifier_length=None,
- ):
+ ) -> ops.UpgradeOps | list[Any]:
if max_identifier_length:
dialect = self.bind.dialect
existing_length = dialect.max_identifier_length
from .langhelpers import asbool as asbool
from .langhelpers import dedupe_tuple as dedupe_tuple
from .langhelpers import Dispatcher as Dispatcher
+from .langhelpers import DispatchPriority as DispatchPriority
from .langhelpers import EMPTY_DICT as EMPTY_DICT
from .langhelpers import immutabledict as immutabledict
from .langhelpers import memoized_property as memoized_property
from .langhelpers import ModuleClsProxy as ModuleClsProxy
from .langhelpers import not_none as not_none
+from .langhelpers import PriorityDispatcher as PriorityDispatcher
+from .langhelpers import PriorityDispatchResult as PriorityDispatchResult
from .langhelpers import rev_id as rev_id
from .langhelpers import to_list as to_list
from .langhelpers import to_tuple as to_tuple
import collections
from collections.abc import Iterable
+import enum
import textwrap
from typing import Any
from typing import Callable
from typing import Set
from typing import Tuple
from typing import Type
-from typing import TYPE_CHECKING
from typing import TypeVar
-from typing import Union
import uuid
import warnings
return tuple(unique_list(tup))
+class PriorityDispatchResult(enum.Enum):
+    """Indicate an action after running a function within a
+    :class:`.PriorityDispatcher`
+
+    .. versionadded:: 1.18.0
+
+    """
+
+    CONTINUE = enum.auto()
+    """Continue running more functions.
+
+    Any return value that is not PriorityDispatchResult.STOP is equivalent
+    to this.
+
+    """
+
+    STOP = enum.auto()
+    """Stop running any additional functions within the subgroup"""
+
+
+class DispatchPriority(enum.IntEnum):
+    """Indicate in which of three sub-collections a function inside a
+    :class:`.PriorityDispatcher` should be placed.
+
+    .. versionadded:: 1.18.0
+
+    """
+
+    FIRST = 50
+    """Run the function in the first batch of functions (highest priority)"""
+
+    MEDIUM = 25
+    """Run the function at normal priority (this is the default)"""
+
+    LAST = 10
+    """Run the function in the last batch of functions"""
+
+
class Dispatcher:
- def __init__(self, uselist: bool = False) -> None:
+ def __init__(self) -> None:
self._registry: Dict[Tuple[Any, ...], Any] = {}
- self.uselist = uselist
def dispatch_for(
- self, target: Any, qualifier: str = "default", replace: bool = False
+ self,
+ target: Any,
+ *,
+ qualifier: str = "default",
+ replace: bool = False,
) -> Callable[[_C], _C]:
def decorate(fn: _C) -> _C:
- if self.uselist:
- self._registry.setdefault((target, qualifier), []).append(fn)
- else:
- if (target, qualifier) in self._registry and not replace:
- raise ValueError(
- "Can not set dispatch function for object "
- f"{target!r}: key already exists. To replace "
- "existing function, use replace=True."
- )
- self._registry[(target, qualifier)] = fn
+ if (target, qualifier) in self._registry and not replace:
+ raise ValueError(
+ "Can not set dispatch function for object "
+ f"{target!r}: key already exists. To replace "
+ "existing function, use replace=True."
+ )
+ self._registry[(target, qualifier)] = fn
return fn
return decorate
else:
targets = type(obj).__mro__
- for spcls in targets:
- if qualifier != "default" and (spcls, qualifier) in self._registry:
- return self._fn_or_list(self._registry[(spcls, qualifier)])
- elif (spcls, "default") in self._registry:
- return self._fn_or_list(self._registry[(spcls, "default")])
+ if qualifier != "default":
+ qualifiers = [qualifier, "default"]
else:
- raise ValueError("no dispatch function for object: %s" % obj)
+ qualifiers = ["default"]
- def _fn_or_list(
- self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]]
- ) -> Callable[..., Any]:
- if self.uselist:
-
- def go(*arg: Any, **kw: Any) -> None:
- if TYPE_CHECKING:
- assert isinstance(fn_or_list, Sequence)
- for fn in fn_or_list:
- fn(*arg, **kw)
-
- return go
+ for spcls in targets:
+ for qualifier in qualifiers:
+ if (spcls, qualifier) in self._registry:
+ return self._registry[(spcls, qualifier)]
else:
- return fn_or_list # type: ignore
+ raise ValueError("no dispatch function for object: %s" % obj)
def branch(self) -> Dispatcher:
"""Return a copy of this dispatcher that is independently
writable."""
d = Dispatcher()
- if self.uselist:
- d._registry.update(
- (k, [fn for fn in self._registry[k]]) for k in self._registry
+ d._registry.update(self._registry)
+ return d
+
+
+class PriorityDispatcher:
+ """registers lists of functions at multiple levels of priorty and provides
+ a target to invoke them in priority order.
+
+ .. versionadded:: 1.18.0 - PriorityDispatcher replaces the job
+ of Dispatcher(uselist=True)
+
+ """
+
+ def __init__(self) -> None:
+ self._registry: dict[tuple[Any, ...], Any] = collections.defaultdict(
+ list
+ )
+
+ def dispatch_for(
+ self,
+ target: str,
+ *,
+ priority: DispatchPriority = DispatchPriority.MEDIUM,
+ qualifier: str = "default",
+ subgroup: str | None = None,
+ ) -> Callable[[_C], _C]:
+ """return a decorator callable that registers a function at a
+ given priority, with a given qualifier, to fire off for a given
+ subgroup.
+
+ It's important this remains as a decorator to support third party
+ plugins who are populating the dispatcher using that style.
+
+ """
+
+ def decorate(fn: _C) -> _C:
+ self._registry[(target, qualifier, priority)].append(
+ (fn, subgroup)
)
+ return fn
+
+ return decorate
+
+ def dispatch(
+ self, target: str, *, qualifier: str = "default"
+ ) -> Callable[..., None]:
+ """Provide a callable for the given target and qualifier."""
+
+ if qualifier != "default":
+ qualifiers = [qualifier, "default"]
else:
- d._registry.update(self._registry)
+ qualifiers = ["default"]
+
+ def go(*arg: Any, **kw: Any) -> Any:
+ results_by_subgroup: dict[str, PriorityDispatchResult] = {}
+ for priority in DispatchPriority:
+ for qualifier in qualifiers:
+ for fn, subgroup in self._registry.get(
+ (target, qualifier, priority), ()
+ ):
+ if (
+ results_by_subgroup.get(
+ subgroup, PriorityDispatchResult.CONTINUE
+ )
+ is PriorityDispatchResult.STOP
+ ):
+ continue
+
+ result = fn(*arg, **kw)
+ results_by_subgroup[subgroup] = result
+
+ return go
+
+ def branch(self) -> PriorityDispatcher:
+ """Return a copy of this dispatcher that is independently
+ writable."""
+
+ d = PriorityDispatcher()
+ d.populate_with(self)
return d
+ def populate_with(self, other: PriorityDispatcher) -> None:
+ """Populate this PriorityDispatcher with the contents of another one.
+
+ Additive, does not remove existing contents.
+ """
+ for k in other._registry:
+ new_list = other._registry[k]
+ self._registry[k].extend(new_list)
+
def not_none(value: Optional[_T]) -> _T:
assert value is not None
custom DDL objects representing views, triggers, special constraints,
or anything else we want to support.
+.. _autogenerate_global_comparison_function:
-Registering a Comparison Function
----------------------------------
+Registering a Comparison Function Globally
+------------------------------------------
We now need to register a comparison hook, which will be used
to compare the database to our model and produce ``CreateSequenceOp``
and ``DropSequenceOp`` directives to be included in our migration
-script. Note that we are assuming a
-Postgresql backend::
+script. The example below illustrates registering a comparison function
+using the **global** dispatch::
from alembic.autogenerate import comparators
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.operations.ops import UpgradeOps
+ # new in Alembic 1.18.0 - for older versions, no return value is needed
+ from alembic.util import PriorityDispatchResult
+
+ # the global dispatch includes a decorator function.
+ # for plugin level dispatch, use Plugin.add_autogenerate_comparator()
+ # instead.
@comparators.dispatch_for("schema")
- def compare_sequences(autogen_context, upgrade_ops, schemas):
+ def compare_sequences(
+ autogen_context: AutogenContext,
+ upgrade_ops: UpgradeOps,
+ schemas: set[str | None]
+ ) -> PriorityDispatchResult:
all_conn_sequences = set()
for sch in schemas:
DropSequenceOp(name, schema=sch)
)
+ return PriorityDispatchResult.CONTINUE
+
Above, we've built a new function ``compare_sequences()`` and registered
it as a "schema" level comparison function with autogenerate. The
job that it performs is that it compares the list of sequence names
present in each database schema with that of a list of sequence names
that we are maintaining in our :class:`~sqlalchemy.schema.MetaData` object.
+The registration of our function at the scope of "schema" means our
+autogenerate comparison function is called outside of the context of any
+specific table or column. The four available scopes are "autogenerate" (new in
+1.18.0), "schema", "table", and "column"; these scopes are described fully in
+the section :ref:`plugins_registering_autogenerate`, which details their use in
+terms of a custom plugin, however the interfaces are the same.
+
When autogenerate completes, it will have a series of
``CreateSequenceOp`` and ``DropSequenceOp`` directives in the list of
"upgrade" operations; the list of "downgrade" operations is generated
``CreateSequenceOp.reverse()`` and ``DropSequenceOp.reverse()`` methods
that we've implemented on these objects.
-The registration of our function at the scope of "schema" means our
-autogenerate comparison function is called outside of the context
-of any specific table or column. The three available scopes
-are "schema", "table", and "column", summarized as follows:
-
-* **Schema level** - these hooks are passed a :class:`.AutogenContext`,
- an :class:`.UpgradeOps` collection, and a collection of string schema
- names to be operated upon. If the
- :class:`.UpgradeOps` collection contains changes after all
- hooks are run, it is included in the migration script:
-
- ::
-
- @comparators.dispatch_for("schema")
- def compare_schema_level(autogen_context, upgrade_ops, schemas):
- pass
-
-* **Table level** - these hooks are passed a :class:`.AutogenContext`,
- a :class:`.ModifyTableOps` collection, a schema name, table name,
- a :class:`~sqlalchemy.schema.Table` reflected from the database if any
- or ``None``, and a :class:`~sqlalchemy.schema.Table` present in the
- local :class:`~sqlalchemy.schema.MetaData`. If the
- :class:`.ModifyTableOps` collection contains changes after all
- hooks are run, it is included in the migration script:
-
- ::
-
- @comparators.dispatch_for("table")
- def compare_table_level(autogen_context, modify_ops,
- schemaname, tablename, conn_table, metadata_table):
- pass
-
-* **Column level** - these hooks are passed a :class:`.AutogenContext`,
- an :class:`.AlterColumnOp` object, a schema name, table name,
- column name, a :class:`~sqlalchemy.schema.Column` reflected from the
- database and a :class:`~sqlalchemy.schema.Column` present in the
- local table. If the :class:`.AlterColumnOp` contains changes after
- all hooks are run, it is included in the migration script;
- a "change" is considered to be present if any of the ``modify_`` attributes
- are set to a non-default value, or there are any keys
- in the ``.kw`` collection with the prefix ``"modify_"``:
-
- ::
-
- @comparators.dispatch_for("column")
- def compare_column_level(autogen_context, alter_column_op,
- schemaname, tname, cname, conn_col, metadata_col):
- pass
+The example above illustrates registration with the so-called **global**
+autogenerate dispatch, at ``alembic.autogenerate.comparators``. Alembic as of
+version 1.18 also includes a **plugin level** dispatch, where comparison
+functions are instead registered using
+:meth:`.Plugin.add_autogenerate_comparator`. Comparison functions registered at
+the plugin level operate in the same way as those registered globally, with the
+exception that custom autogenerate compare functions must also be enabled at
+the environment level within the
+:attr:`.EnvironmentContext.configure.autogenerate_plugins` parameter, and also
+have the ability to be omitted from an autogenerate run.
+
+.. seealso::
+
+ :ref:`plugins_registering_autogenerate` - newer plugin-level means of
+ registering autogenerate compare functions.
The :class:`.AutogenContext` passed to these hooks is documented below.
autogenerate
script
ddl
+ plugins
exceptions
--- /dev/null
+.. _alembic.plugins.toplevel:
+
+=======
+Plugins
+=======
+
+.. versionadded:: 1.18.0
+
+Alembic provides a plugin system that allows third-party extensions to
+integrate with Alembic's functionality. Plugins can register custom operations,
+operation implementations, autogenerate comparison functions, and other
+extension points to add new capabilities to Alembic.
+
+The plugin system provides a structured way to organize and distribute these
+extensions, allowing them to be discovered automatically using Python
+entry points.
+
+Overview
+========
+
+The :class:`.Plugin` class provides the foundation for creating plugins.
+A plugin's ``setup()`` function can perform various types of registrations:
+
+* **Custom operations** - Register new operation directives using
+ :meth:`.Operations.register_operation` (e.g., ``op.create_view()``)
+* **Operation implementations** - Provide database-specific implementations
+ using :meth:`.Operations.implementation_for`
+* **Autogenerate comparators** - Add comparison functions for detecting
+ schema differences during autogeneration
+* **Other extensions** - Register any other global handlers or customizations
+
+A single plugin can register handlers across all of these categories. For
+example, a plugin for custom database objects might register both the
+operations to create/drop those objects and the autogenerate logic to
+detect changes to them.
+
+.. seealso::
+
+ :ref:`replaceable_objects` - Cookbook recipe demonstrating custom
+ operations and implementations that would be suitable for packaging
+ as a plugin
+
+Installing and Using Plugins
+============================
+
+Third-party plugins are typically distributed as Python packages that can be
+installed via pip or other package managers::
+
+ pip install mycompany-alembic-plugin
+
+Once installed, plugins that use Python's entry point system are automatically
+discovered and loaded by Alembic at startup, which calls the plugin's
+``setup()`` function to perform any registrations.
+
+Enable Autogenerate Plugins
+---------------------------
+
+For plugins that provide autogenerate comparison functions via the
+:meth:`.Plugin.add_autogenerate_comparator` hook, the specific autogenerate
+functionality registered by the plugin must be enabled with
+:paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter, which
+by default indicates that only Alembic's built-in plugins should be used.
+Note that this step does not apply to older plugins that may be registering
+autogenerate comparison functions globally.
+
+See the section :ref:`plugins_autogenerate` for background on enabling
+autogenerate comparison plugins per environment.
+
+Using Plugins without entry points (such as local plugin code)
+--------------------------------------------------------------
+
+Plugins do not need to be published with entry points to be used. A plugin
+can be manually registered by calling :meth:`.Plugin.setup_plugin_from_module`
+in the ``env.py`` file::
+
+ from alembic.runtime.plugins import Plugin
+ import myproject.alembic_plugin
+
+ # Register the plugin manually
+ Plugin.setup_plugin_from_module(
+ myproject.alembic_plugin,
+ "myproject.custom_operations"
+ )
+
+This approach is useful for project-specific plugins that are not intended
+for distribution, or for testing plugins during development.
+
+.. _plugins_autogenerate:
+
+Enabling Autogenerate Plugins in env.py
+=======================================
+
+If a plugin provides autogenerate functionality that's registered via the
+:meth:`.Plugin.add_autogenerate_comparator` hook, it can be selectively enabled
+or disabled using the
+:paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter in the
+:meth:`.EnvironmentContext.configure` call, typically as used within the
+``env.py`` file. This parameter is passed as a list of strings each naming a
+specific plugin or a matching wildcard. The default value is
+``["alembic.autogenerate.*"]`` which indicates that the full set of Alembic's
+internal plugins should be used.
+
+The :paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter
+accepts a list of string patterns:
+
+* Simple names match plugin names exactly: ``"alembic.autogenerate.tables"``
+* Wildcards match multiple plugins: ``"alembic.autogenerate.*"`` matches all
+ built-in plugins
+* Negation patterns exclude plugins: ``"~alembic.autogenerate.comments"``
+ excludes the comments plugin
+
+For example, to use all built-in plugins except comments, plus a custom
+plugin::
+
+ context.configure(
+ # ...
+ autogenerate_plugins=[
+ "alembic.autogenerate.*",
+ "~alembic.autogenerate.comments",
+ "mycompany.custom_types",
+ ]
+ )
+
+The wildcard syntax using ``*`` indicates that tokens in that segment
+of the name (separated by period characters) will match any name. For
+Alembic's ``alembic.autogenerate.*`` namespace, the built-in names being
+invoked are:
+
+* ``alembic.autogenerate.schemas`` - Schema creation and dropping
+* ``alembic.autogenerate.tables`` - Table creation, dropping, and modification.
+ This plugin depends on the ``schemas`` plugin in order to iterate through
+ tables.
+* ``alembic.autogenerate.types`` - Column type changes. This plugin depends on
+ the ``tables`` plugin in order to iterate through columns.
+* ``alembic.autogenerate.constraints`` - Constraint creation and dropping. This
+ plugin depends on the ``tables`` plugin in order to iterate through columns.
+* ``alembic.autogenerate.defaults`` - Server default changes. This plugin
+ depends on the ``tables`` plugin in order to iterate through columns.
+* ``alembic.autogenerate.comments`` - Table and column comment changes. This
+ plugin depends on the ``tables`` plugin in order to iterate through columns.
+
+While these names can be specified individually, they are subject to change
+as Alembic evolves. Using the wildcard pattern is recommended.
+
+Omitting the built-in plugins entirely would prevent autogeneration from
+proceeding, unless other plugins were provided that replaced its functionality
+(which is possible!). Additionally, as noted above, the column-oriented plugins
+rely on the table- and schema- oriented plugins in order to receive iterated
+columns.
+
+The :paramref:`.EnvironmentContext.configure.autogenerate_plugins`
+parameter only controls which plugins participate in autogenerate
+operations. Other plugin functionality, such as custom operations
+registered with :meth:`.Operations.register_operation`, is available
+regardless of this setting.
+
+
+
+
+Writing a Plugin
+================
+
+Creating a Plugin Module
+-------------------------
+
+A plugin module must define a ``setup()`` function that accepts a
+:class:`.Plugin` instance. This function is called when the plugin is
+loaded, either automatically via entry points or manually via
+:meth:`.Plugin.setup_plugin_from_module`::
+
+ from alembic import op
+ from alembic.operations import Operations
+ from alembic.runtime.plugins import Plugin
+ from alembic.util import DispatchPriority
+
+ def setup(plugin: Plugin) -> None:
+ """Setup function called by Alembic when loading the plugin."""
+
+ # Register custom operations
+ Operations.register_operation("create_view")(CreateViewOp)
+ Operations.implementation_for(CreateViewOp)(create_view_impl)
+
+ # Register autogenerate comparison functions
+ plugin.add_autogenerate_comparator(
+ _compare_views,
+ "view",
+ qualifier="default",
+ priority=DispatchPriority.MEDIUM,
+ )
+
+The ``setup()`` function serves as the entry point for all plugin
+registrations. It can call various Alembic APIs to extend functionality.
+
+Publishing a Plugin
+-------------------
+
+To make a plugin available for installation via pip, create a package with
+an entry point in ``pyproject.toml``::
+
+ [project.entry-points."alembic.plugins"]
+ mycompany.plugin_name = "mycompany.alembic_plugin"
+
+Where ``mycompany.alembic_plugin`` is the module containing the ``setup()``
+function.
+
+When the package is installed, Alembic automatically discovers and loads the
+plugin through the entry point system. If the plugin provides autogenerate
+functionality, users can then enable it by adding its name
+``mycompany.plugin_name`` to the ``autogenerate_plugins`` list in their
+``env.py``.
+
+Registering Custom Operations
+------------------------------
+
+Plugins can register new operation directives that become available as
+``op.custom_operation()`` in migration scripts. This is done using
+:meth:`.Operations.register_operation` and
+:meth:`.Operations.implementation_for`.
+
+Example from the :ref:`replaceable_objects` recipe::
+
+ from alembic.operations import Operations, MigrateOperation
+
+ class CreateViewOp(MigrateOperation):
+ def __init__(self, view_name, select_stmt):
+ self.view_name = view_name
+ self.select_stmt = select_stmt
+
+ @Operations.register_operation("create_view")
+ class CreateViewOp(CreateViewOp):
+ pass
+
+ @Operations.implementation_for(CreateViewOp)
+ def create_view(operations, operation):
+ operations.execute(
+ f"CREATE VIEW {operation.view_name} AS {operation.select_stmt}"
+ )
+
+These registrations can be performed in the plugin's ``setup()`` function,
+making the custom operations available globally.
+
+.. seealso::
+
+ :ref:`replaceable_objects` - Complete example of registering custom
+ operations
+
+ :ref:`operation_plugins` - Documentation on the operations plugin system
+
+.. _plugins_registering_autogenerate:
+
+Registering Autogenerate Comparators at the Plugin Level
+--------------------------------------------------------
+
+Plugins can register comparison functions that participate in the autogenerate
+process, detecting differences between database schema and SQLAlchemy metadata.
+These functions may be registered globally, where they take place
+unconditionally as documented at
+:ref:`autogenerate_global_comparison_function`; for older versions of Alembic
+prior to 1.18.0 this is the only registration system available. However when
+targeting Alembic 1.18.0 or higher, the :class:`.Plugin` approach provides a
+more configurable version of these registration hooks.
+
+Plugin level comparison functions are registered using
+:meth:`.Plugin.add_autogenerate_comparator`. Each comparison function
+establishes itself as part of a named "target", which is invoked by a parent
+handler. For example, if a handler establishes itself as part of the
+``"column"`` target, it will be invoked when the
+``alembic.autogenerate.tables`` plugin proceeds through SQLAlchemy ``Table``
+objects and invokes comparison operations for pairs of same-named columns.
+
+For an example of a complete comparison function, see the example at
+:ref:`autogenerate_global_comparison_function`.
+
+The current levels of comparison are the same between global and plugin-level
+comparison functions, and include:
+
+* ``"autogenerate"`` - this target is invoked at the top of the autogenerate
+ chain. These hooks are passed a :class:`.AutogenContext` and an
+ :class:`.UpgradeOps` collection. Functions that subscribe to the
+ ``autogenerate`` target should look like::
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.operations.ops import UpgradeOps
+ from alembic.runtime.plugins import Plugin
+ from alembic.util import PriorityDispatchResult
+
+ def autogen_toplevel(
+ autogen_context: AutogenContext, upgrade_ops: UpgradeOps
+ ) -> PriorityDispatchResult:
+ # ...
+
+
+ def setup(plugin: Plugin) -> None:
+ plugin.add_autogenerate_comparator(autogen_toplevel, "autogenerate")
+
+ The function should return either :attr:`.PriorityDispatchResult.CONTINUE` or
+ :attr:`.PriorityDispatchResult.STOP` to halt any further comparisons from
+ proceeding, and should respond to detected changes by mutating the given
+ :class:`.UpgradeOps` collection in place (the :class:`.DowngradeOps` version
+ is produced later by reversing the :class:`.UpgradeOps`).
+
+ An autogenerate compare function that seeks to run entirely independently of
+ Alembic's built-in autogenerate plugins, or to replace them completely, would
+ register at the ``"autogenerate"`` level. The remaining levels indicated
+ below are all invoked from within Alembic's own autogenerate plugins and will
+ not take place if ``alembic.autogenerate.*`` is not enabled.
+
+ .. versionadded:: 1.18.0 The ``"autogenerate"`` comparison scope was
+ introduced, replacing ``"schema"`` as the topmost comparison scope.
+
+* ``"schema"`` - this target is invoked for each individual "schema" being
+ compared, and hooks are passed a :class:`.AutogenContext`, an
+ :class:`.UpgradeOps` collection, and a set of schema names, featuring the
+ value ``None`` for the "default" schema. Functions that subscribe to the
+ ``"schema"`` target should look like::
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.operations.ops import UpgradeOps
+ from alembic.runtime.plugins import Plugin
+ from alembic.util import PriorityDispatchResult
+
+ def autogen_for_tables(
+ autogen_context: AutogenContext,
+ upgrade_ops: UpgradeOps,
+ schemas: set[str | None],
+ ) -> PriorityDispatchResult:
+ # ...
+
+ def setup(plugin: Plugin) -> None:
+ plugin.add_autogenerate_comparator(
+ autogen_for_tables,
+ "schema",
+ "tables",
+ )
+
+ The function should normally return :attr:`.PriorityDispatchResult.CONTINUE`
+ and should respond to detected changes by mutating the given
+ :class:`.UpgradeOps` collection in place (the :class:`.DowngradeOps` version
+ is produced later by reversing the :class:`.UpgradeOps`).
+
+ The registration example above includes the ``"tables"`` "compare element",
+ which is optional. This indicates that the comparison function is part of a
+ chain called "tables", which is what Alembic's own
+ ``alembic.autogenerate.tables`` plugin uses. If our custom comparison
+ function were to return the value :attr:`.PriorityDispatchResult.STOP`,
+ further comparison functions in the ``"tables"`` chain would not be called.
+ Similarly, if another plugin in the ``"tables"`` chain returned
+ :attr:`.PriorityDispatchResult.STOP`, then our plugin would not be called.
+ Making use of :attr:`.PriorityDispatchResult.STOP` in terms of other plugins
+ in the same "compare element" may be assisted by placing our function in the
+ comparator chain using :attr:`.DispatchPriority.FIRST` or
+ :attr:`.DispatchPriority.LAST` when registering.
+
+* ``"table"`` - this target is invoked per ``Table`` being compared between a
+ database autoloaded version and the local metadata version. These hooks are
+ passed an :class:`.AutogenContext`, a :class:`.ModifyTableOps` collection, a
+ schema name, table name, a :class:`~sqlalchemy.schema.Table` reflected from
+ the database if any or ``None``, and a :class:`~sqlalchemy.schema.Table`
+ present in the local :class:`~sqlalchemy.schema.MetaData`. If the
+ :class:`.ModifyTableOps` collection contains changes after all hooks are run,
+ it is included in the migration script::
+
+ from sqlalchemy import quoted_name
+ from sqlalchemy import Table
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.operations.ops import ModifyTableOps
+ from alembic.runtime.plugins import Plugin
+ from alembic.util import PriorityDispatchResult
+
+ def compare_tables(
+ autogen_context: AutogenContext,
+ modify_table_ops: ModifyTableOps,
+ schema: str | None,
+ tname: quoted_name | str,
+ conn_table: Table | None,
+ metadata_table: Table | None,
+ ) -> PriorityDispatchResult:
+ # ...
+
+
+ def setup(plugin: Plugin) -> None:
+ plugin.add_autogenerate_comparator(compare_tables, "table")
+
+ This hook may be used to compare elements of tables, such as comments
+ or database-specific storage configurations. It should mutate the given
+ :class:`.ModifyTableOps` object in place to add new change operations.
+
+* ``"column"`` - this target is invoked per ``Column`` being compared between a
+ database autoloaded version and the local metadata version.
+ These hooks are passed an :class:`.AutogenContext`,
+ an :class:`.AlterColumnOp` object, a schema name, table name,
+ column name, a :class:`~sqlalchemy.schema.Column` reflected from the
+ database and a :class:`~sqlalchemy.schema.Column` present in the
+ local table. If the :class:`.AlterColumnOp` contains changes after
+ all hooks are run, it is included in the migration script;
+ a "change" is considered to be present if any of the ``modify_`` attributes
+ are set to a non-default value, or there are any keys
+ in the ``.kw`` collection with the prefix ``"modify_"``::
+
+    from typing import Any
+    from sqlalchemy import Column
+    from sqlalchemy import quoted_name
+    from sqlalchemy import Table
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.operations.ops import AlterColumnOp
+ from alembic.runtime.plugins import Plugin
+ from alembic.util import PriorityDispatchResult
+
+ def compare_columns(
+ autogen_context: AutogenContext,
+ alter_column_op: AlterColumnOp,
+ schema: str | None,
+ tname: quoted_name | str,
+ cname: quoted_name | str,
+ conn_col: Column[Any],
+ metadata_col: Column[Any],
+ ) -> PriorityDispatchResult:
+ # ...
+
+
+ def setup(plugin: Plugin) -> None:
+ plugin.add_autogenerate_comparator(compare_columns, "column")
+
+ Pre-existing compare chains within the ``"column"`` target include
+ ``"comment"``, ``"server_default"``, and ``"types"``. Comparison functions
+ here should mutate the given :class:`.AlterColumnOp` object in place to add
+ new change operations.
+
+.. seealso::
+
+ :ref:`alembic.autogenerate.toplevel` - Detailed documentation on the
+ autogenerate system
+
+ :ref:`autogenerate_global_comparison_function` - a companion section
+ to this one which explains autogenerate comparison functions in terms of
+ the older "global" dispatch, but also includes a complete example of a
+ comparison function.
+
+ :ref:`customizing_revision` - Customizing autogenerate behavior
+
+
+Plugin API Reference
+====================
+
+.. autoclass:: alembic.runtime.plugins.Plugin
+ :members:
+
+.. autoclass:: alembic.util.langhelpers.PriorityDispatchResult
+ :members:
+
+.. autoclass:: alembic.util.langhelpers.DispatchPriority
+ :members:
+
+.. seealso::
+
+ :paramref:`.EnvironmentContext.configure.autogenerate_plugins` -
+ Configuration parameter for enabling autogenerate plugins
+
+ :ref:`operation_plugins` - Documentation on custom operations
+
+ :ref:`replaceable_objects` - Example of custom operations suitable
+ for a plugin
+
+ :ref:`customizing_revision` - General information on customizing
+ autogenerate behavior
==========
.. changelog::
- :version: 1.17.3
+ :version: 1.18.0
:include_notes_from: unreleased
.. changelog::
def drop_sp(operations, operation):
operations.execute("DROP FUNCTION %s" % operation.target.name)
+Publish the Extensions
+----------------------
+
All of the above code can be present anywhere within an application's
source tree; the only requirement is that when the ``env.py`` script is
invoked, it includes imports that ultimately call upon these classes
as well as the :meth:`.Operations.register_operation` and
:meth:`.Operations.implementation_for` sequences.
+Alternatively, custom operations and autogenerate support can be organized
+into reusable plugins using Alembic's plugin system. This allows extensions
+to be packaged and distributed independently, and automatically discovered
+via Python entry points. See :ref:`alembic.plugins.toplevel` for information
+on writing and publishing plugins.
+
Create Initial Migrations
-------------------------
--- /dev/null
+.. change::
+ :tags: feature, autogenerate
+
+ Release 1.18.0 introduces a plugin system that allows for automatic
+ loading of third-party extensions as well as configurable autogenerate
+ compare functionality on a per-environment basis.
+
+ The :class:`.Plugin` class provides a common interface for extensions that
+ register handlers among Alembic's existing extension points such as
+ :meth:`.Operations.register_operation` and
+ :meth:`.Operations.implementation_for`. A new interface for registering
+ autogenerate comparison handlers,
+ :meth:`.Plugin.add_autogenerate_comparator`, provides for autogenerate
+ compare functionality that may be custom-configured on a per-environment
+ basis using the new
+ :paramref:`.EnvironmentContext.configure.autogenerate_plugins` parameter.
+
+ The change does not impact well known Alembic add-ons such as
+ ``alembic-utils``, which continue to work as before; however, such add-ons
+ have the option to provide plugin entrypoints going forward.
+
+ As part of this change, Alembic's autogenerate compare functionality is
+ reorganized into a series of internal plugins under the
+ ``alembic.autogenerate`` namespace, which may be individually or
+ collectively identified for inclusion and/or exclusion within the
+ :meth:`.EnvironmentContext.configure` call using a new parameter
+    :paramref:`.EnvironmentContext.configure.autogenerate_plugins`. This
+    parameter is also where third party comparison plugins may be
+    indicated.
+
+ See :ref:`alembic.plugins.toplevel` for complete documentation on
+ the new :class:`.Plugin` class as well as autogenerate-specific usage
+ instructions.
from alembic import autogenerate
from alembic import testing
from alembic.autogenerate import api
+from alembic.autogenerate.compare.tables import _compare_tables
from alembic.migration import MigrationContext
from alembic.operations import ops
from alembic.testing import assert_raises_message
from alembic.testing.suite._autogen_fixtures import ModelOne
from alembic.util import CommandError
+if True:
+ from alembic.autogenerate.compare.types import (
+ _dialect_impl_compare_type as _compare_type,
+ )
+
# TODO: we should make an adaptation of CompareMetadataToInspectorTest that is
# more well suited towards generic backends (2021-06-10)
def test_skip_null_type_comparison_reflected(self):
ac = ops.AlterColumnOp("sometable", "somecol")
- autogenerate.compare._compare_type(
+ _compare_type(
self.autogen_context,
ac,
None,
def test_skip_null_type_comparison_local(self):
ac = ops.AlterColumnOp("sometable", "somecol")
- autogenerate.compare._compare_type(
+ _compare_type(
self.autogen_context,
ac,
None,
impl = Integer
ac = ops.AlterColumnOp("sometable", "somecol")
- autogenerate.compare._compare_type(
+ _compare_type(
self.autogen_context,
ac,
None,
assert not ac.has_changes()
ac = ops.AlterColumnOp("sometable", "somecol")
- autogenerate.compare._compare_type(
+ _compare_type(
self.autogen_context,
ac,
None,
return dialect.type_descriptor(CHAR(32))
uo = ops.AlterColumnOp("sometable", "somecol")
- autogenerate.compare._compare_type(
+ _compare_type(
self.autogen_context,
uo,
None,
inspector = inspect(self.bind)
uo = ops.UpgradeOps(ops=[])
- autogenerate.compare._compare_tables(
+ _compare_tables(
OrderedSet([(None, "extra"), (None, "user")]),
OrderedSet(),
inspector,
my_compare_type.return_value = False
self.context._user_compare_type = my_compare_type
- diffs = []
ctx = self.autogen_context
- diffs = []
- autogenerate._produce_net_changes(ctx, diffs)
- eq_(diffs, [])
+ uo = ops.UpgradeOps(ops=[])
+ autogenerate._produce_net_changes(ctx, uo)
+
+ eq_(uo.as_diffs(), [])
def test_column_type_modified_custom_compare_type_returns_True(self):
my_compare_type = mock.Mock()
--- /dev/null
+"""Test the Dispatcher and PriorityDispatcher utilities."""
+
+from alembic import testing
+from alembic.testing import eq_
+from alembic.testing.fixtures import TestBase
+from alembic.util import Dispatcher
+from alembic.util import DispatchPriority
+from alembic.util import PriorityDispatcher
+from alembic.util import PriorityDispatchResult
+
+
+class DispatcherTest(TestBase):
+    """Tests for the Dispatcher class.
+
+    Dispatcher resolves a single handler per target, with exact-key,
+    qualifier-fallback, and class/MRO based lookup.
+    """
+
+    def test_dispatch_for_decorator(self):
+        """Test basic decorator registration."""
+        dispatcher = Dispatcher()
+
+        @dispatcher.dispatch_for("target1")
+        def handler1():
+            return "handler1"
+
+        fn = dispatcher.dispatch("target1")
+        eq_(fn(), "handler1")
+
+    def test_dispatch_with_args_kwargs(self):
+        """Test that arguments are passed through to handlers."""
+        dispatcher = Dispatcher()
+
+        @dispatcher.dispatch_for("target1")
+        def handler(arg1, kwarg1=None):
+            return (arg1, kwarg1)
+
+        fn = dispatcher.dispatch("target1")
+        result = fn("value1", kwarg1="value2")
+        eq_(result, ("value1", "value2"))
+
+    def test_dispatch_for_qualifier(self):
+        """Test registration with qualifier."""
+        dispatcher = Dispatcher()
+
+        @dispatcher.dispatch_for("target1", qualifier="postgresql")
+        def handler_pg():
+            return "postgresql"
+
+        @dispatcher.dispatch_for("target1", qualifier="default")
+        def handler_default():
+            return "default"
+
+        fn_pg = dispatcher.dispatch("target1", qualifier="postgresql")
+        eq_(fn_pg(), "postgresql")
+
+        fn_default = dispatcher.dispatch("target1", qualifier="default")
+        eq_(fn_default(), "default")
+
+    def test_dispatch_qualifier_fallback(self):
+        """Test that non-default qualifier falls back to default."""
+        dispatcher = Dispatcher()
+
+        @dispatcher.dispatch_for("target1")
+        def handler_default():
+            return "default"
+
+        # Request with specific qualifier should fallback to default
+        fn = dispatcher.dispatch("target1", qualifier="mysql")
+        eq_(fn(), "default")
+
+    def test_dispatch_type_target(self):
+        """Test dispatching with type targets using MRO."""
+        dispatcher = Dispatcher()
+
+        class Base:
+            pass
+
+        class Child(Base):
+            pass
+
+        @dispatcher.dispatch_for(Base)
+        def handler_base():
+            return "base"
+
+        # Dispatching with Child should find Base handler via MRO
+        fn = dispatcher.dispatch(Child())
+        eq_(fn(), "base")
+
+    def test_dispatch_type_class_vs_instance(self):
+        """Test dispatching with type vs instance."""
+        dispatcher = Dispatcher()
+
+        class MyClass:
+            pass
+
+        @dispatcher.dispatch_for(MyClass)
+        def handler():
+            return "handler"
+
+        # Both class and instance should work
+        fn_class = dispatcher.dispatch(MyClass)
+        eq_(fn_class(), "handler")
+
+        fn_instance = dispatcher.dispatch(MyClass())
+        eq_(fn_instance(), "handler")
+
+    def test_dispatch_no_match_raises(self):
+        """Test that dispatching with no match raises ValueError."""
+        dispatcher = Dispatcher()
+
+        with testing.expect_raises_message(ValueError, "no dispatch function"):
+            dispatcher.dispatch("nonexistent")
+
+    def test_dispatch_replace_false_raises(self):
+        """Test that duplicate registration raises ValueError."""
+        dispatcher = Dispatcher()
+
+        @dispatcher.dispatch_for("target1")
+        def handler1():
+            return "handler1"
+
+        # registration happens at decoration time, so the decorator
+        # itself is expected to raise inside this block
+        with testing.expect_raises_message(ValueError, "key already exists"):
+
+            @dispatcher.dispatch_for("target1")
+            def handler2():
+                return "handler2"
+
+    def test_dispatch_replace_true_works(self):
+        """Test that replace=True allows overwriting."""
+        dispatcher = Dispatcher()
+
+        @dispatcher.dispatch_for("target1")
+        def handler1():
+            return "handler1"
+
+        @dispatcher.dispatch_for("target1", replace=True)
+        def handler2():
+            return "handler2"
+
+        fn = dispatcher.dispatch("target1")
+        eq_(fn(), "handler2")
+
+    def test_branch(self):
+        """Test that branch creates independent copy."""
+        dispatcher = Dispatcher()
+
+        @dispatcher.dispatch_for("target1")
+        def handler1():
+            return "handler1"
+
+        dispatcher2 = dispatcher.branch()
+
+        # Add to branch should not affect original
+        @dispatcher2.dispatch_for("target2")
+        def handler2():
+            return "handler2"
+
+        # Original should not have target2
+        with testing.expect_raises(ValueError):
+            dispatcher.dispatch("target2")
+
+        # Branch should have both
+        fn1 = dispatcher2.dispatch("target1")
+        eq_(fn1(), "handler1")
+        fn2 = dispatcher2.dispatch("target2")
+        eq_(fn2(), "handler2")
+
+
+class PriorityDispatcherTest(TestBase):
+    """Tests for the PriorityDispatcher class.
+
+    Unlike Dispatcher, PriorityDispatcher runs *all* matching handlers in
+    priority order, supports subgroups whose chains can be short-circuited
+    via PriorityDispatchResult.STOP, and no-ops on unknown targets.
+    """
+
+    def test_dispatch_for_decorator(self):
+        """Test basic decorator registration."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        @dispatcher.dispatch_for("target1")
+        def handler1():
+            results.append("handler1")
+
+        fn = dispatcher.dispatch("target1")
+        fn()
+        eq_(results, ["handler1"])
+
+    def test_dispatch_target_not_registered(self):
+        """Test that dispatching unregistered target returns noop."""
+        dispatcher = PriorityDispatcher()
+
+        # Unlike regular Dispatcher, PriorityDispatcher returns a noop
+        # function for unregistered targets
+        fn = dispatcher.dispatch("nonexistent")
+        # Should not raise, just return a callable that does nothing
+        fn()
+
+    def test_dispatch_with_priority(self):
+        """Test that handlers execute in priority order."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        # registered deliberately out of order to prove ordering comes
+        # from the priority, not registration sequence
+        @dispatcher.dispatch_for("target1", priority=DispatchPriority.LAST)
+        def handler_last():
+            results.append("last")
+
+        @dispatcher.dispatch_for("target1", priority=DispatchPriority.FIRST)
+        def handler_first():
+            results.append("first")
+
+        @dispatcher.dispatch_for("target1", priority=DispatchPriority.MEDIUM)
+        def handler_medium():
+            results.append("medium")
+
+        fn = dispatcher.dispatch("target1")
+        fn()
+        eq_(results, ["first", "medium", "last"])
+
+    def test_dispatch_with_subgroup(self):
+        """Test that subgroups track results independently."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        @dispatcher.dispatch_for("target1", subgroup="group1")
+        def handler1():
+            results.append("group1")
+            return PriorityDispatchResult.CONTINUE
+
+        @dispatcher.dispatch_for("target1", subgroup="group2")
+        def handler2():
+            results.append("group2")
+            return PriorityDispatchResult.CONTINUE
+
+        fn = dispatcher.dispatch("target1")
+        fn()
+        eq_(results, ["group1", "group2"])
+
+    def test_dispatch_stop_result(self):
+        """Test that STOP prevents further execution in subgroup."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        @dispatcher.dispatch_for(
+            "target1", priority=DispatchPriority.FIRST, subgroup="group1"
+        )
+        def handler1():
+            results.append("handler1")
+            return PriorityDispatchResult.STOP
+
+        @dispatcher.dispatch_for(
+            "target1", priority=DispatchPriority.MEDIUM, subgroup="group1"
+        )
+        def handler2():
+            results.append("handler2")  # Should not execute
+            return PriorityDispatchResult.CONTINUE
+
+        @dispatcher.dispatch_for(
+            "target1", priority=DispatchPriority.FIRST, subgroup="group2"
+        )
+        def handler3():
+            results.append("handler3")  # Should execute
+            return PriorityDispatchResult.CONTINUE
+
+        fn = dispatcher.dispatch("target1")
+        fn()
+        # handler2 should not run because handler1 returned STOP for group1
+        # handler3 should run because it's in a different subgroup
+        eq_(results, ["handler1", "handler3"])
+
+    def test_dispatch_with_qualifier(self):
+        """Test dispatching with qualifiers includes both specific and
+        default."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        @dispatcher.dispatch_for("target1", qualifier="postgresql")
+        def handler_pg():
+            results.append("postgresql")
+
+        @dispatcher.dispatch_for("target1", qualifier="default")
+        def handler_default():
+            results.append("default")
+
+        fn_pg = dispatcher.dispatch("target1", qualifier="postgresql")
+        fn_pg()
+        # Should run both postgresql and default handlers
+        eq_(results, ["postgresql", "default"])
+
+    def test_dispatch_qualifier_fallback(self):
+        """Test that non-default qualifier also executes default handlers."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        @dispatcher.dispatch_for("target1", qualifier="default")
+        def handler_default():
+            results.append("default")
+
+        # Request with specific qualifier should also run default
+        fn = dispatcher.dispatch("target1", qualifier="mysql")
+        fn()
+        eq_(results, ["default"])
+
+    def test_dispatch_with_args_kwargs(self):
+        """Test that arguments are passed through to handlers."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        @dispatcher.dispatch_for("target1")
+        def handler(arg1, kwarg1=None):
+            results.append((arg1, kwarg1))
+
+        fn = dispatcher.dispatch("target1")
+        fn("value1", kwarg1="value2")
+        eq_(results, [("value1", "value2")])
+
+    def test_multiple_handlers_same_priority(self):
+        """Test multiple handlers at same priority execute in order."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        @dispatcher.dispatch_for("target1", priority=DispatchPriority.MEDIUM)
+        def handler1():
+            results.append("handler1")
+
+        @dispatcher.dispatch_for("target1", priority=DispatchPriority.MEDIUM)
+        def handler2():
+            results.append("handler2")
+
+        fn = dispatcher.dispatch("target1")
+        fn()
+        # Both should execute
+        eq_(results, ["handler1", "handler2"])
+
+    def test_branch(self):
+        """Test that branch creates independent copy."""
+        dispatcher = PriorityDispatcher()
+        results1 = []
+
+        @dispatcher.dispatch_for("target1")
+        def handler1():
+            results1.append("handler1")
+
+        dispatcher2 = dispatcher.branch()
+        results2 = []
+
+        @dispatcher2.dispatch_for("target2")
+        def handler2():
+            results2.append("handler2")
+
+        # Original should have target1
+        fn1 = dispatcher.dispatch("target1")
+        fn1()
+        eq_(results1, ["handler1"])
+
+        # Branch should have both
+        fn1_branch = dispatcher2.dispatch("target1")
+        fn2_branch = dispatcher2.dispatch("target2")
+        fn1_branch()
+        fn2_branch()
+        eq_(results1, ["handler1", "handler1"])
+        eq_(results2, ["handler2"])
+
+    def test_populate_with(self):
+        """Test populate_with method."""
+        dispatcher1 = PriorityDispatcher()
+        results = []
+
+        @dispatcher1.dispatch_for("target1")
+        def handler1():
+            results.append("handler1")
+
+        dispatcher2 = PriorityDispatcher()
+
+        @dispatcher2.dispatch_for("target2")
+        def handler2():
+            results.append("handler2")
+
+        # Populate dispatcher2 with dispatcher1's handlers
+        dispatcher2.populate_with(dispatcher1)
+
+        # dispatcher2 should now have both handlers
+        fn1 = dispatcher2.dispatch("target1")
+        fn2 = dispatcher2.dispatch("target2")
+        fn1()
+        fn2()
+        eq_(results, ["handler1", "handler2"])
+
+    def test_none_subgroup(self):
+        """Test that None subgroup is tracked separately."""
+        dispatcher = PriorityDispatcher()
+        results = []
+
+        # None acts as a subgroup of its own here: STOP in the None
+        # chain halts only other None-subgroup handlers
+        @dispatcher.dispatch_for("target1", subgroup=None)
+        def handler1():
+            results.append("none")
+            return PriorityDispatchResult.STOP
+
+        @dispatcher.dispatch_for("target1", subgroup=None)
+        def handler2():
+            results.append("none2")  # Should not execute
+            return PriorityDispatchResult.CONTINUE
+
+        @dispatcher.dispatch_for("target1", subgroup="other")
+        def handler3():
+            results.append("other")  # Should execute
+            return PriorityDispatchResult.CONTINUE
+
+        fn = dispatcher.dispatch("target1")
+        fn()
+        eq_(results, ["none", "other"])
from alembic import op
from alembic import util
from alembic.autogenerate import api
-from alembic.autogenerate import compare
+from alembic.autogenerate.compare.constraints import _compare_nullable
from alembic.migration import MigrationContext
from alembic.operations import ops
from alembic.testing import assert_raises_message
from alembic.testing.fixtures import op_fixture
from alembic.testing.fixtures import TestBase
+if True:
+ from alembic.autogenerate.compare.server_defaults import (
+ _user_compare_server_default,
+ )
+ from alembic.autogenerate.compare.types import (
+ _dialect_impl_compare_type as _compare_type,
+ )
+
class MySQLOpTest(TestBase):
def test_create_table_with_comment(self):
operation = ops.AlterColumnOp("t", "c")
for fn in (
- compare._compare_nullable,
- compare._compare_type,
- compare._compare_server_default,
+ _compare_nullable,
+ _compare_type,
+ # note that _user_compare_server_default does not actually
+ # do a server default compare here, compare_server_default
+ # is False so this just assigns the existing default to the
+ # AlterColumnOp
+ _user_compare_server_default,
):
fn(
autogen_context,
"t",
"c",
Column("c", Float(), nullable=False, server_default=text("0")),
- Column("c", Float(), nullable=True, default=0),
+ Column("c", Float(), nullable=True, server_default=text("0")),
)
+
op.invoke(operation)
context.assert_("ALTER TABLE t MODIFY c FLOAT NULL DEFAULT 0")
--- /dev/null
+"""Test the Plugin class and plugin system."""
+
+from types import ModuleType
+from unittest import mock
+
+from alembic import testing
+from alembic import util
+from alembic.runtime.plugins import _all_plugins
+from alembic.runtime.plugins import _make_re
+from alembic.runtime.plugins import Plugin
+from alembic.testing import eq_
+from alembic.testing.fixtures import TestBase
+from alembic.util import DispatchPriority
+from alembic.util import PriorityDispatcher
+from alembic.util import PriorityDispatchResult
+
+
+class PluginTest(TestBase):
+    """Tests for the Plugin class.
+
+    Plugin instances self-register into the module-level ``_all_plugins``
+    registry by name; the autouse fixture below isolates that global state
+    per test.
+    """
+
+    @testing.fixture(scope="function", autouse=True)
+    def _clear_plugin_registry(self):
+        """Clear plugin registry before each test and restore after."""
+        # Save original plugins
+        original_plugins = _all_plugins.copy()
+        _all_plugins.clear()
+
+        yield
+
+        # Restore plugin registry after test
+        _all_plugins.clear()
+        _all_plugins.update(original_plugins)
+
+    def test_plugin_creation(self):
+        """Test basic plugin creation."""
+        plugin = Plugin("test.plugin")
+        eq_(plugin.name, "test.plugin")
+        assert "test.plugin" in _all_plugins
+        eq_(_all_plugins["test.plugin"], plugin)
+
+    def test_plugin_creation_duplicate_raises(self):
+        """Test that duplicate plugin names raise ValueError."""
+        Plugin("test.plugin")
+        with testing.expect_raises_message(
+            ValueError, "A plugin named test.plugin is already registered"
+        ):
+            Plugin("test.plugin")
+
+    def test_plugin_remove(self):
+        """Test plugin removal."""
+        plugin = Plugin("test.plugin")
+        assert "test.plugin" in _all_plugins
+        plugin.remove()
+        assert "test.plugin" not in _all_plugins
+
+    def test_add_autogenerate_comparator(self):
+        """Test adding autogenerate comparison functions."""
+        plugin = Plugin("test.plugin")
+
+        def my_comparator():
+            return PriorityDispatchResult.CONTINUE
+
+        plugin.add_autogenerate_comparator(
+            my_comparator,
+            "table",
+            "column",
+            qualifier="postgresql",
+            priority=DispatchPriority.FIRST,
+        )
+
+        # Verify it was registered in the dispatcher
+        fn = plugin.autogenerate_comparators.dispatch(
+            "table", qualifier="postgresql"
+        )
+        # The dispatcher returns a callable, call it to verify
+        fn()
+
+    def test_populate_autogenerate_priority_dispatch_simple(self):
+        """Test populating dispatcher with simple include pattern."""
+        plugin1 = Plugin("test.plugin1")
+        plugin2 = Plugin("test.plugin2")
+
+        mock1 = mock.Mock()
+        mock2 = mock.Mock()
+
+        plugin1.add_autogenerate_comparator(mock1, "test")
+        plugin2.add_autogenerate_comparator(mock2, "test")
+
+        dispatcher = PriorityDispatcher()
+        Plugin.populate_autogenerate_priority_dispatch(
+            dispatcher, ["test.plugin1"]
+        )
+
+        # Should have plugin1's handler, but not plugin2's
+        fn = dispatcher.dispatch("test")
+        fn()
+        eq_(mock1.mock_calls, [mock.call()])
+        eq_(mock2.mock_calls, [])
+
+    def test_populate_autogenerate_priority_dispatch_wildcard(self):
+        """Test populating dispatcher with wildcard pattern."""
+        plugin1_alpha = Plugin("test.plugin1.alpha")
+        plugin1_beta = Plugin("test.plugin1.beta")
+        plugin2_gamma = Plugin("test.plugin2.gamma")
+
+        mock_alpha = mock.Mock()
+        mock_beta = mock.Mock()
+        mock_gamma = mock.Mock()
+
+        plugin1_alpha.add_autogenerate_comparator(mock_alpha, "test")
+        plugin1_beta.add_autogenerate_comparator(mock_beta, "test")
+        plugin2_gamma.add_autogenerate_comparator(mock_gamma, "test")
+
+        dispatcher = PriorityDispatcher()
+        Plugin.populate_autogenerate_priority_dispatch(
+            dispatcher, ["test.plugin1.*"]
+        )
+
+        # Both test.plugin1.* should be included
+        # test.plugin2.* should not be included
+        fn = dispatcher.dispatch("test")
+        fn()
+        eq_(mock_alpha.mock_calls, [mock.call()])
+        eq_(mock_beta.mock_calls, [mock.call()])
+        eq_(mock_gamma.mock_calls, [])
+
+    def test_populate_autogenerate_priority_dispatch_exclude(self):
+        """Test populating dispatcher with exclude pattern."""
+        plugin1 = Plugin("test.plugin1")
+        plugin2 = Plugin("test.plugin2")
+
+        mock1 = mock.Mock()
+        mock2 = mock.Mock()
+
+        plugin1.add_autogenerate_comparator(mock1, "test")
+        plugin2.add_autogenerate_comparator(mock2, "test")
+
+        dispatcher = PriorityDispatcher()
+        # a leading "~" marks a pattern as an exclusion
+        Plugin.populate_autogenerate_priority_dispatch(
+            dispatcher, ["test.*", "~test.plugin2"]
+        )
+
+        # Should have plugin1's handler, but not plugin2's (excluded)
+        fn = dispatcher.dispatch("test")
+        fn()
+        eq_(mock1.mock_calls, [mock.call()])
+        eq_(mock2.mock_calls, [])
+
+    def test_populate_autogenerate_priority_dispatch_not_found(self):
+        """Test that non-matching pattern raises error."""
+        Plugin("test.plugin1")
+
+        dispatcher = PriorityDispatcher()
+        with testing.expect_raises_message(
+            util.CommandError,
+            "Did not locate plugins.*test.nonexistent",
+        ):
+            Plugin.populate_autogenerate_priority_dispatch(
+                dispatcher, ["test.nonexistent"]
+            )
+
+    def test_populate_autogenerate_priority_dispatch_wildcard_not_found(
+        self,
+    ):
+        """Test that non-matching wildcard pattern raises error."""
+        Plugin("test.plugin1")
+
+        dispatcher = PriorityDispatcher()
+        with testing.expect_raises_message(
+            util.CommandError,
+            "Did not locate plugins",
+        ):
+            Plugin.populate_autogenerate_priority_dispatch(
+                dispatcher, ["other.*"]
+            )
+
+    def test_populate_autogenerate_priority_dispatch_multiple_includes(self):
+        """Test populating with multiple include patterns."""
+        Plugin("test.plugin1")
+        Plugin("other.plugin2")
+
+        dispatcher = PriorityDispatcher()
+        Plugin.populate_autogenerate_priority_dispatch(
+            dispatcher, ["test.plugin1", "other.plugin2"]
+        )
+        # Should not raise error
+
+    def test_setup_plugin_from_module(self):
+        """Test setting up plugin from a module."""
+        # Create a mock module with a setup function
+        mock_module = ModuleType("mock_plugin")
+
+        def setup(plugin):
+            eq_(plugin.name, "mock.plugin")
+            # Register a comparator to verify setup was called
+            plugin.add_autogenerate_comparator(
+                lambda: PriorityDispatchResult.CONTINUE,
+                "test_target",
+            )
+
+        mock_module.setup = setup
+
+        Plugin.setup_plugin_from_module(mock_module, "mock.plugin")
+
+        # Verify plugin was created
+        assert "mock.plugin" in _all_plugins
+
+    def test_autogenerate_comparators_dispatcher(self):
+        """Test that autogenerate_comparators is a PriorityDispatcher."""
+        plugin = Plugin("test.plugin")
+        assert isinstance(plugin.autogenerate_comparators, PriorityDispatcher)
+
+    def test_populate_with_real_handlers(self):
+        """Test populating dispatcher with actual comparison handlers."""
+        plugin = Plugin("test.plugin")
+        results = []
+
+        def compare_tables(
+            autogen_context, upgrade_ops, schemas
+        ):  # pragma: no cover
+            results.append(("compare_tables", autogen_context))
+            return PriorityDispatchResult.CONTINUE
+
+        def compare_types(
+            autogen_context,
+            alter_column_op,
+            schema,
+            tname,
+            cname,
+            conn_col,
+            metadata_col,
+        ):  # pragma: no cover
+            results.append(("compare_types", tname))
+            return PriorityDispatchResult.CONTINUE
+
+        plugin.add_autogenerate_comparator(compare_tables, "table")
+        plugin.add_autogenerate_comparator(compare_types, "type")
+
+        dispatcher = PriorityDispatcher()
+        Plugin.populate_autogenerate_priority_dispatch(
+            dispatcher, ["test.plugin"]
+        )
+
+        # Verify handlers are in dispatcher
+        fn_table = dispatcher.dispatch("table")
+        fn_type = dispatcher.dispatch("type")
+
+        # Call them to verify they work; plain strings stand in for the
+        # real AutogenContext / op objects since the dispatcher is opaque
+        # to argument types
+        fn_table("autogen_ctx", "upgrade_ops", "schemas")
+        fn_type(
+            "autogen_ctx",
+            "alter_op",
+            "schema",
+            "tablename",
+            "colname",
+            "conn_col",
+            "meta_col",
+        )
+
+        eq_(results[0][0], "compare_tables")
+        eq_(results[1][0], "compare_types")
+        eq_(results[1][1], "tablename")
+
+
+class MakeReTest(TestBase):
+ """Tests for the _make_re helper function."""
+
+ def test_simple_name(self):
+ """Test regex generation for simple dotted names."""
+ pattern = _make_re("test.plugin")
+ assert pattern.match("test.plugin")
+
+ # Partial matches dont work; use a * for this
+ assert not pattern.match("test.plugin.extra")
+
+ # other tokens don't match either
+ assert not pattern.match("test.pluginfoo")
+
+ assert not pattern.match("other.plugin")
+ assert not pattern.match("test")
+
+ def test_wildcard(self):
+ """Test regex generation with wildcard."""
+ pattern = _make_re("test.*")
+ assert pattern.match("test.plugin")
+ assert pattern.match("test.plugin.extra")
+ assert not pattern.match("test")
+ assert not pattern.match("other.plugin")
+
+ def test_multiple_wildcards(self):
+ """Test regex generation with multiple wildcards."""
+ pattern = _make_re("test.*.sub.*")
+ assert pattern.match("test.plugin.sub.item")
+ assert pattern.match("test.a.sub.b")
+ assert not pattern.match("test.plugin")
+
+ def test_invalid_pattern_raises(self):
+ """Test that invalid patterns raise ValueError."""
+ with testing.expect_raises_message(
+ ValueError, "Invalid plugin expression"
+ ):
+ _make_re("test.plugin-name")
+
+ def test_valid_underscore(self):
+ """Test that underscores are valid in names."""
+ pattern = _make_re("test.my_plugin")
+ assert pattern.match("test.my_plugin")
+
+ def test_valid_mixed_case(self):
+ """Test that mixed case is valid in names."""
+ pattern = _make_re("test.MyPlugin")
+ assert pattern.match("test.MyPlugin")
+ assert not pattern.match("test.myplugin")
from alembic import testing
from alembic import util
from alembic.autogenerate import api
-from alembic.autogenerate.compare import _compare_server_default
-from alembic.autogenerate.compare import _compare_tables
-from alembic.autogenerate.compare import _render_server_default_for_compare
+from alembic.autogenerate.compare.tables import _compare_tables
from alembic.migration import MigrationContext
from alembic.operations import ops
from alembic.script import ScriptDirectory
from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest
+if True:
+ from alembic.autogenerate.compare.server_defaults import (
+ _render_server_default_for_compare,
+ ) # noqa: E501
+ from alembic.autogenerate.compare.server_defaults import (
+ _dialect_impl_compare_server_default as _compare_server_default,
+ )
+
+
class PostgresqlOpTest(TestBase):
def test_rename_table_postgresql(self):
context = op_fixture("postgresql")
from alembic import autogenerate
from alembic import op
from alembic.autogenerate import api
-from alembic.autogenerate.compare import _compare_server_default
from alembic.migration import MigrationContext
from alembic.operations import ops
from alembic.testing import assert_raises_message
from alembic.testing.fixtures import TestBase
from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest
+if True:
+ from alembic.autogenerate.compare.server_defaults import (
+ _dialect_impl_compare_server_default as _compare_server_default,
+ )
+
class SQLiteTest(TestBase):
def test_add_column(self):