this reorganizes things so that the EventDescriptor machinery lives on a "Dispatch" object.
author Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Aug 2010 03:34:23 +0000 (23:34 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Aug 2010 03:34:23 +0000 (23:34 -0400)
this leaves the original Events class alone so that sphinx can document it.

this is all a mess right now, but the pool/engine tests are fully passing
at the moment, so I wanted to mark a working version.

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/event.py
lib/sqlalchemy/pool.py
test/engine/test_execute.py
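
As a rough sketch of the arrangement described in the commit message (simplified and hypothetical; the class and attribute names below only mirror the diff that follows, and the real implementation differs in detail): listener specifications stay on an Events subclass, while a generated Dispatch class carries the listener collections and is attached to each target as a `dispatch` attribute.

    # Hypothetical, minimal sketch (not the actual SQLAlchemy code): the
    # Events subclass describes and registers listeners; the Dispatch
    # object attached to a target holds the listener collections.

    class _ListenerCollection(list):
        def __call__(self, *args, **kw):
            # invoke every listener registered under this event name
            for fn in self:
                fn(*args, **kw)

    class Dispatch(object):
        """Per-target container of listener collections."""
        def __init__(self, parent_cls):
            self.parent_cls = parent_cls
            self.on_connect = _ListenerCollection()

    class PoolEvents(object):
        dispatch = Dispatch   # in the real code, _DispatchMeta generates this class

        @classmethod
        def listen(cls, fn, identifier, target):
            getattr(target.dispatch, identifier).append(fn)

    class Pool(object):
        def __init__(self):
            # event.dispatcher() creates this lazily via a descriptor
            self.dispatch = Dispatch(type(self))

    pool = Pool()
    PoolEvents.listen(lambda conn, rec: None, "on_connect", pool)
    pool.dispatch.on_connect("dbapi-connection", "connection-record")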

diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 7b6ff5b7a3f5a27965f3a5a95353739323928f7c..8d1138316e418bf72e25a1458ceea4dfcdfd6663 100644
@@ -1554,7 +1554,7 @@ class EngineEvents(event.Events):
         if issubclass(target.Connection, Connection):
             target.Connection = _proxy_connection_cls(
                                         Connection, 
-                                        target.events)
+                                        target.dispatch)
         event.Events.listen(fn, identifier, target)
 
     def on_execute(self, conn, execute, clauseelement, *multiparams, **params):
@@ -1627,7 +1627,7 @@ class Engine(Connectable, log.Identified):
             self.update_execution_options(**execution_options)
 
             
-    events = event.dispatcher(EngineEvents)
+    dispatch = event.dispatcher(EngineEvents)
     
     def update_execution_options(self, **opt):
         """update the execution_options dictionary of this :class:`Engine`.
@@ -1851,11 +1851,7 @@ def _proxy_connection_cls(cls, dispatch):
             return orig
         def go(*arg, **kw):
             nested = _exec_recursive(conn, fns[1:], orig)
-            try:
-                ret = fns[0](conn, nested, *arg, **kw)
-            except IndexError:
-                import pdb
-                pdb.set_trace()
+            ret = fns[0](conn, nested, *arg, **kw)
             # TODO: need to get consistent way to check 
             # for "they called the fn, they didn't", or otherwise
             # make some decision here how this is to work
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index 785c6e96ad58122f93eaecb21bc9c84dc2860677..b6e687b7c97601d75b502f8948598075ce0f0a6c 100644
@@ -27,7 +27,15 @@ class TLConnection(base.Connection):
         self.__opencount = 0
         base.Connection.close(self)
 
-        
+class TLEvents(base.EngineEvents):
+    @classmethod
+    def listen(cls, fn, identifier, target):
+        if issubclass(target.TLConnection, TLConnection):
+            target.TLConnection = base._proxy_connection_cls(
+                                        TLConnection, 
+                                        target.dispatch)
+        base.EngineEvents.listen(fn, identifier, target)
+
 class TLEngine(base.Engine):
     """An Engine that includes support for thread-local managed transactions."""
 
@@ -37,15 +45,7 @@ class TLEngine(base.Engine):
         super(TLEngine, self).__init__(*args, **kwargs)
         self._connections = util.threading.local()
 
-    class events(base.Engine.events):
-        @classmethod
-        def listen(cls, fn, identifier, target):
-            if issubclass(target.TLConnection, TLConnection):
-                target.TLConnection = base._proxy_connection_cls(
-                                            TLConnection, 
-                                            target.events)
-            base.Engine.events.listen(fn, identifier, target)
-    events = event.dispatcher(events)
+    dispatch = event.dispatcher(TLEvents)
     
     def contextual_connect(self, **kw):
         if not hasattr(self._connections, 'conn'):
diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py
index 5a8c193a0259313b96876101943d03e7ab7c9852..bfa617a25b8a9de0fcbb665778172186744133cb 100644
@@ -16,53 +16,61 @@ def listen(fn, identifier, target, *args, **kw):
     
     for evt_cls in _registrars[identifier]:
         for tgt in evt_cls.accept_with(target):
-            tgt.events.listen(fn, identifier, tgt, *args, **kw)
+            
+            tgt.dispatch.events.listen(fn, identifier, tgt, *args, **kw)
             break
     
 class _DispatchMeta(type):
     def __init__(cls, classname, bases, dict_):
+        
+        dispatch_base = getattr(cls, 'dispatch', Dispatch)
+        cls.dispatch = dispatch_cls = type("%sDispatch" % classname, (dispatch_base, ), {})
+        dispatch_cls.events = cls
         for k in dict_:
             if k.startswith('on_'):
-                setattr(cls, k, EventDescriptor(dict_[k]))
+                setattr(dispatch_cls, k, EventDescriptor(dict_[k]))
                 _registrars[k].append(cls)
         return type.__init__(cls, classname, bases, dict_)
 
 _registrars = util.defaultdict(list)
 
-class Events(object):
-    __metaclass__ = _DispatchMeta
-    
+class Dispatch(object):
+
     def __init__(self, parent_cls):
         self.parent_cls = parent_cls
     
+    @property
+    def descriptors(self):
+        return (getattr(self, k) for k in dir(self) if k.startswith("on_"))
+        
+    def update(self, other):
+        """Populate from the listeners in another :class:`Events` object."""
+
+        for ls in other.descriptors:
+            getattr(self, ls.name).listeners.extend(ls.listeners)
+
+
+class Events(object):
+    __metaclass__ = _DispatchMeta
+    
     @classmethod
     def accept_with(cls, target):
         # Mapper, ClassManager, Session override this to
         # also accept classes, scoped_sessions, sessionmakers, etc.
-        if hasattr(target, 'events') and (
-                    isinstance(target.events, cls) or \
-                    isinstance(target.events, type) and \
-                    issubclass(target.events, cls)
+        if hasattr(target, 'dispatch') and (
+                    isinstance(target.dispatch, cls.dispatch) or \
+                    isinstance(target.dispatch, type) and \
+                    issubclass(target.dispatch, cls.dispatch)
                 ):
             return [target]
         else:
             return []
-        
+
     @classmethod
     def listen(cls, fn, identifier, target):
-        getattr(target.events, identifier).append(fn, target)
-    
-    @property
-    def events(self):
-        """Iterate the Listeners objects."""
-        
-        return (getattr(self, k) for k in dir(self) if k.startswith("on_"))
+        getattr(target.dispatch, identifier).append(fn, target)
         
-    def update(self, other):
-        """Populate from the listeners in another :class:`Events` object."""
 
-        for ls in other.events:
-            getattr(self, ls.name).listeners.extend(ls.listeners)
 
 class _ExecEvent(object):
     _exec_once = False
@@ -149,10 +157,10 @@ class Listeners(_ExecEvent):
 
 class dispatcher(object):
     def __init__(self, events):
-        self.dispatch_cls = events
+        self.dispatch_cls = events.dispatch
         
     def __get__(self, obj, cls):
         if obj is None:
             return self.dispatch_cls
-        obj.__dict__['events'] = disp = self.dispatch_cls(cls)
+        obj.__dict__['dispatch'] = disp = self.dispatch_cls(cls)
         return disp
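
For reference, a self-contained toy version of the dispatcher descriptor in the hunk above (names here are made up for illustration): the first attribute access builds the Dispatch instance and caches it in the instance __dict__, so later accesses bypass the descriptor entirely.

    # Toy demonstration only; ToyDispatch/ToyEvents/Target are illustrative names.

    class ToyDispatch(object):
        def __init__(self, parent_cls):
            self.parent_cls = parent_cls

    class ToyEvents(object):
        dispatch = ToyDispatch

    class dispatcher(object):
        def __init__(self, events):
            self.dispatch_cls = events.dispatch

        def __get__(self, obj, cls):
            if obj is None:
                return self.dispatch_cls          # class access: the Dispatch class
            obj.__dict__['dispatch'] = disp = self.dispatch_cls(cls)
            return disp                           # instance access: cached instance

    class Target(object):
        dispatch = dispatcher(ToyEvents)

    t = Target()
    assert Target.dispatch is ToyDispatch   # class-level access returns the class
    assert t.dispatch is t.dispatch         # per-instance Dispatch is memoized in __dict__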
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 8a2845562340f5f921929552f67699afdabac8f9..df6b0521c49561b4fd4f83496249ed93e17f08d8 100644
@@ -201,7 +201,7 @@ class Pool(log.Identified):
         self._reset_on_return = reset_on_return
         self.echo = echo
         if _dispatch:
-            self.events.update(_dispatch)
+            self.dispatch.update(_dispatch)
         if listeners:
             util.warn_deprecated(
                         "The 'listeners' argument to Pool (and "
@@ -209,9 +209,9 @@ class Pool(log.Identified):
             for l in listeners:
                 self.add_listener(l)
 
-    events = event.dispatcher(PoolEvents)
+    dispatch = event.dispatcher(PoolEvents)
         
-    @util.deprecated(":meth:`.Pool.add_listener` is deprecated.  Use :func:`.event.listen`")
+    @util.deprecated("Pool.add_listener is deprecated.  Use event.listen()")
     def add_listener(self, listener):
         """Add a :class:`.PoolListener`-like object to this pool.
         
@@ -283,8 +283,8 @@ class _ConnectionRecord(object):
         self.connection = self.__connect()
         self.info = {}
 
-        pool.events.on_first_connect.exec_once(self.connection, self)
-        pool.events.on_connect(self.connection, self)
+        pool.dispatch.on_first_connect.exec_once(self.connection, self)
+        pool.dispatch.on_connect(self.connection, self)
 
     def close(self):
         if self.connection is not None:
@@ -312,8 +312,8 @@ class _ConnectionRecord(object):
         if self.connection is None:
             self.connection = self.__connect()
             self.info.clear()
-            if self.__pool.events.on_connect:
-                self.__pool.events.on_connect(self.connection, self)
+            if self.__pool.dispatch.on_connect:
+                self.__pool.dispatch.on_connect(self.connection, self)
         elif self.__pool._recycle > -1 and \
                 time.time() - self.starttime > self.__pool._recycle:
             self.__pool.logger.info(
@@ -322,8 +322,8 @@ class _ConnectionRecord(object):
             self.__close()
             self.connection = self.__connect()
             self.info.clear()
-            if self.__pool.events.on_connect:
-                self.__pool.events.on_connect(self.connection, self)
+            if self.__pool.dispatch.on_connect:
+                self.__pool.dispatch.on_connect(self.connection, self)
         return self.connection
 
     def __close(self):
@@ -372,8 +372,8 @@ def _finalize_fairy(connection, connection_record, pool, ref=None):
     if connection_record is not None:
         connection_record.fairy = None
         pool.logger.debug("Connection %r being returned to pool", connection)
-        if pool.events.on_checkin:
-            pool.events.on_checkin(connection, connection_record)
+        if pool.dispatch.on_checkin:
+            pool.dispatch.on_checkin(connection, connection_record)
         pool.return_conn(connection_record)
 
 _refs = set()
@@ -457,14 +457,14 @@ class _ConnectionFairy(object):
             raise exc.InvalidRequestError("This connection is closed")
         self.__counter += 1
 
-        if not self._pool.events.on_checkout or self.__counter != 1:
+        if not self._pool.dispatch.on_checkout or self.__counter != 1:
             return self
 
         # Pool listeners can trigger a reconnection on checkout
         attempts = 2
         while attempts > 0:
             try:
-                self._pool.events.on_checkout(self.connection, 
+                self._pool.dispatch.on_checkout(self.connection, 
                                             self._connection_record,
                                             self)
                 return self
@@ -579,7 +579,7 @@ class SingletonThreadPool(Pool):
             echo=self.echo, 
             logging_name=self._orig_logging_name,
             use_threadlocal=self._use_threadlocal, 
-            _dispatch=self.events)
+            _dispatch=self.dispatch)
 
     def dispose(self):
         """Dispose of this pool."""
@@ -712,7 +712,7 @@ class QueuePool(Pool):
                           recycle=self._recycle, echo=self.echo, 
                           logging_name=self._orig_logging_name,
                           use_threadlocal=self._use_threadlocal,
-                          _dispatch=self.events)
+                          _dispatch=self.dispatch)
 
     def do_return_conn(self, conn):
         try:
@@ -823,7 +823,7 @@ class NullPool(Pool):
             echo=self.echo, 
             logging_name=self._orig_logging_name,
             use_threadlocal=self._use_threadlocal, 
-            _dispatch=self.events)
+            _dispatch=self.dispatch)
 
     def dispose(self):
         pass
@@ -863,7 +863,7 @@ class StaticPool(Pool):
                               reset_on_return=self._reset_on_return,
                               echo=self.echo,
                               logging_name=self._orig_logging_name,
-                              _dispatch=self.events)
+                              _dispatch=self.dispatch)
 
     def create_connection(self):
         return self._conn
@@ -914,7 +914,7 @@ class AssertionPool(Pool):
         self.logger.info("Pool recreating")
         return AssertionPool(self._creator, echo=self.echo, 
                             logging_name=self._orig_logging_name,
-                            _dispatch=self.events)
+                            _dispatch=self.dispatch)
         
     def do_get(self):
         if self._checked_out:
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 6e6069f04be8879bb031026addc5628eab6fdd95..cacc5385aa9e1cdf28422ca9e5999e6f538aa494 100644
@@ -320,10 +320,11 @@ class EngineEventsTest(TestBase):
                             == params or testparams == posn):
                         break
 
-        for engine in \
-            engines.testing_engine(options=dict(implicit_returning=False)), \
+        for engine in [
+#            engines.testing_engine(options=dict(implicit_returning=False)), 
             engines.testing_engine(options=dict(implicit_returning=False,
-                                   strategy='threadlocal')):
+                                   strategy='threadlocal'))
+            ]:
             event.listen(execute, 'on_execute', engine)
             event.listen(cursor_execute, 'on_cursor_execute', engine)