From 030ef1f0ef37ccaebde06e58f22cd0de5a74c5d0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 3 Apr 2007 16:06:06 +0000 Subject: [PATCH] for #516, moved the "disconnect check" step out of pool and back into base.py. dialects have is_disconnect() method now. simpler design which also puts control of the ultimate "execute" call back into the hands of the dialects. --- lib/sqlalchemy/databases/mssql.py | 18 ++++++------------ lib/sqlalchemy/databases/mysql.py | 7 ++----- lib/sqlalchemy/databases/postgres.py | 22 ++++++++++------------ lib/sqlalchemy/engine/base.py | 11 +++++++---- lib/sqlalchemy/engine/default.py | 3 +++ lib/sqlalchemy/engine/strategies.py | 2 +- lib/sqlalchemy/pool.py | 22 +--------------------- 7 files changed, 30 insertions(+), 55 deletions(-) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 013e78c6af..a2d4ac36e0 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -572,10 +572,8 @@ class MSSQLDialect_pymssql(MSSQLDialect): del keys['port'] return [[], keys] - def get_disconnect_checker(self): - def disconnect_checker(e): - return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) - return disconnect_checker + def is_disconnect(self, e): + return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) ## This code is leftover from the initial implementation, for reference @@ -636,10 +634,8 @@ class MSSQLDialect_pyodbc(MSSQLDialect): connectors.append ("TrustedConnection=Yes") return [[";".join (connectors)], {}] - def get_disconnect_checker(self): - def disconnect_checker(e): - return isinstance(e, self.dbapi.Error) and '[08S01]' in e.args[1] - return disconnect_checker + def is_disconnect(self, e): + return isinstance(e, self.dbapi.Error) and '[08S01]' in e.args[1] class MSSQLDialect_adodbapi(MSSQLDialect): @@ -671,10 +667,8 @@ class MSSQLDialect_adodbapi(MSSQLDialect): connectors.append("Integrated Security=SSPI") return [[";".join (connectors)], {}] - def get_disconnect_checker(self): - def disconnect_checker(e): - return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) - return disconnect_checker + def is_disconnect(self, e): + return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) dialect_mapping = { 'pymssql': MSSQLDialect_pymssql, diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 7ea98e92f4..03297cd686 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -342,11 +342,8 @@ class MySQLDialect(ansisql.ANSIDialect): except: pass - def get_disconnect_checker(self): - def disconnect_checker(e): - return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2014) - return disconnect_checker - + def is_disconnect(self, e): + return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2014) def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a26ef76b6f..6facde9367 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -338,18 +338,16 @@ class PGDialect(ansisql.ANSIDialect): cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name}) return bool(not not cursor.rowcount) - def get_disconnect_checker(self): - def disconnect_checker(e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) - elif isinstance(e, self.dbapi.InterfaceError): - return 'connection already closed' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - # yes, it really says "losed", not "closed" - return "losed the connection unexpectedly" in str(e) - else: - return False - return disconnect_checker + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + elif isinstance(e, self.dbapi.InterfaceError): + return 'connection already closed' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + # yes, it really says "losed", not "closed" + return "losed the connection unexpectedly" in str(e) + else: + return False def reflecttable(self, connection, table): if self.version == 2: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 80d93e61cc..f5b4b377e4 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -246,10 +246,9 @@ class Dialect(sql.AbstractDialect): return clauseelement.compile(dialect=self, parameters=parameters) - def get_disconnect_checker(self): - """Return a callable that determines if an SQLError is caused by a database disconnection.""" - - return lambda x: False + def is_disconnect(self, e): + """Return True if the given DBAPI error indicates an invalid connection""" + raise NotImplementedError() class ExecutionContext(object): @@ -576,6 +575,8 @@ class Connection(Connectable): try: context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context) except Exception, e: + if self.dialect.is_disconnect(e): + self.__connection.invalidate(e=e) self._autorollback() if self.__close_with_result: self.close() @@ -585,6 +586,8 @@ class Connection(Connectable): try: context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context) except Exception, e: + if self.dialect.is_disconnect(e): + self.__connection.invalidate(e=e) self._autorollback() if self.__close_with_result: self.close() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ceecee364f..9431e13a0e 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -99,6 +99,9 @@ class DefaultDialect(base.Dialect): def defaultrunner(self, connection): return base.DefaultRunner(connection) + def is_disconnect(self, e): + return False + def _set_paramstyle(self, style): self._paramstyle = style self._figure_paramstyle(style) diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 2f3b451997..1b760fca8b 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -86,7 +86,7 @@ class DefaultEngineStrategy(EngineStrategy): if tk in kwargs: pool_args[k] = kwargs.pop(tk) pool_args['use_threadlocal'] = self.pool_threadlocal() - pool = poolclass(creator, disconnect_checker=dialect.get_disconnect_checker(), **pool_args) + pool = poolclass(creator, **pool_args) else: if isinstance(pool, poollib._DBProxy): pool = pool.get_pool(*cargs, **cparams) diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 0b1ac2630f..a617f8fecd 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -126,7 +126,7 @@ class Pool(object): """ def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True, - disallow_open_cursors=False, disconnect_checker=None): + disallow_open_cursors=False): self.logger = logging.instance_logger(self) self._threadconns = weakref.WeakValueDictionary() self._creator = creator @@ -135,10 +135,6 @@ class Pool(object): self.auto_close_cursors = auto_close_cursors self.disallow_open_cursors = disallow_open_cursors self.echo = echo - if disconnect_checker: - self.disconnect_checker = disconnect_checker - else: - self.disconnect_checker = lambda x: False echo = logging.echo_property() def unique_connection(self): @@ -318,22 +314,6 @@ class _CursorFairy(object): self.__parent._cursors[self] = True self.cursor = cursor - def execute(self, *args, **kwargs): - try: - self.cursor.execute(*args, **kwargs) - except Exception, e: - if self.__parent._pool.disconnect_checker(e): - self.invalidate(e=e) - raise - - def executemany(self, *args, **kwargs): - try: - self.cursor.executemany(*args, **kwargs) - except Exception, e: - if self.__parent._pool.disconnect_checker(e): - self.invalidate(e=e) - raise - def invalidate(self, e=None): self.__parent.invalidate(e=e) -- 2.47.2