]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- changed "for_update" parameter to accept False/True/"nowait"
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 10 Sep 2006 23:52:04 +0000 (23:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 10 Sep 2006 23:52:04 +0000 (23:52 +0000)
and "read", the latter two of which are interpreted only by
Oracle and Mysql [ticket:292]
- added "lockmode" argument to base Query select/get functions,
including "with_lockmode" function to get a Query copy that has
a default locking mode.  Will translate "read"/"update"
arguments into a for_update argument on the select side.
[ticket:292]

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/orm/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 3ce07ba9cf85ded5ba219a3ba93a599d2ecf2561..7cb9b2cb517d3be8e676685cca92fcb0ce006974 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -2,6 +2,14 @@
 - more rearrangements of unit-of-work commit scheme to better allow
 dependencies within circular flushes to work properly...updated
 task traversal/logging implementation
+- changed "for_update" parameter to accept False/True/"nowait"
+and "read", the latter two of which are interpreted only by
+Oracle and Mysql [ticket:292]
+- added "lockmode" argument to base Query select/get functions, 
+including "with_lockmode" function to get a Query copy that has 
+a default locking mode.  Will translate "read"/"update" 
+arguments into a for_update argument on the select side.
+[ticket:292]
 
 0.2.8
 - cleanup on connection methods + documentation.  custom DBAPI
index c44595f36a3f375ec2bf095758de7253645ba153..66b917c208979797e2a92a23237bda06f2574cf7 100644 (file)
@@ -394,13 +394,9 @@ class ANSICompiler(sql.Compiled):
             text += " ORDER BY " + order_by
 
         text += self.visit_select_postclauses(select)
-        if select.for_update:
-            text += " FOR UPDATE"
 
-        if select.nowait:
-            text += " NOWAIT"
-            
+        text += self.for_update_clause(select)
+
         if getattr(select, 'parens', False):
             self.strings[select] = "(" + text + ")"
         else:
@@ -415,6 +411,12 @@ class ANSICompiler(sql.Compiled):
         """ called when building a SELECT statement, position is after all other SELECT clauses. Most DB syntaxes put LIMIT/OFFSET here """
         return (select.limit or select.offset) and self.limit_clause(select) or ""
 
+    def for_update_clause(self, select):
+        if select.for_update:
+            return " FOR UPDATE"
+        else:
+            return ""
+
     def limit_clause(self, select):
         text = ""
         if select.limit is not None:
index c6d78cf904f1f24014ce68cf423b7f07e5797753..4eab9e55c554f44f107bca7af986868ce364d729 100644 (file)
@@ -411,6 +411,12 @@ class MySQLCompiler(ansisql.ANSICompiler):
             # TODO: put whatever MySQL does for CAST here.
             self.strings[cast] = self.strings[cast.clause]
 
+    def for_update_clause(self, select):
+        if select.for_update == 'read':
+             return ' LOCK IN SHARE MODE'
+        else:
+            return super(MySQLCompiler, self).for_update_clause(select)
+
     def limit_clause(self, select):
         text = ""
         if select.limit is not None:
index 5f574338b7052229c8d4e1502f5c5e2063f8b327..5db157cbb191d7f7b2b38bb997ea121c9e396b2c 100644 (file)
@@ -402,6 +402,12 @@ class OracleCompiler(ansisql.ANSICompiler):
     def limit_clause(self, select):
         return ""
 
+    def for_update_clause(self, select):
+        if select.for_update=="nowait":
+            return " FOR UPDATE NOWAIT"
+        else:
+            return super(OracleCompiler, self).for_update_clause(select)
+
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
index 29cc56761da852b4375af32255b786ea5da6dcfc..d35219208d72cb51cc9e9450d59fe08ecd8154da 100644 (file)
@@ -1,4 +1,4 @@
-# orm/query.py
+ # orm/query.py
 # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
@@ -12,7 +12,7 @@ import mapper
 
 class Query(object):
     """encapsulates the object-fetching operations provided by Mappers."""
-    def __init__(self, class_or_mapper, session=None, entity_name=None, **kwargs):
+    def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, **kwargs):
         if isinstance(class_or_mapper, type):
             self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
         else:
@@ -20,6 +20,7 @@ class Query(object):
         self.mapper = self.mapper.get_select_mapper().compile()
         self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
         self.order_by = kwargs.pop('order_by', self.mapper.order_by)
+        self.lockmode = lockmode
         self.extension = kwargs.pop('extension', self.mapper.extension)
         self._session = session
         if not hasattr(self.mapper, '_get_clause'):
