From: Jason Kirtland Date: Tue, 12 Jun 2007 00:45:23 +0000 (+0000) Subject: - Added tests for SELECT ... FOR UPDATE X-Git-Tag: rel_0_4_6~207 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=60bbb8020260121298b4ddd10a4a61c982a3ab09;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Added tests for SELECT ... FOR UPDATE - Added postgres support for FOR UPDATE NOWAIT via select(for_update='nowait') --- diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 0eca18be38..b2fe00fa0e 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -529,6 +529,12 @@ class PGCompiler(ansisql.ANSICompiler): else: return "" + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" + else: + return super(PGCompiler, self).for_update_clause(select) + def binary_operator_string(self, binary): if isinstance(binary.type, sqltypes.String) and binary.operator == '+': return '||' diff --git a/test/engine/transaction.py b/test/engine/transaction.py index f80352f63e..179a39b802 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -1,6 +1,6 @@ import testbase -import unittest, sys, datetime +import unittest, sys, datetime, random, time, threading import tables db = testbase.db from sqlalchemy import * @@ -348,6 +348,133 @@ class TLTransactionTest(testbase.PersistTest): assert c1.connection is c2.connection c2.close() assert c1.connection.connection is not None + +class ForUpdateTest(testbase.PersistTest): + def setUpAll(self): + global counters, metadata + metadata = MetaData() + counters = Table('forupdate_counters', metadata, + Column('counter_id', INT, primary_key = True), + Column('counter_value', INT), + mysql_engine='InnoDB' + ) + counters.create(testbase.db) + def tearDown(self): + testbase.db.connect().execute(counters.delete()) + def tearDownAll(self): + counters.drop(testbase.db) + + def increment(self, count, errors, delay=False, + delay_duration=0.025, update_style=True): + con = db.connect() + sel = counters.select(for_update=update_style, + whereclause=counters.c.counter_id==1) + + for i in xrange(count): + trans = con.begin() + try: + existing = con.execute(sel).fetchone() + incr = existing['counter_value'] + 1 + + if delay and random.randint(1,20) <= delay: + time.sleep(delay_duration) + + con.execute(counters.update(counters.c.counter_id==1, + values={'counter_value':incr})) + if delay and random.randint(1,20) <= delay: + time.sleep(delay_duration) + + readback = con.execute(sel).fetchone() + if (readback['counter_value'] != incr): + raise AssertionError("Got %s post-update, expected %s" % + (readback['counter_value'], incr)) + trans.commit() + except Exception, e: + trans.rollback() + errors.append(e) + break + + con.close() + + def _threaded_increment(self, iterations, thread_count, delay, + update_style): + db = testbase.db + db.execute(counters.insert(), counter_id=1, counter_value=0) + + threads, errors = [], [] + for i in xrange(thread_count): + thread = threading.Thread(target=self.increment, + args=(iterations,), + kwargs={'errors': errors, + 'delay': delay, + 'update_style': update_style}) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + + for e in errors: + sys.stderr.write("Failure: %s\n" % e) + + self.assert_(len(errors) == 0) + + sel = counters.select(whereclause=counters.c.counter_id==1) + final = db.execute(sel).fetchone() + self.assert_(final['counter_value'] == iterations * thread_count) + + @testbase.supported('mysql', 'oracle', 'postgres') + def testqueued_fullspeed(self): + """Test SELECT FOR UPDATE. + + Runs concurrent modifications on a single row in the users table, + with each mutator trying to increment a value stored in user_name. + + Updates are made as fast a possible, with no added delays. + """ + self._threaded_increment(50, 5, False, True) + + @testbase.supported('mysql', 'oracle', 'postgres') + def testqueued_delayed(self): + """Test SELECT FOR UPDATE with artificial delays. + + Runs concurrent modifications on a single row in the users table, + with each mutator trying to increment a value stored in user_name. + + Individual updates may random sleep, causing all updates to queue + for a while. + """ + self._threaded_increment(50, 5, True, True) + + @testbase.supported('oracle', 'postgres') + def testnowait(self): + """Test SELECT FOR UPDATE NOWAIT. + + Run concurrent modifications on a single row with an artificial + delay, expecting that writers will abort when encountering the + locked row. + """ + + db = testbase.db + db.execute(counters.insert(), counter_id=1, counter_value=0) + + iterations, thread_count = 4, 2 + threads, errors = [], [] + for i in xrange(thread_count): + thread = threading.Thread(target=self.increment, + args=(iterations,), + kwargs={'errors': errors, + 'delay': 20, + 'delay_duration': 0.1, + 'update_style': 'nowait'}) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + + self.assert_(len(errors) != 0) + sel = counters.select(whereclause=counters.c.counter_id==1) + final = db.execute(sel).fetchone() + self.assert_(final['counter_value'] != iterations * thread_count) if __name__ == "__main__": testbase.main()