]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix up rendering of "of"
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Nov 2013 03:25:09 +0000 (22:25 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Nov 2013 03:25:09 +0000 (22:25 -0500)
- move out tests, dialect specific out of compiler, compiler tests use new API,
legacy API tests in test_selecatble
- add support for adaptation of ForUpdateArg, alias support in compilers

lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/dialect/mysql/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/dialect/test_oracle.py
test/orm/test_lockmode.py
test/sql/test_compiler.py
test/sql/test_selectable.py

index a3c31b7ccb243452140b1dfd8c1859ba905ed765..ba69c3d1fb8270f96e4ffa48b7c15a574f7f675c 100644 (file)
@@ -666,15 +666,15 @@ class OracleCompiler(compiler.SQLCompiler):
 
         tmp = ' FOR UPDATE'
 
-        if select._for_update_arg.nowait:
-            tmp += " NOWAIT"
-
         if select._for_update_arg.of:
             tmp += ' OF ' + ', '.join(
-                                    self._process(elem) for elem in
+                                    self.process(elem) for elem in
                                     select._for_update_arg.of
                                 )
 
+        if select._for_update_arg.nowait:
+            tmp += " NOWAIT"
+
         return tmp
 
 
index 091fdeda2610e9ed25a0e92960d0c16ed4051bcb..69b0fb040339c06bc9c8c026664a2e2be526b81d 100644 (file)
@@ -1020,17 +1020,17 @@ class PGCompiler(compiler.SQLCompiler):
         else:
             tmp = " FOR UPDATE"
 
-        if select._for_update_arg.nowait:
-            tmp += " NOWAIT"
-
         if select._for_update_arg.of:
             # TODO: assuming simplistic c.table here
             tables = set(c.table for c in select._for_update_arg.of)
             tmp += " OF " + ", ".join(
-                                self.process(table, asfrom=True)
+                                self.process(table, ashint=True)
                                 for table in tables
                             )
 
+        if select._for_update_arg.nowait:
+            tmp += " NOWAIT"
+
         return tmp
 
     def returning_clause(self, stmt, returning_cols):
index f0d9a47d6a208c3f447780db50c72680f8083774..173ad038eb4ec35250f642602ebe23c908bd52d3 100644 (file)
@@ -1124,10 +1124,10 @@ class Query(object):
         self._execution_options = self._execution_options.union(kwargs)
 
     @_generative()
-    def with_lockmode(self, mode, of=None):
+    def with_lockmode(self, mode):
         """Return a new Query object with the specified locking mode.
 
-        .. deprecated:: 0.9.0b2 superseded by :meth:`.Query.for_update`.
+        .. deprecated:: 0.9.0b2 superseded by :meth:`.Query.with_for_update`.
 
         :param mode: a string representing the desired locking mode. A
             corresponding :meth:`~sqlalchemy.orm.query.LockmodeArgs` object
index 0fc99897efe87b10b2365b02a91b7501120de834..3ba3957d6485559d3974a1b9cced2dfa6c448e80 100644 (file)
@@ -1513,9 +1513,11 @@ class SQLCompiler(Compiled):
 
             text += self.order_by_clause(select,
                             order_by_select=order_by_select, **kwargs)
+
         if select._limit is not None or select._offset is not None:
             text += self.limit_clause(select)
-        if select._for_update_arg:
+
+        if select._for_update_arg is not None:
             text += self.for_update_clause(select)
 
         if self.ctes and \
index e49c10001c3ee7e83da8614571d5f863691af50b..01c803f3b412e80bf056b42f46041177bf0acbeb 100644 (file)
@@ -1151,7 +1151,7 @@ class TableClause(Immutable, FromClause):
         return [self]
 
 
-class ForUpdateArg(object):
+class ForUpdateArg(ClauseElement):
 
     @classmethod
     def parse_legacy_select(self, arg):
@@ -1185,6 +1185,8 @@ class ForUpdateArg(object):
             read = True
         elif arg == 'read_nowait':
             read = nowait = True
+        elif arg is not True:
+            raise exc.ArgumentError("Unknown for_update argument: %r" % arg)
 
         return ForUpdateArg(read=read, nowait=nowait)
 
@@ -1195,9 +1197,13 @@ class ForUpdateArg(object):
         elif self.read and self.nowait:
             return "read_nowait"
         elif self.nowait:
-            return "update_nowait"
+            return "nowait"
         else:
-            return "update"
+            return True
+
+    def _copy_internals(self, clone=_clone, **kw):
+        if self.of is not None:
+            self.of = [clone(col, **kw) for col in self.of]
 
     def __init__(self, nowait=False, read=False, of=None):
         """Represents arguments specified to :meth:`.Select.for_update`.
@@ -1208,7 +1214,7 @@ class ForUpdateArg(object):
         self.nowait = nowait
         self.read = read
         if of is not None:
-            self.of = [_only_column_elements(of, "of")
+            self.of = [_only_column_elements(elem, "of")
                         for elem in util.to_list(of)]
         else:
             self.of = None
@@ -1770,7 +1776,7 @@ class CompoundSelect(SelectBase):
         self.selects = [clone(s, **kw) for s in self.selects]
         if hasattr(self, '_col_map'):
             del self._col_map
-        for attr in ('_order_by_clause', '_group_by_clause'):
+        for attr in ('_order_by_clause', '_group_by_clause', '_for_update_arg'):
             if getattr(self, attr) is not None:
                 setattr(self, attr, clone(getattr(self, attr), **kw))
 
@@ -2255,7 +2261,7 @@ class Select(HasPrefixes, SelectBase):
         # present here.
         self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
         for attr in '_whereclause', '_having', '_order_by_clause', \
-            '_group_by_clause':
+            '_group_by_clause', '_for_update_arg':
             if getattr(self, attr) is not None:
                 setattr(self, attr, clone(getattr(self, attr), **kw))
 
index a50c6a90172b76d30c6265ff2f500abfe9e396d1..46e8bfb828a95fddb6c03998cbcc81f93712d206 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import sql, exc, schema, types as sqltypes
 from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.testing import fixtures, AssertsCompiledSQL
 from sqlalchemy import testing
+from sqlalchemy.sql import table, column
 
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
@@ -131,6 +132,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             schema.CreateTable(t2).compile, dialect=mysql.dialect()
         )
 
+    def test_for_update(self):
+        table1 = table('mytable',
+                    column('myid'), column('name'), column('description'))
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s FOR UPDATE")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(read=True),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE")
+
 class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
     """Tests MySQL-dialect specific compilation."""
 
