]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support column list for foreign key ON DELETE SET actions on PostgreSQL
authorDenis Laxalde <denis@laxalde.org>
Tue, 11 Mar 2025 19:48:39 +0000 (20:48 +0100)
committerDenis Laxalde <denis@laxalde.org>
Wed, 12 Mar 2025 17:10:11 +0000 (18:10 +0100)
Added support for specifying a list of columns for ON DELETE SET
(NULL|DEFAULT) actions of foreign key definition on PostgreSQL. This is
handled on both compiler and reflection sides.

In order to make it possible to override the logic of
DDLCompiler.define_constraint_cascades() in derived classes, namely
PGDDLCompiler here, we add two methods,
define_constraint_ondelete_cascade() and
define_constraint_onupdate_cascade(), the former being overridden in
PGDDLCompiler.

Test cases (tables definition) are taken from PostgreSQL test suite.

Fixes: #11595
Fixes: #11946
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_reflection.py

index 1f00127bfa63e2ca16751b5aeb2eea2c81271680..ccc426b5bffaad62ccaec80f40b1e7a9240dee6c 100644 (file)
@@ -1672,6 +1672,11 @@ RESERVED_WORDS = {
     "verbose",
 }
 
+FK_ON_DELETE = re.compile(
+    r"^(?:RESTRICT|CASCADE|SET (?:NULL|DEFAULT)(?:\s*\(.+\))?|NO ACTION)$",
+    re.I,
+)
+
 colspecs = {
     sqltypes.ARRAY: _array.ARRAY,
     sqltypes.Interval: INTERVAL,
@@ -2250,6 +2255,11 @@ class PGDDLCompiler(compiler.DDLCompiler):
         text += self._define_constraint_validity(constraint)
         return text
 
+    def define_constraint_ondelete_cascade(self, constraint):
+        return " ON DELETE %s" % self.preparer.validate_sql_phrase(
+            constraint.ondelete, FK_ON_DELETE
+        )
+
     def visit_create_enum_type(self, create, **kw):
         type_ = create.element
 
@@ -4251,7 +4261,8 @@ class PGDialect(default.DefaultDialect):
             r"[\s]?(ON UPDATE "
             r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
             r"[\s]?(ON DELETE "
-            r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
+            r"(CASCADE|RESTRICT|NO ACTION|"
+            r"SET (?:NULL|DEFAULT)(?:\s\(.+\))?)+)?"
             r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?"
             r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
         )
index 32043dd7bb4534c632e1cf7fa66bfc6dae708c89..66b874635a13a9d4c9fe176e93950877c6dafd8f 100644 (file)
@@ -7100,15 +7100,22 @@ class DDLCompiler(Compiled):
     def define_constraint_cascades(self, constraint):
         text = ""
         if constraint.ondelete is not None:
-            text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
-                constraint.ondelete, FK_ON_DELETE
-            )
+            text += self.define_constraint_ondelete_cascade(constraint)
+
         if constraint.onupdate is not None:
-            text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
-                constraint.onupdate, FK_ON_UPDATE
-            )
+            text += self.define_constraint_onupdate_cascade(constraint)
         return text
 
+    def define_constraint_ondelete_cascade(self, constraint):
+        return " ON DELETE %s" % self.preparer.validate_sql_phrase(
+            constraint.ondelete, FK_ON_DELETE
+        )
+
+    def define_constraint_onupdate_cascade(self, constraint):
+        return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
+            constraint.onupdate, FK_ON_UPDATE
+        )
+
     def define_constraint_deferrability(self, constraint):
         text = ""
         if constraint.deferrable is not None:
index 8e241b82e5879e61430bf666fd04e6ce376ef62e..ac49f6f4b5105f481d3b32d497b9d0d224fff9f0 100644 (file)
@@ -1142,6 +1142,48 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             ")",
         )
 
