]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Factor out constraints into separate methods
authorG Allajmi <ghaith.ger@gmail.com>
Tue, 9 Dec 2025 19:13:52 +0000 (14:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Dec 2025 20:12:34 +0000 (15:12 -0500)
Fixed issue where PostgreSQL dialect options such as ``postgresql_include``
on :class:`.PrimaryKeyConstraint` and :class:`.UniqueConstraint` were
rendered in the wrong position when combined with constraint deferrability
options like ``deferrable=True``. Pull request courtesy G Allajmi.

Fixes: #12867
Closes: #13003
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13003
Pull-request-sha: 1a9216062f12cba2695b0b4a1407e092556c2305

Change-Id: I8c55d8faae25d56ff63c9126d569c01d8ee6c7dd

doc/build/changelog/unreleased_20/12867.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
test/dialect/postgresql/test_compiler.py
test/sql/test_constraints.py

diff --git a/doc/build/changelog/unreleased_20/12867.rst b/doc/build/changelog/unreleased_20/12867.rst
new file mode 100644 (file)
index 0000000..c0ab6fc
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 12867
+
+    Fixed issue where PostgreSQL dialect options such as ``postgresql_include``
+    on :class:`.PrimaryKeyConstraint` and :class:`.UniqueConstraint` were
+    rendered in the wrong position when combined with constraint deferrability
+    options like ``deferrable=True``. Pull request courtesy G Allajmi.
index c4f166ad668c95329cf20f71cc8298d484674f65..fca3defe16759d319dd632f33d42191dd11baa97 100644 (file)
@@ -2580,13 +2580,19 @@ class PGDDLCompiler(compiler.DDLCompiler):
         return text
 
     def visit_primary_key_constraint(self, constraint, **kw):
-        text = super().visit_primary_key_constraint(constraint)
+        text = self.define_constraint_preamble(constraint, **kw)
+        text += self.define_primary_key_body(constraint, **kw)
         text += self._define_include(constraint)
+        text += self.define_constraint_deferrability(constraint)
         return text
 
     def visit_unique_constraint(self, constraint, **kw):
-        text = super().visit_unique_constraint(constraint)
+        if len(constraint) == 0:
+            return ""
+        text = self.define_constraint_preamble(constraint, **kw)
+        text += self.define_unique_body(constraint, **kw)
         text += self._define_include(constraint)
+        text += self.define_constraint_deferrability(constraint)
         return text
 
     @util.memoized_property
index 43e70d08b0bcf1dfa77553d020c5f3ab72cec0b2..ab507997bb72fc56d6cfb6717e318d47d5d09190 100644 (file)
@@ -117,6 +117,7 @@ if typing.TYPE_CHECKING:
     from .elements import Null
     from .elements import True_
     from .functions import Function
+    from .schema import CheckConstraint
     from .schema import Column
     from .schema import Constraint
     from .schema import ForeignKeyConstraint
@@ -7366,26 +7367,14 @@ class DDLCompiler(Compiled):
             return self.visit_check_constraint(constraint)
 
     def visit_check_constraint(self, constraint, **kw):
-        text = ""
-        if constraint.name is not None:
-            formatted_name = self.preparer.format_constraint(constraint)
-            if formatted_name is not None:
-                text += "CONSTRAINT %s " % formatted_name
-        text += "CHECK (%s)" % self.sql_compiler.process(
-            constraint.sqltext, include_table=False, literal_binds=True
-        )
+        text = self.define_constraint_preamble(constraint, **kw)
+        text += self.define_check_body(constraint, **kw)
         text += self.define_constraint_deferrability(constraint)
         return text
 
     def visit_column_check_constraint(self, constraint, **kw):
-        text = ""
-        if constraint.name is not None:
-            formatted_name = self.preparer.format_constraint(constraint)
-            if formatted_name is not None:
-                text += "CONSTRAINT %s " % formatted_name
-        text += "CHECK (%s)" % self.sql_compiler.process(
-            constraint.sqltext, include_table=False, literal_binds=True
-        )
+        text = self.define_constraint_preamble(constraint, **kw)
+        text += self.define_check_body(constraint, **kw)
         text += self.define_constraint_deferrability(constraint)
         return text
 
@@ -7394,11 +7383,50 @@ class DDLCompiler(Compiled):
     ) -> str:
         if len(constraint) == 0:
             return ""