index 76fd9d90722b3313f7a20d3d7a65ad9b3ffc5359..05963e51c38d67908aae3fb86585a2f28ec5271c 100644 (file)
@@ -249,6 +249,61 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             'SUBSTRING(%(substring_1)s FROM %(substring_2)s)')
 
 
+    def test_for_update(self):
+        table1 = table('mytable',
+                    column('myid'), column('name'), column('description'))
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(nowait=True),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(read=True),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).
+                    with_for_update(read=True, nowait=True),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).
+                    with_for_update(of=table1.c.myid),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR UPDATE OF mytable")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).
+                with_for_update(read=True, nowait=True, of=table1.c.myid),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR SHARE OF mytable NOWAIT")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).
+                with_for_update(read=True, nowait=True,
+                        of=[table1.c.myid, table1.c.name]),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = %(myid_1)s "
+            "FOR SHARE OF mytable NOWAIT")
+
+        ta = table1.alias()
+        self.assert_compile(
+            ta.select(ta.c.myid == 7).
+                with_for_update(of=[ta.c.myid, ta.c.name]),
+            "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM mytable AS mytable_1 "
+            "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1"
+        )
 
 
     def test_reserved_words(self):
index 185bfb883142537db450098e213dbfa9241b68fb..3af57c50b03d165fb3fca1101ad8141536c8eca1 100644 (file)
@@ -217,6 +217,49 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             ':ROWNUM_1) WHERE ora_rn > :ora_rn_1 FOR '
                             'UPDATE')
 
+    def test_for_update(self):
+        table1 = table('mytable',
+                    column('myid'), column('name'), column('description'))
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(of=table1.c.myid),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF mytable.myid")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).with_for_update(nowait=True),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).
+                                with_for_update(nowait=True, of=table1.c.myid),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = :myid_1 "
+            "FOR UPDATE OF mytable.myid NOWAIT")
+
+        self.assert_compile(
+            table1.select(table1.c.myid == 7).
+                with_for_update(nowait=True, of=[table1.c.myid, table1.c.name]),
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF "
+            "mytable.myid, mytable.name NOWAIT")
+
+        ta = table1.alias()
+        self.assert_compile(
+            ta.select(ta.c.myid == 7).
+                with_for_update(of=[ta.c.myid, ta.c.name]),
+            "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM mytable mytable_1 "
+            "WHERE mytable_1.myid = :myid_1 FOR UPDATE OF "
+            "mytable_1.myid, mytable_1.name"
+        )
+
     def test_limit_preserves_typing_information(self):
         class MyType(TypeDecorator):
             impl = Integer
index f9950c2610cca530c93a18f4b5e043db6d9c7760..3a8379be9e1a142f0988aae8fb8f3a9e0860d2e5 100644 (file)
@@ -76,7 +76,7 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     def test_postgres_update_of(self):
         User = self.classes.User
         sess = Session()
