]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Pool listeners may now be specified as a duck-type of PoolListener or a dict of...
authorJason Kirtland <jek@discorporate.us>
Fri, 4 Apr 2008 19:07:30 +0000 (19:07 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 4 Apr 2008 19:07:30 +0000 (19:07 +0000)
CHANGES
lib/sqlalchemy/pool.py
lib/sqlalchemy/util.py
test/base/utils.py
test/engine/pool.py

diff --git a/CHANGES b/CHANGES
index b0fbbc8012f76fa2fe80c5a58794c63c49ebad47..750410db26877f79c5b0094ef7b34622590bbe0a 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,12 +1,18 @@
 =======
 CHANGES
 =======
+
 0.4.6
 =====
 - sql
     - Fixed bug with union() when applied to non-Table connected
       select statements
-      
+
+- engines
+    - Pool listeners can now be provided as a dictionary of
+      callables or a (possibly partial) duck-type of
+      PoolListener, your choice.
+
 0.4.5
 =====
 - orm
index 94d9127f0cd061fecd6c26f4e5d18b5a39178ba7..71c0c82df5e43c7ef35580561e74df61aff24111 100644 (file)
@@ -20,7 +20,7 @@ import weakref, time
 
 from sqlalchemy import exceptions, logging
 from sqlalchemy import queue as Queue
-from sqlalchemy.util import thread, threading, pickle
+from sqlalchemy.util import thread, threading, pickle, as_interface
 
 proxies = {}
 
@@ -106,9 +106,10 @@ class Pool(object):
       newly opened connection. Defaults to -1.
 
     listeners
-      A list of ``PoolListener``-like objects that receive events when
-      DB-API connections are created, checked out and checked in to
-      the pool.
+      A list of ``PoolListener``-like objects or dictionaries of callables
+      that receive events when DB-API connections are created, checked out and
+      checked in to the pool.
+
     """
 
     def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=True,
@@ -182,7 +183,16 @@ class Pool(object):
         raise NotImplementedError()
 
     def add_listener(self, listener):
-        """Add a ``PoolListener``-like object to this pool."""
+        """Add a ``PoolListener``-like object to this pool.
+
+        ``listener`` may be an object that implements some or all of
+        PoolListener, or a dictionary of callables containing implementations
+        of some or all of the named methods in PoolListener.
+
+        """
+
+        listener = as_interface(
+            listener, methods=('connect', 'checkout', 'checkin'))
 
         self.listeners.append(listener)
         if hasattr(listener, 'connect'):
index 36b40a04de2fdf420b755e595c297514f3dc2c04..740e97e3f0b6d37a57c28987e77b89ab5c12a2bd 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import inspect, itertools, new, sets, sys, warnings, weakref
+import inspect, itertools, new, operator, sets, sys, warnings, weakref
 import __builtin__
 types = __import__('types')
 
@@ -1060,6 +1060,87 @@ class symbol(object):
             symbol._lock.release()
 
 
+def as_interface(obj, cls=None, methods=None, required=None):
+    """Ensure basic interface compliance for an instance or dict of callables.
+
+    Checks that ``obj`` implements public methods of ``cls`` or has members
+    listed in ``methods``.  If ``required`` is not supplied, implementing at
+    least one interface method is sufficient.  Methods present on ``obj`` that
+    are not in the interface are ignored.
+
+    If ``obj`` is a dict and ``dict`` does not meet the interface
+    requirements, the keys of the dictionary are inspected. Keys present in
+    ``obj`` that are not in the interface will raise TypeErrors.
+
+    Raises TypeError if ``obj`` does not meet the interface criteria.
+
+    In all passing cases, an object with callable members is returned.  In the
+    simple case, ``obj`` is returned as-is; if dict processing kicks in then
+    an anonymous class is returned.
+
+    obj
+      A type, instance, or dictionary of callables.
+    cls
+      Optional, a type.  All public methods of cls are considered the
+      interface.  An ``obj`` instance of cls will always pass, ignoring
+      ``required``..
+    methods
+      Optional, a sequence of method names to consider as the interface.
+    required
+      Optional, a sequence of mandatory implementations. If omitted, an
+      ``obj`` that provides at least one interface method is considered
+      sufficient.  As a convenience, required may be a type, in which case
+      all public methods of the type are required.
+
+    """
+    if not cls and not methods:
+        raise TypeError('a class or collection of method names are required')
+
+    if isinstance(cls, type) and isinstance(obj, cls):
+        return obj
+
+    interface = Set(methods or [m for m in dir(cls) if not m.startswith('_')])
+    implemented = Set(dir(obj))
+
+    complies = operator.ge
+    if isinstance(required, type):
+        required = interface
+    elif not required:
+        required = Set()
+        complies = operator.gt
+    else:
+        required = Set(required)
+
+    if complies(implemented.intersection(interface), required):
+        return obj
+
+    # No dict duck typing here.
+    if not type(obj) is dict:
+        qualifier = complies is operator.gt and 'any of' or 'all of'
+        raise TypeError("%r does not implement %s: %s" % (
+            obj, qualifier, ', '.join(interface)))
+
+    class AnonymousInterface(object):
+        """A callable-holding shell."""
+
+    if cls:
+        AnonymousInterface.__name__ = 'Anonymous' + cls.__name__
+    found = Set()
+
+    for method, impl in dictlike_iteritems(obj):
+        if method not in interface:
+            raise TypeError("%r: unknown in this interface" % method)
+        if not callable(impl):
+            raise TypeError("%r=%r is not callable" % (method, impl))
+        setattr(AnonymousInterface, method, staticmethod(impl))
+        found.add(method)
+
+    if complies(found, required):
+        return AnonymousInterface
+
+    raise TypeError("dictionary does not contain required keys %s" %
+                    ', '.join(required - found))
+
 def function_named(fn, name):
     """Return a function with a given __name__.
 
index fc72cf8e12ff7c618453725c195f2d0f7d44a231..a00338f5f52ec50d43d62caf1c6d269e6529d9fa 100644 (file)
@@ -419,6 +419,100 @@ class SymbolTest(TestBase):
             assert rt is sym1
             assert rt is sym2
 
+class AsInterfaceTest(TestBase):
+    class Something(object):
+        def _ignoreme(self): pass
+        def foo(self): pass
+        def bar(self): pass
+
+    class Partial(object):
+        def bar(self): pass
+
+    class Object(object): pass
+
+    def test_instance(self):
+        obj = object()
+        self.assertRaises(TypeError, util.as_interface, obj,
+                          cls=self.Something)
+
+        self.assertRaises(TypeError, util.as_interface, obj,
+                          methods=('foo'))
+
+        self.assertRaises(TypeError, util.as_interface, obj,
+                          cls=self.Something, required=('foo'))
+
+        obj = self.Something()
+        self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
+        self.assertEqual(obj, util.as_interface(obj, methods=('foo',)))
+        self.assertEqual(
+            obj, util.as_interface(obj, cls=self.Something,
+                                   required=('outofband',)))
+        partial = self.Partial()
+
+        slotted = self.Object()
+        slotted.bar = lambda self: 123
+
+        for obj in partial, slotted:
+            self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
+            self.assertRaises(TypeError, util.as_interface, obj,
+                              methods=('foo'))
+            self.assertEqual(obj, util.as_interface(obj, methods=('bar',)))
+            self.assertEqual(
+                obj, util.as_interface(obj, cls=self.Something,
+                                       required=('bar',)))
+            self.assertRaises(TypeError, util.as_interface, obj,
+                              cls=self.Something, required=('foo',))
+
+            self.assertRaises(TypeError, util.as_interface, obj,
+                              cls=self.Something, required=self.Something)
+
+    def test_dict(self):
+        obj = {}
+
+        self.assertRaises(TypeError, util.as_interface, obj,
+                          cls=self.Something)
+        self.assertRaises(TypeError, util.as_interface, obj,
+                          methods=('foo'))
+        self.assertRaises(TypeError, util.as_interface, obj,
+                          cls=self.Something, required=('foo'))
+
+        def assertAdapted(obj, *methods):
+            assert isinstance(obj, type)
+            found = set([m for m in dir(obj) if not m.startswith('_')])
+            for method in methods:
+                assert method in found
+                found.remove(method)
+            assert not found
+
+        fn = lambda self: 123
+
+        obj = {'foo': fn, 'bar': fn}
+
+        res = util.as_interface(obj, cls=self.Something)
+        assertAdapted(res, 'foo', 'bar')
+
+        res = util.as_interface(obj, cls=self.Something, required=self.Something)
+        assertAdapted(res, 'foo', 'bar')
+
+        res = util.as_interface(obj, cls=self.Something, required=('foo',))
+        assertAdapted(res, 'foo', 'bar')
+
+        res = util.as_interface(obj, methods=('foo', 'bar'))
+        assertAdapted(res, 'foo', 'bar')
+
+        res = util.as_interface(obj, methods=('foo', 'bar', 'baz'))
+        assertAdapted(res, 'foo', 'bar')
+
+        res = util.as_interface(obj, methods=('foo', 'bar'), required=('foo',))
+        assertAdapted(res, 'foo', 'bar')
+
+        self.assertRaises(TypeError, util.as_interface, obj, methods=('foo',))
+
+        self.assertRaises(TypeError, util.as_interface, obj,
+                          methods=('foo', 'bar', 'baz'), required=('baz',))
+
+        obj = {'foo': 123}
+        self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something)
 
 if __name__ == "__main__":
     testenv.main()
