]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added LockmodeArgs
authorMario Lassnig <mario@lassnig.net>
Thu, 28 Nov 2013 13:50:41 +0000 (14:50 +0100)
committerMario Lassnig <mario@lassnig.net>
Thu, 28 Nov 2013 13:50:41 +0000 (14:50 +0100)
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py

index 6883be5af631d167e737290301e7cda6bd8a9924..d70e9b6068ea9ccdfe7246af936ae48d08d9e2f6 100644 (file)
@@ -1422,7 +1422,14 @@ class MySQLCompiler(compiler.SQLCompiler):
              self.process(join.onclause, **kwargs)))
 
     def for_update_clause(self, select):
-        if select.for_update == 'read':
+        # backwards compatibility
+        if isinstance(select.for_update, bool):
+            return ' FOR UPDATE'
+        elif isinstance(select.for_update, str):
+            if select.for_update == 'read':
+                return ' LOCK IN SHARE MODE'
+
+        if select.for_update.mode == 'read':
             return ' LOCK IN SHARE MODE'
         else:
             return super(MySQLCompiler, self).for_update_clause(select)
index 74441e9a84af2c1ce57664d616a78a29bb81142a..c0d9732b4bfc5da631f5ebb2837d5faea87f983e 100644 (file)
@@ -664,14 +664,24 @@ class OracleCompiler(compiler.SQLCompiler):
 
         tmp = ' FOR UPDATE'
 
-        if isinstance(select.for_update_of, list):
-            tmp += ' OF ' + ', '.join(['.'.join(of) for of in select.for_update_of])
-        elif isinstance(select.for_update_of, tuple):
-            tmp += ' OF ' + '.'.join(select.for_update_of)
+        # backwards compatibility
+        if isinstance(select.for_update, bool):
+            if select.for_update:
+                return tmp
+        elif isinstance(select.for_update, str):
+            if select.for_update == 'nowait':
+                return tmp + ' NOWAIT'
+            else:
+                return tmp
+
+        if isinstance(select.for_update.of, list):
+            tmp += ' OF ' + ', '.join(['.'.join(of) for of in select.for_update.of])
+        elif isinstance(select.for_update.of, tuple):
+            tmp += ' OF ' + '.'.join(select.for_update.of)
 
-        if select.for_update == 'nowait':
+        if select.for_update.mode == 'update_nowait':
             return tmp + ' NOWAIT'
-        elif select.for_update:
+        elif select.for_update.mode == 'update':
             return tmp
         else:
             return super(OracleCompiler, self).for_update_clause(select)
index ec22e8633198f47855fdc0ac512a0458659031e0..08976997515fa71821929b059c704c7f330d1e10 100644 (file)
@@ -1015,20 +1015,32 @@ class PGCompiler(compiler.SQLCompiler):
 
     def for_update_clause(self, select):
 
-        if select.for_update == 'read':
+        tmp = ' FOR UPDATE'
+
+        # backwards compatibility
+        if isinstance(select.for_update, bool):
+            return tmp
+        elif isinstance(select.for_update, str):
+            if select.for_update == 'nowait':
+                return tmp + ' NOWAIT'
+            elif select.for_update == 'read':
+                return ' FOR SHARE'
+            elif select.for_update == 'read_nowait':
+                return ' FOR SHARE NOWAIT'
+
+        if select.for_update.mode == 'read':
             return ' FOR SHARE'
-        elif select.for_update == 'read_nowait':
+        elif select.for_update.mode == 'read_nowait':
             return ' FOR SHARE NOWAIT'
 
-        tmp = ' FOR UPDATE'
-        if isinstance(select.for_update_of, list):
-            tmp += ' OF ' + ', '.join([of[0] for of in select.for_update_of])
-        elif isinstance(select.for_update_of, tuple):
-            tmp += ' OF ' + select.for_update_of[0]
+        if isinstance(select.for_update.of, list):
+            tmp += ' OF ' + ', '.join([of[0] for of in select.for_update.of])
+        elif isinstance(select.for_update.of, tuple):
+            tmp += ' OF ' + select.for_update.of[0]
 
-        if select.for_update == 'nowait':
+        if select.for_update.mode == 'update_nowait':
             return tmp + ' NOWAIT'
-        elif select.for_update:
+        elif select.for_update.mode == 'update':
             return tmp
         else:
             return super(PGCompiler, self).for_update_clause(select)
index f0b6bb03198dc362b86347935ccb2f276276ad55..a37d8d7e92604de6c6ea6b13d4d2a72d5593ad73 100644 (file)
@@ -70,7 +70,6 @@ class Query(object):
     _criterion = None
     _yield_per = None
     _lockmode = None
-    _lockmode_of = None
     _order_by = False
     _group_by = False
     _having = None