+    def test_create_foreign_key_constraint_ondelete_column_list(self):
+        m = MetaData()
+        pktable = Table(
+            "pktable",
+            m,
+            Column("tid", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True),
+        )
+        fktable = Table(
+            "fktable",
+            m,
+            Column("tid", Integer),
+            Column("id", Integer),
+            Column("fk_id_del_set_null", Integer),
+            Column("fk_id_del_set_default", Integer, server_default=text("0")),
+            ForeignKeyConstraint(
+                columns=["tid", "fk_id_del_set_null"],
+                refcolumns=[pktable.c.tid, pktable.c.id],
+                ondelete="SET NULL (fk_id_del_set_null)",
+            ),
+            ForeignKeyConstraint(
+                columns=["tid", "fk_id_del_set_default"],
+                refcolumns=[pktable.c.tid, pktable.c.id],
+                ondelete="SET DEFAULT(fk_id_del_set_default)",
+            ),
+        )
+
+        self.assert_compile(
+            schema.CreateTable(fktable),
+            "CREATE TABLE fktable ("
+            "tid INTEGER, id INTEGER, "
+            "fk_id_del_set_null INTEGER, "
+            "fk_id_del_set_default INTEGER DEFAULT 0, "
+            "FOREIGN KEY(tid, fk_id_del_set_null)"
+            " REFERENCES pktable (tid, id)"
+            " ON DELETE SET NULL (fk_id_del_set_null), "
+            "FOREIGN KEY(tid, fk_id_del_set_default)"
+            " REFERENCES pktable (tid, id)"
+            " ON DELETE SET DEFAULT(fk_id_del_set_default)"
+            ")",
+        )
+
     def test_exclude_constraint_min(self):
         m = MetaData()
         tbl = Table("testtbl", m, Column("room", Integer, primary_key=True))
index 4d889c6775fd7329b57b78d54c2b03b36fd1895c..c67e09e42074ed41e9c74b9181a44d19e628414f 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import BigInteger
 from sqlalchemy import Column
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import Identity
 from sqlalchemy import Index
 from sqlalchemy import inspect
@@ -20,6 +21,7 @@ from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import Text
+from sqlalchemy import text
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import base as postgresql
@@ -908,6 +910,51 @@ class ReflectionTest(
         subject = Table("subject", meta2, autoload_with=connection)
         eq_(subject.primary_key.columns.keys(), ["p2", "p1"])
 
+    def test_reflected_foreign_key_ondelete_column_list(
+        self, metadata, connection
+    ):
+        meta1 = metadata
+        pktable = Table(
+            "pktable",
+            meta1,
+            Column("tid", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True),
+        )
+        Table(
+            "fktable",
+            meta1,
+            Column("tid", Integer),
+            Column("id", Integer),
+            Column("fk_id_del_set_null", Integer),
+            Column("fk_id_del_set_default", Integer, server_default=text("0")),
+            ForeignKeyConstraint(
+                columns=["tid", "fk_id_del_set_null"],
+                refcolumns=[pktable.c.tid, pktable.c.id],
+                ondelete="SET NULL (fk_id_del_set_null)",
+            ),
+            ForeignKeyConstraint(
+                columns=["tid", "fk_id_del_set_default"],
+                refcolumns=[pktable.c.tid, pktable.c.id],
+                ondelete="SET DEFAULT(fk_id_del_set_default)",
+            ),
+        )
+
+        meta1.create_all(connection)
+        meta2 = MetaData()
+        fktable = Table("fktable", meta2, autoload_with=connection)
+        fkey_set_null = next(
+            c
+            for c in fktable.foreign_key_constraints
+            if c.name == "fktable_tid_fk_id_del_set_null_fkey"
+        )
+        eq_(fkey_set_null.ondelete, "SET NULL (fk_id_del_set_null)")
+        fkey_set_default = next(
+            c
+            for c in fktable.foreign_key_constraints
+            if c.name == "fktable_tid_fk_id_del_set_default_fkey"
+        )
+        eq_(fkey_set_default.ondelete, "SET DEFAULT (fk_id_del_set_default)")
+
     def test_pg_weirdchar_reflection(self, metadata, connection):
         meta1 = metadata
         subject = Table(