]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- create a new system where we can decorate an event method
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Jul 2013 17:39:56 +0000 (13:39 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Jul 2013 17:39:56 +0000 (13:39 -0400)
with @_legacy_signature, will inspect incoming listener functions
to see if they match an older signature, will wrap into a newer sig
- add an event listen argument named=True, will send all args as
kw args so that event listeners can be written with **kw, any combination
of names
- add a doc system to events that writes out the various calling styles
for a given event, produces deprecation messages automatically.
a little concerned that it's a bit verbose but will look at it up
on RTD for awhile to get a feel.
- change the calling signature for bulk update/delete events - we have
the BulkUD object right there, and there's at least six or seven things
people might want to see, so just send the whole BulkUD in
[ticket:2775]

12 files changed:
doc/build/core/event.rst
doc/build/static/docs.css
lib/sqlalchemy/event.py
lib/sqlalchemy/events.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/deprecations.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_events.py
test/orm/test_events.py

index f3433876c2147c88b0526078c35091619d3157e3..73d0dab4c153fb82a1cfcbef3ddeba6b65ae0bcf 100644 (file)
@@ -3,7 +3,7 @@
 Events
 ======
 
-SQLAlchemy includes an event API which publishes a wide variety of hooks into 
+SQLAlchemy includes an event API which publishes a wide variety of hooks into
 the internals of both SQLAlchemy Core and ORM.
 
 .. versionadded:: 0.7
@@ -13,13 +13,15 @@ the internals of both SQLAlchemy Core and ORM.
 Event Registration
 ------------------
 
-Subscribing to an event occurs through a single API point, the :func:`.listen` function.   This function
-accepts a user-defined listening function, a string identifier which identifies the event to be
-intercepted, and a target.  Additional positional and keyword arguments may be supported by
+Subscribing to an event occurs through a single API point, the :func:`.listen` function,
+or alternatively the :func:`.listens_for` decorator.   These functions
+accept a user-defined listening function, a string identifier which identifies the event to be
+intercepted, and a target.  Additional positional and keyword arguments to these
+two functions may be supported by
 specific types of events, which may specify alternate interfaces for the given event function, or provide
 instructions regarding secondary event targets based on the given target.
 
-The name of an event and the argument signature of a corresponding listener function is derived from 
+The name of an event and the argument signature of a corresponding listener function is derived from
 a class bound specification method, which exists bound to a marker class that's described in the documentation.
 For example, the documentation for :meth:`.PoolEvents.connect` indicates that the event name is ``"connect"``
 and that a user-defined listener function should receive two positional arguments::
@@ -32,13 +34,62 @@ and that a user-defined listener function should receive two positional argument
 
     listen(Pool, 'connect', my_on_connect)
 
+To listen with the :func:`.listens_for` decorator looks like::
+
+    from sqlalchemy.event import listens_for
+    from sqlalchemy.pool import Pool
+
+    @listens_for(Pool, "connect")
+    def my_on_connect(dbapi_con, connection_record):
+        print "New DBAPI connection:", dbapi_con
+
+Named Argument Styles
+---------------------
+
+There are some varieties of argument styles which can be accepted by listener
+functions.  Taking the example of :meth:`.PoolEvents.connect`, this function
+is documented as receiving ``dbapi_connection`` and ``connection_record`` arguments.
+We can opt to receive these arguments by name, by establishing a listener function
+that accepts ``**keyword`` arguments, by passing ``named=True`` to either
+:func:`.listen` or :func:`.listens_for`::
+
+    from sqlalchemy.event import listens_for
+    from sqlalchemy.pool import Pool
+
+    @listens_for(Pool, "connect", named=True)
+    def my_on_connect(**kw):
+        print("New DBAPI connection:", kw['dbapi_connection'])
+
+When using named argument passing, the names listed in the function argument
+specification will be used as keys in the dictionary.
+
+Named style passes all arguments by name regardless of the function
+signature, so specific arguments may be listed as well, in any order,
+as long as the names match up::
+
+    from sqlalchemy.event import listens_for
+    from sqlalchemy.pool import Pool
+
+    @listens_for(Pool, "connect", named=True)
+    def my_on_connect(dbapi_connection, **kw):
+        print("New DBAPI connection:", dbapi_connection)
+        print("Connection record:", kw['connection_record'])
+
+Above, the presence of ``**kw`` tells :func:`.event.listen_for` that
+arguments should be passed to the function by name, rather than positionally.
+
+.. versionadded:: 0.9.0 Added optional ``named`` argument dispatch to
+   event calling.
+
 Targets
 -------
 
-The :func:`.listen` function is very flexible regarding targets.  It generally accepts classes, instances of those
-classes, and related classes or objects from which the appropriate target can be derived.  For example,
-the above mentioned ``"connect"`` event accepts :class:`.Engine` classes and objects as well as :class:`.Pool`
-classes and objects::
+The :func:`.listen` function is very flexible regarding targets.  It
+generally accepts classes, instances of those classes, and related
+classes or objects from which the appropriate target can be derived.
+For example, the above mentioned ``"connect"`` event accepts
+:class:`.Engine` classes and objects as well as :class:`.Pool` classes
+and objects::
 
     from sqlalchemy.event import listen
     from sqlalchemy.pool import Pool, QueuePool
@@ -68,10 +119,12 @@ classes and objects::
 Modifiers
 ----------
 
-Some listeners allow modifiers to be passed to :func:`.listen`.  These modifiers sometimes provide alternate
-calling signatures for listeners.  Such as with ORM events, some event listeners can have a return value
-which modifies the subsequent handling.   By default, no listener ever requires a return value, but by passing
-``retval=True`` this value can be supported::
+Some listeners allow modifiers to be passed to :func:`.listen`.  These
+modifiers sometimes provide alternate calling signatures for
+listeners.  Such as with ORM events, some event listeners can have a
+return value which modifies the subsequent handling.   By default, no
+listener ever requires a return value, but by passing ``retval=True``
+this value can be supported::
 
     def validate_phone(target, value, oldvalue, initiator):
         """Strip non-numeric characters from a phone number"""
index bb300e829b7483bc52e088a9f0cb5e7f6f140fc2..f08c94b59f87d02a77353150939646c0fc9d772f 100644 (file)
@@ -481,3 +481,9 @@ div .version-warning {
   background:#FFBBBB;
 }
 
+/*div .event-signatures {
+  background-color:#F0F0FD;
+  padding:0 10px;
+  border:1px solid #BFBFBF;
+}*/
+
index bfd027ead0686f8703a97a7d05cf185d22b25799..64ae49976f17b84c85bf3d739be3cefdf448961c 100644 (file)
@@ -6,6 +6,8 @@
 
 """Base event API."""
 
+from __future__ import absolute_import
+
 from . import util, exc
 from itertools import chain
 import weakref
@@ -77,6 +79,15 @@ def remove(target, identifier, fn):
             tgt.dispatch._remove(identifier, tgt, fn)
             return
 
+def _legacy_signature(since, argnames, converter=None):
+    def leg(fn):
+        if not hasattr(fn, '_legacy_signatures'):
+            fn._legacy_signatures = []
+        fn._legacy_signatures.append((since, argnames, converter))
+        return fn
+    return leg
+
+
 _registrars = util.defaultdict(list)
 
 
@@ -189,7 +200,7 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
 
     for k in dict_:
         if _is_event_name(k):
-            setattr(dispatch_cls, k, _DispatchDescriptor(dict_[k]))
+            setattr(dispatch_cls, k, _DispatchDescriptor(cls, dict_[k]))
             _registrars[k].append(cls)
 
 
@@ -217,12 +228,16 @@ class Events(util.with_metaclass(_EventMeta, object)):
             return None
 
     @classmethod
-    def _listen(cls, target, identifier, fn, propagate=False, insert=False):
+    def _listen(cls, target, identifier, fn, propagate=False, insert=False,
+                            named=False):
+        dispatch_descriptor = getattr(target.dispatch, identifier)
+        fn = dispatch_descriptor._adjust_fn_spec(fn, named)
+
         if insert:
-            getattr(target.dispatch, identifier).\
+            dispatch_descriptor.\
                     for_modify(target.dispatch).insert(fn, target, propagate)
         else:
-            getattr(target.dispatch, identifier).\
+            dispatch_descriptor.\
                     for_modify(target.dispatch).append(fn, target, propagate)
 
     @classmethod
@@ -237,12 +252,169 @@ class Events(util.with_metaclass(_EventMeta, object)):
 class _DispatchDescriptor(object):
     """Class-level attributes on :class:`._Dispatch` classes."""
 
-    def __init__(self, fn):
+    def __init__(self, parent_dispatch_cls, fn):
         self.__name__ = fn.__name__
-        self.__doc__ = fn.__doc__
+        argspec = util.inspect_getargspec(fn)
+        self.arg_names = argspec.args[1:]
+        self.has_kw = bool(argspec.keywords)
+        self.legacy_signatures = list(reversed(
+                        sorted(
+                            getattr(fn, '_legacy_signatures', []),
+                            key=lambda s: s[0]
+                        )
+                    ))
+        self.__doc__ = fn.__doc__ = self._augment_fn_docs(parent_dispatch_cls, fn)
+
         self._clslevel = weakref.WeakKeyDictionary()
         self._empty_listeners = weakref.WeakKeyDictionary()
 
+    def _adjust_fn_spec(self, fn, named):
+        argspec = util.get_callable_argspec(fn, no_self=True)
+        if named:
+            fn = self._wrap_fn_for_kw(fn)
+        fn = self._wrap_fn_for_legacy(fn, argspec)
+        return fn
+
+    def _wrap_fn_for_kw(self, fn):
+        def wrap_kw(*args, **kw):
+            argdict = dict(zip(self.arg_names, args))
+            argdict.update(kw)
+            return fn(**argdict)
+        return wrap_kw
+
+    def _wrap_fn_for_legacy(self, fn, argspec):
+        for since, argnames, conv in self.legacy_signatures:
+            if argnames[-1] == "**kw":
+                has_kw = True
+                argnames = argnames[0:-1]
+            else:
+                has_kw = False
+
+            if len(argnames) == len(argspec.args) \
+                and has_kw is bool(argspec.keywords):
+
+                if conv:
+                    assert not has_kw
+                    def wrap_leg(*args):
+                        return fn(*conv(*args))
+                else:
+                    def wrap_leg(*args, **kw):
+                        argdict = dict(zip(self.arg_names, args))
+                        args = [argdict[name] for name in argnames]
+                        if has_kw:
+                            return fn(*args, **kw)
+                        else:
+                            return fn(*args)
+                return wrap_leg
+        else:
+            return fn
+
+    def _indent(self, text, indent):
+        return "\n".join(
+                    indent + line
+                    for line in text.split("\n")
+                )
+
+    def _standard_listen_example(self, sample_target, fn):
+        example_kw_arg = self._indent(
+                "\n".join(
+                    "%(arg)s = kw['%(arg)s']" % {"arg": arg}
+                    for arg in self.arg_names[0:2]
+                ),
+                "    ")
+        if self.legacy_signatures:
+            current_since = max(since for since, args, conv in self.legacy_signatures)
+        else:
+            current_since = None
+        text = (
+                "from sqlalchemy import event\n\n"
+                "# standard decorator style%(current_since)s\n"
+                "@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
+                "def receive_%(event_name)s(%(named_event_arguments)s%(has_kw_arguments)s):\n"
+                "    \"listen for the '%(event_name)s' event\"\n"
+                "\n    # ... (event handling logic) ...\n"
+        )
+
+        if len(self.arg_names) > 2:
+            text += (
+
+                "\n# named argument style (new in 0.9)\n"
+                "@event.listens_for(%(sample_target)s, '%(event_name)s', named=True)\n"
+                "def receive_%(event_name)s(**kw):\n"
+                "    \"listen for the '%(event_name)s' event\"\n"
+                "%(example_kw_arg)s\n"
+                "\n    # ... (event handling logic) ...\n"
+            )
+
+        text %= {
+                    "current_since": " (arguments as of %s)" %
+                                    current_since if current_since else "",
+                    "event_name": fn.__name__,
+                    "has_kw_arguments": " **kw" if self.has_kw else "",
+                    "named_event_arguments": ", ".join(self.arg_names),
+                    "example_kw_arg": example_kw_arg,
+                    "sample_target": sample_target
+                }
+        return text
+
+    def _legacy_listen_examples(self, sample_target, fn):
+        text = ""
+        for since, args, conv in self.legacy_signatures:
+            text += (
+                "\n# legacy calling style (pre-%(since)s)\n"
+                "@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
+                "def receive_%(event_name)s(%(named_event_arguments)s%(has_kw_arguments)s):\n"
+                "    \"listen for the '%(event_name)s' event\"\n"
+                "\n    # ... (event handling logic) ...\n" % {
+                    "since": since,
+                    "event_name": fn.__name__,
+                    "has_kw_arguments": " **kw" if self.has_kw else "",
+                    "named_event_arguments": ", ".join(args),
+                    "sample_target": sample_target
+                }
+            )
+        return text
+
+    def _version_signature_changes(self):
+        since, args, conv = self.legacy_signatures[0]
+        return (
+                "\n.. versionchanged:: %(since)s\n"
+                "    The ``%(event_name)s`` event now accepts the \n"
+                "    arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n"
+                "    Listener functions which accept the previous argument \n"
+                "    signature(s) listed above will be automatically \n"
+                "    adapted to the new signature." % {
+                    "since": since,
+                    "event_name": self.__name__,
+                    "named_event_arguments": ", ".join(self.arg_names),
+                    "has_kw_arguments": ", **kw" if self.has_kw else ""
+                }
+            )
+
+    def _augment_fn_docs(self, parent_dispatch_cls, fn):
+        header = ".. container:: event_signatures\n\n"\
+                "     Example argument forms::\n"\
+                "\n"
+
+        sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj")
+        text = (
+                header +
+                self._indent(
+                            self._standard_listen_example(sample_target, fn),
+                            " " * 8)
+            )
+        if self.legacy_signatures:
+            text += self._indent(
+                            self._legacy_listen_examples(sample_target, fn),
+                            " " * 8)
+
+            text += self._version_signature_changes()
+
+        return util.inject_docstring_text(fn.__doc__,
+                text,
+                1
+            )
+
     def _contains(self, cls, evt):
         return cls in self._clslevel and \
             evt in self._clslevel[cls]
@@ -324,8 +496,11 @@ class _DispatchDescriptor(object):
         obj.__dict__[self.__name__] = ret
         return ret
 
+class _HasParentDispatchDescriptor(object):
+    def _adjust_fn_spec(self, fn, named):
+        return self.parent._adjust_fn_spec(fn, named)
 
-class _EmptyListener(object):
+class _EmptyListener(_HasParentDispatchDescriptor):
     """Serves as a class-level interface to the events
     served by a _DispatchDescriptor, when there are no
     instance-level events present.
@@ -337,12 +512,13 @@ class _EmptyListener(object):
     def __init__(self, parent, target_cls):
         if target_cls not in parent._clslevel:
             parent.update_subclass(target_cls)
-        self.parent = parent
+        self.parent = parent  # _DispatchDescriptor
         self.parent_listeners = parent._clslevel[target_cls]
         self.name = parent.__name__
         self.propagate = frozenset()
         self.listeners = ()
 
+
     def for_modify(self, obj):
         """Return an event collection which can be modified.
 
@@ -380,7 +556,7 @@ class _EmptyListener(object):
     __nonzero__ = __bool__
 
 
-class _CompoundListener(object):
+class _CompoundListener(_HasParentDispatchDescriptor):
     _exec_once = False
 
     def exec_once(self, *args, **kw):
@@ -432,6 +608,7 @@ class _ListenerCollection(_CompoundListener):
         if target_cls not in parent._clslevel:
             parent.update_subclass(target_cls)
         self.parent_listeners = parent._clslevel[target_cls]
+        self.parent = parent
         self.name = parent.__name__
         self.listeners = []
         self.propagate = set()
@@ -520,6 +697,9 @@ class _JoinedListener(_CompoundListener):
         # each time. less performant.
         self.listeners = list(getattr(self.parent, self.name))
 
+    def _adjust_fn_spec(self, fn, named):
+        return self.local._adjust_fn_spec(fn, named)
+
     def for_modify(self, obj):
         self.local = self.parent_listeners = self.local.for_modify(obj)
         return self
index 7f11232ac40adce7e2bc20c58099c363bb601af2..4fb997b9c9401d699c65e3856f7a7c0d739ac81e 100644 (file)
@@ -70,6 +70,8 @@ class DDLEvents(event.Events):
 
     """
 
+    _target_class_doc = "SomeSchemaClassOrObject"
+
     def before_create(self, target, connection, **kw):
         """Called before CREATE statments are emitted.
 
@@ -266,6 +268,8 @@ class PoolEvents(event.Events):
 
     """
 
+    _target_class_doc = "SomeEngineOrPool"
+
     @classmethod
     def _accept_with(cls, target):
         if isinstance(target, type):
@@ -443,6 +447,8 @@ class ConnectionEvents(event.Events):
 
     """
 
+    _target_class_doc = "SomeEngine"
+
     @classmethod
     def _listen(cls, target, identifier, fn, retval=False):
         target._has_events = True
@@ -753,7 +759,7 @@ class ConnectionEvents(event.Events):
         :param conn: :class:`.Connection` object
         """
 
-    def savepoint(self, conn, name=None):
+    def savepoint(self, conn, name):
         """Intercept savepoint() events.
 
         :param conn: :class:`.Connection` object
index cea07bcf0342804c9835e65fd527bdb5a98202e8..97019bb4e58b147982ecf49483aaff1a86c5d388 100644 (file)
@@ -42,6 +42,8 @@ class InstrumentationEvents(event.Events):
 
     """
 
+    _target_class_doc = "SomeBaseClass"
+
     @classmethod
     def _accept_with(cls, target):
         # TODO: there's no coverage for this
@@ -151,6 +153,9 @@ class InstanceEvents(event.Events):
        object, rather than the mapped instance itself.
 
     """
+
+    _target_class_doc = "SomeMappedClass"
+
     @classmethod
     def _accept_with(cls, target):
         if isinstance(target, orm.instrumentation.ClassManager):
@@ -450,6 +455,8 @@ class MapperEvents(event.Events):
 
     """
 
+    _target_class_doc = "SomeMappedClass"
+
     @classmethod
     def _accept_with(cls, target):
         if target is orm.mapper:
@@ -1083,6 +1090,9 @@ class SessionEvents(event.Events):
     globally.
 
     """
+
+    _target_class_doc = "SomeSessionOrFactory"
+
     @classmethod
     def _accept_with(cls, target):
         if isinstance(target, orm.scoped_session):
@@ -1382,31 +1392,55 @@ class SessionEvents(event.Events):
 
         """
 
-    def after_bulk_update(self, session, query, query_context, result):
+    @event._legacy_signature("0.9",
+                    ["session", "query", "query_context", "result"],
+                    lambda update_context: (
+                            update_context.session,
+                            update_context.query,
+                            update_context.context,
+                            update_context.result))
+    def after_bulk_update(self, update_context):
         """Execute after a bulk update operation to the session.
 
         This is called as a result of the :meth:`.Query.update` method.
 
-        :param query: the :class:`.Query` object that this update operation was
-         called upon.
-        :param query_context: The :class:`.QueryContext` object, corresponding
-         to the invocation of an ORM query.
-        :param result: the :class:`.ResultProxy` returned as a result of the
-         bulk UPDATE operation.
+        :param update_context: an "update context" object which contains
+         details about the update, including these attributes:
+
+            * ``session`` - the :class:`.Session` involved
+            * ``query`` -the :class:`.Query` object that this update operation was
+              called upon.
+            * ``context`` The :class:`.QueryContext` object, corresponding
+              to the invocation of an ORM query.
+            * ``result`` the :class:`.ResultProxy` returned as a result of the
+              bulk UPDATE operation.
+
 
         """
 
-    def after_bulk_delete(self, session, query, query_context, result):
+    @event._legacy_signature("0.9",
+                    ["session", "query", "query_context", "result"],
+                    lambda delete_context: (
+                            delete_context.session,
+                            delete_context.query,
+                            delete_context.context,
+                            delete_context.result))
+    def after_bulk_delete(self, delete_context):
         """Execute after a bulk delete operation to the session.
 
         This is called as a result of the :meth:`.Query.delete` method.
 
-        :param query: the :class:`.Query` object that this update operation was
-         called upon.
-        :param query_context: The :class:`.QueryContext` object, corresponding
-         to the invocation of an ORM query.
-        :param result: the :class:`.ResultProxy` returned as a result of the
-         bulk DELETE operation.
+        :param delete_context: a "delete context" object which contains
+         details about the update, including these attributes:
+
+            * ``session`` - the :class:`.Session` involved
+            * ``query`` -the :class:`.Query` object that this update operation was
+              called upon.
+            * ``context`` The :class:`.QueryContext` object, corresponding
+              to the invocation of an ORM query.
+            * ``result`` the :class:`.ResultProxy` returned as a result of the
+              bulk DELETE operation.
+
 
         """
 
@@ -1468,6 +1502,8 @@ class AttributeEvents(event.Events):
 
     """
 
+    _target_class_doc = "SomeClass.some_attribute"
+
     @classmethod
     def _accept_with(cls, target):
         # TODO: coverage
index 944623b079d07e6391335f9e71f9815621380721..44da881189549fe11707d719cde58140b650a858 100644 (file)
@@ -798,6 +798,10 @@ class BulkUD(object):
     def __init__(self, query):
         self.query = query.enable_eagerloads(False)
 
+    @property
+    def session(self):
+        return self.query.session
+
     @classmethod
     def _factory(cls, lookup, synchronize_session, *arg):
         try:
@@ -915,8 +919,7 @@ class BulkUpdate(BulkUD):
 
     def _do_post(self):
         session = self.query.session
-        session.dispatch.after_bulk_update(session, self.query,
-                                self.context, self.result)
+        session.dispatch.after_bulk_update(self)
 
 
 class BulkDelete(BulkUD):
@@ -944,8 +947,7 @@ class BulkDelete(BulkUD):
 
     def _do_post(self):
         session = self.query.session
-        session.dispatch.after_bulk_delete(session, self.query,
-                        self.context, self.result)
+        session.dispatch.after_bulk_delete(self)
 
 
 class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
index 739caefe034d799f184059afaed0f7552d7317ec..5d9c4d49ec82480ff89a239f11e25deee4ff85d0 100644 (file)
@@ -10,7 +10,7 @@ from .compat import callable, cmp, reduce,  \
     raise_from_cause, text_type, string_types, int_types, binary_type, \
     quote_plus, with_metaclass, print_, itertools_filterfalse, u, ue, b,\
     unquote_plus, b64decode, b64encode, byte_buffer, itertools_filter,\
-    StringIO
+    StringIO, inspect_getargspec
 
 from ._collections import KeyedTuple, ImmutableContainer, immutabledict, \
     Properties, OrderedProperties, ImmutableProperties, OrderedDict, \
@@ -30,10 +30,11 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
     duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
     classproperty, set_creation_order, warn_exception, warn, NoneType,\
     constructor_copy, methods_equivalent, chop_traceback, asint,\
-    generic_repr, counter, PluginLoader, hybridmethod, safe_reraise
+    generic_repr, counter, PluginLoader, hybridmethod, safe_reraise,\
+    get_callable_argspec
 
 from .deprecations import warn_deprecated, warn_pending_deprecation, \
-    deprecated, pending_deprecation
+    deprecated, pending_deprecation, inject_docstring_text
 
 # things that used to be not always available,
 # but are now as of current support Python versions
index d866534ab598ed4b90497062e3da82e51699fa79..a89762b4e78141207c8f3326ee9c0a98ec26f404 100644 (file)
@@ -22,7 +22,7 @@ pypy = hasattr(sys, 'pypy_version_info')
 win32 = sys.platform.startswith('win')
 cpython = not pypy and not jython  # TODO: something better for this ?
 
-
+import collections
 next = next
 
 if py3k:
@@ -33,6 +33,9 @@ else:
     except ImportError:
         import pickle
 
+ArgSpec = collections.namedtuple("ArgSpec",
+                ["args", "varargs", "keywords", "defaults"])
+
 if py3k:
     import builtins
 
@@ -43,6 +46,10 @@ if py3k:
 
     from io import BytesIO as byte_buffer
 
+    def inspect_getargspec(func):
+        return ArgSpec(
+                    *inspect_getfullargspec(func)[0:4]
+                )
 
     string_types = str,
     binary_type = bytes
@@ -87,6 +94,7 @@ if py3k:
 
 else:
     from inspect import getargspec as inspect_getfullargspec
+    inspect_getargspec = inspect_getfullargspec
     from urllib import quote_plus, unquote_plus
     from urlparse import parse_qsl
     import ConfigParser as configparser
index e0dc168db7c57194d566436e8ae429142c7acb11..c315d2da6ab95252b47325767c351133420b1017 100644 (file)
@@ -107,17 +107,37 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None):
     doc = func.__doc__ is not None and func.__doc__ or ''
     if docstring_header is not None:
         docstring_header %= dict(func=func.__name__)
