]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use schema._copy_expression() fully in column collection constraints
authorGord Thompson <gord@gordthompson.com>
Mon, 25 Jan 2021 18:24:25 +0000 (11:24 -0700)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Jan 2021 21:47:15 +0000 (16:47 -0500)
Fixed issue where using :meth:`_schema.Table.to_metadata` (called
:meth:`_schema.Table.tometadata` in 1.3) in conjunction with a PostgreSQL
:class:`_postgresql.ExcludeConstraint` that made use of ad-hoc column
expressions would fail to copy correctly.

Fixes: #5850
Change-Id: I062480afb23f6f60962b7b55bc93f5e4e6ff05e4
(cherry picked from commit 81896c31ffc4db081f1f2bba199a52328398a236)

doc/build/changelog/unreleased_13/5850.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ext.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/sql/schema.py
test/dialect/postgresql/test_compiler.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_13/5850.rst b/doc/build/changelog/unreleased_13/5850.rst
new file mode 100644 (file)
index 0000000..2d73a42
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 5850
+
+    Fixed issue where using :meth:`_schema.Table.to_metadata` (called
+    :meth:`_schema.Table.tometadata` in 1.3) in conjunction with a PostgreSQL
+    :class:`_postgresql.ExcludeConstraint` that made use of ad-hoc column
+    expressions would fail to copy correctly.
\ No newline at end of file
index 79522b80f9e75c8dc1bf06aed61a269c18149002..5fdb065cd6f61d3f84ba829f2282261f1143dfc7 100644 (file)
@@ -9,6 +9,7 @@ from .array import ARRAY
 from ...sql import elements
 from ...sql import expression
 from ...sql import functions
+from ...sql import schema
 from ...sql.schema import ColumnCollectionConstraint
 
 
@@ -218,8 +219,14 @@ class ExcludeConstraint(ColumnCollectionConstraint):
 
         self.ops = kw.get("ops", {})
 
-    def copy(self, **kw):
-        elements = [(col, self.operators[col]) for col in self.columns.keys()]
+    def copy(self, target_table=None, **kw):
+        elements = [
+            (
+                schema._copy_expression(expr, self.parent, target_table),
+                self.operators[expr.name],
+            )
+            for expr in self.columns
+        ]
         c = self.__class__(
             *elements,
             name=self.name,
index 97fdc7b91f9157e04402951be47e0b079bea873b..1d4375aeaba88d348d9e69c7a68e93e4a659786a 100644 (file)
@@ -618,6 +618,7 @@ from ...engine import default
 from ...engine import reflection
 from ...sql import ColumnElement
 from ...sql import compiler
+from ...sql import schema
 from ...types import BLOB  # noqa
 from ...types import BOOLEAN  # noqa
 from ...types import CHAR  # noqa
@@ -1187,9 +1188,11 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
             "on_conflict"
         ]
         if on_conflict_clause is None and len(constraint.columns) == 1:
-            on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][
-                "on_conflict_unique"
-            ]
+            col1 = list(constraint)[0]
+            if isinstance(col1, schema.SchemaItem):
+                on_conflict_clause = list(constraint)[0].dialect_options[
+                    "sqlite"
+                ]["on_conflict_unique"]
 
         if on_conflict_clause is not None:
             text += " ON CONFLICT " + on_conflict_clause
index ee5a1aa211c2c2cba7b46fa106900068a57485d3..191246f4606ffe58ea7665dc337198a0eba98857 100644 (file)
@@ -80,6 +80,9 @@ def _get_table_key(name, schema):
 # this should really be in sql/util.py but we'd have to
 # break an import cycle
 def _copy_expression(expression, source_table, target_table):
+    if source_table is None or target_table is None:
+        return expression
+
     def replace(col):
         if (
             isinstance(col, Column)
@@ -3172,7 +3175,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
     def __contains__(self, x):
         return x in self.columns
 
-    def copy(self, **kw):
+    def copy(self, target_table=None, **kw):
         # ticket #5276
         constraint_kwargs = {}
         for dialect_name in self.dialect_options:
@@ -3189,7 +3192,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
             name=self.name,
             deferrable=self.deferrable,
             initially=self.initially,
-            *self.columns.keys(),
+            *[
+                _copy_expression(expr, self.parent, target_table)
+                for expr in self.columns
+            ],
             **constraint_kwargs
         )
         return self._schema_item_copy(c)
@@ -3300,6 +3306,9 @@ class CheckConstraint(ColumnCollectionConstraint):
 
     def copy(self, target_table=None, **kw):
         if target_table is not None:
+            # note that target_table is None for the copy process of
+            # a column-bound CheckConstraint, so this path is not reached
+            # in that case.
             sqltext = _copy_expression(self.sqltext, self.table, target_table)
         else:
             sqltext = self.sqltext
@@ -4880,10 +4889,11 @@ class Computed(FetchedValue, SchemaItem):
         return self
 
     def copy(self, target_table=None, **kw):
