]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- initial "events" idea. will replace all Extension, Proxy, Listener
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Jul 2010 17:19:59 +0000 (13:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Jul 2010 17:19:59 +0000 (13:19 -0400)
implementations with a single interface.

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/event.py [new file with mode: 0644]
lib/sqlalchemy/interfaces.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/util.py

index cf459f9e65fa51744a6096b033a9e9126605d3e7..51620dd37a3dd97a41f68f969302fbaaa438f3da 100644 (file)
@@ -22,7 +22,7 @@ __all__ = [
 
 import inspect, StringIO, sys, operator
 from itertools import izip
-from sqlalchemy import exc, schema, util, types, log
+from sqlalchemy import exc, schema, util, types, log, interfaces, event
 from sqlalchemy.sql import expression
 
 class Dialect(object):
@@ -1546,6 +1546,19 @@ class TwoPhaseTransaction(Transaction):
     def _do_commit(self):
         self.connection._commit_twophase_impl(self.xid, self._is_prepared)
 
+class _EngineDispatch(event.Dispatch):
+    def append(self, fn, identifier, target):
+        if isinstance(target.Connection, Connection):
+            target.Connection = _proxy_connection_cls(target.Connection, self)
+        event.Dispatch.append(self, fn, identifier)
+
+    def exec_(self, identifier, orig, kw):
+        for fn in getattr(self, identifier):
+            r = fn(**kw)
+            if r:
+                return r
+        else:
+            return orig()
 
 class Engine(Connectable, log.Identified):
     """
@@ -1559,7 +1572,9 @@ class Engine(Connectable, log.Identified):
     """
 
     _execution_options = util.frozendict()
-
+    Connection = Connection
+    _dispatch = event.dispatcher(_EngineDispatch)
+    
     def __init__(self, pool, dialect, url, 
                         logging_name=None, echo=None, proxy=None,
                         execution_options=None
@@ -1573,9 +1588,7 @@ class Engine(Connectable, log.Identified):
         self.engine = self
         self.logger = log.instance_logger(self, echoflag=echo)
         if proxy:
-            self.Connection = _proxy_connection_cls(Connection, proxy)
-        else:
-            self.Connection = Connection
+            interfaces.ConnectionProxy._adapt_listener(self, proxy)
         if execution_options:
             self.update_execution_options(**execution_options)
     
@@ -1795,25 +1808,54 @@ class Engine(Connectable, log.Identified):
 
         return self.pool.unique_connection()
 
-
-def _proxy_connection_cls(cls, proxy):
+def _proxy_connection_cls(cls, dispatch):
     class ProxyConnection(cls):
         def execute(self, object, *multiparams, **params):
-            return proxy.execute(self, super(ProxyConnection, self).execute, 
-                                            object, *multiparams, **params)
-
+            if not dispatch.on_execute:
+                return super(ProxyConnection, self).execute(object, *multiparams, **params)
+            else:
+                orig = super(ProxyConnection, self).execute
+                return dispatch.exec_('on_execute', orig, 
+                                        conn=self, 
+                                        execute=orig, 
+                                        clauseelement=object, 
+                                        multiparams=multiparams, 
+                                        params=params
+                )
+            
         def _execute_clauseelement(self, elem, multiparams=None, params=None):
-            return proxy.execute(self, super(ProxyConnection, self).execute, 
-                                            elem, 
-                                            *(multiparams or []),
-                                            **(params or {}))
+            if not dispatch.on_execute:
+                return super(ProxyConnection, self).\
+                        _execute_clauseelement(elem, 
+                                    multiparams=multiparams, 
+                                    params=params)
+            else:
+                orig = super(ProxyConnection, self).execute
+                return dispatch.exec_('on_execute', orig, 
+                                    conn=self, 
+                                    execute=orig, 
+                                    clauseelement=elem, 
+                                    multiparams=multiparams or [], 
+                                    params=params or {}
+                )
+
 
         def _cursor_execute(self, cursor, statement, 
                                     parameters, context=None):
-            return proxy.cursor_execute(
-                            super(ProxyConnection, self)._cursor_execute, 
-                            cursor, statement, parameters, context, False)
-
+            orig = super(ProxyConnection, self)._cursor_execute
+            if not dispatch.on_cursor_execute:
+                return orig(cursor, statement, parameters, context=context)
+            else:
+                return dispatch.exec_('on_cursor_execute', orig, 
+                                    conn=self, 
+                                    execute=super(ProxyConnection, self).execute, 
+                                    cursor=cursor,
+                                    statement=statement,
+                                    parameters=parameters,
+                                    executemany=False,
+                                    context=context)
+        
+        # these are all TODO
         def _cursor_executemany(self, cursor, statement, 
                                     parameters, context=None):
             return proxy.cursor_execute(
index ec2b4f302e2915c8b6742f5e8b88730e24d9f205..20393a5b3f871c9feba453fc3d97f31b82b2b4b1 100644 (file)
@@ -31,17 +31,18 @@ class TLConnection(base.Connection):
 class TLEngine(base.Engine):
     """An Engine that includes support for thread-local managed transactions."""
 
+    TLConnection = TLConnection
+    # TODO
+    #_dispatch = event.dispatcher(_TLEngineDispatch)
 
     def __init__(self, *args, **kwargs):
         super(TLEngine, self).__init__(*args, **kwargs)
         self._connections = util.threading.local()
-        proxy = kwargs.get('proxy')
-        if proxy:
-            self.TLConnection = base._proxy_connection_cls(
-                                        TLConnection, proxy)
-        else:
-            self.TLConnection = TLConnection
-
+        
+        # dont have to deal with proxy here, the
+        # superclass constructor + class level 
+        # _dispatch handles it
+        
     def contextual_connect(self, **kw):
         if not hasattr(self._connections, 'conn'):
             connection = None
diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py
new file mode 100644 (file)
index 0000000..1b0b62b
--- /dev/null
@@ -0,0 +1,44 @@
+from sqlalchemy import util
+
+def listen(fn, identifier, target, *args):
+    """Listen for events, passing to fn."""
+    
+    target._dispatch.append(fn, identifier, target, *args)
+
+NO_RESULT = util.symbol('no_result')
+
+
+class Dispatch(object):
+        
+    def append(self, identifier, fn, target):
+        getattr(self, identifier).append(fn)
+    
+    def __getattr__(self, key):
+        self.__dict__[key] = coll = []
+        return coll
+    
+    def chain(self, identifier, chain_kw, **kw):
+        ret = NO_RESULT
+        for fn in getattr(self, identifier):
+            ret = fn(**kw)
+            kw['chain_kw'] = ret
+        return ret
+            
+    def __call__(self, identifier, **kw):
+        for fn in getattr(self, identifier):
+            fn(**kw)
+        
+        
+class dispatcher(object):
+    def __init__(self, dispatch_cls=Dispatch):
+        self.dispatch_cls = dispatch_cls
+        self._dispatch = dispatch_cls()
+        
+    def __get__(self, obj, cls):
+        if obj is None:
+            return self._dispatch
+        obj.__dict__['_dispatch'] = disp = self.dispatch_cls()
+        for key in self._dispatch.__dict__:
+            if key.startswith('on_'):
+                disp.__dict__[key] = self._dispatch.__dict__[k].copy()
+        return disp
index c2a267d5f32c7250eebe73904e098bcc60c8bc06..2447b15bff0c481649c3d9c5a2041219f311a453 100644 (file)
@@ -6,10 +6,14 @@
 
 """Interfaces and abstract types."""
 
+from sqlalchemy.util import as_interface, adapt_kw_to_positional
 
 class PoolListener(object):
-    """Hooks into the lifecycle of connections in a ``Pool``.
+    """Hooks into the lifecycle of connections in a :class:`Pool`.
 
+    .. note:: :class:`PoolListener` is deprecated.   Please
+       refer to :func:`event.listen`.
+    
     Usage::
     
         class MyListener(PoolListener):
@@ -58,7 +62,32 @@ class PoolListener(object):
     providing implementations for the hooks you'll be using.
     
     """
-
+    
+    @classmethod
+    def _adapt_listener(cls, self, listener):
+        """Adapt a :class:`PoolListener` to individual 
+        :class:`event.Dispatch` events.
+        
+        """
+        listener = as_interface(listener,
+            methods=('connect', 'first_connect', 'checkout', 'checkin'))
+
+        if hasattr(listener, 'connect'):
+            self._dispatch.append('on_connect', 
+                                adapt_kw_to_positional(listener.connect, 
+                                                    'dbapi_con', 'con_record'), 
+                                self)
+        if hasattr(listener, 'first_connect'):
+            self._dispatch.append('on_first_connect', 
+                                adapt_kw_to_positional(listener.first_connect, 
+                                                    'dbapi_con', 'con_record'),
+                                self)
+        if hasattr(listener, 'checkout'):
+            self._dispatch.append('on_checkout', listener.checkout, self)
+        if hasattr(listener, 'checkin'):
+            self._dispatch.append('on_checkin', listener.checkin, self)
+            
+        
     def connect(self, dbapi_con, con_record):
         """Called once for each new DB-API connection or Pool's ``creator()``.
 
@@ -119,6 +148,9 @@ class PoolListener(object):
 
 class ConnectionProxy(object):
     """Allows interception of statement execution by Connections.
+
+    .. note:: :class:`ConnectionProxy` is deprecated.   Please
+       refer to :func:`event.listen`.
     
     Either or both of the ``execute()`` and ``cursor_execute()``
     may be implemented to intercept compiled statement and
@@ -143,6 +175,11 @@ class ConnectionProxy(object):
         e = create_engine('someurl://', proxy=MyProxy())
     
     """
+    
+    @classmethod
+    def _adapt_listener(cls, self, listener):
+        pass
+        
     def execute(self, conn, execute, clauseelement, *multiparams, **params):
         """Intercept high level execute() events."""
         
index ab31736ed19522be3f83e91ad779f00195f5271e..93e01272a6f1dcda5a17029a56998bf71a66dd64 100644 (file)
@@ -461,6 +461,7 @@ class ScalarAttributeImpl(AttributeImpl):
         dict_[self.key] = value
 
     def fire_replace_event(self, state, dict_, value, previous, initiator):
+#        value = self._dispatch.chain('set', 'value', state, value, previous, initiator or self)
         for ext in self.extensions:
             value = ext.set(state, value, previous, initiator or self)
         return value
index 9d37b183844b73f5df6d5d1906e9fd1c3865a607..bc8d6929cf1833fde400f5ac6a634be0fa85ec58 100644 (file)
@@ -19,9 +19,9 @@ SQLAlchemy connection pool.
 
 import weakref, time, threading
 
-from sqlalchemy import exc, log
+from sqlalchemy import exc, log, event, interfaces, util
 from sqlalchemy import queue as sqla_queue
-from sqlalchemy.util import threading, pickle, as_interface, memoized_property
+from sqlalchemy.util import threading, pickle, memoized_property
 
 proxies = {}
 
@@ -64,7 +64,9 @@ class Pool(log.Identified):
                     creator, recycle=-1, echo=None, 
                     use_threadlocal=False,
                     logging_name=None,
-                    reset_on_return=True, listeners=None):
+                    reset_on_return=True, 
+                    listeners=None,
+                    _dispatch=None):
         """
         Construct a Pool.
 
@@ -102,11 +104,12 @@ class Pool(log.Identified):
           ROLLBACK to release locks and transaction resources.
           Disable at your own peril.  Defaults to True.
 
-        :param listeners: A list of
+        :param listeners: Deprecated.  A list of
           :class:`~sqlalchemy.interfaces.PoolListener`-like objects or
           dictionaries of callables that receive events when DB-API
           connections are created, checked out and checked in to the
-          pool.
+          pool.  This has been superceded by 
+          :func:`~sqlalchemy.event.listen`.
 
         """
         if logging_name:
@@ -121,16 +124,41 @@ class Pool(log.Identified):
         self._use_threadlocal = use_threadlocal
         self._reset_on_return = reset_on_return
         self.echo = echo
-        self.listeners = []
-        self._on_connect = []
-        self._on_first_connect = []
-        self._on_checkout = []
-        self._on_checkin = []
-
+        if _dispatch:
+            self._dispatch = _dispatch
         if listeners:
             for l in listeners:
                 self.add_listener(l)
 
+    if False:
+        # this might be a nice way to define events and have them 
+        # documented at the same time.
+        class events(event.Dispatch):
+            def on_connect(self, dbapi_con, con_record):
+                """Called once for each new DB-API connection or Pool's ``creator()``.
+
+                dbapi_con
+                  A newly connected raw DB-API connection (not a SQLAlchemy
+                  ``Connection`` wrapper).
+
+                con_record
+                  The ``_ConnectionRecord`` that persistently manages the connection
+
+                """
+        
+    _dispatch = event.dispatcher()
+    
+    @util.deprecated("Use event.listen()")
+    def add_listener(self, listener):
+        """Add a ``PoolListener``-like object to this pool.
+        
+        ``listener`` may be an object that implements some or all of
+        PoolListener, or a dictionary of callables containing implementations
+        of some or all of the named methods in PoolListener.
+
+        """
+        interfaces.PoolListener._adapt_listener(self, listener)
+    
     def unique_connection(self):
         return _ConnectionFairy(self).checkout()
 
@@ -185,40 +213,18 @@ class Pool(log.Identified):
     def status(self):
         raise NotImplementedError()
 
-    def add_listener(self, listener):
-        """Add a ``PoolListener``-like object to this pool.
-
-        ``listener`` may be an object that implements some or all of
-        PoolListener, or a dictionary of callables containing implementations
-        of some or all of the named methods in PoolListener.
-
-        """
-
-        listener = as_interface(listener,
-            methods=('connect', 'first_connect', 'checkout', 'checkin'))
-
-        self.listeners.append(listener)
-        if hasattr(listener, 'connect'):
-            self._on_connect.append(listener)
-        if hasattr(listener, 'first_connect'):
-            self._on_first_connect.append(listener)
-        if hasattr(listener, 'checkout'):
-            self._on_checkout.append(listener)
-        if hasattr(listener, 'checkin'):
-            self._on_checkin.append(listener)
 
 class _ConnectionRecord(object):
     def __init__(self, pool):
         self.__pool = pool
         self.connection = self.__connect()
         self.info = {}
-        ls = pool.__dict__.pop('_on_first_connect', None)
-        if ls is not None:
-            for l in ls:
-                l.first_connect(self.connection, self)
-        if pool._on_connect:
-            for l in pool._on_connect:
-                l.connect(self.connection, self)
+
+        if pool._dispatch.on_first_connect:
+            pool._dispatch('on_first_connect', dbapi_con=self.connection, con_record=self)
+            del pool._dispatch.on_first_connect
+        if pool._dispatch.on_connect:
+            pool._dispatch('on_connect', dbapi_con=self.connection, con_record=self)
 
     def close(self):
         if self.connection is not None:
@@ -246,9 +252,8 @@ class _ConnectionRecord(object):
         if self.connection is None:
             self.connection = self.__connect()
             self.info.clear()
-            if self.__pool._on_connect:
-                for l in self.__pool._on_connect:
-                    l.connect(self.connection, self)
+            if self.__pool._dispatch.on_connect:
+                self.__pool._dispatch('on_connect', dbapi_con=self.connection, con_record=self)
         elif self.__pool._recycle > -1 and \
                 time.time() - self.starttime > self.__pool._recycle:
             self.__pool.logger.info(
@@ -257,9 +262,8 @@ class _ConnectionRecord(object):
             self.__close()
             self.connection = self.__connect()
             self.info.clear()
-            if self.__pool._on_connect:
-                for l in self.__pool._on_connect:
-                    l.connect(self.connection, self)
+            if self.__pool._dispatch.on_connect:
+                self.__pool._dispatch('on_connect', dbapi_con=self.connection, con_record=self)
         return self.connection
 
     def __close(self):
@@ -308,9 +312,8 @@ def _finalize_fairy(connection, connection_record, pool, ref=None):
     if connection_record is not None:
         connection_record.fairy = None
         pool.logger.debug("Connection %r being returned to pool", connection)
-        if pool._on_checkin:
-            for l in pool._on_checkin:
-                l.checkin(connection, connection_record)
+        if pool._dispatch.on_checkin:
+            pool._dispatch('on_checkin', dbapi_con=connection, con_record=connection_record)
         pool.return_conn(connection_record)
 
 _refs = set()
@@ -394,15 +397,16 @@ class _ConnectionFairy(object):
             raise exc.InvalidRequestError("This connection is closed")
         self.__counter += 1
 
-        if not self._pool._on_checkout or self.__counter != 1:
+        if not self._pool._dispatch.on_checkout or self.__counter != 1:
             return self
 
         # Pool listeners can trigger a reconnection on checkout
         attempts = 2
         while attempts > 0:
             try:
-                for l in self._pool._on_checkout:
-                    l.checkout(self.connection, self._connection_record, self)
+                self._pool._dispatch('on_checkout', dbapi_con=self.connection, 
+                                            con_record=self._connection_record,
+                                            con_proxy=self)
                 return self
             except exc.DisconnectionError, e:
                 self._pool.logger.info(
@@ -515,7 +519,7 @@ class SingletonThreadPool(Pool):
             echo=self.echo, 
             logging_name=self._orig_logging_name,
             use_threadlocal=self._use_threadlocal, 
-            listeners=self.listeners)
+            _dispatch=self._dispatch)
 
     def dispose(self):
         """Dispose of this pool."""
@@ -648,7 +652,7 @@ class QueuePool(Pool):
                           recycle=self._recycle, echo=self.echo, 
                           logging_name=self._orig_logging_name,
                           use_threadlocal=self._use_threadlocal,
-                          listeners=self.listeners)
+                          _dispatch=self._dispatch)
 
     def do_return_conn(self, conn):
         try:
