]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add psql FOR UPDATE OF functionality
authorMario Lassnig <mario@lassnig.net>
Tue, 12 Nov 2013 22:08:51 +0000 (23:08 +0100)
committerMario Lassnig <mario@lassnig.net>
Tue, 12 Nov 2013 22:08:51 +0000 (23:08 +0100)
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_lockmode.py

index e1dc4af7149e335c6254977fe48abf94879027ba..19d7c81fa7c42ad1579e79e1cf0e475a3fe712c7 100644 (file)
@@ -1015,6 +1015,8 @@ class PGCompiler(compiler.SQLCompiler):
 
     def for_update_clause(self, select):
         if select.for_update == 'nowait':
+            if select.for_update_of is not None:
+                return " FOR UPDATE OF " + select.for_update_of + " NOWAIT"
             return " FOR UPDATE NOWAIT"
         elif select.for_update == 'read':
             return " FOR SHARE"
index c9e7d444bbfa41c5c984094c94d46666681c773b..db688955d013a74894a162cecac9ddf93861c783 100644 (file)
@@ -70,6 +70,7 @@ class Query(object):
     _criterion = None
     _yield_per = None
     _lockmode = None
+    _lockmode_of = None
     _order_by = False
     _group_by = False
     _having = None
@@ -1124,7 +1125,7 @@ class Query(object):
         self._execution_options = self._execution_options.union(kwargs)
 
     @_generative()
-    def with_lockmode(self, mode):
+    def with_lockmode(self, mode, of=None):
         """Return a new Query object with the specified locking mode.
 
         :param mode: a string representing the desired locking mode. A
@@ -1148,9 +1149,13 @@ class Query(object):
 
             .. versionadded:: 0.7.7
                 ``FOR SHARE`` and ``FOR SHARE NOWAIT`` (PostgreSQL).
+        :param of: a table descriptor representing the optional OF
+            part of the clause. This passes ``for_update_of=table'``
+            which translates to ``FOR UPDATE OF table [NOWAIT]``.
         """
 
         self._lockmode = mode
+        self._lockmode_of = of
 
     @_generative()
     def params(self, *args, **kwargs):
@@ -2705,6 +2710,9 @@ class Query(object):
             except KeyError:
                 raise sa_exc.ArgumentError(
                                 "Unknown lockmode %r" % self._lockmode)
+            if self._lockmode_of is not None:
+                context.for_update_of = self._lockmode_of
+
         for entity in self._entities:
             entity.setup_context(self, context)
 
@@ -2789,6 +2797,7 @@ class Query(object):
         statement = sql.select(
                             [inner] + context.secondary_columns,
                             for_update=context.for_update,
+                            for_update_of=context.for_update_of,
                             use_labels=context.labels)
 
         from_clause = inner
@@ -2834,6 +2843,7 @@ class Query(object):
                         from_obj=context.froms,
                         use_labels=context.labels,
                         for_update=context.for_update,
+                        for_update_of=context.for_update_of,
                         order_by=context.order_by,
                         **self._select_args
                     )
@@ -3415,6 +3425,7 @@ class QueryContext(object):
     adapter = None
     froms = ()
     for_update = False
+    for_update_of = None
 
     def __init__(self, query):
 
index 4f3dbba3688ff172302d333a1ef454396c37dbff..51ec0d9ebea76bd7a34333a5069ddba4b8a6bb5b 100644 (file)
@@ -1571,6 +1571,8 @@ class SQLCompiler(Compiled):
 
     def for_update_clause(self, select):
         if select.for_update:
+            if select.for_update_of is not None:
+                return " FOR UPDATE OF " + select.for_update_of
             return " FOR UPDATE"
         else:
             return ""
index 550e250f1327a83795ebc0327396333be5cb5c7a..8ad238ca332d339283323d04ce0c837f74e1ff03 100644 (file)
@@ -1162,6 +1162,7 @@ class SelectBase(Executable, FromClause):
     def __init__(self,
             use_labels=False,
             for_update=False,
+            for_update_of=None,
             limit=None,
             offset=None,
             order_by=None,
@@ -1170,6 +1171,7 @@ class SelectBase(Executable, FromClause):
             autocommit=None):
         self.use_labels = use_labels
         self.for_update = for_update
+        self.for_update_of = for_update_of
         if autocommit is not None:
             util.warn_deprecated('autocommit on select() is '
                                  'deprecated.  Use .execution_options(a'
index 0fe82f39443522f10c9082b12a71b75b1383a903..a16a545ba80d70893d9d7ba00e10988f012468aa 100644 (file)
@@ -73,6 +73,14 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             dialect=postgresql.dialect()
         )
 
+    def test_postgres_update_of(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update', of='users'),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
+            dialect=postgresql.dialect()
+        )
+
     def test_postgres_update_nowait(self):
         User = self.classes.User
         sess = Session()
@@ -81,6 +89,14 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             dialect=postgresql.dialect()
         )
 
+    def test_postgres_update_nowait_of(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update_nowait', of='users'),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT",
+            dialect=postgresql.dialect()
+        )
+
     def test_oracle_update(self):
         User = self.classes.User
         sess = Session()