import testbase
-import unittest, sys, datetime
+import unittest, sys, datetime, random, time, threading
import tables
db = testbase.db
from sqlalchemy import *
assert c1.connection is c2.connection
c2.close()
assert c1.connection.connection is not None
+
+class ForUpdateTest(testbase.PersistTest):
+ def setUpAll(self):
+ global counters, metadata
+ metadata = MetaData()
+ counters = Table('forupdate_counters', metadata,
+ Column('counter_id', INT, primary_key = True),
+ Column('counter_value', INT),
+ mysql_engine='InnoDB'
+ )
+ counters.create(testbase.db)
+ def tearDown(self):
+ testbase.db.connect().execute(counters.delete())
+ def tearDownAll(self):
+ counters.drop(testbase.db)
+
+ def increment(self, count, errors, delay=False,
+ delay_duration=0.025, update_style=True):
+ con = db.connect()
+ sel = counters.select(for_update=update_style,
+ whereclause=counters.c.counter_id==1)
+
+ for i in xrange(count):
+ trans = con.begin()
+ try:
+ existing = con.execute(sel).fetchone()
+ incr = existing['counter_value'] + 1
+
+ if delay and random.randint(1,20) <= delay:
+ time.sleep(delay_duration)
+
+ con.execute(counters.update(counters.c.counter_id==1,
+ values={'counter_value':incr}))
+ if delay and random.randint(1,20) <= delay:
+ time.sleep(delay_duration)
+
+ readback = con.execute(sel).fetchone()
+ if (readback['counter_value'] != incr):
+ raise AssertionError("Got %s post-update, expected %s" %
+ (readback['counter_value'], incr))
+ trans.commit()
+ except Exception, e:
+ trans.rollback()
+ errors.append(e)
+ break
+
+ con.close()
+
+ def _threaded_increment(self, iterations, thread_count, delay,
+ update_style):
+ db = testbase.db
+ db.execute(counters.insert(), counter_id=1, counter_value=0)
+
+ threads, errors = [], []
+ for i in xrange(thread_count):
+ thread = threading.Thread(target=self.increment,
+ args=(iterations,),
+ kwargs={'errors': errors,
+ 'delay': delay,
+ 'update_style': update_style})
+ thread.start()
+ threads.append(thread)
+ for thread in threads:
+ thread.join()
+
+ for e in errors:
+ sys.stderr.write("Failure: %s\n" % e)
+
+ self.assert_(len(errors) == 0)
+
+ sel = counters.select(whereclause=counters.c.counter_id==1)
+ final = db.execute(sel).fetchone()
+ self.assert_(final['counter_value'] == iterations * thread_count)
+
+ @testbase.supported('mysql', 'oracle', 'postgres')
+ def testqueued_fullspeed(self):
+ """Test SELECT FOR UPDATE.
+
+ Runs concurrent modifications on a single row in the users table,
+ with each mutator trying to increment a value stored in user_name.
+
+ Updates are made as fast a possible, with no added delays.
+ """
+ self._threaded_increment(50, 5, False, True)
+
+ @testbase.supported('mysql', 'oracle', 'postgres')
+ def testqueued_delayed(self):
+ """Test SELECT FOR UPDATE with artificial delays.
+
+ Runs concurrent modifications on a single row in the users table,
+ with each mutator trying to increment a value stored in user_name.
+
+ Individual updates may random sleep, causing all updates to queue
+ for a while.
+ """
+ self._threaded_increment(50, 5, True, True)
+
+ @testbase.supported('oracle', 'postgres')
+ def testnowait(self):
+ """Test SELECT FOR UPDATE NOWAIT.
+
+ Run concurrent modifications on a single row with an artificial
+ delay, expecting that writers will abort when encountering the
+ locked row.
+ """
+
+ db = testbase.db
+ db.execute(counters.insert(), counter_id=1, counter_value=0)
+
+ iterations, thread_count = 4, 2
+ threads, errors = [], []
+ for i in xrange(thread_count):
+ thread = threading.Thread(target=self.increment,
+ args=(iterations,),
+ kwargs={'errors': errors,
+ 'delay': 20,
+ 'delay_duration': 0.1,
+ 'update_style': 'nowait'})
+ thread.start()
+ threads.append(thread)
+ for thread in threads:
+ thread.join()
+
+ self.assert_(len(errors) != 0)
+ sel = counters.select(whereclause=counters.c.counter_id==1)
+ final = db.execute(sel).fetchone()
+ self.assert_(final['counter_value'] != iterations * thread_count)
if __name__ == "__main__":
testbase.main()