-        if target_table is not None:
-            sqltext = _copy_expression(self.sqltext, self.table, target_table)
-        else:
-            sqltext = self.sqltext
+        sqltext = _copy_expression(
+            self.sqltext,
+            self.column.table if self.column is not None else None,
+            target_table,
+        )
         g = Computed(sqltext, persisted=self.persisted)
 
         return self._schema_item_copy(g)
index 411b9547a29fc3fed35ebeca39b2e81e6fdeb36e..e71156f442f91e1cf45d9252b7a95d5d4e18aa76 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy import and_
 from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import Computed
+from sqlalchemy import Date
 from sqlalchemy import delete
 from sqlalchemy import Enum
 from sqlalchemy import exc
@@ -767,6 +768,64 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
+    @testing.combinations(
+        (True, "deferred"),
+        (False, "immediate"),
+        argnames="deferrable_value, initially_value",
+    )
+    def test_copy_exclude_constraint_adhoc_columns(
+        self, deferrable_value, initially_value
+    ):
+        meta = MetaData()
+        table = Table(
+            "mytable",
+            meta,
+            Column("myid", Integer, Sequence("foo_id_seq"), primary_key=True),
+            Column("valid_from_date", Date(), nullable=True),
+            Column("valid_thru_date", Date(), nullable=True),
+        )
+        cons = ExcludeConstraint(
+            (
+                literal_column(
+                    "daterange(valid_from_date, valid_thru_date, '[]')"
+                ),
+                "&&",
+            ),
+            where=column("valid_from_date") <= column("valid_thru_date"),
+            name="ex_mytable_valid_date_range",
+            deferrable=deferrable_value,
+            initially=initially_value,
+        )
+
+        table.append_constraint(cons)
+        expected = (
+            "ALTER TABLE mytable ADD CONSTRAINT ex_mytable_valid_date_range "
+            "EXCLUDE USING gist "
+            "(daterange(valid_from_date, valid_thru_date, '[]') WITH &&) "
+            "WHERE (valid_from_date <= valid_thru_date) "
+            "%s %s"
+            % (
+                "NOT DEFERRABLE" if not deferrable_value else "DEFERRABLE",
+                "INITIALLY %s" % initially_value,
+            )
+        )
+        self.assert_compile(
+            schema.AddConstraint(cons),
+            expected,
+            dialect=postgresql.dialect(),
+        )
+
+        meta2 = MetaData()
+        table2 = table.tometadata(meta2)
+        cons2 = [
+            c for c in table2.constraints if isinstance(c, ExcludeConstraint)
+        ][0]
+        self.assert_compile(
+            schema.AddConstraint(cons2),
+            expected,
+            dialect=postgresql.dialect(),
+        )
+
     def test_exclude_constraint_full(self):
         m = MetaData()
         room = Column("room", Integer, primary_key=True)
index fee1fbb7497cc7e53bbeec9cfce1f340dcc0b17b..cd86d1f1c0cf7581d33e2e19ba90e79dc93ac92d 100644 (file)
@@ -41,6 +41,7 @@ from sqlalchemy.schema import DropIndex
 from sqlalchemy.sql import elements
 from sqlalchemy.sql import naming
 from sqlalchemy.sql.elements import _NONE_NAME
+from sqlalchemy.sql.elements import literal_column
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -760,7 +761,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
             eq_(repr(const), exp)
 
 
-class ToMetaDataTest(fixtures.TestBase, ComparesTables):
+class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
     @testing.requires.check_constraints
     def test_copy(self):
         # TODO: modernize this test
@@ -913,6 +914,53 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables):
         a2 = a.tometadata(m2)
         assert b2.c.y.references(a2.c.x)
 
+    def test_column_collection_constraint_w_ad_hoc_columns(self):
+        """Test ColumnCollectionConstraint that has columns that aren't
+        part of the Table.
+
+        """
+        meta = MetaData()
+
+        uq1 = UniqueConstraint(literal_column("some_name"))
+        cc1 = CheckConstraint(literal_column("some_name") > 5)
+        table = Table(
+            "mytable",
+            meta,
+            Column("myid", Integer, primary_key=True),
+            Column("name", String(40), nullable=True),
+            uq1,
+            cc1,
+        )
+
+        self.assert_compile(
+            schema.AddConstraint(uq1),
+            "ALTER TABLE mytable ADD UNIQUE (some_name)",
+            dialect="default",
+        )
+        self.assert_compile(
+            schema.AddConstraint(cc1),
+            "ALTER TABLE mytable ADD CHECK (some_name > 5)",
+            dialect="default",
+        )
+        meta2 = MetaData()
+        table2 = table.tometadata(meta2)
+        uq2 = [
+            c for c in table2.constraints if isinstance(c, UniqueConstraint)
+        ][0]
+        cc2 = [
+            c for c in table2.constraints if isinstance(c, CheckConstraint)
+        ][0]
+        self.assert_compile(
+            schema.AddConstraint(uq2),
+            "ALTER TABLE mytable ADD UNIQUE (some_name)",
+            dialect="default",
+        )
+        self.assert_compile(
+            schema.AddConstraint(cc2),
+            "ALTER TABLE mytable ADD CHECK (some_name > 5)",
+            dialect="default",
+        )
+
     def test_change_schema(self):
         meta = MetaData()