From 47d8b03b14145997fc0936bd674363f0e213f019 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 10 Sep 2006 23:52:04 +0000 Subject: [PATCH] - changed "for_update" parameter to accept False/True/"nowait" and "read", the latter two of which are interpreted only by Oracle and Mysql [ticket:292] - added "lockmode" argument to base Query select/get functions, including "with_lockmode" function to get a Query copy that has a default locking mode. Will translate "read"/"update" arguments into a for_update argument on the select side. [ticket:292] --- CHANGES | 8 ++++++++ lib/sqlalchemy/ansisql.py | 14 ++++++++------ lib/sqlalchemy/databases/mysql.py | 6 ++++++ lib/sqlalchemy/databases/oracle.py | 6 ++++++ lib/sqlalchemy/orm/query.py | 28 +++++++++++++++++++--------- test/sql/select.py | 29 +++++++++++++++++++---------- 6 files changed, 66 insertions(+), 25 deletions(-) diff --git a/CHANGES b/CHANGES index 3ce07ba9cf..7cb9b2cb51 100644 --- a/CHANGES +++ b/CHANGES @@ -2,6 +2,14 @@ - more rearrangements of unit-of-work commit scheme to better allow dependencies within circular flushes to work properly...updated task traversal/logging implementation +- changed "for_update" parameter to accept False/True/"nowait" +and "read", the latter two of which are interpreted only by +Oracle and Mysql [ticket:292] +- added "lockmode" argument to base Query select/get functions, +including "with_lockmode" function to get a Query copy that has +a default locking mode. Will translate "read"/"update" +arguments into a for_update argument on the select side. +[ticket:292] 0.2.8 - cleanup on connection methods + documentation. custom DBAPI diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index c44595f36a..66b917c208 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -394,13 +394,9 @@ class ANSICompiler(sql.Compiled): text += " ORDER BY " + order_by text += self.visit_select_postclauses(select) - - if select.for_update: - text += " FOR UPDATE" - if select.nowait: - text += " NOWAIT" - + text += self.for_update_clause(select) + if getattr(select, 'parens', False): self.strings[select] = "(" + text + ")" else: @@ -415,6 +411,12 @@ class ANSICompiler(sql.Compiled): """ called when building a SELECT statement, position is after all other SELECT clauses. Most DB syntaxes put LIMIT/OFFSET here """ return (select.limit or select.offset) and self.limit_clause(select) or "" + def for_update_clause(self, select): + if select.for_update: + return " FOR UPDATE" + else: + return "" + def limit_clause(self, select): text = "" if select.limit is not None: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index c6d78cf904..4eab9e55c5 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -411,6 +411,12 @@ class MySQLCompiler(ansisql.ANSICompiler): # TODO: put whatever MySQL does for CAST here. self.strings[cast] = self.strings[cast.clause] + def for_update_clause(self, select): + if select.for_update == 'read': + return ' LOCK IN SHARE MODE' + else: + return super(MySQLCompiler, self).for_update_clause(select) + def limit_clause(self, select): text = "" if select.limit is not None: diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 5f574338b7..5db157cbb1 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -402,6 +402,12 @@ class OracleCompiler(ansisql.ANSICompiler): def limit_clause(self, select): return "" + def for_update_clause(self, select): + if select.for_update=="nowait": + return " FOR UPDATE NOWAIT" + else: + return super(OracleCompiler, self).for_update_clause(select) + class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 29cc56761d..d35219208d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1,4 +1,4 @@ -# orm/query.py + # orm/query.py # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under @@ -12,7 +12,7 @@ import mapper class Query(object): """encapsulates the object-fetching operations provided by Mappers.""" - def __init__(self, class_or_mapper, session=None, entity_name=None, **kwargs): + def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, **kwargs): if isinstance(class_or_mapper, type): self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name) else: @@ -20,6 +20,7 @@ class Query(object): self.mapper = self.mapper.get_select_mapper().compile() self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) self.order_by = kwargs.pop('order_by', self.mapper.order_by) + self.lockmode = lockmode self.extension = kwargs.pop('extension', self.mapper.extension) self._session = session if not hasattr(self.mapper, '_get_clause'): @@ -67,7 +68,8 @@ class Query(object): e.g. u = usermapper.get_by(user_name = 'fred') """ - x = self.select_whereclause(self.join_by(*args, **params), limit=1) + lockmode=params.pop('lockmode', self.lockmode) + x = self.select_whereclause(self.join_by(*args, **params), lockmode=lockmode, limit=1) if x: return x[0] else: @@ -248,7 +250,11 @@ class Query(object): def options(self, *args, **kwargs): """returns a new Query object using the given MapperOptions.""" return self.mapper.options(*args, **kwargs).using(session=self._session) - + + def with_lockmode(self, mode): + """return a new Query object with the specified locking mode.""" + return Query(self.mapper, self._session, lockmode=mode) + def __getattr__(self, key): if (key.startswith('select_by_')): key = key[10:] @@ -270,8 +276,9 @@ class Query(object): finally: result.close() - def _get(self, key, ident=None, reload=False): - if not reload and not self.always_refresh: + def _get(self, key, ident=None, reload=False, lockmode=None): + lockmode = lockmode or self.lockmode + if not reload and not self.always_refresh and lockmode == None: try: return self.session._get(key) except KeyError: @@ -293,7 +300,7 @@ class Query(object): if len(ident) > i + 1: i += 1 try: - statement = self.compile(self._get_clause) + statement = self.compile(self._get_clause, lockmode=lockmode) return self._select_statement(statement, params=params, populate_existing=reload)[0] except IndexError: return None @@ -320,11 +327,14 @@ class Query(object): def compile(self, whereclause = None, **kwargs): order_by = kwargs.pop('order_by', False) from_obj = kwargs.pop('from_obj', []) + lockmode = kwargs.pop('lockmode', self.lockmode) if order_by is False: order_by = self.order_by if order_by is False: if self.table.default_order_by() is not None: order_by = self.table.default_order_by() + + for_update = {'read':'read','update':True,'update_nowait':'nowait'}.get(lockmode, False) if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity) @@ -349,7 +359,7 @@ class Query(object): crit = [] for i in range(0, len(self.table.primary_key)): crit.append(s3.primary_key[i] == self.table.primary_key[i]) - statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True) + statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True, for_update=for_update) # raise "OK statement", str(statement) # now for the order by, convert the columns to their corresponding columns @@ -364,7 +374,7 @@ class Query(object): statement.order_by(*util.to_list(order_by)) else: from_obj.append(self.table) - statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, **kwargs) + statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **kwargs) if order_by: statement.order_by(*util.to_list(order_by)) # for a DISTINCT query, you need the columns explicitly specified in order diff --git a/test/sql/select.py b/test/sql/select.py index 6eeef27040..8a5fc302cc 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -55,7 +55,7 @@ class SQLTest(PersistTest): c = clause.compile(parameters=params, dialect=dialect) self.echo("\nSQL String:\n" + str(c) + repr(c.get_params())) cc = re.sub(r'\n', '', str(c)) - self.assert_(cc == result, str(c) + "\n does not match \n" + result) + self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'") if checkparams is not None: if isinstance(checkparams, list): self.assert_(c.get_params().values() == checkparams, "params dont match ") @@ -213,12 +213,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) def testunicodestartswith(self): - string = u"hi \xf6 \xf5" - self.runtest( - table1.select(table1.c.name.startswith(string)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name", - checkparams = {'mytable_name': u'hi \xf6 \xf5%'}, - ) + string = u"hi \xf6 \xf5" + self.runtest( + table1.select(table1.c.name.startswith(string)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name", + checkparams = {'mytable_name': u'hi \xf6 \xf5%'}, + ) def testmultiparam(self): self.runtest( @@ -249,9 +249,18 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) def testforupdate(self): - self.runtest( - table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE" - ) + self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE") + + self.runtest(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE") + + self.runtest(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE NOWAIT", dialect=oracle.dialect()) + + self.runtest(table1.select(table1.c.myid==7, for_update="read"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE", dialect=mysql.dialect()) + + self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s FOR UPDATE", dialect=mysql.dialect()) + + self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE", dialect=oracle.dialect()) + def testalias(self): # test the alias for a table1. column names stay the same, table name "changes" to "foo". self.runtest( -- 2.47.2