From 9fe3c3cd30bd7d4afc877bb2243ba7679ebe185d Mon Sep 17 00:00:00 2001 From: G Allajmi Date: Tue, 9 Dec 2025 14:13:52 -0500 Subject: [PATCH] Factor out constraints into separate methods Fixed issue where PostgreSQL dialect options such as ``postgresql_include`` on :class:`.PrimaryKeyConstraint` and :class:`.UniqueConstraint` were rendered in the wrong position when combined with constraint deferrability options like ``deferrable=True``. Pull request courtesy G Allajmi. Fixes: #12867 Closes: #13003 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13003 Pull-request-sha: 1a9216062f12cba2695b0b4a1407e092556c2305 Change-Id: I8c55d8faae25d56ff63c9126d569c01d8ee6c7dd --- doc/build/changelog/unreleased_20/12867.rst | 8 + lib/sqlalchemy/dialects/postgresql/base.py | 10 +- lib/sqlalchemy/sql/compiler.py | 100 ++++++---- test/dialect/postgresql/test_compiler.py | 206 ++++++++++++++++++++ test/sql/test_constraints.py | 151 ++++++++++++++ 5 files changed, 430 insertions(+), 45 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/12867.rst diff --git a/doc/build/changelog/unreleased_20/12867.rst b/doc/build/changelog/unreleased_20/12867.rst new file mode 100644 index 0000000000..c0ab6fc4c1 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12867.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, postgresql + :tickets: 12867 + + Fixed issue where PostgreSQL dialect options such as ``postgresql_include`` + on :class:`.PrimaryKeyConstraint` and :class:`.UniqueConstraint` were + rendered in the wrong position when combined with constraint deferrability + options like ``deferrable=True``. Pull request courtesy G Allajmi. diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index c4f166ad66..fca3defe16 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2580,13 +2580,19 @@ class PGDDLCompiler(compiler.DDLCompiler): return text def visit_primary_key_constraint(self, constraint, **kw): - text = super().visit_primary_key_constraint(constraint) + text = self.define_constraint_preamble(constraint, **kw) + text += self.define_primary_key_body(constraint, **kw) text += self._define_include(constraint) + text += self.define_constraint_deferrability(constraint) return text def visit_unique_constraint(self, constraint, **kw): - text = super().visit_unique_constraint(constraint) + if len(constraint) == 0: + return "" + text = self.define_constraint_preamble(constraint, **kw) + text += self.define_unique_body(constraint, **kw) text += self._define_include(constraint) + text += self.define_constraint_deferrability(constraint) return text @util.memoized_property diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 43e70d08b0..ab507997bb 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -117,6 +117,7 @@ if typing.TYPE_CHECKING: from .elements import Null from .elements import True_ from .functions import Function + from .schema import CheckConstraint from .schema import Column from .schema import Constraint from .schema import ForeignKeyConstraint @@ -7366,26 +7367,14 @@ class DDLCompiler(Compiled): return self.visit_check_constraint(constraint) def visit_check_constraint(self, constraint, **kw): - text = "" - if constraint.name is not None: - formatted_name = self.preparer.format_constraint(constraint) - if formatted_name is not None: - text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process( - constraint.sqltext, include_table=False, literal_binds=True - ) + text = self.define_constraint_preamble(constraint, **kw) + text += self.define_check_body(constraint, **kw) text += self.define_constraint_deferrability(constraint) return text def visit_column_check_constraint(self, constraint, **kw): - text = "" - if constraint.name is not None: - formatted_name = self.preparer.format_constraint(constraint) - if formatted_name is not None: - text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process( - constraint.sqltext, include_table=False, literal_binds=True - ) + text = self.define_constraint_preamble(constraint, **kw) + text += self.define_check_body(constraint, **kw) text += self.define_constraint_deferrability(constraint) return text @@ -7394,11 +7383,50 @@ class DDLCompiler(Compiled): ) -> str: if len(constraint) == 0: return "" + text = self.define_constraint_preamble(constraint, **kw) + text += self.define_primary_key_body(constraint, **kw) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_foreign_key_constraint( + self, constraint: ForeignKeyConstraint, **kw: Any + ) -> str: + text = self.define_constraint_preamble(constraint, **kw) + text += self.define_foreign_key_body(constraint, **kw) + text += self.define_constraint_match(constraint) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table) + + def visit_unique_constraint( + self, constraint: UniqueConstraint, **kw: Any + ) -> str: + if len(constraint) == 0: + return "" + text = self.define_constraint_preamble(constraint, **kw) + text += self.define_unique_body(constraint, **kw) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_preamble( + self, constraint: Constraint, **kw: Any + ) -> str: text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name + return text + + def define_primary_key_body( + self, constraint: PrimaryKeyConstraint, **kw: Any + ) -> str: + text = "" text += "PRIMARY KEY " text += "(%s)" % ", ".join( self.preparer.quote(c.name) @@ -7408,18 +7436,14 @@ class DDLCompiler(Compiled): else constraint.columns ) ) - text += self.define_constraint_deferrability(constraint) return text - def visit_foreign_key_constraint(self, constraint, **kw): + def define_foreign_key_body( + self, constraint: ForeignKeyConstraint, **kw: Any + ) -> str: preparer = self.preparer - text = "" - if constraint.name is not None: - formatted_name = self.preparer.format_constraint(constraint) - if formatted_name is not None: - text += "CONSTRAINT %s " % formatted_name remote_table = list(constraint.elements)[0].column.table - text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + text = "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ", ".join( preparer.quote(f.parent.name) for f in constraint.elements ), @@ -7430,31 +7454,21 @@ class DDLCompiler(Compiled): preparer.quote(f.column.name) for f in constraint.elements ), ) - text += self.define_constraint_match(constraint) - text += self.define_constraint_cascades(constraint) - text += self.define_constraint_deferrability(constraint) return text - def define_constraint_remote_table(self, constraint, table, preparer): - """Format the remote table clause of a CREATE CONSTRAINT clause.""" - - return preparer.format_table(table) - - def visit_unique_constraint( + def define_unique_body( self, constraint: UniqueConstraint, **kw: Any ) -> str: - if len(constraint) == 0: - return "" - text = "" - if constraint.name is not None: - formatted_name = self.preparer.format_constraint(constraint) - if formatted_name is not None: - text += "CONSTRAINT %s " % formatted_name - text += "UNIQUE %s(%s)" % ( + text = "UNIQUE %s(%s)" % ( self.define_unique_constraint_distinct(constraint, **kw), ", ".join(self.preparer.quote(c.name) for c in constraint), ) - text += self.define_constraint_deferrability(constraint) + return text + + def define_check_body(self, constraint: CheckConstraint, **kw: Any) -> str: + text = "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) return text def define_unique_constraint_distinct( @@ -7500,7 +7514,7 @@ class DDLCompiler(Compiled): ) return text - def define_constraint_match(self, constraint): + def define_constraint_match(self, constraint: ForeignKeyConstraint) -> str: text = "" if constraint.match is not None: text += " MATCH %s" % constraint.match diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index ed1bece524..817fb620c5 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -2666,6 +2666,212 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)" ) + def test_primary_key_constraint_with_include(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("id", Integer), + Column("data", Integer), + PrimaryKeyConstraint("id", postgresql_include=["data"]), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (id SERIAL NOT NULL, data INTEGER, " + "PRIMARY KEY (id) INCLUDE (data))", + ) + + def test_primary_key_constraint_with_deferrable(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("id", Integer), + PrimaryKeyConstraint("id", deferrable=True), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (id SERIAL NOT NULL, " + "PRIMARY KEY (id) DEFERRABLE)", + ) + + def test_primary_key_constraint_with_deferrable_and_include(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("id", Integer), + Column("created_at", Integer), + PrimaryKeyConstraint( + "id", + deferrable=True, + initially="IMMEDIATE", + postgresql_include=["created_at"], + ), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (id SERIAL NOT NULL, created_at INTEGER, " + "PRIMARY KEY (id) INCLUDE (created_at) " + "DEFERRABLE INITIALLY IMMEDIATE)", + ) + + def test_foreign_key_constraint_with_deferrable(self): + m = MetaData() + Table("t1", m, Column("id", Integer, primary_key=True)) + t2 = Table( + "t2", + m, + Column("t1_id", Integer), + ForeignKeyConstraint(["t1_id"], ["t1.id"], deferrable=True), + ) + self.assert_compile( + schema.CreateTable(t2), + "CREATE TABLE t2 (t1_id INTEGER, " + "FOREIGN KEY(t1_id) REFERENCES t1 (id) DEFERRABLE)", + ) + + def test_foreign_key_constraint_with_not_valid(self): + m = MetaData() + Table("t1", m, Column("id", Integer, primary_key=True)) + t2 = Table( + "t2", + m, + Column("t1_id", Integer), + ForeignKeyConstraint( + ["t1_id"], ["t1.id"], name="fk_t1", postgresql_not_valid=True + ), + ) + self.assert_compile( + schema.CreateTable(t2), + "CREATE TABLE t2 (t1_id INTEGER, " + "CONSTRAINT fk_t1 FOREIGN KEY(t1_id) REFERENCES " + "t1 (id) NOT VALID)", + ) + + def test_foreign_key_constraint_with_cascades_and_not_valid(self): + m = MetaData() + Table("t1", m, Column("id", Integer, primary_key=True)) + t2 = Table("t2", m, Column("t1_id", Integer)) + constraint = ForeignKeyConstraint( + ["t1_id"], + ["t1.id"], + name="fk_t1", + ondelete="CASCADE", + onupdate="SET NULL", + postgresql_not_valid=True, + ) + t2.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE t2 ADD CONSTRAINT fk_t1 FOREIGN KEY(t1_id) " + "REFERENCES t1 (id) ON DELETE CASCADE " + "ON UPDATE SET NULL NOT VALID", + ) + + def test_unique_constraint_with_deferrable(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("id", Integer), + UniqueConstraint("id", name="uq_id", deferrable=True), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (id INTEGER, " + "CONSTRAINT uq_id UNIQUE (id) DEFERRABLE)", + ) + + def test_unique_constraint_with_include(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("id", Integer), + Column("data", Integer), + Column("created_at", Integer), + UniqueConstraint( + "id", name="uq_id", postgresql_include=["data", "created_at"] + ), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (id INTEGER, data INTEGER, created_at INTEGER, " + "CONSTRAINT uq_id UNIQUE (id) INCLUDE (data, created_at))", + ) + + def test_unique_constraint_with_deferrable_and_include(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("id", Integer), + Column("data", Integer), + UniqueConstraint( + "id", + name="uq_id", + deferrable=True, + initially="DEFERRED", + postgresql_include=["data"], + ), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (id INTEGER, data INTEGER, " + "CONSTRAINT uq_id UNIQUE (id) INCLUDE (data) " + "DEFERRABLE INITIALLY DEFERRED)", + ) + + def test_check_constraint_with_not_valid(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("data", Integer), + CheckConstraint( + "data > 0", name="ck_data", postgresql_not_valid=True + ), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (data INTEGER, " + "CONSTRAINT ck_data CHECK (data > 0) NOT VALID)", + ) + + def test_check_constraint_with_deferrable_and_not_valid(self): + m = MetaData() + tbl = Table("test", m, Column("data", Integer)) + constraint = CheckConstraint( + "data > 0", + name="ck_data", + deferrable=True, + initially="DEFERRED", + postgresql_not_valid=True, + ) + tbl.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE test ADD CONSTRAINT ck_data CHECK (data > 0) " + "DEFERRABLE INITIALLY DEFERRED NOT VALID", + ) + + def test_check_constraint_with_deferrable(self): + m = MetaData() + tbl = Table( + "test", + m, + Column("data", Integer), + CheckConstraint("data > 0", name="ck_data", deferrable=True), + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE test (data INTEGER, " + "CONSTRAINT ck_data CHECK (data > 0) DEFERRABLE)", + ) + @testing.fixture def update_tables(self): self.weather = table( diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 6c47edd9c8..70055ace1e 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -1380,3 +1380,154 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( schema.CreateIndex(constraint), "CREATE INDEX name ON tbl (a + 5)" ) + + +class ConstraintCompositionTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + def _constraint_create_fixture(self): + m = MetaData() + t = Table("tbl", m, Column("a", Integer), Column("b", Integer)) + t2 = Table("t2", m, Column("a", Integer), Column("b", Integer)) + return t, t2 + + def test_define_check_body(self): + t, _ = self._constraint_create_fixture() + constraint = CheckConstraint("a > 5", table=t) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CHECK (a > 5)", + ) + + def test_define_check_body_column_level(self): + m = MetaData() + t = Table( + "tbl", + m, + Column("a", Integer, CheckConstraint("a > 5", name="ck_a")), + ) + + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE tbl (a INTEGER CONSTRAINT ck_a CHECK (a > 5))", + ) + + def test_define_foreign_key_body_single_column(self): + t, _ = self._constraint_create_fixture() + constraint = ForeignKeyConstraint(["b"], ["t2.a"]) + t.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)", + ) + + def test_define_foreign_key_body_multi_column(self): + m = MetaData() + Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("id2", Integer, primary_key=True), + ) + t2 = Table("t2", m, Column("a", Integer), Column("b", Integer)) + + constraint = ForeignKeyConstraint(["a", "b"], ["t1.id", "t1.id2"]) + t2.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE t2 ADD FOREIGN KEY(a, b) REFERENCES t1 (id, id2)", + ) + + def test_define_unique_body_single_column(self): + t, _ = self._constraint_create_fixture() + constraint = UniqueConstraint("a", name="uq_a") + t.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CONSTRAINT uq_a UNIQUE (a)", + ) + + def test_define_unique_body_multi_column(self): + t, _ = self._constraint_create_fixture() + constraint = UniqueConstraint("a", "b", name="uq_ab") + t.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CONSTRAINT uq_ab UNIQUE (a, b)", + ) + + def test_define_constraint_preamble_named(self): + t, _ = self._constraint_create_fixture() + constraint = CheckConstraint("a > 5", name="ck_test", table=t) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CONSTRAINT ck_test CHECK (a > 5)", + ) + + def test_define_constraint_preamble_unnamed(self): + t, _ = self._constraint_create_fixture() + constraint = CheckConstraint("a > 5", table=t) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CHECK (a > 5)", + ) + + def test_visit_check_constraint_composition(self): + t, _ = self._constraint_create_fixture() + constraint = CheckConstraint( + "a < b", + name="ck_test", + deferrable=True, + initially="DEFERRED", + table=t, + ) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CONSTRAINT ck_test CHECK (a < b) " + "DEFERRABLE INITIALLY DEFERRED", + ) + + def test_visit_foreign_key_constraint_composition(self): + m = MetaData() + Table("t1", m, Column("a", Integer, primary_key=True)) + t2 = Table("t2", m, Column("b", Integer)) + + constraint = ForeignKeyConstraint( + ["b"], + ["t1.a"], + name="fk_test", + ondelete="CASCADE", + onupdate="SET NULL", + match="FULL", + deferrable=True, + initially="IMMEDIATE", + ) + t2.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE t2 ADD CONSTRAINT fk_test FOREIGN KEY(b) " + "REFERENCES t1 (a) MATCH FULL ON DELETE CASCADE " + "ON UPDATE SET NULL DEFERRABLE INITIALLY IMMEDIATE", + ) + + def test_visit_unique_constraint_composition(self): + t, _ = self._constraint_create_fixture() + constraint = UniqueConstraint( + "a", "b", name="uq_test", deferrable=True, initially="DEFERRED" + ) + t.append_constraint(constraint) + + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CONSTRAINT uq_test UNIQUE (a, b) " + "DEFERRABLE INITIALLY DEFERRED", + ) -- 2.47.3