@@ -1129,49 +1128,38 @@ class Query(object):
         """Return a new Query object with the specified locking mode.
 
         :param mode: a string representing the desired locking mode. A
-            corresponding value is passed to the ``for_update`` parameter of
-            :meth:`~sqlalchemy.sql.expression.select` when the query is
-            executed. Valid values are:
+            corresponding :meth:`~sqlalchemy.orm.query.LockmodeArgs` object
+            is passed to the ``for_update`` parameter of
+            :meth:`~sqlalchemy.sql.expression.select` when the
+            query is executed. Valid values are:
 
-            ``'update'`` - passes ``for_update=True``, which translates to
-            ``FOR UPDATE`` (standard SQL, supported by most dialects)
+            ``None`` - translates to no lockmode
 
-            ``'update_nowait'`` - passes ``for_update='nowait'``, which
-            translates to ``FOR UPDATE NOWAIT`` (supported by Oracle,
-            PostgreSQL 8.1 upwards)
+            ``'update'`` - translates to ``FOR UPDATE``
+            (standard SQL, supported by most dialects)
 
-            ``'read'`` - passes ``for_update='read'``, which translates to
-            ``LOCK IN SHARE MODE`` (for MySQL), and ``FOR SHARE`` (for
-            PostgreSQL)
+            ``'update_nowait'`` - translates to ``FOR UPDATE NOWAIT``
+            (supported by Oracle, PostgreSQL 8.1 upwards)
 
-            ``'read_nowait'`` - passes ``for_update='read_nowait'``, which
-            translates to ``FOR SHARE NOWAIT`` (supported by PostgreSQL).
+            ``'read'`` - translates to ``LOCK IN SHARE MODE`` (for MySQL),
+            and ``FOR SHARE`` (for PostgreSQL)
 
             .. versionadded:: 0.7.7
                 ``FOR SHARE`` and ``FOR SHARE NOWAIT`` (PostgreSQL).
-        :param of: either a column descriptor, or list of column
+
+         :param of: either a column descriptor, or list of column
             descriptors, representing the optional OF part of the
-            clause. This passes ``for_update_of=descriptor(s)'`` which
-            translates to ``FOR UPDATE OF table [NOWAIT]`` respectively
+            clause. This passes the descriptor to the
+            corresponding :meth:`~sqlalchemy.orm.query.LockmodeArgs` object,
+            and translates to ``FOR UPDATE OF table [NOWAIT]`` respectively
             ``FOR UPDATE OF table, table [NOWAIT]`` (PostgreSQL), or
             ``FOR UPDATE OF table.column [NOWAIT]`` respectively
             ``FOR UPDATE OF table.column, table.column [NOWAIT]`` (Oracle).
 
-            .. versionadded:: 0.9.0
+            .. versionadded:: 0.9.0b2
         """
 
-        self._lockmode = mode
-
-        # do not drag the ORM layer into the dialect,
-        # we only need the table name and column name
-        if isinstance(of, attributes.QueryableAttribute):
-            self._lockmode_of = (of.expression.table.name,
-                                 of.expression.name)
-        elif isinstance(of, (tuple, list)):
-            self._lockmode_of = [(o.expression.table.name,
-                                  o.expression.name) for o in of]
-        elif of is not None:
-            raise TypeError('OF parameter is not a column(list)')
+        self._lockmode = LockmodeArgs(mode=mode, of=of)
 
     @_generative()
     def params(self, *args, **kwargs):
@@ -2704,13 +2692,6 @@ class Query(object):
         update_op.exec_()
         return update_op.rowcount
 
-    _lockmode_lookup = {
-            'read': 'read',
-              'read_nowait': 'read_nowait',
-              'update': True,
-              'update_nowait': 'nowait',
-              None: False
-    }
 
     def _compile_context(self, labels=True):
         context = QueryContext(self)
@@ -2720,14 +2701,12 @@ class Query(object):
 
         context.labels = labels
 
-        if self._lockmode:
-            try:
-                context.for_update = self._lockmode_lookup[self._lockmode]
-            except KeyError:
-                raise sa_exc.ArgumentError(
-                                "Unknown lockmode %r" % self._lockmode)
-            if self._lockmode_of is not None:
-                context.for_update_of = self._lockmode_of
+        if isinstance(self._lockmode, bool) and self._lockmode:
+            context.for_update = LockmodeArgs(mode='update')
+        elif isinstance(self._lockmode, LockmodeArgs):
+            if self._lockmode.mode not in LockmodeArgs.lockmodes:
+                raise sa_exc.ArgumentError('Unknown lockmode %r' % self._lockmode.mode)
+            context.for_update = self._lockmode
 
         for entity in self._entities:
             entity.setup_context(self, context)
@@ -2813,7 +2792,6 @@ class Query(object):
         statement = sql.select(
                             [inner] + context.secondary_columns,
                             for_update=context.for_update,
-                            for_update_of=context.for_update_of,
                             use_labels=context.labels)
 
         from_clause = inner
