From: Gord Thompson Date: Mon, 25 Jan 2021 18:24:25 +0000 (-0700) Subject: Use schema._copy_expression() fully in column collection constraints X-Git-Tag: rel_1_3_23~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7dd3381edb15e9699d24d78caa8a1021667ce92d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Use schema._copy_expression() fully in column collection constraints Fixed issue where using :meth:`_schema.Table.to_metadata` (called :meth:`_schema.Table.tometadata` in 1.3) in conjunction with a PostgreSQL :class:`_postgresql.ExcludeConstraint` that made use of ad-hoc column expressions would fail to copy correctly. Fixes: #5850 Change-Id: I062480afb23f6f60962b7b55bc93f5e4e6ff05e4 (cherry picked from commit 81896c31ffc4db081f1f2bba199a52328398a236) --- diff --git a/doc/build/changelog/unreleased_13/5850.rst b/doc/build/changelog/unreleased_13/5850.rst new file mode 100644 index 0000000000..2d73a42fb2 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5850.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, postgresql + :tickets: 5850 + + Fixed issue where using :meth:`_schema.Table.to_metadata` (called + :meth:`_schema.Table.tometadata` in 1.3) in conjunction with a PostgreSQL + :class:`_postgresql.ExcludeConstraint` that made use of ad-hoc column + expressions would fail to copy correctly. \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 79522b80f9..5fdb065cd6 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -9,6 +9,7 @@ from .array import ARRAY from ...sql import elements from ...sql import expression from ...sql import functions +from ...sql import schema from ...sql.schema import ColumnCollectionConstraint @@ -218,8 +219,14 @@ class ExcludeConstraint(ColumnCollectionConstraint): self.ops = kw.get("ops", {}) - def copy(self, **kw): - elements = [(col, self.operators[col]) for col in self.columns.keys()] + def copy(self, target_table=None, **kw): + elements = [ + ( + schema._copy_expression(expr, self.parent, target_table), + self.operators[expr.name], + ) + for expr in self.columns + ] c = self.__class__( *elements, name=self.name, diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 97fdc7b91f..1d4375aeab 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -618,6 +618,7 @@ from ...engine import default from ...engine import reflection from ...sql import ColumnElement from ...sql import compiler +from ...sql import schema from ...types import BLOB # noqa from ...types import BOOLEAN # noqa from ...types import CHAR # noqa @@ -1187,9 +1188,11 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): "on_conflict" ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ - "on_conflict_unique" - ] + col1 = list(constraint)[0] + if isinstance(col1, schema.SchemaItem): + on_conflict_clause = list(constraint)[0].dialect_options[ + "sqlite" + ]["on_conflict_unique"] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index ee5a1aa211..191246f460 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -80,6 +80,9 @@ def _get_table_key(name, schema): # this should really be in sql/util.py but we'd have to # break an import cycle def _copy_expression(expression, source_table, target_table): + if source_table is None or target_table is None: + return expression + def replace(col): if ( isinstance(col, Column) @@ -3172,7 +3175,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): def __contains__(self, x): return x in self.columns - def copy(self, **kw): + def copy(self, target_table=None, **kw): # ticket #5276 constraint_kwargs = {} for dialect_name in self.dialect_options: @@ -3189,7 +3192,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): name=self.name, deferrable=self.deferrable, initially=self.initially, - *self.columns.keys(), + *[ + _copy_expression(expr, self.parent, target_table) + for expr in self.columns + ], **constraint_kwargs ) return self._schema_item_copy(c) @@ -3300,6 +3306,9 @@ class CheckConstraint(ColumnCollectionConstraint): def copy(self, target_table=None, **kw): if target_table is not None: + # note that target_table is None for the copy process of + # a column-bound CheckConstraint, so this path is not reached + # in that case. sqltext = _copy_expression(self.sqltext, self.table, target_table) else: sqltext = self.sqltext @@ -4880,10 +4889,11 @@ class Computed(FetchedValue, SchemaItem): return self def copy(self, target_table=None, **kw): - if target_table is not None: - sqltext = _copy_expression(self.sqltext, self.table, target_table) - else: - sqltext = self.sqltext + sqltext = _copy_expression( + self.sqltext, + self.column.table if self.column is not None else None, + target_table, + ) g = Computed(sqltext, persisted=self.persisted) return self._schema_item_copy(g) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 411b9547a2..e71156f442 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -4,6 +4,7 @@ from sqlalchemy import and_ from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import Computed +from sqlalchemy import Date from sqlalchemy import delete from sqlalchemy import Enum from sqlalchemy import exc @@ -767,6 +768,64 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): dialect=postgresql.dialect(), ) + @testing.combinations( + (True, "deferred"), + (False, "immediate"), + argnames="deferrable_value, initially_value", + ) + def test_copy_exclude_constraint_adhoc_columns( + self, deferrable_value, initially_value + ): + meta = MetaData() + table = Table( + "mytable", + meta, + Column("myid", Integer, Sequence("foo_id_seq"), primary_key=True), + Column("valid_from_date", Date(), nullable=True), + Column("valid_thru_date", Date(), nullable=True), + ) + cons = ExcludeConstraint( + ( + literal_column( + "daterange(valid_from_date, valid_thru_date, '[]')" + ), + "&&", + ), + where=column("valid_from_date") <= column("valid_thru_date"), + name="ex_mytable_valid_date_range", + deferrable=deferrable_value, + initially=initially_value, + ) + + table.append_constraint(cons) + expected = ( + "ALTER TABLE mytable ADD CONSTRAINT ex_mytable_valid_date_range " + "EXCLUDE USING gist " + "(daterange(valid_from_date, valid_thru_date, '[]') WITH &&) " + "WHERE (valid_from_date <= valid_thru_date) " + "%s %s" + % ( + "NOT DEFERRABLE" if not deferrable_value else "DEFERRABLE", + "INITIALLY %s" % initially_value, + ) + ) + self.assert_compile( + schema.AddConstraint(cons), + expected, + dialect=postgresql.dialect(), + ) + + meta2 = MetaData() + table2 = table.tometadata(meta2) + cons2 = [ + c for c in table2.constraints if isinstance(c, ExcludeConstraint) + ][0] + self.assert_compile( + schema.AddConstraint(cons2), + expected, + dialect=postgresql.dialect(), + ) + def test_exclude_constraint_full(self): m = MetaData() room = Column("room", Integer, primary_key=True) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index fee1fbb749..cd86d1f1c0 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -41,6 +41,7 @@ from sqlalchemy.schema import DropIndex from sqlalchemy.sql import elements from sqlalchemy.sql import naming from sqlalchemy.sql.elements import _NONE_NAME +from sqlalchemy.sql.elements import literal_column from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -760,7 +761,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): eq_(repr(const), exp) -class ToMetaDataTest(fixtures.TestBase, ComparesTables): +class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables): @testing.requires.check_constraints def test_copy(self): # TODO: modernize this test @@ -913,6 +914,53 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): a2 = a.tometadata(m2) assert b2.c.y.references(a2.c.x) + def test_column_collection_constraint_w_ad_hoc_columns(self): + """Test ColumnCollectionConstraint that has columns that aren't + part of the Table. + + """ + meta = MetaData() + + uq1 = UniqueConstraint(literal_column("some_name")) + cc1 = CheckConstraint(literal_column("some_name") > 5) + table = Table( + "mytable", + meta, + Column("myid", Integer, primary_key=True), + Column("name", String(40), nullable=True), + uq1, + cc1, + ) + + self.assert_compile( + schema.AddConstraint(uq1), + "ALTER TABLE mytable ADD UNIQUE (some_name)", + dialect="default", + ) + self.assert_compile( + schema.AddConstraint(cc1), + "ALTER TABLE mytable ADD CHECK (some_name > 5)", + dialect="default", + ) + meta2 = MetaData() + table2 = table.tometadata(meta2) + uq2 = [ + c for c in table2.constraints if isinstance(c, UniqueConstraint) + ][0] + cc2 = [ + c for c in table2.constraints if isinstance(c, CheckConstraint) + ][0] + self.assert_compile( + schema.AddConstraint(uq2), + "ALTER TABLE mytable ADD UNIQUE (some_name)", + dialect="default", + ) + self.assert_compile( + schema.AddConstraint(cc2), + "ALTER TABLE mytable ADD CHECK (some_name > 5)", + dialect="default", + ) + def test_change_schema(self): meta = MetaData()