-        self.assert_compile(sess.query(User.id).with_lockmode('update', of=User.id),
+        self.assert_compile(sess.query(User.id).for_update(of=User.id),
             "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
             dialect=postgresql.dialect()
         )
@@ -84,8 +84,8 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     def test_postgres_update_of_list(self):
         User = self.classes.User
         sess = Session()
-        self.assert_compile(sess.query(User.id).with_lockmode('update', of=[User.id, User.id, User.id]),
-            "SELECT users.id AS users_id FROM users FOR UPDATE OF users, users, users",
+        self.assert_compile(sess.query(User.id).for_update(of=[User.id, User.id, User.id]),
+            "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
             dialect=postgresql.dialect()
         )
 
@@ -93,7 +93,7 @@ class LockModeTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     def test_postgres_update_nowait(self):
         User = self.classes.User
         sess = Session()
-        self.assert_compile(sess.query(User.id).with_lockmode('update_nowait'),
+        self.assert_compile(sess.query(User.id).for_updatewith_lockmode('update_nowait'),
             "SELECT users.id AS users_id FROM users FOR UPDATE NOWAIT",
             dialect=postgresql.dialect()
         )
index 26cd3002664571bcf426a5ebae1c67596ced075d..f1f852ddcac71b205720f94527bfcc961eba352a 100644 (file)
@@ -1045,86 +1045,22 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def test_for_update(self):
         self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update=True),
+            table1.select(table1.c.myid == 7).with_for_update(),
             "SELECT mytable.myid, mytable.name, mytable.description "
             "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
 
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update=False),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = :myid_1")
-
         # not supported by dialect, should just use update
         self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update='nowait'),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
-
-        # unknown lock mode
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update='unknown_mode'),
+            table1.select(table1.c.myid == 7).with_for_update(nowait=True),
             "SELECT mytable.myid, mytable.name, mytable.description "
             "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE")
 
-        # ----- mysql
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update=True),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
-            dialect=mysql.dialect())
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update="read"),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
-            dialect=mysql.dialect())
-
-        # ----- oracle
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update=True),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE",
-            dialect=oracle.dialect())
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update="nowait"),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT",
-            dialect=oracle.dialect())
-
-        # ----- postgresql
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update=True),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE",
-            dialect=postgresql.dialect())
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update="nowait"),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT",
-            dialect=postgresql.dialect())
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update="read"),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE",
-            dialect=postgresql.dialect())
-
-        self.assert_compile(
-            table1.select(table1.c.myid == 7, for_update="read_nowait"),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT",
-            dialect=postgresql.dialect())
+        assert_raises_message(
+            exc.ArgumentError,
+            "Unknown for_update argument: 'unknown_mode'",
+            table1.select, table1.c.myid == 7, for_update='unknown_mode'
+        )
 
-        self.assert_compile(
-            table1.select(table1.c.myid == 7).with_for_update(of=table1.c.myid),
-            "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE OF mytable",
-            dialect=postgresql.dialect())
 
     def test_alias(self):
         # test the alias for a table1.  column names stay the same,
index 0fc7a0ed0dbc49c3e8205b8ef5f3d5231171f7b3..66cdd87c2b04469e11b0b06acf197fe807e05084 100644 (file)
@@ -1903,3 +1903,57 @@ class WithLabelsTest(fixtures.TestBase):
             ['t1_x', 't2_x']
         )
         self._assert_result_keys(sel, ['t1_a', 't2_b'])
+
+class ForUpdateTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    def _assert_legacy(self, leg, read=False, nowait=False):
+        t = table('t', column('c'))
+        s1 = select([t], for_update=leg)
+        if leg is False:
+            assert s1._for_update_arg is None
+            assert s1.for_update is None
+        else:
+            eq_(
+                s1._for_update_arg.read, read
+            )
+            eq_(
+                s1._for_update_arg.nowait, nowait
+            )
+            eq_(s1.for_update, leg)
+
+    def test_false_legacy(self):
+        self._assert_legacy(False)
+
+    def test_plain_true_legacy(self):
+        self._assert_legacy(True)
+
+    def test_read_legacy(self):
+        self._assert_legacy("read", read=True)
+
+    def test_nowait_legacy(self):
+        self._assert_legacy("nowait", nowait=True)
+
+    def test_read_nowait_legacy(self):
+        self._assert_legacy("read_nowait", read=True, nowait=True)
+
+    def test_basic_clone(self):
+        t = table('t', column('c'))
+        s = select([t]).with_for_update(read=True, of=t.c.c)
+        s2 = visitors.ReplacingCloningVisitor().traverse(s)
+        assert s2._for_update_arg is not s._for_update_arg
+        eq_(s2._for_update_arg.read, True)
+        eq_(s2._for_update_arg.of, [t.c.c])
+        self.assert_compile(s2,
+            "SELECT t.c FROM t FOR SHARE OF t",
+            dialect="postgresql")
+
+    def test_adapt(self):
+        t = table('t', column('c'))
+        s = select([t]).with_for_update(read=True, of=t.c.c)
+        a = t.alias()
+        s2 = sql_util.ClauseAdapter(a).traverse(s)
+        eq_(s2._for_update_arg.of, [a.c.c])
+        self.assert_compile(s2,
+            "SELECT t_1.c FROM t AS t_1 FOR SHARE OF t_1",
+            dialect="postgresql")