-        docs = doc and doc.expandtabs().split('\n') or []
-        indent = ''
-        for line in docs[1:]:
-            text = line.lstrip()
-            if text:
-                indent = line[0:len(line) - len(text)]
-                break
-        point = min(len(docs), 1)
-        docs.insert(point, '\n' + indent + docstring_header.rstrip())
-        doc = '\n'.join(docs)
+
+        doc = inject_docstring_text(doc, docstring_header, 1)
 
     decorated = warned(func)
     decorated.__doc__ = doc
     return decorated
+
+import textwrap
+
+def _dedent_docstring(text):
+    split_text = text.split("\n", 1)
+    if len(split_text) == 1:
+        return text
+    else:
+        firstline, remaining = split_text
+    if not firstline.startswith(" "):
+        return firstline + "\n" + textwrap.dedent(remaining)
+    else:
+        return textwrap.dedent(text)
+
+def inject_docstring_text(doctext, injecttext, pos):
+    doctext = _dedent_docstring(doctext or "")
+    lines = doctext.split('\n')
+    injectlines = textwrap.dedent(injecttext).split("\n")
+    if injectlines[0]:
+        injectlines.insert(0, "")
+
+    blanks = [num for num, line in enumerate(lines) if not line.strip()]
+    blanks.insert(0, 0)
+
+    inject_pos = blanks[min(pos, len(blanks) - 1)]
+
+    lines = lines[0:inject_pos] + injectlines + lines[inject_pos:]
+    return "\n".join(lines)
index 1ff868e01978fc7a539d767d5ae4308f8d31dec9..c91178a75eccd56bafac119204dcd024ae039897 100644 (file)
@@ -211,8 +211,20 @@ def get_func_kwargs(func):
 
     """
 
-    return inspect.getargspec(func)[0]
-
+    return compat.inspect_getargspec(func)[0]
+
+def get_callable_argspec(fn, no_self=False):
+    if isinstance(fn, types.FunctionType):
+        return compat.inspect_getargspec(fn)
+    elif isinstance(fn, types.MethodType) and no_self:
+        spec = compat.inspect_getargspec(fn.__func__)
+        return compat.ArgSpec(spec.args[1:], spec.varargs, spec.keywords, spec.defaults)
+    elif hasattr(fn, '__func__'):
+        return compat.inspect_getargspec(fn.__func__)
+    elif hasattr(fn, '__call__'):
+        return get_callable_argspec(fn.__call__)
+    else:
+        raise ValueError("Can't inspect function: %s" % fn)
 
 def format_argspec_plus(fn, grouped=True):
     """Returns a dictionary of formatted, introspected function arguments.
