]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix PostgreSQL for_update on JTI subclasses 4551/head
authorraylu <lurayl@gmail.com>
Sat, 16 Mar 2019 00:24:39 +0000 (17:24 -0700)
committerraylu <lurayl@gmail.com>
Wed, 20 Mar 2019 18:16:52 +0000 (11:16 -0700)
When using joined table inheritance, querying for a child model
automatically joins to the parent model. When passing the child model to
with_for_update's of kwarg like with_for_update(of=Engineer), this
results in invalid SQL like
FOR UPDATE OF engineer JOIN employee ON engineer.id = employee.id

Fixes #4550

lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_compiler.py

index 4d302dabe730e3140eee1e161fedf0fb14fb5b35..90b13473df977f4d8fee8c9d187d4ef70d797625 100644 (file)
@@ -928,6 +928,7 @@ from ...engine import reflection
 from ...sql import compiler
 from ...sql import elements
 from ...sql import expression
+from ...sql import selectable
 from ...sql import sqltypes
 from ...types import BIGINT
 from ...types import BOOLEAN
@@ -1681,10 +1682,15 @@ class PGCompiler(compiler.SQLCompiler):
             tmp = " FOR UPDATE"
 
         if select._for_update_arg.of:
-            tables = util.OrderedSet(
-                c.table if isinstance(c, expression.ColumnClause) else c
-                for c in select._for_update_arg.of
-            )
+            tables = util.OrderedSet()
+            for clause in select._for_update_arg.of:
+                if isinstance(clause, expression.ColumnClause):
+                    tables.add(clause.table)
+                else:
+                    table_classes = (schema.Table, selectable.TableClause)
+                    for f in clause.select()._froms:
+                        if isinstance(f, table_classes):
+                            tables.add(f)
             tmp += " OF " + ", ".join(
                 self.process(table, ashint=True, use_schema=False, **kw)
                 for table in tables
index 696078cc4434c43eed9d8565be75e1cab8c0a2b2..10d15fbf5af0b5482e9501029db138f2c6a24845 100644 (file)
@@ -1056,6 +1056,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1",
         )
 
+        table2 = table('table2', column('mytable_id'))
+        join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
+        self.assert_compile(
+            join.select(table2.c.mytable_id == 7).
+            with_for_update(of=[join]),
+            "SELECT table2.mytable_id, "
+            "mytable.myid, mytable.name, mytable.description "
+            "FROM table2 "
+            "JOIN mytable ON table2.mytable_id = mytable.myid "
+            "WHERE table2.mytable_id = %(mytable_id_1)s "
+            "FOR UPDATE OF table2, mytable"
+        )
+
     def test_for_update_with_schema(self):
         m = MetaData()
         table1 = Table(