From 57877461c1bd3b43a9d833fbca873d59db36b6f7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 27 Jan 2022 15:07:17 -0500 Subject: [PATCH] generalize conditional DDL throughout schema / DDL Expanded on the "conditional DDL" system implemented by the :class:`_schema.DDLElement` class to be directly available on :class:`_schema.SchemaItem` constructs such as :class:`_schema.Index`, :class:`_schema.ForeignKeyConstraint`, etc. such that the conditional logic for generating these elements is included within the default DDL emitting process. This system can also be accommodated by a future release of Alembic to support conditional DDL elements within all schema-management systems. Fixes: #7631 Change-Id: I9457524d7f66f49696187cf7d2b37dbb44f0e20b --- doc/build/changelog/migration_20.rst | 37 ++++ doc/build/changelog/unreleased_20/7631.rst | 17 ++ doc/build/core/constraints.rst | 1 + doc/build/core/ddl.rst | 168 ++++++++++++----- lib/sqlalchemy/sql/compiler.py | 5 +- lib/sqlalchemy/sql/ddl.py | 206 +++++++++++++++------ lib/sqlalchemy/sql/schema.py | 71 ++++++- pyproject.toml | 2 +- test/engine/test_ddlevents.py | 121 +++++++++++- 9 files changed, 516 insertions(+), 112 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/7631.rst diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index b8415d61a9..2fcfafbbe8 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -100,7 +100,44 @@ automatically. :ref:`postgresql_psycopg` +.. _ticket_7631: + +New Conditional DDL for Constraints and Indexes +------------------------------------------------ + +A new method :meth:`_schema.Constraint.ddl_if` and :meth:`_schema.Index.ddl_if` +allows constructs such as :class:`_schema.CheckConstraint`, :class:`_schema.UniqueConstraint` +and :class:`_schema.Index` to be rendered conditionally for a given +:class:`_schema.Table`, based on the same kinds of criteria that are accepted +by the :meth:`_schema.DDLElement.execute_if` method. In the example below, +the CHECK constraint and index will only be produced against a PostgreSQL +backend:: + + meta = MetaData() + + + my_table = Table( + "my_table", + meta, + Column("id", Integer, primary_key=True), + Column("num", Integer), + Column("data", String), + Index("my_pg_index", "data").ddl_if(dialect="postgresql"), + CheckConstraint("num > 5").ddl_if(dialect="postgresql"), + ) + + e1 = create_engine("sqlite://", echo=True) + meta.create_all(e1) # will not generate CHECK and INDEX + + + e2 = create_engine("postgresql://scott:tiger@localhost/test", echo=True) + meta.create_all(e2) # will generate CHECK and INDEX + +.. seealso:: + + :ref:`schema_ddl_ddl_if` +:ticket:`7631` Behavioral Changes ================== diff --git a/doc/build/changelog/unreleased_20/7631.rst b/doc/build/changelog/unreleased_20/7631.rst new file mode 100644 index 0000000000..d6c69f5d53 --- /dev/null +++ b/doc/build/changelog/unreleased_20/7631.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: feature, schema + :tickets: 7631 + + Expanded on the "conditional DDL" system implemented by the + :class:`_schema.DDLElement` class to be directly available on + :class:`_schema.SchemaItem` constructs such as :class:`_schema.Index`, + :class:`_schema.ForeignKeyConstraint`, etc. such that the conditional logic + for generating these elements is included within the default DDL emitting + process. This system can also be accommodated by a future release of + Alembic to support conditional DDL elements within all schema-management + systems. + + + .. seealso:: + + :ref:`ticket_7631` \ No newline at end of file diff --git a/doc/build/core/constraints.rst b/doc/build/core/constraints.rst index e4b07af7e2..ea84c15a3e 100644 --- a/doc/build/core/constraints.rst +++ b/doc/build/core/constraints.rst @@ -770,6 +770,7 @@ Constraints API --------------- .. autoclass:: Constraint :members: + :inherited-members: .. autoclass:: ColumnCollectionMixin :members: diff --git a/doc/build/core/ddl.rst b/doc/build/core/ddl.rst index 9c2fed198d..6bbd494246 100644 --- a/doc/build/core/ddl.rst +++ b/doc/build/core/ddl.rst @@ -150,7 +150,9 @@ Using the built-in DDLElement Classes ------------------------------------- The ``sqlalchemy.schema`` package contains SQL expression constructs that -provide DDL expressions. For example, to produce a ``CREATE TABLE`` statement: +provide DDL expressions, all of which extend from the common base +:class:`.DDLElement`. For example, to produce a ``CREATE TABLE`` statement, +one can use the :class:`.CreateTable` construct: .. sourcecode:: python+sql @@ -178,65 +180,133 @@ User-defined DDL constructs may also be created as subclasses of :class:`.DDLElement` itself. The documentation in :ref:`sqlalchemy.ext.compiler_toplevel` has several examples of this. -The event-driven DDL system described in the previous section -:ref:`schema_ddl_sequences` is available with other :class:`.DDLElement` -objects as well. However, when dealing with the built-in constructs -such as :class:`.CreateIndex`, :class:`.CreateSequence`, etc, the event -system is of **limited** use, as methods like :meth:`_schema.Table.create` and -:meth:`_schema.MetaData.create_all` will invoke these constructs unconditionally. -In a future SQLAlchemy release, the DDL event system including conditional -execution will taken into account for built-in constructs that currently -invoke in all cases. - -We can illustrate an event-driven -example with the :class:`.AddConstraint` and :class:`.DropConstraint` -constructs, as the event-driven system will work for CHECK and UNIQUE -constraints, using these as we did in our previous example of -:meth:`.DDLElement.execute_if`: +.. _schema_ddl_ddl_if: + +Controlling DDL Generation of Constraints and Indexes +----------------------------------------------------- + +.. versionadded:: 2.0 + +While the previously mentioned :meth:`.DDLElement.execute_if` method is +useful for custom :class:`.DDL` classes which need to invoke conditionally, +there is also a common need for elements that are typically related to a +particular :class:`.Table`, namely constraints and indexes, to also be +subject to "conditional" rules, such as an index that includes features +that are specific to a particular backend such as PostgreSQL or SQL Server. +For this use case, the :meth:`.Constraint.ddl_if` and :meth:`.Index.ddl_if` +methods may be used against constructs such as :class:`.CheckConstraint`, +:class:`.UniqueConstraint` and :class:`.Index`, accepting the same +arguments as the :meth:`.DDLElement.execute_if` method in order to control +whether or not their DDL will be emitted in terms of their parent +:class:`.Table` object. These methods may be used inline when +creating the definition for a :class:`.Table` +(or similarly, when using the ``__table_args__`` collection in an ORM +declarative mapping), such as:: + + from sqlalchemy import CheckConstraint, Index + from sqlalchemy import MetaData, Table, Column + from sqlalchemy import Integer, String + + meta = MetaData() + + my_table = Table( + "my_table", + meta, + Column("id", Integer, primary_key=True), + Column("num", Integer), + Column("data", String), + Index("my_pg_index", "data").ddl_if(dialect="postgresql"), + CheckConstraint("num > 5").ddl_if(dialect="postgresql"), + ) + +In the above example, the :class:`.Table` construct refers to both an +:class:`.Index` and a :class:`.CheckConstraint` construct, both which +indicate ``.ddl_if(dialect="postgresql")``, which indicates that these +elements will be included in the CREATE TABLE sequence only against the +PostgreSQL dialect. If we run ``meta.create_all()`` against the SQLite +dialect, for example, neither construct will be included: .. sourcecode:: python+sql - def should_create(ddl, target, connection, **kw): - row = connection.execute( - "select conname from pg_constraint where conname='%s'" % - ddl.element.name).scalar() - return not bool(row) + >>> from sqlalchemy import create_engine + >>> sqlite_engine = create_engine("sqlite+pysqlite://", echo=True) + >>> meta.create_all(sqlite_engine) + {opensql}BEGIN (implicit) + PRAGMA main.table_info("my_table") + [raw sql] () + PRAGMA temp.table_info("my_table") + [raw sql] () + + CREATE TABLE my_table ( + id INTEGER NOT NULL, + num INTEGER, + data VARCHAR, + PRIMARY KEY (id) + ) - def should_drop(ddl, target, connection, **kw): - return not should_create(ddl, target, connection, **kw) +However, if we run the same commands against a PostgreSQL database, we will +see inline DDL for the CHECK constraint as well as a separate CREATE +statement emitted for the index: - event.listen( - users, - "after_create", - AddConstraint(constraint).execute_if(callable_=should_create) - ) - event.listen( - users, - "before_drop", - DropConstraint(constraint).execute_if(callable_=should_drop) +.. sourcecode:: python+sql + + >>> from sqlalchemy import create_engine + >>> postgresql_engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", echo=True) + >>> meta.create_all(postgresql_engine) + {opensql}BEGIN (implicit) + select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where pg_catalog.pg_table_is_visible(c.oid) and relname=%(name)s + [generated in 0.00009s] {'name': 'my_table'} + + CREATE TABLE my_table ( + id SERIAL NOT NULL, + num INTEGER, + data VARCHAR, + PRIMARY KEY (id), + CHECK (num > 5) ) + [no key 0.00007s] {} + CREATE INDEX my_pg_index ON my_table (data) + [no key 0.00013s] {} + COMMIT + +The :meth:`.Constraint.ddl_if` and :meth:`.Index.ddl_if` methods create +an event hook that may be consulted not just at DDL execution time, as is the +behavior with :meth:`.DDLElement.execute_if`, but also within the SQL compilation +phase of the :class:`.CreateTable` object, which is responsible for rendering +the ``CHECK (num > 5)`` DDL inline within the CREATE TABLE statement. +As such, the event hook that is received by the :meth:`.Constraint.ddl_if.callable_` +parameter has a richer argument set present, including that there is +a ``dialect`` keyword argument passed, as well as an instance of :class:`.DDLCompiler` +via the ``compiler`` keyword argument for the "inline rendering" portion of the +sequence. The ``bind`` argument is **not** present when the event is triggered +within the :class:`.DDLCompiler` sequence, so a modern event hook that wishes +to inspect the database versioning information would best use the given +:class:`.Dialect` object, such as to test PostgreSQL versioning: - {sql}users.create(engine) - CREATE TABLE users ( - user_id SERIAL NOT NULL, - user_name VARCHAR(40) NOT NULL, - PRIMARY KEY (user_id) +.. sourcecode:: python+sql + + def only_pg_14(ddl_element, target, bind, dialect, **kw): + return ( + dialect.name == "postgresql" and + dialect.server_version_info >= (14,) + ) + + my_table = Table( + "my_table", + meta, + Column("id", Integer, primary_key=True), + Column("num", Integer), + Column("data", String), + Index("my_pg_index", "data").ddl_if(callable_=only_pg_14), ) - select conname from pg_constraint where conname='cst_user_name_length' - ALTER TABLE users ADD CONSTRAINT cst_user_name_length CHECK (length(user_name) >= 8){stop} +.. seealso:: + + :meth:`.Constraint.ddl_if` + + :meth:`.Index.ddl_if` - {sql}users.drop(engine) - select conname from pg_constraint where conname='cst_user_name_length' - ALTER TABLE users DROP CONSTRAINT cst_user_name_length - DROP TABLE users{stop} -While the above example is against the built-in :class:`.AddConstraint` -and :class:`.DropConstraint` objects, the main usefulness of DDL events -for now remains focused on the use of the :class:`.DDL` construct itself, -as well as with user-defined subclasses of :class:`.DDLElement` that aren't -already part of the :meth:`_schema.MetaData.create_all`, :meth:`_schema.Table.create`, -and corresponding "drop" processes. .. _schema_api_ddl: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5ba52ae51c..aa98ff2565 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -4889,10 +4889,7 @@ class DDLCompiler(Compiled): for p in ( self.process(constraint) for constraint in constraints - if ( - constraint._create_rule is None - or constraint._create_rule(self) - ) + if (constraint._should_create_for_compiler(self)) and ( not self.dialect.supports_alter or not getattr(constraint, "use_alter", False) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 7acb69bebb..4d57ad8698 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -12,10 +12,11 @@ to invoke them for a create/drop call. from __future__ import annotations import typing +from typing import Any from typing import Callable from typing import List from typing import Optional -from typing import Sequence +from typing import Sequence as typing_Sequence from typing import Tuple from . import roles @@ -26,11 +27,20 @@ from .elements import ClauseElement from .. import exc from .. import util from ..util import topological - +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from .compiler import Compiled + from .compiler import DDLCompiler + from .elements import BindParameter from .schema import ForeignKeyConstraint + from .schema import SchemaItem from .schema import Table + from ..engine.base import _CompiledCacheType + from ..engine.base import Connection + from ..engine.interfaces import _SchemaTranslateMapType + from ..engine.interfaces import CacheStats + from ..engine.interfaces import Dialect class _DDLCompiles(ClauseElement): @@ -43,10 +53,70 @@ class _DDLCompiles(ClauseElement): return dialect.ddl_compiler(dialect, self, **kw) - def _compile_w_cache(self, *arg, **kw): + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[_CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[ + Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats + ]: raise NotImplementedError() +class DDLIfCallable(Protocol): + def __call__( + self, + ddl: "DDLElement", + target: "SchemaItem", + bind: Optional["Connection"], + tables: Optional[List["Table"]] = None, + state: Optional[Any] = None, + *, + dialect: Dialect, + compiler: Optional[DDLCompiler] = ..., + checkfirst: bool, + ) -> bool: + ... + + +class DDLIf(typing.NamedTuple): + dialect: Optional[str] + callable_: Optional[DDLIfCallable] + state: Optional[Any] + + def _should_execute(self, ddl, target, bind, compiler=None, **kw): + if bind is not None: + dialect = bind.dialect + elif compiler is not None: + dialect = compiler.dialect + else: + assert False, "compiler or dialect is required" + + if isinstance(self.dialect, str): + if self.dialect != dialect.name: + return False + elif isinstance(self.dialect, (tuple, list, set)): + if dialect.name not in self.dialect: + return False + if self.callable_ is not None and not self.callable_( + ddl, + target, + bind, + state=self.state, + dialect=dialect, + compiler=compiler, + **kw, + ): + return False + + return True + + SelfDDLElement = typing.TypeVar("SelfDDLElement", bound="DDLElement") @@ -80,10 +150,8 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): """ - target = None - on = None - dialect = None - callable_ = None + _ddl_if: Optional[DDLIf] = None + target: Optional["SchemaItem"] = None def _execute_on_connection( self, connection, distilled_params, execution_options @@ -93,7 +161,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): ) @_generative - def against(self: SelfDDLElement, target) -> SelfDDLElement: + def against(self: SelfDDLElement, target: SchemaItem) -> SelfDDLElement: """Return a copy of this :class:`_schema.DDLElement` which will include the given target. @@ -125,13 +193,15 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): processing the DDL string. """ - self.target = target return self @_generative def execute_if( - self: SelfDDLElement, dialect=None, callable_=None, state=None + self: SelfDDLElement, + dialect: Optional[str] = None, + callable_: Optional[DDLIfCallable] = None, + state: Optional[Any] = None, ) -> SelfDDLElement: r"""Return a callable that will execute this :class:`_ddl.DDLElement` conditionally within an event handler. @@ -155,7 +225,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): DDL('something').execute_if(dialect=('postgresql', 'mysql')) :param callable\_: A callable, which will be invoked with - four positional arguments as well as optional keyword + three positional arguments as well as optional keyword arguments: :ddl: @@ -168,13 +238,22 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): explicitly. :bind: - The :class:`_engine.Connection` being used for DDL execution + The :class:`_engine.Connection` being used for DDL execution. + May be None if this construct is being created inline within + a table, in which case ``compiler`` will be present. :tables: Optional keyword argument - a list of Table objects which are to be created/ dropped within a MetaData.create_all() or drop_all() method call. + :dialect: keyword argument, but always present - the + :class:`.Dialect` involved in the operation. + + :compiler: keyword argument. Will be ``None`` for an engine + level DDL invocation, but will refer to a :class:`.DDLCompiler` + if this DDL element is being created inline within a table. + :state: Optional keyword argument - will be the ``state`` argument passed to this function. @@ -192,35 +271,30 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): .. seealso:: + :meth:`.SchemaItem.ddl_if` + :class:`.DDLEvents` :ref:`event_toplevel` """ - self.dialect = dialect - self.callable_ = callable_ - self.state = state + self._ddl_if = DDLIf(dialect, callable_, state) return self def _should_execute(self, target, bind, **kw): - if isinstance(self.dialect, str): - if self.dialect != bind.engine.name: - return False - elif isinstance(self.dialect, (tuple, list, set)): - if bind.engine.name not in self.dialect: - return False - if self.callable_ is not None and not self.callable_( - self, target, bind, state=self.state, **kw - ): - return False + if self._ddl_if is None: + return True + else: + return self._ddl_if._should_execute(self, target, bind, **kw) - return True + def _invoke_with(self, bind): + if self._should_execute(self.target, bind): + return bind.execute(self) def __call__(self, target, bind, **kw): """Execute the DDL as a ddl_listener.""" - if self._should_execute(target, bind, **kw): - return bind.execute(self.against(target)) + self.against(target)._invoke_with(bind) def _generate(self): s = self.__class__.__new__(self.__class__) @@ -330,9 +404,10 @@ class _CreateDropBase(DDLElement): if_exists=False, if_not_exists=False, ): - self.element = element + self.element = self.target = element self.if_exists = if_exists self.if_not_exists = if_not_exists + self._ddl_if = getattr(element, "_ddl_if", None) @property def stringify_dialect(self): @@ -358,11 +433,19 @@ class CreateSchema(_CreateDropBase): __visit_name__ = "create_schema" - def __init__(self, name, quote=None, **kw): + def __init__( + self, + name, + quote=None, + if_exists=False, + if_not_exists=False, + ): """Create a new :class:`.CreateSchema` construct.""" self.quote = quote - super(CreateSchema, self).__init__(name, **kw) + self.element = name + self.if_exists = if_exists + self.if_not_exists = if_not_exists class DropSchema(_CreateDropBase): @@ -374,12 +457,22 @@ class DropSchema(_CreateDropBase): __visit_name__ = "drop_schema" - def __init__(self, name, quote=None, cascade=False, **kw): + def __init__( + self, + name, + quote=None, + cascade=False, + if_exists=False, + if_not_exists=False, + ): """Create a new :class:`.DropSchema` construct.""" self.quote = quote self.cascade = cascade - super(DropSchema, self).__init__(name, **kw) + self.quote = quote + self.element = name + self.if_exists = if_exists + self.if_not_exists = if_not_exists class CreateTable(_CreateDropBase): @@ -427,6 +520,11 @@ class _DropView(_CreateDropBase): __visit_name__ = "drop_view" +class CreateConstraint(_DDLCompiles): + def __init__(self, element): + self.element = element + + class CreateColumn(_DDLCompiles): """Represent a :class:`_schema.Column` as rendered in a CREATE TABLE statement, @@ -784,15 +882,10 @@ class SchemaGenerator(DDLBase): # e.g., don't omit any foreign key constraints include_foreign_key_constraints = None - self.connection.execute( - # fmt: off - CreateTable( - table, - include_foreign_key_constraints= # noqa - include_foreign_key_constraints, # noqa - ) - # fmt: on - ) + CreateTable( + table, + include_foreign_key_constraints=include_foreign_key_constraints, + )._invoke_with(self.connection) if hasattr(table, "indexes"): for index in table.indexes: @@ -800,11 +893,11 @@ class SchemaGenerator(DDLBase): if self.dialect.supports_comments and not self.dialect.inline_comments: if table.comment is not None: - self.connection.execute(SetTableComment(table)) + SetTableComment(table)._invoke_with(self.connection) for column in table.columns: if column.comment is not None: - self.connection.execute(SetColumnComment(column)) + SetColumnComment(column)._invoke_with(self.connection) table.dispatch.after_create( table, @@ -817,17 +910,17 @@ class SchemaGenerator(DDLBase): def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: return - self.connection.execute(AddConstraint(constraint)) + AddConstraint(constraint)._invoke_with(self.connection) def visit_sequence(self, sequence, create_ok=False): if not create_ok and not self._can_create_sequence(sequence): return - self.connection.execute(CreateSequence(sequence)) + CreateSequence(sequence)._invoke_with(self.connection) def visit_index(self, index, create_ok=False): if not create_ok and not self._can_create_index(index): return - self.connection.execute(CreateIndex(index)) + CreateIndex(index)._invoke_with(self.connection) class SchemaDropper(DDLBase): @@ -964,7 +1057,7 @@ class SchemaDropper(DDLBase): if not drop_ok and not self._can_drop_index(index): return - self.connection.execute(DropIndex(index)) + DropIndex(index)(index, self.connection) def visit_table( self, @@ -984,7 +1077,7 @@ class SchemaDropper(DDLBase): _is_metadata_operation=_is_metadata_operation, ) - self.connection.execute(DropTable(table)) + DropTable(table)._invoke_with(self.connection) # traverse client side defaults which may refer to server-side # sequences. noting that some of these client side defaults may also be @@ -1009,19 +1102,21 @@ class SchemaDropper(DDLBase): def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: return - self.connection.execute(DropConstraint(constraint)) + DropConstraint(constraint)._invoke_with(self.connection) def visit_sequence(self, sequence, drop_ok=False): if not drop_ok and not self._can_drop_sequence(sequence): return - self.connection.execute(DropSequence(sequence)) + DropSequence(sequence)._invoke_with(self.connection) def sort_tables( - tables: Sequence["Table"], + tables: typing_Sequence["Table"], skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, - extra_dependencies: Optional[Sequence[Tuple["Table", "Table"]]] = None, + extra_dependencies: Optional[ + typing_Sequence[Tuple["Table", "Table"]] + ] = None, ) -> List["Table"]: """Sort a collection of :class:`_schema.Table` objects based on dependency. @@ -1082,16 +1177,17 @@ def sort_tables( """ if skip_fn is not None: + fixed_skip_fn = skip_fn def _skip_fn(fkc): for fk in fkc.elements: - if skip_fn(fk): + if fixed_skip_fn(fk): return True else: return None else: - _skip_fn = None + _skip_fn = None # type: ignore return [ t diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 540b62e8aa..dfe82432d7 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -184,6 +184,61 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): _use_schema_map = True +SelfHasConditionalDDL = TypeVar( + "SelfHasConditionalDDL", bound="HasConditionalDDL" +) + + +class HasConditionalDDL: + """define a class that includes the :meth:`.HasConditionalDDL.ddl_if` + method, allowing for conditional rendering of DDL. + + Currently applies to constraints and indexes. + + .. versionadded:: 2.0 + + + """ + + _ddl_if: Optional[ddl.DDLIf] = None + + def ddl_if( + self: SelfHasConditionalDDL, + dialect: Optional[str] = None, + callable_: Optional[ddl.DDLIfCallable] = None, + state: Optional[Any] = None, + ) -> SelfHasConditionalDDL: + r"""apply a conditional DDL rule to this schema item. + + These rules work in a similar manner to the + :meth:`.DDLElement.execute_if` callable, with the added feature that + the criteria may be checked within the DDL compilation phase for a + construct such as :class:`.CreateTable`. + :meth:`.HasConditionalDDL.ddl_if` currently applies towards the + :class:`.Index` construct as well as all :class:`.Constraint` + constructs. + + :param dialect: string name of a dialect, or a tuple of string names + to indicate multiple dialect types. + + :param callable\_: a callable that is constructed using the same form + as that described in :paramref:`.DDLElement.execute_if.callable_`. + + :param state: any arbitrary object that will be passed to the + callable, if present. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`schema_ddl_ddl_if` - background and usage examples + + + """ + self._ddl_if = ddl.DDLIf(dialect, callable_, state) + return self + + class HasSchemaAttr(SchemaItem): """schema item that includes a top-level schema name""" @@ -3355,7 +3410,7 @@ class DefaultClause(FetchedValue): return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) -class Constraint(DialectKWArgs, SchemaItem): +class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem): """A table-level SQL constraint. :class:`_schema.Constraint` serves as the base class for the series of @@ -3424,6 +3479,16 @@ class Constraint(DialectKWArgs, SchemaItem): util.set_creation_order(self) self._validate_dialect_kwargs(dialect_kw) + def _should_create_for_compiler(self, compiler, **kw): + if self._create_rule is not None and not self._create_rule(compiler): + return False + elif self._ddl_if is not None: + return self._ddl_if._should_execute( + ddl.CreateConstraint(self), self, None, compiler=compiler, **kw + ) + else: + return True + @property def table(self): try: @@ -4292,7 +4357,9 @@ class UniqueConstraint(ColumnCollectionConstraint): __visit_name__ = "unique_constraint" -class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): +class Index( + DialectKWArgs, ColumnCollectionMixin, HasConditionalDDL, SchemaItem +): """A table-level INDEX. Defines a composite (one or more column) INDEX. diff --git a/pyproject.toml b/pyproject.toml index acbc69537a..aa2790b049 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,6 @@ module = [ "sqlalchemy.ext.serializer", "sqlalchemy.sql.selectable", # would be nice as strict - "sqlalchemy.sql.ddl", "sqlalchemy.sql.functions", # would be nice as strict "sqlalchemy.sql.lambdas", "sqlalchemy.sql.dml", # would be nice as strict @@ -132,6 +131,7 @@ module = [ "sqlalchemy.sql.coercions", "sqlalchemy.sql.compiler", "sqlalchemy.sql.crud", + "sqlalchemy.sql.ddl", # would be nice as strict "sqlalchemy.sql.elements", # would be nice as strict "sqlalchemy.sql.naming", "sqlalchemy.sql.schema", # would be nice as strict diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index f339ef1715..0c72e32c7b 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -1,7 +1,11 @@ +from unittest import mock +from unittest.mock import Mock + import sqlalchemy as tsa from sqlalchemy import create_engine from sqlalchemy import create_mock_engine from sqlalchemy import event +from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import String @@ -373,7 +377,7 @@ class DDLEventTest(fixtures.TestBase): eq_(metadata_canary.mock_calls, []) -class DDLExecutionTest(fixtures.TestBase): +class DDLExecutionTest(AssertsCompiledSQL, fixtures.TestBase): def setup_test(self): self.engine = engines.mock_engine() self.metadata = MetaData() @@ -485,6 +489,121 @@ class DDLExecutionTest(fixtures.TestBase): strings = " ".join(str(x) for x in pg_mock.mock) assert "my_test_constraint" in strings + @testing.combinations(("dialect",), ("callable",), ("callable_w_state",)) + def test_inline_ddl_if_dialect_name(self, ddl_if_type): + nonpg_mock = engines.mock_engine(dialect_name="sqlite") + pg_mock = engines.mock_engine(dialect_name="postgresql") + + metadata = MetaData() + + capture_mock = Mock() + state = object() + + if ddl_if_type == "dialect": + ddl_kwargs = dict(dialect="postgresql") + elif ddl_if_type == "callable": + + def is_pg(ddl, target, bind, **kw): + capture_mock.is_pg(ddl, target, bind, **kw) + return kw["dialect"].name == "postgresql" + + ddl_kwargs = dict(callable_=is_pg) + elif ddl_if_type == "callable_w_state": + + def is_pg(ddl, target, bind, **kw): + capture_mock.is_pg(ddl, target, bind, **kw) + return kw["dialect"].name == "postgresql" + + ddl_kwargs = dict(callable_=is_pg, state=state) + else: + assert False + + data_col = Column("data", String) + t = Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("num", Integer), + data_col, + Index("my_pg_index", data_col).ddl_if(**ddl_kwargs), + CheckConstraint("num > 5").ddl_if(**ddl_kwargs), + ) + + metadata.create_all(nonpg_mock) + eq_(len(nonpg_mock.mock), 1) + self.assert_compile( + nonpg_mock.mock[0], + "CREATE TABLE a (id INTEGER NOT NULL, num INTEGER, " + "data VARCHAR, PRIMARY KEY (id))", + dialect=nonpg_mock.dialect, + ) + + metadata.create_all(pg_mock) + + eq_(len(pg_mock.mock), 2) + + self.assert_compile( + pg_mock.mock[0], + "CREATE TABLE a (id SERIAL NOT NULL, num INTEGER, " + "data VARCHAR, PRIMARY KEY (id), CHECK (num > 5))", + dialect=pg_mock.dialect, + ) + self.assert_compile( + pg_mock.mock[1], + "CREATE INDEX my_pg_index ON a (data)", + dialect="postgresql", + ) + + the_index = list(t.indexes)[0] + the_constraint = list( + c for c in t.constraints if isinstance(c, CheckConstraint) + )[0] + + if ddl_if_type in ("callable", "callable_w_state"): + + if ddl_if_type == "callable": + check_state = None + else: + check_state = state + + eq_( + capture_mock.mock_calls, + [ + mock.call.is_pg( + mock.ANY, + the_index, + mock.ANY, + state=check_state, + dialect=nonpg_mock.dialect, + compiler=None, + ), + mock.call.is_pg( + mock.ANY, + the_constraint, + None, + state=check_state, + dialect=nonpg_mock.dialect, + compiler=mock.ANY, + ), + mock.call.is_pg( + mock.ANY, + the_index, + mock.ANY, + state=check_state, + dialect=pg_mock.dialect, + compiler=None, + ), + mock.call.is_pg( + mock.ANY, + the_constraint, + None, + state=check_state, + dialect=pg_mock.dialect, + compiler=mock.ANY, + ), + ], + ) + @testing.requires.sqlite def test_ddl_execute(self): engine = create_engine("sqlite:///") -- 2.47.2