@@ -2859,7 +2837,6 @@ class Query(object):
                         from_obj=context.froms,
                         use_labels=context.labels,
                         for_update=context.for_update,
-                        for_update_of=context.for_update_of,
                         order_by=context.order_by,
                         **self._select_args
                     )
@@ -3435,13 +3412,11 @@ class _ColumnEntity(_QueryEntity):
         return str(self.column)
 
 
-
 class QueryContext(object):
     multi_row_eager_loaders = False
     adapter = None
     froms = ()
-    for_update = False
-    for_update_of = None
+    for_update = None
 
     def __init__(self, query):
 
@@ -3516,3 +3491,62 @@ class AliasOption(interfaces.MapperOption):
         else:
             alias = self.alias
         query._from_obj_alias = sql_util.ColumnAdapter(alias)
+
+
+class LockmodeArgs(object):
+
+    lockmodes = [None,
+                 'read', 'read_nowait',
+                 'update', 'update_nowait'
+    ]
+
+    mode = None
+    of = None
+
+    def __init__(self, mode=None, of=None):
+        """ORM-level Lockmode
+
+        :class:`.LockmodeArgs` defines the locking strategy for the
+        dialects as given by ``FOR UPDATE [OF] [NOWAIT]``. The optional
+        OF component is translated by the dialects into the supported
+        tablename and columnname descriptors.
+
+        :param mode: Defines the lockmode to use.
+
+            ``None`` - translates to no lockmode
+
+            ``'update'`` - translates to ``FOR UPDATE``
+            (standard SQL, supported by most dialects)
+
+            ``'update_nowait'`` - translates to ``FOR UPDATE NOWAIT``
+            (supported by Oracle, PostgreSQL 8.1 upwards)
+
+            ``'read'`` - translates to ``LOCK IN SHARE MODE`` (for MySQL),
+            and ``FOR SHARE`` (for PostgreSQL)
+
+            ``'read_nowait'`` - translates to ``FOR SHARE NOWAIT``
+            (supported by PostgreSQL). ``FOR SHARE`` and
+            ``FOR SHARE NOWAIT`` (PostgreSQL).
+
+        :param of: either a column descriptor, or list of column
+            descriptors, representing the optional OF part of the
+            clause. This passes the descriptor to the
+            corresponding :meth:`~sqlalchemy.orm.query.LockmodeArgs` object,
+            and translates to ``FOR UPDATE OF table [NOWAIT]`` respectively
+            ``FOR UPDATE OF table, table [NOWAIT]`` (PostgreSQL), or
+            ``FOR UPDATE OF table.column [NOWAIT]`` respectively
+            ``FOR UPDATE OF table.column, table.column [NOWAIT]`` (Oracle).
+
+        .. versionadded:: 0.9.0b2
+        """
+
+        if isinstance(mode, bool) and mode:
+            mode = 'update'
+
+        self.mode = mode
+
+        # extract table names and column names
+        if isinstance(of, attributes.QueryableAttribute):
+            self.of = (of.expression.table.name, of.expression.name)
+        elif isinstance(of, (tuple, list)) and of != []:
+            self.of = [(o.expression.table.name, o.expression.name) for o in of]
index 4f3dbba3688ff172302d333a1ef454396c37dbff..54eb1f9eb3b32610497f71f1e08c73a9c919431c 100644 (file)
@@ -1570,7 +1570,12 @@ class SQLCompiler(Compiled):
             return ""
 
     def for_update_clause(self, select):
-        if select.for_update:
+        # backwards compatibility
+        if isinstance(select.for_update, bool):
+            return " FOR UPDATE" if select.for_update else ""
+        elif isinstance(select.for_update, str):
+            return " FOR UPDATE"
+        elif select.for_update.mode is not None:
             return " FOR UPDATE"
         else:
             return ""
index 8ad238ca332d339283323d04ce0c837f74e1ff03..dcf7689cf5d29de7d1e23b916a039943a07d71e0 100644 (file)
@@ -1162,7 +1162,6 @@ class SelectBase(Executable, FromClause):
     def __init__(self,
             use_labels=False,
             for_update=False,
-            for_update_of=None,
             limit=None,
             offset=None,
             order_by=None,
@@ -1171,7 +1170,6 @@ class SelectBase(Executable, FromClause):
             autocommit=None):
         self.use_labels = use_labels
         self.for_update = for_update
-        self.for_update_of = for_update_of
         if autocommit is not None:
             util.warn_deprecated('autocommit on select() is '
                                  'deprecated.  Use .execution_options(a'
@@ -2787,4 +2785,3 @@ class AnnotatedFromClause(Annotated):
         Annotated.__init__(self, element, values)
 
 
-