]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
generalize conditional DDL throughout schema / DDL
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Jan 2022 20:07:17 +0000 (15:07 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 Mar 2022 14:16:36 +0000 (10:16 -0400)
Expanded on the "conditional DDL" system implemented by the
:class:`_schema.DDLElement` class to be directly available on
:class:`_schema.SchemaItem` constructs such as :class:`_schema.Index`,
:class:`_schema.ForeignKeyConstraint`, etc. such that the conditional logic
for generating these elements is included within the default DDL emitting
process. This system can also be accommodated by a future release of
Alembic to support conditional DDL elements within all schema-management
systems.

Fixes: #7631
Change-Id: I9457524d7f66f49696187cf7d2b37dbb44f0e20b

doc/build/changelog/migration_20.rst
doc/build/changelog/unreleased_20/7631.rst [new file with mode: 0644]
doc/build/core/constraints.rst
doc/build/core/ddl.rst
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/schema.py
pyproject.toml
test/engine/test_ddlevents.py

index b8415d61a92a4cdb6b49f01639a036aaa0fc846b..2fcfafbbe85f41f3fa33cae5e8d5495dad7db2a3 100644 (file)
@@ -100,7 +100,44 @@ automatically.
 
     :ref:`postgresql_psycopg`
 
+.. _ticket_7631:
+
+New Conditional DDL for Constraints and Indexes
+------------------------------------------------
+
+The new methods :meth:`_schema.Constraint.ddl_if` and :meth:`_schema.Index.ddl_if`
+allow constructs such as :class:`_schema.CheckConstraint`, :class:`_schema.UniqueConstraint`
+and :class:`_schema.Index` to be rendered conditionally for a given
+:class:`_schema.Table`, based on the same kinds of criteria that are accepted
+by the :meth:`_schema.DDLElement.execute_if` method.  In the example below,
+the CHECK constraint and index will only be produced against a PostgreSQL
+backend::
+
+    meta = MetaData()
+
+
+    my_table = Table(
+        "my_table",
+        meta,
+        Column("id", Integer, primary_key=True),
+        Column("num", Integer),
+        Column("data", String),
+        Index("my_pg_index", "data").ddl_if(dialect="postgresql"),
+        CheckConstraint("num > 5").ddl_if(dialect="postgresql"),
+    )
+
+    e1 = create_engine("sqlite://", echo=True)
+    meta.create_all(e1)  # will not generate CHECK and INDEX
+
+
+    e2 = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
+    meta.create_all(e2)  # will generate CHECK and INDEX
+
+.. seealso::
+
+    :ref:`schema_ddl_ddl_if`
 
+:ticket:`7631`
 
 Behavioral Changes
 ==================
diff --git a/doc/build/changelog/unreleased_20/7631.rst b/doc/build/changelog/unreleased_20/7631.rst
new file mode 100644 (file)
index 0000000..d6c69f5
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: feature, schema
+    :tickets: 7631
+
+    Expanded on the "conditional DDL" system implemented by the
+    :class:`_schema.DDLElement` class to be directly available on
+    :class:`_schema.SchemaItem` constructs such as :class:`_schema.Index`,
+    :class:`_schema.ForeignKeyConstraint`, etc. such that the conditional logic
+    for generating these elements is included within the default DDL emitting
+    process. This system can also be accommodated by a future release of
+    Alembic to support conditional DDL elements within all schema-management
+    systems.
+
+
+    .. seealso::
+
+        :ref:`ticket_7631`
\ No newline at end of file
index e4b07af7e2cab81ba34e5b66fac683f5751bb889..ea84c15a3e6d3c4d80c2f9d443622e267e721481 100644 (file)
@@ -770,6 +770,7 @@ Constraints API
 ---------------
 .. autoclass:: Constraint
     :members:
+    :inherited-members:
 
 .. autoclass:: ColumnCollectionMixin
     :members:
index 9c2fed198dbb593599abf0b32c5c949fce034124..6bbd4942467081b463b049c9d2e67cd7580fc199 100644 (file)
@@ -150,7 +150,9 @@ Using the built-in DDLElement Classes
 -------------------------------------
 
 The ``sqlalchemy.schema`` package contains SQL expression constructs that
-provide DDL expressions. For example, to produce a ``CREATE TABLE`` statement:
+provide DDL expressions, all of which extend from the common base
+:class:`.DDLElement`. For example, to produce a ``CREATE TABLE`` statement,
+one can use the :class:`.CreateTable` construct:
 
 .. sourcecode:: python+sql
 
@@ -178,65 +180,133 @@ User-defined DDL constructs may also be created as subclasses of
 :class:`.DDLElement` itself.   The documentation in
 :ref:`sqlalchemy.ext.compiler_toplevel` has several examples of this.
 
-The event-driven DDL system described in the previous section
-:ref:`schema_ddl_sequences` is available with other :class:`.DDLElement`
-objects as well.  However, when dealing with the built-in constructs
-such as :class:`.CreateIndex`, :class:`.CreateSequence`, etc, the event
-system is of **limited** use, as methods like :meth:`_schema.Table.create` and
-:meth:`_schema.MetaData.create_all` will invoke these constructs unconditionally.
-In a future SQLAlchemy release, the DDL event system including conditional
-execution will taken into account for built-in constructs that currently
-invoke in all cases.
-
-We can illustrate an event-driven
-example with the :class:`.AddConstraint` and :class:`.DropConstraint`
-constructs, as the event-driven system will work for CHECK and UNIQUE
-constraints, using these as we did in our previous example of
-:meth:`.DDLElement.execute_if`:
+.. _schema_ddl_ddl_if:
+
+Controlling DDL Generation of Constraints and Indexes
+-----------------------------------------------------
+
+.. versionadded:: 2.0
+
+While the previously mentioned :meth:`.DDLElement.execute_if` method is
+useful for custom :class:`.DDL` classes which need to invoke conditionally,
+there is also a common need for elements that are typically related to a
+particular :class:`.Table`, namely constraints and indexes, to also be
+subject to "conditional" rules, such as an index that includes features
+that are specific to a particular backend such as PostgreSQL or SQL Server.
+For this use case, the :meth:`.Constraint.ddl_if` and :meth:`.Index.ddl_if`
+methods may be used against constructs such as :class:`.CheckConstraint`,
+:class:`.UniqueConstraint` and :class:`.Index`, accepting the same
+arguments as the :meth:`.DDLElement.execute_if` method in order to control
+whether or not their DDL will be emitted in terms of their parent
+:class:`.Table` object.  These methods may be used inline when
+creating the definition for a :class:`.Table`
+(or similarly, when using the ``__table_args__`` collection in an ORM
+declarative mapping), such as::
+
+    from sqlalchemy import CheckConstraint, Index
+    from sqlalchemy import MetaData, Table, Column
+    from sqlalchemy import Integer, String
+
+    meta = MetaData()
+
+    my_table = Table(
+        "my_table",
+        meta,
+        Column("id", Integer, primary_key=True),
+        Column("num", Integer),
+        Column("data", String),
+        Index("my_pg_index", "data").ddl_if(dialect="postgresql"),
+        CheckConstraint("num > 5").ddl_if(dialect="postgresql"),
+    )
+
+In the above example, the :class:`.Table` construct refers to both an
+:class:`.Index` and a :class:`.CheckConstraint` construct, both of which
+specify ``.ddl_if(dialect="postgresql")``, indicating that these
+elements will be included in the CREATE TABLE sequence only against the
+PostgreSQL dialect.  If we run ``meta.create_all()`` against the SQLite
+dialect, for example, neither construct will be included:
 
 .. sourcecode:: python+sql
 
-    def should_create(ddl, target, connection, **kw):
-        row = connection.execute(
-            "select conname from pg_constraint where conname='%s'" %
-            ddl.element.name).scalar()
-        return not bool(row)
+    >>> from sqlalchemy import create_engine
+    >>> sqlite_engine = create_engine("sqlite+pysqlite://", echo=True)
+    >>> meta.create_all(sqlite_engine)
+    {opensql}BEGIN (implicit)
+    PRAGMA main.table_info("my_table")
+    [raw sql] ()
+    PRAGMA temp.table_info("my_table")
+    [raw sql] ()
+
+    CREATE TABLE my_table (
+        id INTEGER NOT NULL,
+        num INTEGER,
+        data VARCHAR,
+        PRIMARY KEY (id)
+    )
 
-    def should_drop(ddl, target, connection, **kw):
-        return not should_create(ddl, target, connection, **kw)
+However, if we run the same commands against a PostgreSQL database, we will
+see inline DDL for the CHECK constraint as well as a separate CREATE
+statement emitted for the index:
 
-    event.listen(
-        users,
-        "after_create",
-        AddConstraint(constraint).execute_if(callable_=should_create)
-    )
-    event.listen(
-        users,
-        "before_drop",
-        DropConstraint(constraint).execute_if(callable_=should_drop)
+.. sourcecode:: python+sql
+
+    >>> from sqlalchemy import create_engine
+    >>> postgresql_engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", echo=True)
+    >>> meta.create_all(postgresql_engine)
+    {opensql}BEGIN (implicit)
+    select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where pg_catalog.pg_table_is_visible(c.oid) and relname=%(name)s
+    [generated in 0.00009s] {'name': 'my_table'}
+
+    CREATE TABLE my_table (
+        id SERIAL NOT NULL,
+        num INTEGER,
+        data VARCHAR,
+        PRIMARY KEY (id),
+        CHECK (num > 5)
     )
+    [no key 0.00007s] {}
+    CREATE INDEX my_pg_index ON my_table (data)
+    [no key 0.00013s] {}
+    COMMIT
+
+The :meth:`.Constraint.ddl_if` and :meth:`.Index.ddl_if` methods create
+an event hook that may be consulted not just at DDL execution time, as is the
+behavior with :meth:`.DDLElement.execute_if`, but also within the SQL compilation
+phase of the :class:`.CreateTable` object, which is responsible for rendering
+the ``CHECK (num > 5)`` DDL inline within the CREATE TABLE statement.
+As such, the event hook that is received by the :meth:`.Constraint.ddl_if.callable_`
+parameter has a richer argument set, including a ``dialect`` keyword
+argument that is always passed, as well as an instance of :class:`.DDLCompiler`
+via the ``compiler`` keyword argument for the "inline rendering" portion of the
+sequence.  The ``bind`` argument is **not** present when the event is triggered
+within the :class:`.DDLCompiler` sequence, so a modern event hook that wishes
+to inspect the database versioning information would best use the given
+:class:`.Dialect` object, such as to test PostgreSQL versioning:
 
-    {sql}users.create(engine)
-    CREATE TABLE users (
-        user_id SERIAL NOT NULL,
-        user_name VARCHAR(40) NOT NULL,
-        PRIMARY KEY (user_id)
+.. sourcecode:: python+sql
+
+    def only_pg_14(ddl_element, target, bind, dialect, **kw):
+        return (
+            dialect.name == "postgresql" and
+            dialect.server_version_info >= (14,)
+        )
+
+    my_table = Table(
+        "my_table",
+        meta,
+        Column("id", Integer, primary_key=True),
+        Column("num", Integer),
+        Column("data", String),
+        Index("my_pg_index", "data").ddl_if(callable_=only_pg_14),
     )
 
-    select conname from pg_constraint where conname='cst_user_name_length'
-    ALTER TABLE users ADD CONSTRAINT cst_user_name_length  CHECK (length(user_name) >= 8){stop}
+.. seealso::
+
+    :meth:`.Constraint.ddl_if`
+
+    :meth:`.Index.ddl_if`
 
-    {sql}users.drop(engine)
-    select conname from pg_constraint where conname='cst_user_name_length'
-    ALTER TABLE users DROP CONSTRAINT cst_user_name_length
-    DROP TABLE users{stop}
 
-While the above example is against the built-in :class:`.AddConstraint`
-and :class:`.DropConstraint` objects, the main usefulness of DDL events
-for now remains focused on the use of the :class:`.DDL` construct itself,
-as well as with user-defined subclasses of :class:`.DDLElement` that aren't
-already part of the :meth:`_schema.MetaData.create_all`, :meth:`_schema.Table.create`,
-and corresponding "drop" processes.
 
 .. _schema_api_ddl:
 
index 5ba52ae51c03695a80af7670585a69073b1cb157..aa98ff2565dc6301ceb37f39e6f08a0af8abfb34 100644 (file)
@@ -4889,10 +4889,7 @@ class DDLCompiler(Compiled):
             for p in (
                 self.process(constraint)
                 for constraint in constraints
-                if (
-                    constraint._create_rule is None
-                    or constraint._create_rule(self)
-                )
+                if (constraint._should_create_for_compiler(self))
                 and (
                     not self.dialect.supports_alter
                     or not getattr(constraint, "use_alter", False)
index 7acb69bebb57bc412e4f11ccc67fdd5431df92b7..4d57ad869810a3a59bf09959e33d6ea03ebc00d8 100644 (file)
@@ -12,10 +12,11 @@ to invoke them for a create/drop call.
 from __future__ import annotations
 
 import typing
+from typing import Any
 from typing import Callable
 from typing import List
 from typing import Optional
-from typing import Sequence
+from typing import Sequence as typing_Sequence
 from typing import Tuple
 
 from . import roles
@@ -26,11 +27,20 @@ from .elements import ClauseElement
 from .. import exc
 from .. import util
 from ..util import topological
-
+from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from .compiler import Compiled
+    from .compiler import DDLCompiler
+    from .elements import BindParameter
     from .schema import ForeignKeyConstraint
+    from .schema import SchemaItem
     from .schema import Table
+    from ..engine.base import _CompiledCacheType
+    from ..engine.base import Connection
+    from ..engine.interfaces import _SchemaTranslateMapType
+    from ..engine.interfaces import CacheStats
+    from ..engine.interfaces import Dialect
 
 
 class _DDLCompiles(ClauseElement):
@@ -43,10 +53,70 @@ class _DDLCompiles(ClauseElement):
 
         return dialect.ddl_compiler(dialect, self, **kw)
 
-    def _compile_w_cache(self, *arg, **kw):
+    def _compile_w_cache(
+        self,
+        dialect: Dialect,
+        *,
+        compiled_cache: Optional[_CompiledCacheType],
+        column_keys: List[str],
+        for_executemany: bool = False,
+        schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+        **kw: Any,
+    ) -> Tuple[
+        Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats
+    ]:
         raise NotImplementedError()
 
 
+class DDLIfCallable(Protocol):
+    def __call__(
+        self,
+        ddl: "DDLElement",
+        target: "SchemaItem",
+        bind: Optional["Connection"],
+        tables: Optional[List["Table"]] = None,
+        state: Optional[Any] = None,
+        *,
+        dialect: Dialect,
+        compiler: Optional[DDLCompiler] = ...,
+        checkfirst: bool,
+    ) -> bool:
+        ...
+
+
+class DDLIf(typing.NamedTuple):
+    dialect: Optional[str]
+    callable_: Optional[DDLIfCallable]
+    state: Optional[Any]
+
+    def _should_execute(self, ddl, target, bind, compiler=None, **kw):
+        if bind is not None:
+            dialect = bind.dialect
+        elif compiler is not None:
+            dialect = compiler.dialect
+        else:
+            assert False, "compiler or dialect is required"
+
+        if isinstance(self.dialect, str):
+            if self.dialect != dialect.name:
+                return False
+        elif isinstance(self.dialect, (tuple, list, set)):
+            if dialect.name not in self.dialect:
+                return False
+        if self.callable_ is not None and not self.callable_(
+            ddl,
+            target,
+            bind,
+            state=self.state,
+            dialect=dialect,
+            compiler=compiler,
+            **kw,
+        ):
+            return False
+
+        return True
+
+
 SelfDDLElement = typing.TypeVar("SelfDDLElement", bound="DDLElement")
 
 
@@ -80,10 +150,8 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
 
     """
 
-    target = None
-    on = None
-    dialect = None
-    callable_ = None
+    _ddl_if: Optional[DDLIf] = None
+    target: Optional["SchemaItem"] = None
 
     def _execute_on_connection(
         self, connection, distilled_params, execution_options
@@ -93,7 +161,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
         )
 
     @_generative
-    def against(self: SelfDDLElement, target) -> SelfDDLElement:
+    def against(self: SelfDDLElement, target: SchemaItem) -> SelfDDLElement:
         """Return a copy of this :class:`_schema.DDLElement` which will include
         the given target.
 
@@ -125,13 +193,15 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
             processing the DDL string.
 
         """
-
         self.target = target
         return self
 
     @_generative
     def execute_if(
-        self: SelfDDLElement, dialect=None, callable_=None, state=None
+        self: SelfDDLElement,
+        dialect: Optional[str] = None,
+        callable_: Optional[DDLIfCallable] = None,
+        state: Optional[Any] = None,
     ) -> SelfDDLElement:
         r"""Return a callable that will execute this
         :class:`_ddl.DDLElement` conditionally within an event handler.
@@ -155,7 +225,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
             DDL('something').execute_if(dialect=('postgresql', 'mysql'))
 
         :param callable\_: A callable, which will be invoked with
-          four positional arguments as well as optional keyword
+          three positional arguments as well as optional keyword
           arguments:
 
             :ddl:
@@ -168,13 +238,22 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
               explicitly.
 
             :bind:
-              The :class:`_engine.Connection` being used for DDL execution
+              The :class:`_engine.Connection` being used for DDL execution.
+              May be None if this construct is being created inline within
+              a table, in which case ``compiler`` will be present.
 
             :tables:
               Optional keyword argument - a list of Table objects which are to
               be created/ dropped within a MetaData.create_all() or drop_all()
               method call.
 
+            :dialect: keyword argument, but always present - the
+              :class:`.Dialect` involved in the operation.
+
+            :compiler: keyword argument.  Will be ``None`` for an engine
+              level DDL invocation, but will refer to a :class:`.DDLCompiler`
+              if this DDL element is being created inline within a table.
+
             :state:
               Optional keyword argument - will be the ``state`` argument
               passed to this function.
@@ -192,35 +271,30 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
 
         .. seealso::
 
+            :meth:`.SchemaItem.ddl_if`
+
             :class:`.DDLEvents`
 
             :ref:`event_toplevel`
 
         """
-        self.dialect = dialect
-        self.callable_ = callable_
-        self.state = state
+        self._ddl_if = DDLIf(dialect, callable_, state)
         return self
 
     def _should_execute(self, target, bind, **kw):
-        if isinstance(self.dialect, str):
-            if self.dialect != bind.engine.name:
-                return False
-        elif isinstance(self.dialect, (tuple, list, set)):
-            if bind.engine.name not in self.dialect:
-                return False
-        if self.callable_ is not None and not self.callable_(
-            self, target, bind, state=self.state, **kw
-        ):
-            return False
+        if self._ddl_if is None:
+            return True
+        else:
+            return self._ddl_if._should_execute(self, target, bind, **kw)
 
-        return True
+    def _invoke_with(self, bind):
+        if self._should_execute(self.target, bind):
+            return bind.execute(self)
 
     def __call__(self, target, bind, **kw):
         """Execute the DDL as a ddl_listener."""
 
-        if self._should_execute(target, bind, **kw):
-            return bind.execute(self.against(target))
+        self.against(target)._invoke_with(bind)
 
     def _generate(self):
         s = self.__class__.__new__(self.__class__)
@@ -330,9 +404,10 @@ class _CreateDropBase(DDLElement):
         if_exists=False,
         if_not_exists=False,
     ):
-        self.element = element
+        self.element = self.target = element
         self.if_exists = if_exists
         self.if_not_exists = if_not_exists
+        self._ddl_if = getattr(element, "_ddl_if", None)
 
     @property
     def stringify_dialect(self):
@@ -358,11 +433,19 @@ class CreateSchema(_CreateDropBase):
 
     __visit_name__ = "create_schema"
 
-    def __init__(self, name, quote=None, **kw):
+    def __init__(
+        self,
+        name,
+        quote=None,
+        if_exists=False,
+        if_not_exists=False,
+    ):
         """Create a new :class:`.CreateSchema` construct."""
 
         self.quote = quote
-        super(CreateSchema, self).__init__(name, **kw)
+        self.element = name
+        self.if_exists = if_exists
+        self.if_not_exists = if_not_exists
 
 
 class DropSchema(_CreateDropBase):
@@ -374,12 +457,22 @@ class DropSchema(_CreateDropBase):
 
     __visit_name__ = "drop_schema"
 
-    def __init__(self, name, quote=None, cascade=False, **kw):
+    def __init__(
+        self,
+        name,
+        quote=None,
+        cascade=False,
+        if_exists=False,
+        if_not_exists=False,
+    ):
         """Create a new :class:`.DropSchema` construct."""
 
         self.quote = quote
         self.cascade = cascade
-        super(DropSchema, self).__init__(name, **kw)
+        self.quote = quote
+        self.element = name
+        self.if_exists = if_exists
+        self.if_not_exists = if_not_exists
 
 
 class CreateTable(_CreateDropBase):
@@ -427,6 +520,11 @@ class _DropView(_CreateDropBase):
     __visit_name__ = "drop_view"
 
 
+class CreateConstraint(_DDLCompiles):
+    def __init__(self, element):
+        self.element = element
+
+
 class CreateColumn(_DDLCompiles):
     """Represent a :class:`_schema.Column`
     as rendered in a CREATE TABLE statement,
@@ -784,15 +882,10 @@ class SchemaGenerator(DDLBase):
             # e.g., don't omit any foreign key constraints
             include_foreign_key_constraints = None
 
-        self.connection.execute(
-            # fmt: off
-            CreateTable(
-                table,
-                include_foreign_key_constraints=  # noqa
-                    include_foreign_key_constraints,  # noqa
-            )
-            # fmt: on
-        )
+        CreateTable(
+            table,
+            include_foreign_key_constraints=include_foreign_key_constraints,
+        )._invoke_with(self.connection)
 
         if hasattr(table, "indexes"):
             for index in table.indexes:
@@ -800,11 +893,11 @@ class SchemaGenerator(DDLBase):
 
         if self.dialect.supports_comments and not self.dialect.inline_comments:
             if table.comment is not None:
-                self.connection.execute(SetTableComment(table))
+                SetTableComment(table)._invoke_with(self.connection)
 
             for column in table.columns:
                 if column.comment is not None:
-                    self.connection.execute(SetColumnComment(column))
+                    SetColumnComment(column)._invoke_with(self.connection)
 
         table.dispatch.after_create(
             table,
@@ -817,17 +910,17 @@ class SchemaGenerator(DDLBase):
     def visit_foreign_key_constraint(self, constraint):
         if not self.dialect.supports_alter:
             return
-        self.connection.execute(AddConstraint(constraint))
+        AddConstraint(constraint)._invoke_with(self.connection)
 
     def visit_sequence(self, sequence, create_ok=False):
         if not create_ok and not self._can_create_sequence(sequence):
             return
-        self.connection.execute(CreateSequence(sequence))
+        CreateSequence(sequence)._invoke_with(self.connection)
 
     def visit_index(self, index, create_ok=False):
         if not create_ok and not self._can_create_index(index):
             return
-        self.connection.execute(CreateIndex(index))
+        CreateIndex(index)._invoke_with(self.connection)
 
 
 class SchemaDropper(DDLBase):
@@ -964,7 +1057,7 @@ class SchemaDropper(DDLBase):
         if not drop_ok and not self._can_drop_index(index):
             return
 
-        self.connection.execute(DropIndex(index))
+        DropIndex(index)(index, self.connection)
 
     def visit_table(
         self,
@@ -984,7 +1077,7 @@ class SchemaDropper(DDLBase):
             _is_metadata_operation=_is_metadata_operation,
         )
 
-        self.connection.execute(DropTable(table))
+        DropTable(table)._invoke_with(self.connection)
 
         # traverse client side defaults which may refer to server-side
         # sequences. noting that some of these client side defaults may also be
@@ -1009,19 +1102,21 @@ class SchemaDropper(DDLBase):
     def visit_foreign_key_constraint(self, constraint):
         if not self.dialect.supports_alter:
             return
-        self.connection.execute(DropConstraint(constraint))
+        DropConstraint(constraint)._invoke_with(self.connection)
 
     def visit_sequence(self, sequence, drop_ok=False):
 
         if not drop_ok and not self._can_drop_sequence(sequence):
             return
-        self.connection.execute(DropSequence(sequence))
+        DropSequence(sequence)._invoke_with(self.connection)
 
 
 def sort_tables(
-    tables: Sequence["Table"],
+    tables: typing_Sequence["Table"],
     skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None,
-    extra_dependencies: Optional[Sequence[Tuple["Table", "Table"]]] = None,
+    extra_dependencies: Optional[
+        typing_Sequence[Tuple["Table", "Table"]]
+    ] = None,
 ) -> List["Table"]:
     """Sort a collection of :class:`_schema.Table` objects based on
     dependency.
@@ -1082,16 +1177,17 @@ def sort_tables(
     """
 
     if skip_fn is not None:
+        fixed_skip_fn = skip_fn
 
         def _skip_fn(fkc):
             for fk in fkc.elements:
-                if skip_fn(fk):
+                if fixed_skip_fn(fk):
                     return True
             else:
                 return None
 
     else:
-        _skip_fn = None
+        _skip_fn = None  # type: ignore
 
     return [
         t
index 540b62e8aa469e37db1589f8a625efa59bdbb43a..dfe82432d77e6f6c240ca39add985e841b18210c 100644 (file)
@@ -184,6 +184,61 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
     _use_schema_map = True
 
 
+SelfHasConditionalDDL = TypeVar(
+    "SelfHasConditionalDDL", bound="HasConditionalDDL"
+)
+
+
+class HasConditionalDDL:
+    """define a class that includes the :meth:`.HasConditionalDDL.ddl_if`
+    method, allowing for conditional rendering of DDL.
+
+    Currently applies to constraints and indexes.
+
+    .. versionadded:: 2.0
+
+
+    """
+
+    _ddl_if: Optional[ddl.DDLIf] = None
+
+    def ddl_if(
+        self: SelfHasConditionalDDL,
+        dialect: Optional[str] = None,
+        callable_: Optional[ddl.DDLIfCallable] = None,
+        state: Optional[Any] = None,
+    ) -> SelfHasConditionalDDL:
+        r"""apply a conditional DDL rule to this schema item.
+
+        These rules work in a similar manner to the
+        :meth:`.DDLElement.execute_if` callable, with the added feature that
+        the criteria may be checked within the DDL compilation phase for a
+        construct such as :class:`.CreateTable`.
+        :meth:`.HasConditionalDDL.ddl_if` currently applies towards the
+        :class:`.Index` construct as well as all :class:`.Constraint`
+        constructs.
+
+        :param dialect: string name of a dialect, or a tuple of string names
+         to indicate multiple dialect types.
+
+        :param callable\_: a callable that is constructed using the same form
+         as that described in :paramref:`.DDLElement.execute_if.callable_`.
+
+        :param state: any arbitrary object that will be passed to the
+         callable, if present.
+
+        .. versionadded:: 2.0
+
+        .. seealso::
+
+            :ref:`schema_ddl_ddl_if` - background and usage examples
+
+
+        """
+        self._ddl_if = ddl.DDLIf(dialect, callable_, state)
+        return self
+
+
 class HasSchemaAttr(SchemaItem):
     """schema item that includes a top-level schema name"""
 
@@ -3355,7 +3410,7 @@ class DefaultClause(FetchedValue):
         return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
 
 
-class Constraint(DialectKWArgs, SchemaItem):
+class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem):
     """A table-level SQL constraint.
 
     :class:`_schema.Constraint` serves as the base class for the series of
@@ -3424,6 +3479,16 @@ class Constraint(DialectKWArgs, SchemaItem):
         util.set_creation_order(self)
         self._validate_dialect_kwargs(dialect_kw)
 
+    def _should_create_for_compiler(self, compiler, **kw):
+        if self._create_rule is not None and not self._create_rule(compiler):
+            return False
+        elif self._ddl_if is not None:
+            return self._ddl_if._should_execute(
+                ddl.CreateConstraint(self), self, None, compiler=compiler, **kw
+            )
+        else:
+            return True
+
     @property
     def table(self):
         try:
@@ -4292,7 +4357,9 @@ class UniqueConstraint(ColumnCollectionConstraint):
     __visit_name__ = "unique_constraint"
 
 
-class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
+class Index(
+    DialectKWArgs, ColumnCollectionMixin, HasConditionalDDL, SchemaItem
+):
     """A table-level INDEX.
 
     Defines a composite (one or more column) INDEX.
index acbc69537adc19c56a845d5c7c7939d588dfc1f6..aa2790b049261e59d8cc0a9e62adf6e276c51ec5 100644 (file)
@@ -79,7 +79,6 @@ module = [
     "sqlalchemy.ext.serializer",
 
     "sqlalchemy.sql.selectable",   # would be nice as strict
-    "sqlalchemy.sql.ddl",
     "sqlalchemy.sql.functions",  # would be nice as strict
     "sqlalchemy.sql.lambdas",
     "sqlalchemy.sql.dml",   # would be nice as strict
@@ -132,6 +131,7 @@ module = [
     "sqlalchemy.sql.coercions",
     "sqlalchemy.sql.compiler",
     "sqlalchemy.sql.crud",
+    "sqlalchemy.sql.ddl",  # would be nice as strict
     "sqlalchemy.sql.elements",    # would be nice as strict
     "sqlalchemy.sql.naming",
     "sqlalchemy.sql.schema",      # would be nice as strict
index f339ef1715cf404067f98caa4e92789aaa21c138..0c72e32c7b206e3d2d7c17462f755822f5e5e98f 100644 (file)
@@ -1,7 +1,11 @@
+from unittest import mock
+from unittest.mock import Mock
+
 import sqlalchemy as tsa
 from sqlalchemy import create_engine
 from sqlalchemy import create_mock_engine
 from sqlalchemy import event
+from sqlalchemy import Index
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import String
@@ -373,7 +377,7 @@ class DDLEventTest(fixtures.TestBase):
         eq_(metadata_canary.mock_calls, [])
 
 
-class DDLExecutionTest(fixtures.TestBase):
+class DDLExecutionTest(AssertsCompiledSQL, fixtures.TestBase):
     def setup_test(self):
         self.engine = engines.mock_engine()
         self.metadata = MetaData()
@@ -485,6 +489,121 @@ class DDLExecutionTest(fixtures.TestBase):
         strings = " ".join(str(x) for x in pg_mock.mock)
         assert "my_test_constraint" in strings
 
+    @testing.combinations(("dialect",), ("callable",), ("callable_w_state",))
+    def test_inline_ddl_if_dialect_name(self, ddl_if_type):
+        nonpg_mock = engines.mock_engine(dialect_name="sqlite")
+        pg_mock = engines.mock_engine(dialect_name="postgresql")
+
+        metadata = MetaData()
+
+        capture_mock = Mock()
+        state = object()
+
+        if ddl_if_type == "dialect":
+            ddl_kwargs = dict(dialect="postgresql")
+        elif ddl_if_type == "callable":
+
+            def is_pg(ddl, target, bind, **kw):
+                capture_mock.is_pg(ddl, target, bind, **kw)
+                return kw["dialect"].name == "postgresql"
+
+            ddl_kwargs = dict(callable_=is_pg)
+        elif ddl_if_type == "callable_w_state":
+
+            def is_pg(ddl, target, bind, **kw):
+                capture_mock.is_pg(ddl, target, bind, **kw)
+                return kw["dialect"].name == "postgresql"
+
+            ddl_kwargs = dict(callable_=is_pg, state=state)
+        else:
+            assert False
+
+        data_col = Column("data", String)
+        t = Table(
+            "a",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("num", Integer),
+            data_col,
+            Index("my_pg_index", data_col).ddl_if(**ddl_kwargs),
+            CheckConstraint("num > 5").ddl_if(**ddl_kwargs),
+        )
+
+        metadata.create_all(nonpg_mock)
+        eq_(len(nonpg_mock.mock), 1)
+        self.assert_compile(
+            nonpg_mock.mock[0],
+            "CREATE TABLE a (id INTEGER NOT NULL, num INTEGER, "
+            "data VARCHAR, PRIMARY KEY (id))",
+            dialect=nonpg_mock.dialect,
+        )
+
+        metadata.create_all(pg_mock)
+
+        eq_(len(pg_mock.mock), 2)
+
+        self.assert_compile(
+            pg_mock.mock[0],
+            "CREATE TABLE a (id SERIAL NOT NULL, num INTEGER, "
+            "data VARCHAR, PRIMARY KEY (id), CHECK (num > 5))",
+            dialect=pg_mock.dialect,
+        )
+        self.assert_compile(
+            pg_mock.mock[1],
+            "CREATE INDEX my_pg_index ON a (data)",
+            dialect="postgresql",
+        )
+
+        the_index = list(t.indexes)[0]
+        the_constraint = list(
+            c for c in t.constraints if isinstance(c, CheckConstraint)
+        )[0]
+
+        if ddl_if_type in ("callable", "callable_w_state"):
+
+            if ddl_if_type == "callable":
+                check_state = None
+            else:
+                check_state = state
+
+            eq_(
+                capture_mock.mock_calls,
+                [
+                    mock.call.is_pg(
+                        mock.ANY,
+                        the_index,
+                        mock.ANY,
+                        state=check_state,
+                        dialect=nonpg_mock.dialect,
+                        compiler=None,
+                    ),
+                    mock.call.is_pg(
+                        mock.ANY,
+                        the_constraint,
+                        None,
+                        state=check_state,
+                        dialect=nonpg_mock.dialect,
+                        compiler=mock.ANY,
+                    ),
+                    mock.call.is_pg(
+                        mock.ANY,
+                        the_index,
+                        mock.ANY,
+                        state=check_state,
+                        dialect=pg_mock.dialect,
+                        compiler=None,
+                    ),
+                    mock.call.is_pg(
+                        mock.ANY,
+                        the_constraint,
+                        None,
+                        state=check_state,
+                        dialect=pg_mock.dialect,
+                        compiler=mock.ANY,
+                    ),
+                ],
+            )
+
     @testing.requires.sqlite
     def test_ddl_execute(self):
         engine = create_engine("sqlite:///")