index 20bfa62ff936b13d6801997c1beb52784465a409..1e0568f2774059c72f2caaa14fae207c730aaad0 100644 (file)
@@ -171,6 +171,206 @@ class EventsTest(fixtures.TestBase):
                 meth
             )
 
+class NamedCallTest(fixtures.TestBase):
+
+    def setUp(self):
+        class TargetEventsOne(event.Events):
+            def event_one(self, x, y):
+                pass
+
+            def event_two(self, x, y, **kw):
+                pass
+
+            def event_five(self, x, y, z, q):
+                pass
+
+        class TargetOne(object):
+            dispatch = event.dispatcher(TargetEventsOne)
+        self.TargetOne = TargetOne
+
+    def tearDown(self):
+        event._remove_dispatcher(self.TargetOne.__dict__['dispatch'].events)
+
+
+    def test_kw_accept(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_one", named=True)
+        def handler1(**kw):
+            canary(kw)
+
+        self.TargetOne().dispatch.event_one(4, 5)
+
+        eq_(
+            canary.mock_calls,
+            [call({"x": 4, "y": 5})]
+        )
+
+    def test_partial_kw_accept(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_five", named=True)
+        def handler1(z, y, **kw):
+            canary(z, y, kw)
+
+        self.TargetOne().dispatch.event_five(4, 5, 6, 7)
+
+        eq_(
+            canary.mock_calls,
+            [call(6, 5, {"x": 4, "q": 7})]
+        )
+
+    def test_kw_accept_plus_kw(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_two", named=True)
+        def handler1(**kw):
+            canary(kw)
+
+        self.TargetOne().dispatch.event_two(4, 5, z=8, q=5)
+
+        eq_(
+            canary.mock_calls,
+            [call({"x": 4, "y": 5, "z": 8, "q": 5})]
+        )
+
+
+class LegacySignatureTest(fixtures.TestBase):
+    """test adaption of legacy args"""
+
+
+    def setUp(self):
+        class TargetEventsOne(event.Events):
+
+            @event._legacy_signature("0.9", ["x", "y"])
+            def event_three(self, x, y, z, q):
+                pass
+
+            @event._legacy_signature("0.9", ["x", "y", "**kw"])
+            def event_four(self, x, y, z, q, **kw):
+                pass
+
+            @event._legacy_signature("0.9", ["x", "y", "z", "q"],
+                                lambda x, y: (x, y, x + y, x * y))
+            def event_six(self, x, y):
+                pass
+
+
+        class TargetOne(object):
+            dispatch = event.dispatcher(TargetEventsOne)
+        self.TargetOne = TargetOne
+
+    def tearDown(self):
+        event._remove_dispatcher(self.TargetOne.__dict__['dispatch'].events)
+
+    def test_legacy_accept(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_three")
+        def handler1(x, y):
+            canary(x, y)
+
+        self.TargetOne().dispatch.event_three(4, 5, 6, 7)
+
+        eq_(
+            canary.mock_calls,
+            [call(4, 5)]
+        )
+
+    def test_legacy_accept_kw_cls(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_four")
+        def handler1(x, y, **kw):
+            canary(x, y, kw)
+        self._test_legacy_accept_kw(self.TargetOne(), canary)
+
+    def test_legacy_accept_kw_instance(self):
+        canary = Mock()
+
+        inst = self.TargetOne()
+        @event.listens_for(inst, "event_four")
+        def handler1(x, y, **kw):
+            canary(x, y, kw)
+        self._test_legacy_accept_kw(inst, canary)
+
+    def _test_legacy_accept_kw(self, target, canary):
+        target.dispatch.event_four(4, 5, 6, 7, foo="bar")
+
+        eq_(
+            canary.mock_calls,
+            [call(4, 5, {"foo": "bar"})]
+        )
+
+    def test_complex_legacy_accept(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_six")
+        def handler1(x, y, z, q):
+            canary(x, y, z, q)
+
+        self.TargetOne().dispatch.event_six(4, 5)
+        eq_(
+            canary.mock_calls,
+            [call(4, 5, 9, 20)]
+        )
+
+    def test_legacy_accept_from_method(self):
+        canary = Mock()
+
+        class MyClass(object):
+            def handler1(self, x, y):
+                canary(x, y)
+
+        event.listen(self.TargetOne, "event_three", MyClass().handler1)
+
+        self.TargetOne().dispatch.event_three(4, 5, 6, 7)
+        eq_(
+            canary.mock_calls,
+            [call(4, 5)]
+        )
+
+    def test_standard_accept_has_legacies(self):
+        canary = Mock()
+
+        event.listen(self.TargetOne, "event_three", canary)
+
+        self.TargetOne().dispatch.event_three(4, 5)
+
+        eq_(
+            canary.mock_calls,
+            [call(4, 5)]
+        )
+
+    def test_kw_accept_has_legacies(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_three", named=True)
+        def handler1(**kw):
+            canary(kw)
+
+        self.TargetOne().dispatch.event_three(4, 5, 6, 7)
+
+        eq_(
+            canary.mock_calls,
+            [call({"x": 4, "y": 5, "z": 6, "q": 7})]
+        )
+
+    def test_kw_accept_plus_kw_has_legacies(self):
+        canary = Mock()
+
+        @event.listens_for(self.TargetOne, "event_four", named=True)
+        def handler1(**kw):
+            canary(kw)
+
+        self.TargetOne().dispatch.event_four(4, 5, 6, 7, foo="bar")
+
+        eq_(
+            canary.mock_calls,
+            [call({"x": 4, "y": 5, "z": 6, "q": 7, "foo": "bar"})]
+        )
+
+
 class ClsLevelListenTest(fixtures.TestBase):
 
 
@@ -508,6 +708,21 @@ class JoinTest(fixtures.TestBase):
         element.run_event(2)
         element.run_event(3)
 
+    def test_kw_ok(self):
+        l1 = Mock()
+        def listen(**kw):
+            l1(kw)
+
+        event.listen(self.TargetFactory, "event_one", listen, named=True)
+        element = self.TargetFactory().create()
+        element.run_event(1)
+        element.run_event(2)
+        eq_(
+            l1.mock_calls,
+            [call({"target": element, "arg": 1}),
+                call({"target": element, "arg": 2}),]
+        )
+
     def test_parent_class_only(self):
         l1 = Mock()
 
index 2f91f5c8324a0c7791d946467f8e27c097a7a612..d2dae8ba3e7dfe7707949a812ec97eaf5c8397f6 100644 (file)
@@ -14,7 +14,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.util import gc_collect
 from test.orm import _fixtures
 from sqlalchemy import event
-
+from sqlalchemy.testing.mock import Mock, call
 
 class _RemoveListeners(object):
     def teardown(self):
@@ -341,12 +341,12 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
                                 self.classes.User)
 
         canary = []
-        def evt(x):
+        def evt(x, y, z):
             canary.append(x)
         event.listen(User, "before_insert", evt, raw=True)
 
         m = mapper(User, users)
-        m.dispatch.before_insert(5)
+        m.dispatch.before_insert(5, 6, 7)
         eq_(canary, [5])
 
     def test_deferred_map_event_subclass_propagate(self):
@@ -363,12 +363,12 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
             pass
 
         canary = []
-        def evt(x):
+        def evt(x, y, z):
             canary.append(x)
         event.listen(User, "before_insert", evt, propagate=True, raw=True)
 
         m = mapper(SubUser, users)
-        m.dispatch.before_insert(5)
+        m.dispatch.before_insert(5, 6, 7)
         eq_(canary, [5])
 
     def test_deferred_map_event_subclass_no_propagate(self):
@@ -385,12 +385,12 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
             pass
 
         canary = []
-        def evt(x):
+        def evt(x, y, z):
             canary.append(x)
         event.listen(User, "before_insert", evt, propagate=False)
 
         m = mapper(SubUser, users)
-        m.dispatch.before_insert(5)
+        m.dispatch.before_insert(5, 6, 7)
         eq_(canary, [])
 
     def test_deferred_map_event_subclass_post_mapping_propagate(self):
@@ -409,11 +409,11 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
         m = mapper(SubUser, users)
 
         canary = []
-        def evt(x):
+        def evt(x, y, z):
             canary.append(x)
         event.listen(User, "before_insert", evt, propagate=True, raw=True)
 
-        m.dispatch.before_insert(5)
+        m.dispatch.before_insert(5, 6, 7)
         eq_(canary, [5])
 
     def test_deferred_instance_event_subclass_post_mapping_propagate(self):
@@ -1068,18 +1068,75 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
     def test_on_bulk_update_hook(self):
         User, users = self.classes.User, self.tables.users
 
-        sess, canary = self._listener_fixture()
+        sess = Session()
+        canary = Mock()
+
+        event.listen(sess, "after_begin", canary.after_begin)
+        event.listen(sess, "after_bulk_update", canary.after_bulk_update)
+
+        def legacy(ses, qry, ctx, res):
+            canary.after_bulk_update_legacy(ses, qry, ctx, res)
+        event.listen(sess, "after_bulk_update", legacy)
+
         mapper(User, users)
+
         sess.query(User).update({'name': 'foo'})
-        eq_(canary, ['after_begin', 'after_bulk_update'])
+
+        eq_(
+            canary.after_begin.call_count,
+            1
+        )
+        eq_(
+            canary.after_bulk_update.call_count,
+            1
+        )
+
+        upd = canary.after_bulk_update.mock_calls[0][1][0]
+        eq_(
+            upd.session,
+            sess
+        )
+        eq_(
+            canary.after_bulk_update_legacy.mock_calls,
+            [call(sess, upd.query, upd.context, upd.result)]
+        )
+
 
     def test_on_bulk_delete_hook(self):
         User, users = self.classes.User, self.tables.users
 
-        sess, canary = self._listener_fixture()
+        sess = Session()
+        canary = Mock()
+
+        event.listen(sess, "after_begin", canary.after_begin)
+        event.listen(sess, "after_bulk_delete", canary.after_bulk_delete)
+
+        def legacy(ses, qry, ctx, res):
+            canary.after_bulk_delete_legacy(ses, qry, ctx, res)
+        event.listen(sess, "after_bulk_delete", legacy)
+
         mapper(User, users)
+
         sess.query(User).delete()
-        eq_(canary, ['after_begin', 'after_bulk_delete'])
+
+        eq_(
+            canary.after_begin.call_count,
+            1
+        )
+        eq_(
+            canary.after_bulk_delete.call_count,
+            1
+        )
+
+        upd = canary.after_bulk_delete.mock_calls[0][1][0]
+        eq_(
+            upd.session,
+            sess
+        )
+        eq_(
+            canary.after_bulk_delete_legacy.mock_calls,
+            [call(sess, upd.query, upd.context, upd.result)]
+        )
 
     def test_connection_emits_after_begin(self):
         sess, canary = self._listener_fixture(bind=testing.db)
@@ -1508,19 +1565,13 @@ class SessionExtensionTest(_fixtures.FixtureTest):
                 log.append('after_attach')
             def after_bulk_update(
                 self,
-                session,
-                query,
-                query_context,
-                result,
+                session, query, query_context, result
                 ):
                 log.append('after_bulk_update')
 
             def after_bulk_delete(
                 self,
-                session,
-                query,
-                query_context,
-                result,
+                session, query, query_context, result
                 ):
                 log.append('after_bulk_delete')