From 9f100231798d83f2bf4a53494eb5199864a0094d Mon Sep 17 00:00:00 2001 From: Jason Kirtland Date: Sat, 28 Jul 2007 19:51:55 +0000 Subject: [PATCH] Added pool hooks for connection creation, check out and check in. --- CHANGES | 2 + lib/sqlalchemy/exceptions.py | 4 + lib/sqlalchemy/interfaces.py | 51 ++++++++++++ lib/sqlalchemy/pool.py | 62 ++++++++++++++- test/engine/pool.py | 146 +++++++++++++++++++++++++++++++++++ 5 files changed, 262 insertions(+), 3 deletions(-) create mode 100644 lib/sqlalchemy/interfaces.py diff --git a/CHANGES b/CHANGES index ec8d8fcce1..a92199a785 100644 --- a/CHANGES +++ b/CHANGES @@ -206,6 +206,8 @@ semantics for "__contains__" [ticket:606] - engines + - You can now hook into the pool lifecycle and run SQL statements or + other logic at new each DBAPI connection, pool check-out and check-in. - Connections gain a .properties collection, with contents scoped to the lifetime of the underlying DBAPI connection - extensions diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py index 55c345bd72..7fe5cf518b 100644 --- a/lib/sqlalchemy/exceptions.py +++ b/lib/sqlalchemy/exceptions.py @@ -89,3 +89,7 @@ class DBAPIError(SQLAlchemyError): def __init__(self, message, orig): SQLAlchemyError.__init__(self, "(%s) (%s) %s"% (message, orig.__class__.__name__, str(orig))) self.orig = orig + +class DisconnectionError(SQLAlchemyError): + """Raised within ``Pool`` when a disconnect is detected on a raw DBAPI connection.""" + pass diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py new file mode 100644 index 0000000000..5df19ceef6 --- /dev/null +++ b/lib/sqlalchemy/interfaces.py @@ -0,0 +1,51 @@ +# interfaces.py +# Copyright (C) 2007 Jason Kirtland jek@discorporate.us +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +class PoolListener(object): + """Hooks into the lifecycle of connections in a ``Pool``. + + """ + + def connect(dbapi_con, con_record): + """Called once for each new DBAPI connection or pool's ``creator()``. + + dbapi_con: + A newly connected raw DBAPI connection (not a SQLAlchemy + ``Connection`` wrapper). + + con_record: + The ``_ConnectionRecord`` that currently owns the connection + """ + + def checkout(dbapi_con, con_record): + """Called when a connection is retrieved from the pool. + + dbapi_con: + A raw DBAPI connection + + con_record: + The ``_ConnectionRecord`` that currently owns the connection + + If you raise an ``exceptions.DisconnectionError``, the current + connection will be disposed and a fresh connection retrieved. + Processing of all checkout listeners will abort and restart + using the new connection. + """ + + def checkin(dbapi_con, con_record): + """Called when a connection returns to the pool. + + Note that the connection may be closed, and may be None if the + connection has been invalidated. ``checkin`` will not be called + for detached connections. (They do not return to the pool.) + + dbapi_con: + A raw DBAPI connection + + con_record: + The _ConnectionRecord that currently owns the connection + """ diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index f86e14ab1e..02f7b1527b 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -111,6 +111,11 @@ class Pool(object): surpassed the connection will be closed and replaced with a newly opened connection. Defaults to -1. + listeners + A list of ``PoolListener``-like objects that receive events when + DBAPI connections are created, checked out and checked in to the + pool. + auto_close_cursors Cursors, returned by ``connection.cursor()``, are tracked and are automatically closed when the connection is returned to the @@ -126,8 +131,9 @@ class Pool(object): False, then no cursor processing occurs upon checkin. """ - def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True, - disallow_open_cursors=False): + def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, + auto_close_cursors=True, disallow_open_cursors=False, + listeners=None): self.logger = logging.instance_logger(self) self._threadconns = weakref.WeakValueDictionary() self._creator = creator @@ -136,6 +142,13 @@ class Pool(object): self.auto_close_cursors = auto_close_cursors self.disallow_open_cursors = disallow_open_cursors self.echo = echo + self.listeners = [] + self._on_connect = [] + self._on_checkout = [] + self._on_checkin = [] + if listeners: + for l in listeners: + self.add_listener(l) echo = logging.echo_property() def unique_connection(self): @@ -183,6 +196,17 @@ class Pool(object): def status(self): raise NotImplementedError() + def add_listener(self, listener): + """Add a ``PoolListener``-like object to this pool.""" + + self.listeners.append(listener) + if hasattr(listener, 'connect'): + self._on_connect.append(listener) + if hasattr(listener, 'checkout'): + self._on_checkout.append(listener) + if hasattr(listener, 'checkin'): + self._on_checkin.append(listener) + def log(self, msg): self.logger.info(msg) @@ -191,6 +215,9 @@ class _ConnectionRecord(object): self.__pool = pool self.connection = self.__connect() self.properties = {} + if pool._on_connect: + for l in pool._on_connect: + l.connect(self.connection, self) def close(self): if self.connection is not None: @@ -209,11 +236,17 @@ class _ConnectionRecord(object): if self.connection is None: self.connection = self.__connect() self.properties.clear() + if self.__pool._on_connect: + for l in self.__pool._on_connect: + l.connect(self.connection, self) elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle): self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection)) self.__close() self.connection = self.__connect() self.properties.clear() + if self.__pool._on_connect: + for l in self.__pool._on_connect: + l.connect(self.connection, self) return self.connection def __close(self): @@ -305,7 +338,27 @@ class _ConnectionFairy(object): if self.connection is None: raise exceptions.InvalidRequestError("This connection is closed") self.__counter +=1 - return self + + if not self._pool._on_checkout or self.__counter != 1: + return self + + # Pool listeners can trigger a reconnection on checkout + attempts = 2 + while attempts > 0: + try: + for l in self._pool._on_checkout: + l.checkout(self.connection, self._connection_record) + return self + except exceptions.DisconnectionError, e: + self._pool.log( + "Disconnection detected on checkout: %s" % (str(e))) + self._connection_record.invalidate(e) + self.connection = self._connection_record.get_connection() + attempts -= 1 + + self._pool.log("Reconnection attempts exhausted on checkout") + self.invalidate() + raise exceptions.InvalidRequestError("This connection is closed") def detach(self): """Separate this Connection from its Pool. @@ -357,6 +410,9 @@ class _ConnectionFairy(object): if self._connection_record is not None: if self._pool.echo: self._pool.log("Connection %s being returned to pool" % repr(self.connection)) + if self._pool._on_checkin: + for l in self._pool._on_checkin: + l.checkin(self.connection, self._connection_record) self._pool.return_conn(self) self.connection = None self._connection_record = None diff --git a/test/engine/pool.py b/test/engine/pool.py index 364afa9d75..7c8dc708ee 100644 --- a/test/engine/pool.py +++ b/test/engine/pool.py @@ -1,6 +1,7 @@ import testbase import threading, thread, time import sqlalchemy.pool as pool +import sqlalchemy.interfaces as interfaces import sqlalchemy.exceptions as exceptions from testlib import * @@ -380,7 +381,152 @@ class PoolTest(PersistTest): self.assert_(c.connection is not c2.connection) self.assert_(not c2.properties) self.assert_('foo2' in c.properties) + + def test_listeners(self): + dbapi = MockDBAPI() + + class InstrumentingListener(object): + def __init__(self): + if hasattr(self, 'connect'): + self.connect = self.inst_connect + if hasattr(self, 'checkout'): + self.checkout = self.inst_checkout + if hasattr(self, 'checkin'): + self.checkin = self.inst_checkin + self.clear() + def clear(self): + self.connected = [] + self.checked_out = [] + self.checked_in = [] + def assert_total(innerself, conn, cout, cin): + self.assert_(len(innerself.connected) == conn) + self.assert_(len(innerself.checked_out) == cout) + self.assert_(len(innerself.checked_in) == cin) + def assert_in(innerself, item, in_conn, in_cout, in_cin): + self.assert_((item in innerself.connected) == in_conn) + self.assert_((item in innerself.checked_out) == in_cout) + self.assert_((item in innerself.checked_in) == in_cin) + def inst_connect(self, con, record): + print "connect(%s, %s)" % (con, record) + assert con is not None + assert record is not None + self.connected.append(con) + def inst_checkout(self, con, record): + print "checkout(%s, %s)" % (con, record) + assert con is not None + assert record is not None + self.checked_out.append(con) + def inst_checkin(self, con, record): + print "checkin(%s, %s)" % (con, record) + # con can be None if invalidated + assert record is not None + self.checked_in.append(con) + class ListenAll(interfaces.PoolListener, InstrumentingListener): + pass + class ListenConnect(InstrumentingListener): + def connect(self, con, record): + pass + class ListenCheckOut(InstrumentingListener): + def checkout(self, con, record, num): + pass + class ListenCheckIn(InstrumentingListener): + def checkin(self, con, record): + pass + + def _pool(**kw): + return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw) + #, pool_size=1, max_overflow=0, **kw) + + def assert_listeners(p, total, conn, cout, cin): + self.assert_(len(p.listeners) == total) + self.assert_(len(p._on_connect) == conn) + self.assert_(len(p._on_checkout) == cout) + self.assert_(len(p._on_checkin) == cin) + p = _pool() + assert_listeners(p, 0, 0, 0, 0) + + p.add_listener(ListenAll()) + assert_listeners(p, 1, 1, 1, 1) + + p.add_listener(ListenConnect()) + assert_listeners(p, 2, 2, 1, 1) + + p.add_listener(ListenCheckOut()) + assert_listeners(p, 3, 2, 2, 1) + + p.add_listener(ListenCheckIn()) + assert_listeners(p, 4, 2, 2, 2) + del p + + print "----" + snoop = ListenAll() + p = _pool(listeners=[snoop]) + assert_listeners(p, 1, 1, 1, 1) + + c = p.connect() + snoop.assert_total(1, 1, 0) + cc = c.connection + snoop.assert_in(cc, True, True, False) + c.close() + snoop.assert_in(cc, True, True, True) + del c, cc + + snoop.clear() + + # this one depends on immediate gc + c = p.connect() + cc = c.connection + snoop.assert_in(cc, False, True, False) + snoop.assert_total(0, 1, 0) + del c, cc + snoop.assert_total(0, 1, 1) + + p.dispose() + snoop.clear() + + c = p.connect() + c.close() + c = p.connect() + snoop.assert_total(1, 2, 1) + c.close() + snoop.assert_total(1, 2, 2) + + # invalidation + p.dispose() + snoop.clear() + + c = p.connect() + snoop.assert_total(1, 1, 0) + c.invalidate() + snoop.assert_total(1, 1, 1) + c.close() + snoop.assert_total(1, 1, 1) + del c + snoop.assert_total(1, 1, 1) + c = p.connect() + snoop.assert_total(2, 2, 1) + c.close() + del c + snoop.assert_total(2, 2, 2) + + # detached + p.dispose() + snoop.clear() + + c = p.connect() + snoop.assert_total(1, 1, 0) + c.detach() + snoop.assert_total(1, 1, 0) + c.close() + del c + snoop.assert_total(1, 1, 0) + c = p.connect() + snoop.assert_total(2, 2, 0) + c.close() + del c + snoop.assert_total(2, 2, 1) + def tearDown(self): pool.clear_managers() -- 2.47.3