@@ -759,7 +763,7 @@ class NullPool(Pool):
             echo=self.echo, 
             logging_name=self._orig_logging_name,
             use_threadlocal=self._use_threadlocal, 
-            listeners=self.listeners)
+            _dispatch=self._dispatch)
 
     def dispose(self):
         pass
@@ -799,7 +803,7 @@ class StaticPool(Pool):
                               reset_on_return=self._reset_on_return,
                               echo=self.echo,
                               logging_name=self._orig_logging_name,
-                              listeners=self.listeners)
+                              _dispatch=self._dispatch)
 
     def create_connection(self):
         return self._conn
@@ -850,7 +854,7 @@ class AssertionPool(Pool):
         self.logger.info("Pool recreating")
         return AssertionPool(self._creator, echo=self.echo, 
                             logging_name=self._orig_logging_name,
-                            listeners=self.listeners)
+                            _dispatch=self._dispatch)
         
     def do_get(self):
         if self._checked_out:
index ae45e17036ea5db5acb2c0ca4a009cb8bcd8266a..73e2533ce533e7e36516ba252e193362665e52df 100644 (file)
@@ -636,6 +636,11 @@ def assert_arg_type(arg, argtype, name):
                             "Argument '%s' is expected to be of type '%s', got '%s'" % 
                             (name, argtype, type(arg)))
 
+def adapt_kw_to_positional(fn, *args):
+    def call(**kw):
+        return fn(*[kw[a] for a in args])
+    return call
+    
 _creation_order = 1
 def set_creation_order(instance):
     """Assign a '_creation_order' sequence to the given instance.