defaults to 3600 seconds; connections after this age will be closed and
replaced with a new one, to handle db's that automatically close
stale connections [ticket:274]
+- changed "invalidate" semantics with pooled connection; will
+instruct the underlying connection record to reconnect the next
+time its called. "invalidate" will also automatically be called
+if any error is thrown in the underlying call to connection.cursor().
+this will hopefully allow the connection pool to reconnect to a
+database that had been stopped and started without restarting
+the connecting application [ticket:121]
- eesh ! the tutorial doctest was broken for quite some time.
- add_property() method on mapper does a "compile all mappers"
step in case the given property references a non-compiled mapper
def return_conn(self, agent):
self.do_return_conn(agent._connection_record)
- def return_invalid(self, agent):
- self.do_return_invalid(agent._connection_record)
-
def get(self):
return self.do_get()
def do_return_conn(self, conn):
raise NotImplementedError()
- def do_return_invalid(self, conn):
- raise NotImplementedError()
-
def status(self):
raise NotImplementedError()
class _ConnectionRecord(object):
def __init__(self, pool):
- self.pool = pool
+ self.__pool = pool
self.connection = self.__connect()
def close(self):
self.connection.close()
+ def invalidate(self):
+ self.__pool.log("Invalidate connection %s" % repr(self.connection))
+ self.__close()
+ self.connection = None
def get_connection(self):
- if self.pool._recycle > -1 and time.time() - self.starttime > self.pool._recycle:
- self.pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
- try:
- self.connection.close()
- except Exception, e:
- self.pool.log("Connection %s threw an error: %s" % (repr(self.connection), str(e)))
+ if self.connection is None:
+ self.connection = self.__connect()
+ elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle):
+ self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
+ self.__close()
self.connection = self.__connect()
return self.connection
+ def __close(self):
+ try:
+ self.__pool.log("Closing connection %s" % (repr(self.connection)))
+ self.connection.close()
+ except Exception, e:
+ self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e)))
def __connect(self):
try:
self.starttime = time.time()
- return self.pool._creator()
- except:
+ return self.__pool._creator()
+ except Exception, e:
+ self.__pool.log("Error on connect(): %s" % (str(e)))
raise
- # TODO: reconnect support here ?
class _ThreadFairy(object):
"""marks a thread identifier as owning a connection, for a thread local pool."""
except:
self.connection = None # helps with endless __getattr__ loops later on
self._connection_record = None
- self.__pool.return_invalid(self)
raise
if self.__pool.echo:
self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
def invalidate(self):
- if self.__pool.echo:
- self.__pool.log("Invalidate connection %s" % repr(self.connection))
+ self._connection_record.invalidate()
self.connection = None
- self._connection_record = None
- self._threadfairy = None
- self.__pool.return_invalid(self)
+ self._close()
def cursor(self, *args, **kwargs):
- return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+ try:
+ return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+ except Exception, e:
+ self.invalidate()
+ raise
def __getattr__(self, key):
return getattr(self.connection, key)
def checkout(self):
self._close()
def _close(self):
if self.connection is not None:
- if self.__pool.echo:
- self.__pool.log("Connection %s being returned to pool" % repr(self.connection))
try:
self.connection.rollback()
except:
# damn mysql -- (todo look for NotSupportedError)
pass
+ if self._connection_record is not None:
+ if self.__pool.echo:
+ self.__pool.log("Connection %s being returned to pool" % repr(self.connection))
self.__pool.return_conn(self)
- self.__pool = None
- self.connection = None
self._connection_record = None
self._threadfairy = None
def do_return_conn(self, conn):
pass
- def do_return_invalid(self, conn):
- try:
- del self._conns[thread.get_ident()]
- except KeyError:
- pass
-
def do_get(self):
try:
return self._conns[thread.get_ident()]
except Queue.Full:
self._overflow -= 1
- def do_return_invalid(self, conn):
- if conn is not None:
- self._overflow -= 1
-
def do_get(self):
try:
return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow, self._timeout)
import sqlalchemy.exceptions as exceptions
class MockDBAPI(object):
+ def __init__(self):
+ self.throw_error = False
def connect(self, argument):
+ if self.throw_error:
+ raise Exception("couldnt connect !")
return MockConnection()
class MockConnection(object):
def close(self):
time.sleep(3)
c3= p.connect()
assert id(c3.connection) != c_id
+
+ def test_invalidate(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True)
+ c1 = p.connect()
+ c_id = id(c1.connection)
+ c1.close(); c1=None
+
+ c1 = p.connect()
+ assert id(c1.connection) == c_id
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert id(c1.connection) != c_id
+
+ def test_reconnect(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True)
+ c1 = p.connect()
+ c_id = id(c1.connection)
+ c1.close(); c1=None
+
+ c1 = p.connect()
+ assert id(c1.connection) == c_id
+ dbapi.raise_error = True
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert id(c1.connection) != c_id
def testthreadlocal_del(self):
self._do_testthreadlocal(useclose=False)