From 47332cb1477f6a5affb5b78bb8aff0523c499d90 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 26 Aug 2006 21:32:11 +0000 Subject: [PATCH] - changed "invalidate" semantics with pooled connection; will instruct the underlying connection record to reconnect the next time its called. "invalidate" will also automatically be called if any error is thrown in the underlying call to connection.cursor(). this will hopefully allow the connection pool to reconnect to a database that had been stopped and started without restarting the connecting application [ticket:121] --- CHANGES | 7 ++++ lib/sqlalchemy/engine/base.py | 3 +- lib/sqlalchemy/pool.py | 66 +++++++++++++++-------------------- test/engine/pool.py | 35 +++++++++++++++++++ 4 files changed, 73 insertions(+), 38 deletions(-) diff --git a/CHANGES b/CHANGES index 3be73d3d52..085a5979ed 100644 --- a/CHANGES +++ b/CHANGES @@ -7,6 +7,13 @@ function to 'create_engine'. defaults to 3600 seconds; connections after this age will be closed and replaced with a new one, to handle db's that automatically close stale connections [ticket:274] +- changed "invalidate" semantics with pooled connection; will +instruct the underlying connection record to reconnect the next +time its called. "invalidate" will also automatically be called +if any error is thrown in the underlying call to connection.cursor(). +this will hopefully allow the connection pool to reconnect to a +database that had been stopped and started without restarting +the connecting application [ticket:121] - eesh ! the tutorial doctest was broken for quite some time. - add_property() method on mapper does a "compile all mappers" step in case the given property references a non-compiled mapper diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 6a96fbcfbe..ce6cc7d82e 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -342,7 +342,8 @@ class Connection(Connectable): try: self.__engine.dialect.do_executemany(c, statement, parameters, context=context) except Exception, e: - self._rollback_impl() + self._autorollback() + #self._rollback_impl() if self.__close_with_result: self.close() raise exceptions.SQLError(statement, parameters, e) diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 211f96070d..577405b0f8 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -104,9 +104,6 @@ class Pool(object): def return_conn(self, agent): self.do_return_conn(agent._connection_record) - def return_invalid(self, agent): - self.do_return_invalid(agent._connection_record) - def get(self): return self.do_get() @@ -116,9 +113,6 @@ class Pool(object): def do_return_conn(self, conn): raise NotImplementedError() - def do_return_invalid(self, conn): - raise NotImplementedError() - def status(self): raise NotImplementedError() @@ -133,26 +127,35 @@ class Pool(object): class _ConnectionRecord(object): def __init__(self, pool): - self.pool = pool + self.__pool = pool self.connection = self.__connect() def close(self): self.connection.close() + def invalidate(self): + self.__pool.log("Invalidate connection %s" % repr(self.connection)) + self.__close() + self.connection = None def get_connection(self): - if self.pool._recycle > -1 and time.time() - self.starttime > self.pool._recycle: - self.pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection)) - try: - self.connection.close() - except Exception, e: - self.pool.log("Connection %s threw an error: %s" % (repr(self.connection), str(e))) + if self.connection is None: + self.connection = self.__connect() + elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle): + self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection)) + self.__close() self.connection = self.__connect() return self.connection + def __close(self): + try: + self.__pool.log("Closing connection %s" % (repr(self.connection))) + self.connection.close() + except Exception, e: + self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e))) def __connect(self): try: self.starttime = time.time() - return self.pool._creator() - except: + return self.__pool._creator() + except Exception, e: + self.__pool.log("Error on connect(): %s" % (str(e))) raise - # TODO: reconnect support here ? class _ThreadFairy(object): """marks a thread identifier as owning a connection, for a thread local pool.""" @@ -171,19 +174,19 @@ class _ConnectionFairy(object): except: self.connection = None # helps with endless __getattr__ loops later on self._connection_record = None - self.__pool.return_invalid(self) raise if self.__pool.echo: self.__pool.log("Connection %s checked out from pool" % repr(self.connection)) def invalidate(self): - if self.__pool.echo: - self.__pool.log("Invalidate connection %s" % repr(self.connection)) + self._connection_record.invalidate() self.connection = None - self._connection_record = None - self._threadfairy = None - self.__pool.return_invalid(self) + self._close() def cursor(self, *args, **kwargs): - return _CursorFairy(self, self.connection.cursor(*args, **kwargs)) + try: + return _CursorFairy(self, self.connection.cursor(*args, **kwargs)) + except Exception, e: + self.invalidate() + raise def __getattr__(self, key): return getattr(self.connection, key) def checkout(self): @@ -199,16 +202,15 @@ class _ConnectionFairy(object): self._close() def _close(self): if self.connection is not None: - if self.__pool.echo: - self.__pool.log("Connection %s being returned to pool" % repr(self.connection)) try: self.connection.rollback() except: # damn mysql -- (todo look for NotSupportedError) pass + if self._connection_record is not None: + if self.__pool.echo: + self.__pool.log("Connection %s being returned to pool" % repr(self.connection)) self.__pool.return_conn(self) - self.__pool = None - self.connection = None self._connection_record = None self._threadfairy = None @@ -257,12 +259,6 @@ class SingletonThreadPool(Pool): def do_return_conn(self, conn): pass - def do_return_invalid(self, conn): - try: - del self._conns[thread.get_ident()] - except KeyError: - pass - def do_get(self): try: return self._conns[thread.get_ident()] @@ -288,10 +284,6 @@ class QueuePool(Pool): except Queue.Full: self._overflow -= 1 - def do_return_invalid(self, conn): - if conn is not None: - self._overflow -= 1 - def do_get(self): try: return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow, self._timeout) diff --git a/test/engine/pool.py b/test/engine/pool.py index cfc8f5684a..c6b9681309 100644 --- a/test/engine/pool.py +++ b/test/engine/pool.py @@ -6,7 +6,11 @@ import sqlalchemy.pool as pool import sqlalchemy.exceptions as exceptions class MockDBAPI(object): + def __init__(self): + self.throw_error = False def connect(self, argument): + if self.throw_error: + raise Exception("couldnt connect !") return MockConnection() class MockConnection(object): def close(self): @@ -154,6 +158,37 @@ class PoolTest(PersistTest): time.sleep(3) c3= p.connect() assert id(c3.connection) != c_id + + def test_invalidate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True) + c1 = p.connect() + c_id = id(c1.connection) + c1.close(); c1=None + + c1 = p.connect() + assert id(c1.connection) == c_id + c1.invalidate() + c1 = None + + c1 = p.connect() + assert id(c1.connection) != c_id + + def test_reconnect(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True) + c1 = p.connect() + c_id = id(c1.connection) + c1.close(); c1=None + + c1 = p.connect() + assert id(c1.connection) == c_id + dbapi.raise_error = True + c1.invalidate() + c1 = None + + c1 = p.connect() + assert id(c1.connection) != c_id def testthreadlocal_del(self): self._do_testthreadlocal(useclose=False) -- 2.47.2