]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added exception wrapping/reconnect support to result set
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Feb 2008 18:32:11 +0000 (18:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Feb 2008 18:32:11 +0000 (18:32 +0000)
fetching.  Reconnect works for those databases that
raise a catchable data error during results
(i.e. doesn't work on MySQL) [ticket:978]

CHANGES
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
test/engine/reconnect.py

diff --git a/CHANGES b/CHANGES
index ba20c25a363fd9ce5b0465eada38db1b3d767523..63dfbd26978ec381115e858a0ffbac9b2dc53b4b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -11,6 +11,11 @@ CHANGES
     - The value of a bindparam() can be a callable, in which case
       it's evaluated at statement execution time to get the value.
 
+    - added exception wrapping/reconnect support to result set 
+      fetching.  Reconnect works for those databases that 
+      raise a catchable data error during results 
+      (i.e. doesn't work on MySQL) [ticket:978]
+
 - orm
     - any(), has(), contains(), attribute level == and != now
       work properly with self-referential relations - the clause
index 578ef48d0bbc6167a237edab6e740a1136513ad0..a9fa90c68842b5646f4de4feca12d3e397b6ece0 100644 (file)
@@ -384,7 +384,7 @@ class PGDialect(default.DefaultDialect):
         if isinstance(e, self.dbapi.OperationalError):
             return 'closed the connection' in str(e) or 'connection not open' in str(e)
         elif isinstance(e, self.dbapi.InterfaceError):
-            return 'connection already closed' in str(e)
+            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
         elif isinstance(e, self.dbapi.ProgrammingError):
             # yes, it really says "losed", not "closed"
             return "losed the connection unexpectedly" in str(e)
index 28951f90035283a71bbe6483534a378246b324b4..bb52070fc29307770c264508020471b40abcb9b5 100644 (file)
@@ -1616,30 +1616,47 @@ class ResultProxy(object):
     def fetchall(self):
         """Fetch all rows, just like DB-API ``cursor.fetchall()``."""
 
-        l = [self._process_row(self, row) for row in self._fetchall_impl()]
-        self.close()
-        return l
+        try:
+            l = [self._process_row(self, row) for row in self._fetchall_impl()]
+            self.close()
+            return l
+        except Exception, e:
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            raise
 
     def fetchmany(self, size=None):
         """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``."""
 
-        l = [self._process_row(self, row) for row in self._fetchmany_impl(size)]
-        if len(l) == 0:
-            self.close()
-        return l
+        try:
+            l = [self._process_row(self, row) for row in self._fetchmany_impl(size)]
+            if len(l) == 0:
+                self.close()
+            return l
+        except Exception, e:
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            raise
 
     def fetchone(self):
         """Fetch one row, just like DB-API ``cursor.fetchone()``."""
-        row = self._fetchone_impl()
-        if row is not None:
-            return self._process_row(self, row)
-        else:
-            self.close()
-            return None
+        try:
+            row = self._fetchone_impl()
+            if row is not None:
+                return self._process_row(self, row)
+            else:
+                self.close()
+                return None
+        except Exception, e:
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            raise
 
     def scalar(self):
         """Fetch the first column of the first row, and close the result set."""
-        row = self._fetchone_impl()
+        try:
+            row = self._fetchone_impl()
+        except Exception, e:
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            raise
+            
         try:
             if row is not None:
                 return self._process_row(self, row)[0]
index 1b9962523f78bb33fc650d27d3557fb6d36d9cc0..d0d037a3407526c8cab7832c7be6957c85f41a00 100644 (file)
@@ -1,6 +1,6 @@
 import testenv; testenv.configure_for_tests()
 import sys, weakref
-from sqlalchemy import create_engine, exceptions, select
+from sqlalchemy import create_engine, exceptions, select, MetaData, Table, Column, Integer, String
 from testlib import *
 
 
@@ -212,7 +212,7 @@ class RealReconnectTest(TestBase):
         assert not conn.invalidated
 
         conn.close()
-
+    
     def test_close(self):
         conn = engine.connect()
         self.assertEquals(conn.execute(select([1])).scalar(), 1)
@@ -275,6 +275,40 @@ class RealReconnectTest(TestBase):
         self.assertEquals(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
 
+class InvalidateDuringResultTest(TestBase):
+    def setUp(self):
+        global meta, table, engine
+        engine = engines.reconnecting_engine()
+        meta = MetaData(engine)
+        table = Table('sometable', meta,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(50)))
+        meta.create_all()
+        table.insert().execute(
+            [{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
+        )
+        
+    def tearDown(self):
+        meta.drop_all()
+        engine.dispose()
+    
+    @testing.fails_on('mysql')    
+    def test_invalidate_on_results(self):
+        conn = engine.connect()
+        
+        result = conn.execute("select * from sometable")
+        for x in xrange(20):
+            result.fetchone()
+        
+        engine.test_shutdown()
+        try:
+            result.fetchone()
+            assert False
+        except exceptions.DBAPIError, e:
+            if not e.connection_invalidated:
+                raise
 
+        assert conn.invalidated
+        
 if __name__ == '__main__':
     testenv.main()