]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
for #516, moved the "disconnect check" step out of pool and back into base.py. diale...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Apr 2007 16:06:06 +0000 (16:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Apr 2007 16:06:06 +0000 (16:06 +0000)
is_disconnect() method now.  simpler design which also puts control of the ultimate "execute" call back into the hands of the dialects.

lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/pool.py

index 013e78c6af5d08529573d9153bd83ca77ad8fa7a..a2d4ac36e082cc8ff0191b1fe74b4e3b4cf56d6d 100644 (file)
@@ -572,10 +572,8 @@ class MSSQLDialect_pymssql(MSSQLDialect):
             del keys['port']
         return [[], keys]
 
-    def get_disconnect_checker(self):
-        def disconnect_checker(e):
-            return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
-        return disconnect_checker
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
 
 
 ##    This code is leftover from the initial implementation, for reference
@@ -636,10 +634,8 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
             connectors.append ("TrustedConnection=Yes")
         return [[";".join (connectors)], {}]
 
-    def get_disconnect_checker(self):
-        def disconnect_checker(e):
-            return isinstance(e, self.dbapi.Error) and '[08S01]' in e.args[1]
-        return disconnect_checker
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.Error) and '[08S01]' in e.args[1]
 
 
 class MSSQLDialect_adodbapi(MSSQLDialect):
@@ -671,10 +667,8 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
             connectors.append("Integrated Security=SSPI")
         return [[";".join (connectors)], {}]
 
-    def get_disconnect_checker(self):
-        def disconnect_checker(e):
-            return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
-        return disconnect_checker
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
 
 dialect_mapping = {
     'pymssql':  MSSQLDialect_pymssql,
index 7ea98e92f4e6cbce99868912dd2020fb75f7686b..03297cd68675c6ba9484f2ae5b84ebfb3f887e56 100644 (file)
@@ -342,11 +342,8 @@ class MySQLDialect(ansisql.ANSIDialect):
         except:
             pass
 
-    def get_disconnect_checker(self):
-        def disconnect_checker(e):
-            return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2014)
-        return disconnect_checker            
-
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2014)
 
     def get_default_schema_name(self):
         if not hasattr(self, '_default_schema_name'):
index a26ef76b6f5c1825bff367aeb1980b71bf642ceb..6facde93671f55eb9036809ee36e5765140143b2 100644 (file)
@@ -338,18 +338,16 @@ class PGDialect(ansisql.ANSIDialect):
         cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name})
         return bool(not not cursor.rowcount)
 
-    def get_disconnect_checker(self):
-        def disconnect_checker(e):
-            if isinstance(e, self.dbapi.OperationalError):
-                return 'closed the connection' in str(e) or 'connection not open' in str(e)
-            elif isinstance(e, self.dbapi.InterfaceError):
-                return 'connection already closed' in str(e)
-            elif isinstance(e, self.dbapi.ProgrammingError):
-                # yes, it really says "losed", not "closed"
-                return "losed the connection unexpectedly" in str(e)
-            else:
-                return False
-        return disconnect_checker
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'closed the connection' in str(e) or 'connection not open' in str(e)
+        elif isinstance(e, self.dbapi.InterfaceError):
+            return 'connection already closed' in str(e)
+        elif isinstance(e, self.dbapi.ProgrammingError):
+            # yes, it really says "losed", not "closed"
+            return "losed the connection unexpectedly" in str(e)
+        else:
+            return False
 
     def reflecttable(self, connection, table):
         if self.version == 2:
index 80d93e61cc43ec1980eff14b73ad820c3c91710c..f5b4b377e4217c077812c8c723e6b93836dda311 100644 (file)
@@ -246,10 +246,9 @@ class Dialect(sql.AbstractDialect):
 
         return clauseelement.compile(dialect=self, parameters=parameters)
 
-    def get_disconnect_checker(self):
-        """Return a callable that determines if an SQLError is caused by a database disconnection."""
-
-        return lambda x: False
+    def is_disconnect(self, e):
+        """Return True if the given DBAPI error indicates an invalid connection"""
+        raise NotImplementedError()
 
 
 class ExecutionContext(object):
@@ -576,6 +575,8 @@ class Connection(Connectable):
         try:
             context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
         except Exception, e:
+            if self.dialect.is_disconnect(e):
+                self.__connection.invalidate(e=e)
             self._autorollback()
             if self.__close_with_result:
                 self.close()
@@ -585,6 +586,8 @@ class Connection(Connectable):
         try:
             context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
         except Exception, e:
+            if self.dialect.is_disconnect(e):
+                self.__connection.invalidate(e=e)
             self._autorollback()
             if self.__close_with_result:
                 self.close()
index ceecee364fb9a2b570c49c5083424cc37f5da792..9431e13a0ecfbd67462fe5d1cd6d5af40ebf1506 100644 (file)
@@ -99,6 +99,9 @@ class DefaultDialect(base.Dialect):
     def defaultrunner(self, connection):
         return base.DefaultRunner(connection)
 
+    def is_disconnect(self, e):
+        return False
+        
     def _set_paramstyle(self, style):
         self._paramstyle = style
         self._figure_paramstyle(style)
index 2f3b4519976844f57d5ee1f2cf9d5adb90bf0c0b..1b760fca8b2f28d6ccc2174b8cee2bff5caeb654 100644 (file)
@@ -86,7 +86,7 @@ class DefaultEngineStrategy(EngineStrategy):
                 if tk in kwargs:
                     pool_args[k] = kwargs.pop(tk)
             pool_args['use_threadlocal'] = self.pool_threadlocal()
-            pool = poolclass(creator, disconnect_checker=dialect.get_disconnect_checker(), **pool_args)
+            pool = poolclass(creator, **pool_args)
         else:
             if isinstance(pool, poollib._DBProxy):
                 pool = pool.get_pool(*cargs, **cparams)
index 0b1ac2630f31f9b56160c1596e4278371abde518..a617f8fecd1cfadde87081cef79e809a58834f70 100644 (file)
@@ -126,7 +126,7 @@ class Pool(object):
     """
 
     def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True,
-                 disallow_open_cursors=False, disconnect_checker=None):
+                 disallow_open_cursors=False):
         self.logger = logging.instance_logger(self)
         self._threadconns = weakref.WeakValueDictionary()
         self._creator = creator
@@ -135,10 +135,6 @@ class Pool(object):
         self.auto_close_cursors = auto_close_cursors
         self.disallow_open_cursors = disallow_open_cursors
         self.echo = echo
-        if disconnect_checker:
-            self.disconnect_checker = disconnect_checker
-        else:
-            self.disconnect_checker = lambda x: False
     echo = logging.echo_property()
 
     def unique_connection(self):
@@ -318,22 +314,6 @@ class _CursorFairy(object):
         self.__parent._cursors[self] = True
         self.cursor = cursor
 
-    def execute(self, *args, **kwargs):
-        try:
-            self.cursor.execute(*args, **kwargs)
-        except Exception, e:
-            if self.__parent._pool.disconnect_checker(e):
-                self.invalidate(e=e)
-            raise
-
-    def executemany(self, *args, **kwargs):
-        try:
-            self.cursor.executemany(*args, **kwargs)
-        except Exception, e:
-            if self.__parent._pool.disconnect_checker(e):
-                self.invalidate(e=e)
-            raise
-
     def invalidate(self, e=None):
         self.__parent.invalidate(e=e)