]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Expand joins when calculating PostgreSQL "WITH FOR UPDATE OF"
authorraylu <lurayl@gmail.com>
Wed, 20 Mar 2019 21:22:19 +0000 (17:22 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 21 Mar 2019 14:25:29 +0000 (10:25 -0400)
Modified the :paramref:`.Select.with_for_update.of` parameter so that if a
join or other composed selectable is passed, the individual :class:`.Table`
objects will be filtered from it, allowing one to pass a join() object to
the parameter, as occurs normally when using joined table inheritance with
the ORM.  Pull request courtesy Raymond Lu.

Fixes: #4550
Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Closes: #4551
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4551
Pull-request-sha: 452da77d154a4087d530456db1c9af207d65cef4

Change-Id: If4b7c231f7b71190d7245543959fb5c3351125a1

doc/build/changelog/unreleased_13/4550.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/util.py
test/dialect/postgresql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_13/4550.rst b/doc/build/changelog/unreleased_13/4550.rst
new file mode 100644 (file)
index 0000000..6837baa
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+   :tags: bug, postgresql
+   :tickets: 4550
+
+   Modified the :paramref:`.Select.with_for_update.of` parameter so that if a
+   join or other composed selectable is passed, the individual :class:`.Table`
+   objects will be filtered from it, allowing one to pass a join() object to
+   the parameter, as occurs normally when using joined table inheritance with
+   the ORM.  Pull request courtesy Raymond Lu.
+
index 4d302dabe730e3140eee1e161fedf0fb14fb5b35..3781a7ba22ddba81618c4fefc5ade2f4b6bf30bd 100644 (file)
@@ -929,6 +929,7 @@ from ...sql import compiler
 from ...sql import elements
 from ...sql import expression
 from ...sql import sqltypes
+from ...sql import util as sql_util
 from ...types import BIGINT
 from ...types import BOOLEAN
 from ...types import CHAR
@@ -1681,10 +1682,11 @@ class PGCompiler(compiler.SQLCompiler):
             tmp = " FOR UPDATE"
 
         if select._for_update_arg.of:
-            tables = util.OrderedSet(
-                c.table if isinstance(c, expression.ColumnClause) else c
-                for c in select._for_update_arg.of
-            )
+
+            tables = util.OrderedSet()
+            for c in select._for_update_arg.of:
+                tables.update(sql_util.surface_selectables_only(c))
+
             tmp += " OF " + ", ".join(
                 self.process(table, ashint=True, use_schema=False, **kw)
                 for table in tables
index 5a44f873d9c9d231de9d6b6b19dfdf349555e463..3077840c65df4a1b233101c567788de872f812dc 100644 (file)
@@ -29,11 +29,13 @@ from .elements import ColumnElement
 from .elements import Null
 from .elements import UnaryExpression
 from .schema import Column
+from .selectable import Alias
 from .selectable import FromClause
 from .selectable import FromGrouping
 from .selectable import Join
 from .selectable import ScalarSelect
 from .selectable import SelectBase
+from .selectable import TableClause
 from .. import exc
 from .. import util
 
@@ -339,6 +341,20 @@ def surface_selectables(clause):
             stack.append(elem.element)
 
 
+def surface_selectables_only(clause):
+    stack = [clause]
+    while stack:
+        elem = stack.pop()
+        if isinstance(elem, (TableClause, Alias)):
+            yield elem
+        if isinstance(elem, Join):
+            stack.extend((elem.left, elem.right))
+        elif isinstance(elem, FromGrouping):
+            stack.append(elem.element)
+        elif isinstance(elem, ColumnClause):
+            stack.append(elem.table)
+
+
 def surface_column_elements(clause, include_scalar_selects=True):
     """traverse and yield only outer-exposed column elements, such as would
     be addressable in the WHERE clause of a SELECT if this element were
index 696078cc4434c43eed9d8565be75e1cab8c0a2b2..13e4aaad5db215d7798d1d421d218fb7a5b64328 100644 (file)
@@ -1056,6 +1056,30 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1",
         )
 
+        table2 = table("table2", column("mytable_id"))
+        join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
+        self.assert_compile(
+            join.select(table2.c.mytable_id == 7).with_for_update(of=[join]),
+            "SELECT table2.mytable_id, "
+            "mytable.myid, mytable.name, mytable.description "
+            "FROM table2 "
+            "JOIN mytable ON table2.mytable_id = mytable.myid "
+            "WHERE table2.mytable_id = %(mytable_id_1)s "
+            "FOR UPDATE OF mytable, table2",
+        )
+
+        join = table2.join(ta, table2.c.mytable_id == ta.c.myid)
+        self.assert_compile(
+            join.select(table2.c.mytable_id == 7).with_for_update(of=[join]),
+            "SELECT table2.mytable_id, "
+            "mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM table2 "
+            "JOIN mytable AS mytable_1 "
+            "ON table2.mytable_id = mytable_1.myid "
+            "WHERE table2.mytable_id = %(mytable_id_1)s "
+            "FOR UPDATE OF mytable_1, table2",
+        )
+
     def test_for_update_with_schema(self):
         m = MetaData()
         table1 = Table(