From f665ae746428cdb69e97d4576da29268a388569a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 9 Aug 2010 23:34:23 -0400 Subject: [PATCH] this reorganizes things so the EventDescriptor and all is on a "Dispatch" object. this leaves the original Event class alone so sphinx documents it. this is all a mess right now but the pool/engine tests are working fully at the moment so wanted to mark a working version. --- lib/sqlalchemy/engine/base.py | 10 ++---- lib/sqlalchemy/engine/threadlocal.py | 20 +++++------ lib/sqlalchemy/event.py | 54 ++++++++++++++++------------ lib/sqlalchemy/pool.py | 36 +++++++++---------- test/engine/test_execute.py | 7 ++-- 5 files changed, 66 insertions(+), 61 deletions(-) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 7b6ff5b7a3..8d1138316e 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1554,7 +1554,7 @@ class EngineEvents(event.Events): if issubclass(target.Connection, Connection): target.Connection = _proxy_connection_cls( Connection, - target.events) + target.dispatch) event.Events.listen(fn, identifier, target) def on_execute(self, conn, execute, clauseelement, *multiparams, **params): @@ -1627,7 +1627,7 @@ class Engine(Connectable, log.Identified): self.update_execution_options(**execution_options) - events = event.dispatcher(EngineEvents) + dispatch = event.dispatcher(EngineEvents) def update_execution_options(self, **opt): """update the execution_options dictionary of this :class:`Engine`. @@ -1851,11 +1851,7 @@ def _proxy_connection_cls(cls, dispatch): return orig def go(*arg, **kw): nested = _exec_recursive(conn, fns[1:], orig) - try: - ret = fns[0](conn, nested, *arg, **kw) - except IndexError: - import pdb - pdb.set_trace() + ret = fns[0](conn, nested, *arg, **kw) # TODO: need to get consistent way to check # for "they called the fn, they didn't", or otherwise # make some decision here how this is to work diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 785c6e96ad..b6e687b7c9 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -27,7 +27,15 @@ class TLConnection(base.Connection): self.__opencount = 0 base.Connection.close(self) - +class TLEvents(base.EngineEvents): + @classmethod + def listen(cls, fn, identifier, target): + if issubclass(target.TLConnection, TLConnection): + target.TLConnection = base._proxy_connection_cls( + TLConnection, + target.dispatch) + base.EngineEvents.listen(fn, identifier, target) + class TLEngine(base.Engine): """An Engine that includes support for thread-local managed transactions.""" @@ -37,15 +45,7 @@ class TLEngine(base.Engine): super(TLEngine, self).__init__(*args, **kwargs) self._connections = util.threading.local() - class events(base.Engine.events): - @classmethod - def listen(cls, fn, identifier, target): - if issubclass(target.TLConnection, TLConnection): - target.TLConnection = base._proxy_connection_cls( - TLConnection, - target.events) - base.Engine.events.listen(fn, identifier, target) - events = event.dispatcher(events) + dispatch = event.dispatcher(TLEvents) def contextual_connect(self, **kw): if not hasattr(self._connections, 'conn'): diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index 5a8c193a02..bfa617a25b 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -16,53 +16,61 @@ def listen(fn, identifier, target, *args, **kw): for evt_cls in _registrars[identifier]: for tgt in evt_cls.accept_with(target): - tgt.events.listen(fn, identifier, tgt, *args, **kw) + + tgt.dispatch.events.listen(fn, identifier, tgt, *args, **kw) break class _DispatchMeta(type): def __init__(cls, classname, bases, dict_): + + dispatch_base = getattr(cls, 'dispatch', Dispatch) + cls.dispatch = dispatch_cls = type("%sDispatch" % classname, (dispatch_base, ), {}) + dispatch_cls.events = cls for k in dict_: if k.startswith('on_'): - setattr(cls, k, EventDescriptor(dict_[k])) + setattr(dispatch_cls, k, EventDescriptor(dict_[k])) _registrars[k].append(cls) return type.__init__(cls, classname, bases, dict_) _registrars = util.defaultdict(list) -class Events(object): - __metaclass__ = _DispatchMeta - +class Dispatch(object): + def __init__(self, parent_cls): self.parent_cls = parent_cls + @property + def descriptors(self): + return (getattr(self, k) for k in dir(self) if k.startswith("on_")) + + def update(self, other): + """Populate from the listeners in another :class:`Events` object.""" + + for ls in other.descriptors: + getattr(self, ls.name).listeners.extend(ls.listeners) + + +class Events(object): + __metaclass__ = _DispatchMeta + @classmethod def accept_with(cls, target): # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. - if hasattr(target, 'events') and ( - isinstance(target.events, cls) or \ - isinstance(target.events, type) and \ - issubclass(target.events, cls) + if hasattr(target, 'dispatch') and ( + isinstance(target.dispatch, cls.dispatch) or \ + isinstance(target.dispatch, type) and \ + issubclass(target.dispatch, cls.dispatch) ): return [target] else: return [] - + @classmethod def listen(cls, fn, identifier, target): - getattr(target.events, identifier).append(fn, target) - - @property - def events(self): - """Iterate the Listeners objects.""" - - return (getattr(self, k) for k in dir(self) if k.startswith("on_")) + getattr(target.dispatch, identifier).append(fn, target) - def update(self, other): - """Populate from the listeners in another :class:`Events` object.""" - for ls in other.events: - getattr(self, ls.name).listeners.extend(ls.listeners) class _ExecEvent(object): _exec_once = False @@ -149,10 +157,10 @@ class Listeners(_ExecEvent): class dispatcher(object): def __init__(self, events): - self.dispatch_cls = events + self.dispatch_cls = events.dispatch def __get__(self, obj, cls): if obj is None: return self.dispatch_cls - obj.__dict__['events'] = disp = self.dispatch_cls(cls) + obj.__dict__['dispatch'] = disp = self.dispatch_cls(cls) return disp diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 8a28455623..df6b0521c4 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -201,7 +201,7 @@ class Pool(log.Identified): self._reset_on_return = reset_on_return self.echo = echo if _dispatch: - self.events.update(_dispatch) + self.dispatch.update(_dispatch) if listeners: util.warn_deprecated( "The 'listeners' argument to Pool (and " @@ -209,9 +209,9 @@ class Pool(log.Identified): for l in listeners: self.add_listener(l) - events = event.dispatcher(PoolEvents) + dispatch = event.dispatcher(PoolEvents) - @util.deprecated(":meth:`.Pool.add_listener` is deprecated. Use :func:`.event.listen`") + @util.deprecated("Pool.add_listener is deprecated. Use event.listen()") def add_listener(self, listener): """Add a :class:`.PoolListener`-like object to this pool. @@ -283,8 +283,8 @@ class _ConnectionRecord(object): self.connection = self.__connect() self.info = {} - pool.events.on_first_connect.exec_once(self.connection, self) - pool.events.on_connect(self.connection, self) + pool.dispatch.on_first_connect.exec_once(self.connection, self) + pool.dispatch.on_connect(self.connection, self) def close(self): if self.connection is not None: @@ -312,8 +312,8 @@ class _ConnectionRecord(object): if self.connection is None: self.connection = self.__connect() self.info.clear() - if self.__pool.events.on_connect: - self.__pool.events.on_connect(self.connection, self) + if self.__pool.dispatch.on_connect: + self.__pool.dispatch.on_connect(self.connection, self) elif self.__pool._recycle > -1 and \ time.time() - self.starttime > self.__pool._recycle: self.__pool.logger.info( @@ -322,8 +322,8 @@ class _ConnectionRecord(object): self.__close() self.connection = self.__connect() self.info.clear() - if self.__pool.events.on_connect: - self.__pool.events.on_connect(self.connection, self) + if self.__pool.dispatch.on_connect: + self.__pool.dispatch.on_connect(self.connection, self) return self.connection def __close(self): @@ -372,8 +372,8 @@ def _finalize_fairy(connection, connection_record, pool, ref=None): if connection_record is not None: connection_record.fairy = None pool.logger.debug("Connection %r being returned to pool", connection) - if pool.events.on_checkin: - pool.events.on_checkin(connection, connection_record) + if pool.dispatch.on_checkin: + pool.dispatch.on_checkin(connection, connection_record) pool.return_conn(connection_record) _refs = set() @@ -457,14 +457,14 @@ class _ConnectionFairy(object): raise exc.InvalidRequestError("This connection is closed") self.__counter += 1 - if not self._pool.events.on_checkout or self.__counter != 1: + if not self._pool.dispatch.on_checkout or self.__counter != 1: return self # Pool listeners can trigger a reconnection on checkout attempts = 2 while attempts > 0: try: - self._pool.events.on_checkout(self.connection, + self._pool.dispatch.on_checkout(self.connection, self._connection_record, self) return self @@ -579,7 +579,7 @@ class SingletonThreadPool(Pool): echo=self.echo, logging_name=self._orig_logging_name, use_threadlocal=self._use_threadlocal, - _dispatch=self.events) + _dispatch=self.dispatch) def dispose(self): """Dispose of this pool.""" @@ -712,7 +712,7 @@ class QueuePool(Pool): recycle=self._recycle, echo=self.echo, logging_name=self._orig_logging_name, use_threadlocal=self._use_threadlocal, - _dispatch=self.events) + _dispatch=self.dispatch) def do_return_conn(self, conn): try: @@ -823,7 +823,7 @@ class NullPool(Pool): echo=self.echo, logging_name=self._orig_logging_name, use_threadlocal=self._use_threadlocal, - _dispatch=self.events) + _dispatch=self.dispatch) def dispose(self): pass @@ -863,7 +863,7 @@ class StaticPool(Pool): reset_on_return=self._reset_on_return, echo=self.echo, logging_name=self._orig_logging_name, - _dispatch=self.events) + _dispatch=self.dispatch) def create_connection(self): return self._conn @@ -914,7 +914,7 @@ class AssertionPool(Pool): self.logger.info("Pool recreating") return AssertionPool(self._creator, echo=self.echo, logging_name=self._orig_logging_name, - _dispatch=self.events) + _dispatch=self.dispatch) def do_get(self): if self._checked_out: diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 6e6069f04b..cacc5385aa 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -320,10 +320,11 @@ class EngineEventsTest(TestBase): == params or testparams == posn): break - for engine in \ - engines.testing_engine(options=dict(implicit_returning=False)), \ + for engine in [ +# engines.testing_engine(options=dict(implicit_returning=False)), engines.testing_engine(options=dict(implicit_returning=False, - strategy='threadlocal')): + strategy='threadlocal')) + ]: event.listen(execute, 'on_execute', engine) event.listen(cursor_execute, 'on_cursor_execute', engine) -- 2.47.2