@@ -67,7 +68,8 @@ class Query(object):
 
         e.g.   u = usermapper.get_by(user_name = 'fred')
         """
-        x = self.select_whereclause(self.join_by(*args, **params), limit=1)
+        lockmode=params.pop('lockmode', self.lockmode)
+        x = self.select_whereclause(self.join_by(*args, **params), lockmode=lockmode, limit=1)
         if x:
             return x[0]
         else:
@@ -248,7 +250,11 @@ class Query(object):
     def options(self, *args, **kwargs):
         """returns a new Query object using the given MapperOptions."""
         return self.mapper.options(*args, **kwargs).using(session=self._session)
-
+    
+    def with_lockmode(self, mode):
+        """return a new Query object with the specified locking mode."""
+        return Query(self.mapper, self._session, lockmode=mode)
+        
     def __getattr__(self, key):
         if (key.startswith('select_by_')):
             key = key[10:]
@@ -270,8 +276,9 @@ class Query(object):
         finally:
             result.close()
         
-    def _get(self, key, ident=None, reload=False):
-        if not reload and not self.always_refresh:
+    def _get(self, key, ident=None, reload=False, lockmode=None):
+        lockmode = lockmode or self.lockmode
+        if not reload and not self.always_refresh and lockmode == None:
             try:
                 return self.session._get(key)
             except KeyError:
@@ -293,7 +300,7 @@ class Query(object):
             if len(ident) > i + 1:
                 i += 1
         try:
-            statement = self.compile(self._get_clause)
+            statement = self.compile(self._get_clause, lockmode=lockmode)
             return self._select_statement(statement, params=params, populate_existing=reload)[0]
         except IndexError:
             return None
@@ -320,11 +327,14 @@ class Query(object):
     def compile(self, whereclause = None, **kwargs):
         order_by = kwargs.pop('order_by', False)
         from_obj = kwargs.pop('from_obj', [])
+        lockmode = kwargs.pop('lockmode', self.lockmode)
         if order_by is False:
             order_by = self.order_by
         if order_by is False:
             if self.table.default_order_by() is not None:
                 order_by = self.table.default_order_by()
+
+        for_update = {'read':'read','update':True,'update_nowait':'nowait'}.get(lockmode, False)
         
         if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
             whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity)
@@ -349,7 +359,7 @@ class Query(object):
             crit = []
             for i in range(0, len(self.table.primary_key)):
                 crit.append(s3.primary_key[i] == self.table.primary_key[i])
-            statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True)
+            statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True, for_update=for_update)
  #           raise "OK statement", str(statement)
  
             # now for the order by, convert the columns to their corresponding columns
@@ -364,7 +374,7 @@ class Query(object):
                 statement.order_by(*util.to_list(order_by))
         else:
             from_obj.append(self.table)
-            statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, **kwargs)
+            statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **kwargs)
             if order_by:
                 statement.order_by(*util.to_list(order_by))
             # for a DISTINCT query, you need the columns explicitly specified in order
index 6eeef270407d78d8fb2be8e221ab50ff661f01bb..8a5fc302ccbca0b7e3e924ec5a279c80aad0db54 100644 (file)
@@ -55,7 +55,7 @@ class SQLTest(PersistTest):
         c = clause.compile(parameters=params, dialect=dialect)
         self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
         cc = re.sub(r'\n', '', str(c))
-        self.assert_(cc == result, str(c) + "\n does not match \n" + result)
+        self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'")
         if checkparams is not None:
             if isinstance(checkparams, list):
                 self.assert_(c.get_params().values() == checkparams, "params dont match ")
@@ -213,12 +213,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         )
 
     def testunicodestartswith(self):
-       string = u"hi \xf6 \xf5"
-       self.runtest(
-               table1.select(table1.c.name.startswith(string)),
-               "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name",
-               checkparams = {'mytable_name': u'hi \xf6 \xf5%'},
-       )
+        string = u"hi \xf6 \xf5"
+        self.runtest(
+            table1.select(table1.c.name.startswith(string)),
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name",
+            checkparams = {'mytable_name': u'hi \xf6 \xf5%'},
+        )
 
     def testmultiparam(self):
         self.runtest(
@@ -249,9 +249,18 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         )
     
     def testforupdate(self):
-        self.runtest(
-            table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE"
-        )
+        self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE")
+    
+        self.runtest(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE")
+
+        self.runtest(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE NOWAIT", dialect=oracle.dialect())
+
+        self.runtest(table1.select(table1.c.myid==7, for_update="read"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE", dialect=mysql.dialect())
+
+        self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s FOR UPDATE", dialect=mysql.dialect())
+
+        self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE", dialect=oracle.dialect())
+   
     def testalias(self):
         # test the alias for a table1.  column names stay the same, table name "changes" to "foo".
         self.runtest(