+        text = self.define_constraint_preamble(constraint, **kw)
+        text += self.define_primary_key_body(constraint, **kw)
+        text += self.define_constraint_deferrability(constraint)
+        return text
+
+    def visit_foreign_key_constraint(
+        self, constraint: ForeignKeyConstraint, **kw: Any
+    ) -> str:
+        text = self.define_constraint_preamble(constraint, **kw)
+        text += self.define_foreign_key_body(constraint, **kw)
+        text += self.define_constraint_match(constraint)
+        text += self.define_constraint_cascades(constraint)
+        text += self.define_constraint_deferrability(constraint)
+        return text
+
+    def define_constraint_remote_table(self, constraint, table, preparer):
+        """Format the remote table clause of a CREATE CONSTRAINT clause."""
+
+        return preparer.format_table(table)
+
+    def visit_unique_constraint(
+        self, constraint: UniqueConstraint, **kw: Any
+    ) -> str:
+        if len(constraint) == 0:
+            return ""
+        text = self.define_constraint_preamble(constraint, **kw)
+        text += self.define_unique_body(constraint, **kw)
+        text += self.define_constraint_deferrability(constraint)
+        return text
+
+    def define_constraint_preamble(
+        self, constraint: Constraint, **kw: Any
+    ) -> str:
         text = ""
         if constraint.name is not None:
             formatted_name = self.preparer.format_constraint(constraint)
             if formatted_name is not None:
                 text += "CONSTRAINT %s " % formatted_name
+        return text
+
+    def define_primary_key_body(
+        self, constraint: PrimaryKeyConstraint, **kw: Any
+    ) -> str:
+        text = ""
         text += "PRIMARY KEY "
         text += "(%s)" % ", ".join(
             self.preparer.quote(c.name)
@@ -7408,18 +7436,14 @@ class DDLCompiler(Compiled):
                 else constraint.columns
             )
         )
-        text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_foreign_key_constraint(self, constraint, **kw):
+    def define_foreign_key_body(
+        self, constraint: ForeignKeyConstraint, **kw: Any
+    ) -> str:
         preparer = self.preparer
-        text = ""
-        if constraint.name is not None:
-            formatted_name = self.preparer.format_constraint(constraint)
-            if formatted_name is not None:
-                text += "CONSTRAINT %s " % formatted_name
         remote_table = list(constraint.elements)[0].column.table
-        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+        text = "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
             ", ".join(
                 preparer.quote(f.parent.name) for f in constraint.elements
             ),
@@ -7430,31 +7454,21 @@ class DDLCompiler(Compiled):
                 preparer.quote(f.column.name) for f in constraint.elements
             ),
         )
-        text += self.define_constraint_match(constraint)
-        text += self.define_constraint_cascades(constraint)
-        text += self.define_constraint_deferrability(constraint)
         return text
 
-    def define_constraint_remote_table(self, constraint, table, preparer):
-        """Format the remote table clause of a CREATE CONSTRAINT clause."""
-
-        return preparer.format_table(table)
-
-    def visit_unique_constraint(
+    def define_unique_body(
         self, constraint: UniqueConstraint, **kw: Any
     ) -> str:
-        if len(constraint) == 0:
-            return ""
-        text = ""
-        if constraint.name is not None:
-            formatted_name = self.preparer.format_constraint(constraint)
-            if formatted_name is not None:
-                text += "CONSTRAINT %s " % formatted_name
-        text += "UNIQUE %s(%s)" % (
+        text = "UNIQUE %s(%s)" % (
             self.define_unique_constraint_distinct(constraint, **kw),
             ", ".join(self.preparer.quote(c.name) for c in constraint),
         )
-        text += self.define_constraint_deferrability(constraint)
+        return text
+
+    def define_check_body(self, constraint: CheckConstraint, **kw: Any) -> str:
+        text = "CHECK (%s)" % self.sql_compiler.process(
+            constraint.sqltext, include_table=False, literal_binds=True
+        )
         return text
 
     def define_unique_constraint_distinct(
@@ -7500,7 +7514,7 @@ class DDLCompiler(Compiled):
             )
         return text
 
-    def define_constraint_match(self, constraint):
+    def define_constraint_match(self, constraint: ForeignKeyConstraint) -> str:
         text = ""
         if constraint.match is not None:
             text += " MATCH %s" % constraint.match
index ed1bece524c44534a8f52d3b330cf0232d1cda5e..817fb620c588881b0beb81be5b1bf2e0d4958642 100644 (file)
@@ -2666,6 +2666,212 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)"
         )
 