index 8e0abb69115ed66a54ed590629379ea4efb76dd0..75cb08e3c880e4d29ce187a0f1b5e258a7096d45 100644 (file)
@@ -453,12 +453,12 @@ class PoolTest(TestBase):
             def checkout(self, con, record, proxy, num):
                 pass
         class ListenCheckIn(InstrumentingListener):
-            def checkin(self, con, proxy, record):
+            def checkin(self, con, record):
                 pass
 
         def _pool(**kw):
-            return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), use_threadlocal=False, **kw)
-            #, pool_size=1, max_overflow=0, **kw)
+            return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
+                                  use_threadlocal=False, **kw)
 
         def assert_listeners(p, total, conn, cout, cin):
             for instance in (p, p.recreate()):
@@ -551,6 +551,65 @@ class PoolTest(TestBase):
         del c
         snoop.assert_total(2, 2, 1)
 
+    def test_listeners_callables(self):
+        dbapi = MockDBAPI()
+
+        counts = [0, 0, 0]
+        def connect(dbapi_con, con_record):
+            counts[0] += 1
+        def checkout(dbapi_con, con_record, con_proxy):
+            counts[1] += 1
+        def checkin(dbapi_con, con_record):
+            counts[2] += 1
+
+        i_all = dict(connect=connect, checkout=checkout, checkin=checkin)
+        i_connect = dict(connect=connect)
+        i_checkout = dict(checkout=checkout)
+        i_checkin = dict(checkin=checkin)
+
+        def _pool(**kw):
+            return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
+                                  use_threadlocal=False, **kw)
+
+        def assert_listeners(p, total, conn, cout, cin):
+            for instance in (p, p.recreate()):
+                self.assert_(len(instance.listeners) == total)
+                self.assert_(len(instance._on_connect) == conn)
+                self.assert_(len(instance._on_checkout) == cout)
+                self.assert_(len(instance._on_checkin) == cin)
+
+        p = _pool()
+        assert_listeners(p, 0, 0, 0, 0)
+
+        p.add_listener(i_all)
+        assert_listeners(p, 1, 1, 1, 1)
+
+        p.add_listener(i_connect)
+        assert_listeners(p, 2, 2, 1, 1)
+
+        p.add_listener(i_checkout)
+        assert_listeners(p, 3, 2, 2, 1)
+
+        p.add_listener(i_checkin)
+        assert_listeners(p, 4, 2, 2, 2)
+        del p
+
+        p = _pool(listeners=[i_all])
+        assert_listeners(p, 1, 1, 1, 1)
+
+        c = p.connect()
+        assert counts == [1, 1, 0]
+        c.close()
+        assert counts == [1, 1, 1]
+
+        c = p.connect()
+        assert counts == [1, 2, 1]
+        p.add_listener(i_checkin)
+        c.close()
+        assert counts == [1, 2, 3]
+
+
+
     def tearDown(self):
        pool.clear_managers()