]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added tests for SELECT ... FOR UPDATE
authorJason Kirtland <jek@discorporate.us>
Tue, 12 Jun 2007 00:45:23 +0000 (00:45 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 12 Jun 2007 00:45:23 +0000 (00:45 +0000)
- Added postgres support for FOR UPDATE NOWAIT via select(for_update='nowait')

lib/sqlalchemy/databases/postgres.py
test/engine/transaction.py

index 0eca18be3800acd3782a72c8ac99b275553d114c..b2fe00fa0eb2a387004962139eb7819440c82d6a 100644 (file)
@@ -529,6 +529,12 @@ class PGCompiler(ansisql.ANSICompiler):
         else:
             return ""
 
+    def for_update_clause(self, select):
+        if select.for_update == 'nowait':
+            return " FOR UPDATE NOWAIT"
+        else:
+            return super(PGCompiler, self).for_update_clause(select)
+
     def binary_operator_string(self, binary):
         if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
             return '||'
index f80352f63efd1968dd48cf4119370f016b139a54..179a39b802d044cefc662cd8c853989b19e329e2 100644 (file)
@@ -1,6 +1,6 @@
 
 import testbase
-import unittest, sys, datetime
+import unittest, sys, datetime, random, time, threading
 import tables
 db = testbase.db
 from sqlalchemy import *
@@ -348,6 +348,133 @@ class TLTransactionTest(testbase.PersistTest):
         assert c1.connection is c2.connection
         c2.close()
         assert c1.connection.connection is not None
+
+class ForUpdateTest(testbase.PersistTest):
+    def setUpAll(self):
+        global counters, metadata
+        metadata = MetaData()
+        counters = Table('forupdate_counters', metadata,
+            Column('counter_id', INT, primary_key = True),
+            Column('counter_value', INT),
+            mysql_engine='InnoDB'
+        )
+        counters.create(testbase.db)
+    def tearDown(self):
+        testbase.db.connect().execute(counters.delete())
+    def tearDownAll(self):
+        counters.drop(testbase.db)
+
+    def increment(self, count, errors, delay=False,
+                  delay_duration=0.025, update_style=True):
+        con = db.connect()
+        sel = counters.select(for_update=update_style,
+                              whereclause=counters.c.counter_id==1)
+        
+        for i in xrange(count):
+            trans = con.begin()
+            try:
+                existing = con.execute(sel).fetchone()
+                incr = existing['counter_value'] + 1
+
+                if delay and random.randint(1,20) <= delay:
+                    time.sleep(delay_duration)
+
+                con.execute(counters.update(counters.c.counter_id==1,
+                                            values={'counter_value':incr}))
+                if delay and random.randint(1,20) <= delay:
+                    time.sleep(delay_duration)
+
+                readback = con.execute(sel).fetchone()
+                if (readback['counter_value'] != incr):
+                    raise AssertionError("Got %s post-update, expected %s" %
+                                         (readback['counter_value'], incr))
+                trans.commit()
+            except Exception, e:
+                trans.rollback()
+                errors.append(e)
+                break
+
+        con.close()
+
+    def _threaded_increment(self, iterations, thread_count, delay,
+                            update_style):
+        db = testbase.db
+        db.execute(counters.insert(), counter_id=1, counter_value=0)
+
+        threads, errors = [], []
+        for i in xrange(thread_count):
+            thread = threading.Thread(target=self.increment,
+                                      args=(iterations,),
+                                      kwargs={'errors': errors,
+                                              'delay': delay,
+                                              'update_style': update_style})
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+
+        for e in errors:
+            sys.stderr.write("Failure: %s\n" % e)
+
+        self.assert_(len(errors) == 0)
+
+        sel = counters.select(whereclause=counters.c.counter_id==1)
+        final = db.execute(sel).fetchone()
+        self.assert_(final['counter_value'] == iterations * thread_count)
+
+    @testbase.supported('mysql', 'oracle', 'postgres')
+    def testqueued_fullspeed(self):
+        """Test SELECT FOR UPDATE.
+
+        Runs concurrent modifications on a single row in the users table,
+        with each mutator trying to increment a value stored in user_name.
+        
+        Updates are made as fast a possible, with no added delays.
+        """
+        self._threaded_increment(50, 5, False, True)
+
+    @testbase.supported('mysql', 'oracle', 'postgres')
+    def testqueued_delayed(self):
+        """Test SELECT FOR UPDATE with artificial delays.
+
+        Runs concurrent modifications on a single row in the users table,
+        with each mutator trying to increment a value stored in user_name.
+
+        Individual updates may random sleep, causing all updates to queue
+        for a while.
+        """
+        self._threaded_increment(50, 5, True, True)
+
+    @testbase.supported('oracle', 'postgres')
+    def testnowait(self):
+        """Test SELECT FOR UPDATE NOWAIT.
+
+        Run concurrent modifications on a single row with an artificial
+        delay, expecting that writers will abort when encountering the
+        locked row.
+        """
+
+        db = testbase.db
+        db.execute(counters.insert(), counter_id=1, counter_value=0)
+        
+        iterations, thread_count = 4, 2
+        threads, errors = [], []
+        for i in xrange(thread_count):
+            thread = threading.Thread(target=self.increment,
+                                      args=(iterations,),
+                                      kwargs={'errors': errors,
+                                              'delay': 20,
+                                              'delay_duration': 0.1,
+                                              'update_style': 'nowait'})
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+
+        self.assert_(len(errors) != 0)
+        sel = counters.select(whereclause=counters.c.counter_id==1)
+        final = db.execute(sel).fetchone()
+        self.assert_(final['counter_value'] != iterations * thread_count)
         
 if __name__ == "__main__":
     testbase.main()