+    def test_primary_key_constraint_with_include(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("id", Integer),
+            Column("data", Integer),
+            PrimaryKeyConstraint("id", postgresql_include=["data"]),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (id SERIAL NOT NULL, data INTEGER, "
+            "PRIMARY KEY (id) INCLUDE (data))",
+        )
+
+    def test_primary_key_constraint_with_deferrable(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("id", Integer),
+            PrimaryKeyConstraint("id", deferrable=True),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (id SERIAL NOT NULL, "
+            "PRIMARY KEY (id) DEFERRABLE)",
+        )
+
+    def test_primary_key_constraint_with_deferrable_and_include(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("id", Integer),
+            Column("created_at", Integer),
+            PrimaryKeyConstraint(
+                "id",
+                deferrable=True,
+                initially="IMMEDIATE",
+                postgresql_include=["created_at"],
+            ),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (id SERIAL NOT NULL, created_at INTEGER, "
+            "PRIMARY KEY (id) INCLUDE (created_at) "
+            "DEFERRABLE INITIALLY IMMEDIATE)",
+        )
+
+    def test_foreign_key_constraint_with_deferrable(self):
+        m = MetaData()
+        Table("t1", m, Column("id", Integer, primary_key=True))
+        t2 = Table(
+            "t2",
+            m,
+            Column("t1_id", Integer),
+            ForeignKeyConstraint(["t1_id"], ["t1.id"], deferrable=True),
+        )
+        self.assert_compile(
+            schema.CreateTable(t2),
+            "CREATE TABLE t2 (t1_id INTEGER, "
+            "FOREIGN KEY(t1_id) REFERENCES t1 (id) DEFERRABLE)",
+        )
+
+    def test_foreign_key_constraint_with_not_valid(self):
+        m = MetaData()
+        Table("t1", m, Column("id", Integer, primary_key=True))
+        t2 = Table(
+            "t2",
+            m,
+            Column("t1_id", Integer),
+            ForeignKeyConstraint(
+                ["t1_id"], ["t1.id"], name="fk_t1", postgresql_not_valid=True
+            ),
+        )
+        self.assert_compile(
+            schema.CreateTable(t2),
+            "CREATE TABLE t2 (t1_id INTEGER, "
+            "CONSTRAINT fk_t1 FOREIGN KEY(t1_id) REFERENCES "
+            "t1 (id) NOT VALID)",
+        )
+
+    def test_foreign_key_constraint_with_cascades_and_not_valid(self):
+        m = MetaData()
+        Table("t1", m, Column("id", Integer, primary_key=True))
+        t2 = Table("t2", m, Column("t1_id", Integer))
+        constraint = ForeignKeyConstraint(
+            ["t1_id"],
+            ["t1.id"],
+            name="fk_t1",
+            ondelete="CASCADE",
+            onupdate="SET NULL",
+            postgresql_not_valid=True,
+        )
+        t2.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE t2 ADD CONSTRAINT fk_t1 FOREIGN KEY(t1_id) "
+            "REFERENCES t1 (id) ON DELETE CASCADE "
+            "ON UPDATE SET NULL NOT VALID",
+        )
+
+    def test_unique_constraint_with_deferrable(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("id", Integer),
+            UniqueConstraint("id", name="uq_id", deferrable=True),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (id INTEGER, "
+            "CONSTRAINT uq_id UNIQUE (id) DEFERRABLE)",
+        )
+
+    def test_unique_constraint_with_include(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("id", Integer),
+            Column("data", Integer),
+            Column("created_at", Integer),
+            UniqueConstraint(
+                "id", name="uq_id", postgresql_include=["data", "created_at"]
+            ),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (id INTEGER, data INTEGER, created_at INTEGER, "
+            "CONSTRAINT uq_id UNIQUE (id) INCLUDE (data, created_at))",
+        )
+
+    def test_unique_constraint_with_deferrable_and_include(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("id", Integer),
+            Column("data", Integer),
+            UniqueConstraint(
+                "id",
+                name="uq_id",
+                deferrable=True,
+                initially="DEFERRED",
+                postgresql_include=["data"],
+            ),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (id INTEGER, data INTEGER, "
+            "CONSTRAINT uq_id UNIQUE (id) INCLUDE (data) "
+            "DEFERRABLE INITIALLY DEFERRED)",
+        )
+
+    def test_check_constraint_with_not_valid(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("data", Integer),
+            CheckConstraint(
+                "data > 0", name="ck_data", postgresql_not_valid=True
+            ),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (data INTEGER, "
+            "CONSTRAINT ck_data CHECK (data > 0) NOT VALID)",
+        )
+
+    def test_check_constraint_with_deferrable_and_not_valid(self):
+        m = MetaData()
+        tbl = Table("test", m, Column("data", Integer))
+        constraint = CheckConstraint(
+            "data > 0",
+            name="ck_data",
+            deferrable=True,
+            initially="DEFERRED",
+            postgresql_not_valid=True,
+        )
+        tbl.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE test ADD CONSTRAINT ck_data CHECK (data > 0) "
+            "DEFERRABLE INITIALLY DEFERRED NOT VALID",
+        )
+
+    def test_check_constraint_with_deferrable(self):
+        m = MetaData()
+        tbl = Table(
+            "test",
+            m,
+            Column("data", Integer),
+            CheckConstraint("data > 0", name="ck_data", deferrable=True),
+        )
+        self.assert_compile(
+            schema.CreateTable(tbl),
+            "CREATE TABLE test (data INTEGER, "
+            "CONSTRAINT ck_data CHECK (data > 0) DEFERRABLE)",
+        )
+
     @testing.fixture
     def update_tables(self):
         self.weather = table(
index 6c47edd9c8faf2e4d1ff654414a3d7b8f08cfdad..70055ace1e5991d015617ba97e76602f92158236 100644 (file)
@@ -1380,3 +1380,154 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             schema.CreateIndex(constraint), "CREATE INDEX name ON tbl (a + 5)"
         )
+
+
+class ConstraintCompositionTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    def _constraint_create_fixture(self):
+        m = MetaData()
+        t = Table("tbl", m, Column("a", Integer), Column("b", Integer))
+        t2 = Table("t2", m, Column("a", Integer), Column("b", Integer))
+        return t, t2
+
+    def test_define_check_body(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = CheckConstraint("a > 5", table=t)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CHECK (a > 5)",
+        )
+
+    def test_define_check_body_column_level(self):
+        m = MetaData()
+        t = Table(
+            "tbl",
+            m,
+            Column("a", Integer, CheckConstraint("a > 5", name="ck_a")),
+        )
+
+        self.assert_compile(
+            schema.CreateTable(t),
+            "CREATE TABLE tbl (a INTEGER CONSTRAINT ck_a CHECK (a > 5))",
+        )
+
+    def test_define_foreign_key_body_single_column(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = ForeignKeyConstraint(["b"], ["t2.a"])
+        t.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)",
+        )
+
+    def test_define_foreign_key_body_multi_column(self):
+        m = MetaData()
+        Table(
+            "t1",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+        )
+        t2 = Table("t2", m, Column("a", Integer), Column("b", Integer))
+
+        constraint = ForeignKeyConstraint(["a", "b"], ["t1.id", "t1.id2"])
+        t2.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE t2 ADD FOREIGN KEY(a, b) REFERENCES t1 (id, id2)",
+        )
+
+    def test_define_unique_body_single_column(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = UniqueConstraint("a", name="uq_a")
+        t.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CONSTRAINT uq_a UNIQUE (a)",
+        )
+
+    def test_define_unique_body_multi_column(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = UniqueConstraint("a", "b", name="uq_ab")
+        t.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CONSTRAINT uq_ab UNIQUE (a, b)",
+        )
+
+    def test_define_constraint_preamble_named(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = CheckConstraint("a > 5", name="ck_test", table=t)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CONSTRAINT ck_test CHECK (a > 5)",
+        )
+
+    def test_define_constraint_preamble_unnamed(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = CheckConstraint("a > 5", table=t)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CHECK (a > 5)",
+        )
+
+    def test_visit_check_constraint_composition(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = CheckConstraint(
+            "a < b",
+            name="ck_test",
+            deferrable=True,
+            initially="DEFERRED",
+            table=t,
+        )
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CONSTRAINT ck_test CHECK (a < b) "
+            "DEFERRABLE INITIALLY DEFERRED",
+        )
+
+    def test_visit_foreign_key_constraint_composition(self):
+        m = MetaData()
+        Table("t1", m, Column("a", Integer, primary_key=True))
+        t2 = Table("t2", m, Column("b", Integer))
+
+        constraint = ForeignKeyConstraint(
+            ["b"],
+            ["t1.a"],
+            name="fk_test",
+            ondelete="CASCADE",
+            onupdate="SET NULL",
+            match="FULL",
+            deferrable=True,
+            initially="IMMEDIATE",
+        )
+        t2.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE t2 ADD CONSTRAINT fk_test FOREIGN KEY(b) "
+            "REFERENCES t1 (a) MATCH FULL ON DELETE CASCADE "
+            "ON UPDATE SET NULL DEFERRABLE INITIALLY IMMEDIATE",
+        )
+
+    def test_visit_unique_constraint_composition(self):
+        t, _ = self._constraint_create_fixture()
+        constraint = UniqueConstraint(
+            "a", "b", name="uq_test", deferrable=True, initially="DEFERRED"
+        )
+        t.append_constraint(constraint)
+
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CONSTRAINT uq_test UNIQUE (a, b) "
+            "DEFERRABLE INITIALLY DEFERRED",
+        )