From: raylu Date: Sat, 16 Mar 2019 00:24:39 +0000 (-0700) Subject: Fix PostgreSQL for_update on JTI subclasses X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F4551%2Fhead;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix PostgreSQL for_update on JTI subclasses When using joined table inheritance, querying for a child model automatically joins to the parent model. When passing the child model to with_for_update's of kwarg like with_for_update(of=Engineer), this results in invalid SQL like FOR UPDATE OF engineer JOIN employee ON engineer.id = employee.id Fixes #4550 --- diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 4d302dabe7..90b13473df 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -928,6 +928,7 @@ from ...engine import reflection from ...sql import compiler from ...sql import elements from ...sql import expression +from ...sql import selectable from ...sql import sqltypes from ...types import BIGINT from ...types import BOOLEAN @@ -1681,10 +1682,15 @@ class PGCompiler(compiler.SQLCompiler): tmp = " FOR UPDATE" if select._for_update_arg.of: - tables = util.OrderedSet( - c.table if isinstance(c, expression.ColumnClause) else c - for c in select._for_update_arg.of - ) + tables = util.OrderedSet() + for clause in select._for_update_arg.of: + if isinstance(clause, expression.ColumnClause): + tables.add(clause.table) + else: + table_classes = (schema.Table, selectable.TableClause) + for f in clause.select()._froms: + if isinstance(f, table_classes): + tables.add(f) tmp += " OF " + ", ".join( self.process(table, ashint=True, use_schema=False, **kw) for table in tables diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 696078cc44..10d15fbf5a 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1056,6 +1056,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1", ) + table2 = table('table2', column('mytable_id')) + join = table2.join(table1, table2.c.mytable_id == table1.c.myid) + self.assert_compile( + join.select(table2.c.mytable_id == 7). + with_for_update(of=[join]), + "SELECT table2.mytable_id, " + "mytable.myid, mytable.name, mytable.description " + "FROM table2 " + "JOIN mytable ON table2.mytable_id = mytable.myid " + "WHERE table2.mytable_id = %(mytable_id_1)s " + "FOR UPDATE OF table2, mytable" + ) + def test_for_update_with_schema(self): m = MetaData() table1 = Table(