]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added ORM support
authorMario Lassnig <mario@lassnig.net>
Thu, 14 Nov 2013 19:18:52 +0000 (20:18 +0100)
committerMario Lassnig <mario@lassnig.net>
Thu, 14 Nov 2013 19:18:52 +0000 (20:18 +0100)
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
test/orm/test_lockmode.py

index 1f5e05cba9575586c19348355acb05604e56a54c..74441e9a84af2c1ce57664d616a78a29bb81142a 100644 (file)
@@ -661,8 +661,18 @@ class OracleCompiler(compiler.SQLCompiler):
     def for_update_clause(self, select):
         if self.is_subquery():
             return ""
-        elif select.for_update == "nowait":
-            return " FOR UPDATE NOWAIT"
+
+        tmp = ' FOR UPDATE'
+
+        if isinstance(select.for_update_of, list):
+            tmp += ' OF ' + ', '.join(['.'.join(of) for of in select.for_update_of])
+        elif isinstance(select.for_update_of, tuple):
+            tmp += ' OF ' + '.'.join(select.for_update_of)
+
+        if select.for_update == 'nowait':
+            return tmp + ' NOWAIT'
+        elif select.for_update:
+            return tmp
         else:
             return super(OracleCompiler, self).for_update_clause(select)
 
index 19d7c81fa7c42ad1579e79e1cf0e475a3fe712c7..ec22e8633198f47855fdc0ac512a0458659031e0 100644 (file)
@@ -230,7 +230,7 @@ RESERVED_WORDS = set(
     "default", "deferrable", "desc", "distinct", "do", "else", "end",
     "except", "false", "fetch", "for", "foreign", "from", "grant", "group",
     "having", "in", "initially", "intersect", "into", "leading", "limit",
-    "localtime", "localtimestamp", "new", "not", "null", "off", "offset",
+    "localtime", "localtimestamp", "new", "not", "null", "of", "off", "offset",
     "old", "on", "only", "or", "order", "placing", "primary", "references",
     "returning", "select", "session_user", "some", "symmetric", "table",
     "then", "to", "trailing", "true", "union", "unique", "user", "using",
@@ -1014,14 +1014,22 @@ class PGCompiler(compiler.SQLCompiler):
             return ""
 
     def for_update_clause(self, select):
-        if select.for_update == 'nowait':
-            if select.for_update_of is not None:
-                return " FOR UPDATE OF " + select.for_update_of + " NOWAIT"
-            return " FOR UPDATE NOWAIT"
-        elif select.for_update == 'read':
-            return " FOR SHARE"
+
+        if select.for_update == 'read':
+            return ' FOR SHARE'
         elif select.for_update == 'read_nowait':
-            return " FOR SHARE NOWAIT"
+            return ' FOR SHARE NOWAIT'
+
+        tmp = ' FOR UPDATE'
+        if isinstance(select.for_update_of, list):
+            tmp += ' OF ' + ', '.join([of[0] for of in select.for_update_of])
+        elif isinstance(select.for_update_of, tuple):
+            tmp += ' OF ' + select.for_update_of[0]
+
+        if select.for_update == 'nowait':
+            return tmp + ' NOWAIT'
+        elif select.for_update:
+            return tmp
         else:
             return super(PGCompiler, self).for_update_clause(select)
 
index db688955d013a74894a162cecac9ddf93861c783..f0b6bb03198dc362b86347935ccb2f276276ad55 100644 (file)
@@ -1149,13 +1149,29 @@ class Query(object):
 
             .. versionadded:: 0.7.7
                 ``FOR SHARE`` and ``FOR SHARE NOWAIT`` (PostgreSQL).
-        :param of: a table descriptor representing the optional OF
-            part of the clause. This passes ``for_update_of=table'``
-            which translates to ``FOR UPDATE OF table [NOWAIT]``.
+        :param of: either a column descriptor, or list of column
+            descriptors, representing the optional OF part of the
+            clause. This passes ``for_update_of=descriptor(s)'`` which
+            translates to ``FOR UPDATE OF table [NOWAIT]`` respectively
+            ``FOR UPDATE OF table, table [NOWAIT]`` (PostgreSQL), or
+            ``FOR UPDATE OF table.column [NOWAIT]`` respectively
+            ``FOR UPDATE OF table.column, table.column [NOWAIT]`` (Oracle).
+
+            .. versionadded:: 0.9.0
         """
 
         self._lockmode = mode
-        self._lockmode_of = of
+
+        # do not drag the ORM layer into the dialect,
+        # we only need the table name and column name
+        if isinstance(of, attributes.QueryableAttribute):
+            self._lockmode_of = (of.expression.table.name,
+                                 of.expression.name)
+        elif isinstance(of, (tuple, list)):
+            self._lockmode_of = [(o.expression.table.name,
+                                  o.expression.name) for o in of]
+        elif of is not None:
+            raise TypeError('OF parameter is not a column(list)')
 
     @_generative()
     def params(self, *args, **kwargs):
index 51ec0d9ebea76bd7a34333a5069ddba4b8a6bb5b..4f3dbba3688ff172302d333a1ef454396c37dbff 100644 (file)
@@ -1571,8 +1571,6 @@ class SQLCompiler(Compiled):
 
     def for_update_clause(self, select):
         if select.for_update:
-            if select.for_update_of is not None:
-                return " FOR UPDATE OF " + select.for_update_of
             return " FOR UPDATE"
         else:
             return ""
index a16a545ba80d70893d9d7ba00e10988f012468aa..f9950c2610cca530c93a18f4b5e043db6d9c7760 100644 (file)
@@ -76,11 +76,20 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     def test_postgres_update_of(self):
         User = self.classes.User
         sess = Session()
-        self.assert_compile(sess.query(User.id).with_lockmode('update', of='users'),
+        self.assert_compile(sess.query(User.id).with_lockmode('update', of=User.id),
             "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
             dialect=postgresql.dialect()
         )
 
+    def test_postgres_update_of_list(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update', of=[User.id, User.id, User.id]),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users, users, users",
+            dialect=postgresql.dialect()
+        )
+
+
     def test_postgres_update_nowait(self):
         User = self.classes.User
         sess = Session()
@@ -92,11 +101,19 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     def test_postgres_update_nowait_of(self):
         User = self.classes.User
         sess = Session()
-        self.assert_compile(sess.query(User.id).with_lockmode('update_nowait', of='users'),
+        self.assert_compile(sess.query(User.id).with_lockmode('update_nowait', of=User.id),
             "SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT",
             dialect=postgresql.dialect()
         )
 
+    def test_postgres_update_nowait_of_list(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update_nowait', of=[User.id, User.id, User.id]),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users, users, users NOWAIT",
+            dialect=postgresql.dialect()
+        )
+
     def test_oracle_update(self):
         User = self.classes.User
         sess = Session()
@@ -105,6 +122,22 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             dialect=oracle.dialect()
         )
 
+    def test_oracle_update_of(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update', of=User.id),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users.id",
+            dialect=oracle.dialect()
+        )
+
+    def test_oracle_update_of_list(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update', of=[User.id, User.id, User.id]),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users.id, users.id, users.id",
+            dialect=oracle.dialect()
+        )
+
     def test_oracle_update_nowait(self):
         User = self.classes.User
         sess = Session()
@@ -113,6 +146,22 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             dialect=oracle.dialect()
         )
 
+    def test_oracle_update_nowait_of(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update_nowait', of=User.id),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users.id NOWAIT",
+            dialect=oracle.dialect()
+        )
+
+    def test_oracle_update_nowait_of_list(self):
+        User = self.classes.User
+        sess = Session()
+        self.assert_compile(sess.query(User.id).with_lockmode('update_nowait', of=[User.id, User.id, User.id]),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users.id, users.id, users.id NOWAIT",
+            dialect=oracle.dialect()
+        )
+
     def test_mysql_read(self):
         User = self.classes.User
         sess = Session()