git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added pool hooks for connection creation, check out and check in.
author    Jason Kirtland <jek@discorporate.us>
          Sat, 28 Jul 2007 19:51:55 +0000 (19:51 +0000)
committer Jason Kirtland <jek@discorporate.us>
          Sat, 28 Jul 2007 19:51:55 +0000 (19:51 +0000)
CHANGES
lib/sqlalchemy/exceptions.py
lib/sqlalchemy/interfaces.py [new file with mode: 0644]
lib/sqlalchemy/pool.py
test/engine/pool.py

diff --git a/CHANGES b/CHANGES
index ec8d8fcce12317deb0406af26343444948c7f783..a92199a785967f6a2c22fe9289eb97d38ad19bb6 100644
--- a/CHANGES
+++ b/CHANGES
     semantics for "__contains__" [ticket:606]
     
 - engines
+  - You can now hook into the pool lifecycle and run SQL statements or
+    other logic at each new DBAPI connection, pool check-out and check-in.
   - Connections gain a .properties collection, with contents scoped to the
     lifetime of the underlying DBAPI connection
 - extensions
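
The new hooks are enabled by passing a ``listeners`` list when a pool is
constructed (see the pool.py changes below), or later via ``Pool.add_listener()``.
A minimal usage sketch, assuming the standard library sqlite3 module as the
DBAPI; the ``SetupListener`` name and the PRAGMA it issues are illustrative
only, not part of this commit:

    import sqlite3
    import sqlalchemy.pool as pool
    import sqlalchemy.interfaces as interfaces

    class SetupListener(interfaces.PoolListener):
        def connect(self, dbapi_con, con_record):
            # runs once per raw DBAPI connection produced by the pool's creator()
            cursor = dbapi_con.cursor()
            cursor.execute("PRAGMA synchronous=OFF")
            cursor.close()

    p = pool.QueuePool(creator=lambda: sqlite3.connect(':memory:'),
                       listeners=[SetupListener()])

    c = p.connect()   # 'connect' fires for the new DBAPI connection, then 'checkout'
    c.close()         # 'checkin' fires as the connection returns to the pool
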
diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py
index 55c345bd72d264b2b7b2bab895999c5ff8935533..7fe5cf518bd8cec061fad40ed26994d8102a6968 100644
--- a/lib/sqlalchemy/exceptions.py
+++ b/lib/sqlalchemy/exceptions.py
@@ -89,3 +89,7 @@ class DBAPIError(SQLAlchemyError):
     def __init__(self, message, orig):
         SQLAlchemyError.__init__(self, "(%s) (%s) %s"% (message, orig.__class__.__name__, str(orig)))
         self.orig = orig
+
+class DisconnectionError(SQLAlchemyError):
+    """Raised within ``Pool`` when a disconnect is detected on a raw DBAPI connection."""
+    pass
diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py
new file mode 100644
index 0000000..5df19ce
--- /dev/null
+++ b/lib/sqlalchemy/interfaces.py
@@ -0,0 +1,51 @@
+# interfaces.py
+# Copyright (C) 2007 Jason Kirtland jek@discorporate.us
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+
+class PoolListener(object):
+    """Hooks into the lifecycle of connections in a ``Pool``.
+
+    """
+
+    def connect(self, dbapi_con, con_record):
+        """Called once for each new DBAPI connection made by the pool's ``creator()``.
+
+        dbapi_con:
+          A newly connected raw DBAPI connection (not a SQLAlchemy
+          ``Connection`` wrapper).
+
+        con_record:
+          The ``_ConnectionRecord`` that currently owns the connection
+        """
+
+    def checkout(self, dbapi_con, con_record):
+        """Called when a connection is retrieved from the pool.
+
+        dbapi_con:
+          A raw DBAPI connection
+
+        con_record:
+          The ``_ConnectionRecord`` that currently owns the connection
+
+        If you raise an ``exceptions.DisconnectionError``, the current
+        connection will be disposed and a fresh connection retrieved.
+        Processing of all checkout listeners will abort and restart
+        using the new connection.
+        """
+
+    def checkin(self, dbapi_con, con_record):
+        """Called when a connection returns to the pool.
+
+        Note that the connection may be closed, and may be None if the
+        connection has been invalidated.  ``checkin`` will not be called
+        for detached connections.  (They do not return to the pool.)
+
+        dbapi_con:
+          A raw DBAPI connection
+
+        con_record:
+          The ``_ConnectionRecord`` that currently owns the connection
+        """
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index f86e14ab1ef3e1aea4345d50624b3bf37977b2bf..02f7b1527b09250ce641409787c270470bbb51cf 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -111,6 +111,11 @@ class Pool(object):
       surpassed the connection will be closed and replaced with a
       newly opened connection. Defaults to -1.
 
+    listeners
+      A list of ``PoolListener``-like objects that receive events when
+      DBAPI connections are created, checked out and checked in to the
+      pool.
+
     auto_close_cursors
       Cursors, returned by ``connection.cursor()``, are tracked and
       are automatically closed when the connection is returned to the
@@ -126,8 +131,9 @@ class Pool(object):
     False, then no cursor processing occurs upon checkin.
     """
 
-    def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True,
-                 disallow_open_cursors=False):
+    def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False,
+                 auto_close_cursors=True, disallow_open_cursors=False,
+                 listeners=None):
         self.logger = logging.instance_logger(self)
         self._threadconns = weakref.WeakValueDictionary()
         self._creator = creator
@@ -136,6 +142,13 @@ class Pool(object):
         self.auto_close_cursors = auto_close_cursors
         self.disallow_open_cursors = disallow_open_cursors
         self.echo = echo
+        self.listeners = []
+        self._on_connect = []
+        self._on_checkout = []
+        self._on_checkin = []
+        if listeners:
+            for l in listeners:
+                self.add_listener(l)
     echo = logging.echo_property()
 
     def unique_connection(self):
@@ -183,6 +196,17 @@ class Pool(object):
     def status(self):
         raise NotImplementedError()
 
+    def add_listener(self, listener):
+        """Add a ``PoolListener``-like object to this pool."""
+
+        self.listeners.append(listener)
+        if hasattr(listener, 'connect'):
+            self._on_connect.append(listener)
+        if hasattr(listener, 'checkout'):
+            self._on_checkout.append(listener)
+        if hasattr(listener, 'checkin'):
+            self._on_checkin.append(listener)
+
     def log(self, msg):
         self.logger.info(msg)
 
@@ -191,6 +215,9 @@ class _ConnectionRecord(object):
         self.__pool = pool
         self.connection = self.__connect()
         self.properties = {}
+        if pool._on_connect:
+            for l in pool._on_connect:
+                l.connect(self.connection, self)
 
     def close(self):
         if self.connection is not None:
@@ -209,11 +236,17 @@ class _ConnectionRecord(object):
         if self.connection is None:
             self.connection = self.__connect()
             self.properties.clear()
+            if self.__pool._on_connect:
+                for l in self.__pool._on_connect:
+                    l.connect(self.connection, self)
         elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle):
             self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
             self.__close()
             self.connection = self.__connect()
             self.properties.clear()
+            if self.__pool._on_connect:
+                for l in self.__pool._on_connect:
+                    l.connect(self.connection, self)
         return self.connection
 
     def __close(self):
@@ -305,7 +338,27 @@ class _ConnectionFairy(object):
         if self.connection is None:
             raise exceptions.InvalidRequestError("This connection is closed")
         self.__counter +=1
-        return self
+
+        if not self._pool._on_checkout or self.__counter != 1:
+            return self
+
+        # Pool listeners can trigger a reconnection on checkout
+        attempts = 2
+        while attempts > 0:
+            try:
+                for l in self._pool._on_checkout:
+                    l.checkout(self.connection, self._connection_record)
+                return self
+            except exceptions.DisconnectionError, e:
+                self._pool.log(
+                    "Disconnection detected on checkout: %s" % (str(e)))
+                self._connection_record.invalidate(e)
+                self.connection = self._connection_record.get_connection()
+                attempts -= 1
+
+        self._pool.log("Reconnection attempts exhausted on checkout")
+        self.invalidate()
+        raise exceptions.InvalidRequestError("This connection is closed")
 
     def detach(self):
         """Separate this Connection from its Pool.
@@ -357,6 +410,9 @@ class _ConnectionFairy(object):
         if self._connection_record is not None:
             if self._pool.echo:
                 self._pool.log("Connection %s being returned to pool" % repr(self.connection))
+            if self._pool._on_checkin:
+                for l in self._pool._on_checkin:
+                    l.checkin(self.connection, self._connection_record)
             self._pool.return_conn(self)
         self.connection = None
         self._connection_record = None
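
``add_listener()`` above uses ``hasattr()`` to see which hook names a listener
provides, so subclassing ``PoolListener`` is optional: any object with one or
more of the three methods can be registered, and only the matching internal
event lists are populated. A small sketch under those assumptions (the
``CheckinCounter`` name and the sqlite3 creator are illustrative):

    import sqlite3
    import sqlalchemy.pool as pool

    class CheckinCounter(object):
        def __init__(self):
            self.count = 0
        def checkin(self, dbapi_con, con_record):
            # dbapi_con may be None if the connection was invalidated
            self.count += 1

    counter = CheckinCounter()
    p = pool.QueuePool(creator=lambda: sqlite3.connect(':memory:'))
    p.add_listener(counter)

    # only the check-in hook was found; the other (internal) event lists stay empty
    assert (len(p._on_connect), len(p._on_checkout), len(p._on_checkin)) == (0, 0, 1)

    c = p.connect()
    c.close()           # triggers 'checkin'
    assert counter.count == 1
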
diff --git a/test/engine/pool.py b/test/engine/pool.py
index 364afa9d750d461eaf12560745d8537a7842e35b..7c8dc708ee102c14d5c5cedc88e8a826ff4e720e 100644
--- a/test/engine/pool.py
+++ b/test/engine/pool.py
@@ -1,6 +1,7 @@
 import testbase
 import threading, thread, time
 import sqlalchemy.pool as pool
+import sqlalchemy.interfaces as interfaces
 import sqlalchemy.exceptions as exceptions
 from testlib import *
 
@@ -380,7 +381,152 @@ class PoolTest(PersistTest):
         self.assert_(c.connection is not c2.connection)
         self.assert_(not c2.properties)
         self.assert_('foo2' in c.properties)
+
+    def test_listeners(self):
+        dbapi = MockDBAPI()
+
+        class InstrumentingListener(object):
+            def __init__(self):
+                if hasattr(self, 'connect'):
+                    self.connect = self.inst_connect
+                if hasattr(self, 'checkout'):
+                    self.checkout = self.inst_checkout
+                if hasattr(self, 'checkin'):
+                    self.checkin = self.inst_checkin
+                self.clear()
+            def clear(self):
+                self.connected = []
+                self.checked_out = []
+                self.checked_in = []
+            def assert_total(innerself, conn, cout, cin):
+                self.assert_(len(innerself.connected) == conn)
+                self.assert_(len(innerself.checked_out) == cout)
+                self.assert_(len(innerself.checked_in) == cin)
+            def assert_in(innerself, item, in_conn, in_cout, in_cin):
+                self.assert_((item in innerself.connected) == in_conn)
+                self.assert_((item in innerself.checked_out) == in_cout)
+                self.assert_((item in innerself.checked_in) == in_cin)
+            def inst_connect(self, con, record):
+                print "connect(%s, %s)" % (con, record)
+                assert con is not None
+                assert record is not None
+                self.connected.append(con)
+            def inst_checkout(self, con, record):
+                print "checkout(%s, %s)" % (con, record)
+                assert con is not None
+                assert record is not None
+                self.checked_out.append(con)
+            def inst_checkin(self, con, record):
+                print "checkin(%s, %s)" % (con, record)
+                # con can be None if invalidated
+                assert record is not None
+                self.checked_in.append(con)
+        class ListenAll(interfaces.PoolListener, InstrumentingListener):
+            pass
+        class ListenConnect(InstrumentingListener):
+            def connect(self, con, record):
+                pass
+        class ListenCheckOut(InstrumentingListener):
+            def checkout(self, con, record):
+                pass
+        class ListenCheckIn(InstrumentingListener):
+            def checkin(self, con, record):
+                pass
+
+        def _pool(**kw):
+            return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw)
+            #, pool_size=1, max_overflow=0, **kw)
+
+        def assert_listeners(p, total, conn, cout, cin):
+            self.assert_(len(p.listeners) == total)
+            self.assert_(len(p._on_connect) == conn)
+            self.assert_(len(p._on_checkout) == cout)
+            self.assert_(len(p._on_checkin) == cin)
             
+        p = _pool()
+        assert_listeners(p, 0, 0, 0, 0)
+
+        p.add_listener(ListenAll())
+        assert_listeners(p, 1, 1, 1, 1)
+
+        p.add_listener(ListenConnect())
+        assert_listeners(p, 2, 2, 1, 1)
+
+        p.add_listener(ListenCheckOut())
+        assert_listeners(p, 3, 2, 2, 1)
+
+        p.add_listener(ListenCheckIn())
+        assert_listeners(p, 4, 2, 2, 2)
+        del p
+
+        print "----"
+        snoop = ListenAll()
+        p = _pool(listeners=[snoop])
+        assert_listeners(p, 1, 1, 1, 1)
+
+        c = p.connect()
+        snoop.assert_total(1, 1, 0)
+        cc = c.connection
+        snoop.assert_in(cc, True, True, False)
+        c.close()
+        snoop.assert_in(cc, True, True, True)
+        del c, cc
+
+        snoop.clear()
+
+        # this one depends on immediate gc 
+        c = p.connect()
+        cc = c.connection
+        snoop.assert_in(cc, False, True, False)
+        snoop.assert_total(0, 1, 0)
+        del c, cc
+        snoop.assert_total(0, 1, 1)
+
+        p.dispose()
+        snoop.clear()
+
+        c = p.connect()
+        c.close()
+        c = p.connect()
+        snoop.assert_total(1, 2, 1)
+        c.close()
+        snoop.assert_total(1, 2, 2)
+
+        # invalidation
+        p.dispose()
+        snoop.clear()
+
+        c = p.connect()
+        snoop.assert_total(1, 1, 0)
+        c.invalidate()
+        snoop.assert_total(1, 1, 1)
+        c.close()
+        snoop.assert_total(1, 1, 1)
+        del c
+        snoop.assert_total(1, 1, 1)
+        c = p.connect()
+        snoop.assert_total(2, 2, 1)
+        c.close()
+        del c
+        snoop.assert_total(2, 2, 2)
+
+        # detached
+        p.dispose()
+        snoop.clear()
+
+        c = p.connect()
+        snoop.assert_total(1, 1, 0)
+        c.detach()
+        snoop.assert_total(1, 1, 0)
+        c.close()
+        del c
+        snoop.assert_total(1, 1, 0)
+        c = p.connect()
+        snoop.assert_total(2, 2, 0)
+        c.close()
+        del c
+        snoop.assert_total(2, 2, 1)
+
     def tearDown(self):
        pool.clear_managers()