]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
generalize scoped_session proxying and apply to asyncio elements
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Oct 2020 19:20:48 +0000 (15:20 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Oct 2020 05:17:25 +0000 (01:17 -0400)
Reworked the proxy creation used by scoped_session() to be
based on fully copied code with augmented docstrings and
moved it into langhelpers.  asyncio session, engine,
connection can now take
advantage of it so that all non-async methods are availble.

Overall implementation of most important accessors / methods
on AsyncConnection, etc. , including awaitable versions
of invalidate, execution_options, etc.

In order to support an event dispatcher on the async
classes while still allowing them to hold __slots__,
make some adjustments to the event system to allow
that to be present, at least rudimentally.

Fixes: #5628
Change-Id: I5eb6929fc1e4fdac99e4b767dcfd49672d56e2b2

19 files changed:
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/event/base.py
lib/sqlalchemy/ext/asyncio/__init__.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/events.py [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_events.py
test/base/test_utils.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py
test/orm/test_scoping.py
test/orm/test_session.py

index 9a6bdd7f36617635c9fe2ab4f9290b7b4b1bfac7..4fbdec1452aa04168a9531fd09c673bd1d761be6 100644 (file)
@@ -184,14 +184,17 @@ class Connection(Connectable):
         r""" Set non-SQL options for the connection which take effect
         during execution.
 
-        The method returns a copy of this :class:`_engine.Connection`
-        which references
-        the same underlying DBAPI connection, but also defines the given
-        execution options which will take effect for a call to
-        :meth:`execute`. As the new :class:`_engine.Connection`
-        references the same
-        underlying resource, it's usually a good idea to ensure that the copies
-        will be discarded immediately, which is implicit if used as in::
+        For a "future" style connection, this method returns this same
+        :class:`_future.Connection` object with the new options added.
+
+        For a legacy connection, this method returns a copy of this
+        :class:`_engine.Connection` which references the same underlying DBAPI
+        connection, but also defines the given execution options which will
+        take effect for a call to
+        :meth:`execute`. As the new :class:`_engine.Connection` references the
+        same underlying resource, it's usually a good idea to ensure that
+        the copies will be discarded immediately, which is implicit if used
+        as in::
 
             result = connection.execution_options(stream_results=True).\
                                 execute(stmt)
@@ -549,9 +552,10 @@ class Connection(Connectable):
         """Invalidate the underlying DBAPI connection associated with
         this :class:`_engine.Connection`.
 
-        The underlying DBAPI connection is literally closed (if
-        possible), and is discarded.  Its source connection pool will
-        typically lazily create a new connection to replace it.
+        An attempt will be made to close the underlying DBAPI connection
+        immediately; however if this operation fails, the error is logged
+        but not raised.  The connection is then discarded whether or not
+        close() succeeded.
 
         Upon the next use (where "use" typically means using the
         :meth:`_engine.Connection.execute` method or similar),
@@ -580,6 +584,10 @@ class Connection(Connectable):
         will at the connection pool level invoke the
         :meth:`_events.PoolEvents.invalidate` event.
 
+        :param exception: an optional ``Exception`` instance that's the
+         reason for the invalidation.  is passed along to event handlers
+         and logging functions.
+
         .. seealso::
 
             :ref:`pool_connection_invalidation`
index daa6f9aeabbbdcb56415033f9ff17050d17d25bf..1ba88f3d21e0d08a3ba559c5e52fec1942299b91 100644 (file)
@@ -195,7 +195,14 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
                 dispatch_cls._event_names.append(ls.name)
 
     if getattr(cls, "_dispatch_target", None):
-        cls._dispatch_target.dispatch = dispatcher(cls)
+        the_cls = cls._dispatch_target
+        if (
+            hasattr(the_cls, "__slots__")
+            and "_slots_dispatch" in the_cls.__slots__
+        ):
+            cls._dispatch_target.dispatch = slots_dispatcher(cls)
+        else:
+            cls._dispatch_target.dispatch = dispatcher(cls)
 
 
 def _remove_dispatcher(cls):
@@ -304,5 +311,29 @@ class dispatcher(object):
     def __get__(self, obj, cls):
         if obj is None:
             return self.dispatch
-        obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj)
+
+        disp = self.dispatch._for_instance(obj)
+        try:
+            obj.__dict__["dispatch"] = disp
+        except AttributeError as ae:
+            util.raise_(
+                TypeError(
+                    "target %r doesn't have __dict__, should it be "
+                    "defining _slots_dispatch?" % (obj,)
+                ),
+                replace_context=ae,
+            )
+        return disp
+
+
+class slots_dispatcher(dispatcher):
+    def __get__(self, obj, cls):
+        if obj is None:
+            return self.dispatch
+
+        if hasattr(obj, "_slots_dispatch"):
+            return obj._slots_dispatch
+
+        disp = self.dispatch._for_instance(obj)
+        obj._slots_dispatch = disp
         return disp
index fbbc958d42518195e87af0d8376567c140aa1a51..9c7d6443cba053febb0747d5b6a188fd94a72117 100644 (file)
@@ -2,6 +2,8 @@ from .engine import AsyncConnection  # noqa
 from .engine import AsyncEngine  # noqa
 from .engine import AsyncTransaction  # noqa
 from .engine import create_async_engine  # noqa
+from .events import AsyncConnectionEvents  # noqa
+from .events import AsyncSessionEvents  # noqa
 from .result import AsyncMappingResult  # noqa
 from .result import AsyncResult  # noqa
 from .result import AsyncScalarResult  # noqa
index 4a92fb1f2c1c556a0e05704aa0ec6fd174bdacee..9e4851dfcda38e7b78a546d774dedca87ed35e44 100644 (file)
@@ -8,12 +8,11 @@ from .base import StartableContext
 from .result import AsyncResult
 from ... import exc
 from ... import util
-from ...engine import Connection
 from ...engine import create_engine as _create_engine
-from ...engine import Engine
 from ...engine import Result
 from ...engine import Transaction
-from ...engine.base import OptionEngineMixin
+from ...future import Connection
+from ...future import Engine
 from ...sql import Executable
 from ...util.concurrency import greenlet_spawn
 
@@ -41,7 +40,24 @@ def create_async_engine(*arg, **kw):
     return AsyncEngine(sync_engine)
 
 
-class AsyncConnection(StartableContext):
+class AsyncConnectable:
+    __slots__ = "_slots_dispatch"
+
+
+@util.create_proxy_methods(
+    Connection,
+    ":class:`_future.Connection`",
+    ":class:`_asyncio.AsyncConnection`",
+    classmethods=[],
+    methods=[],
+    attributes=[
+        "closed",
+        "invalidated",
+        "dialect",
+        "default_isolation_level",
+    ],
+)
+class AsyncConnection(StartableContext, AsyncConnectable):
     """An asyncio proxy for a :class:`_engine.Connection`.
 
     :class:`_asyncio.AsyncConnection` is acquired using the
@@ -58,15 +74,23 @@ class AsyncConnection(StartableContext):
 
     """  # noqa
 
+    # AsyncConnection is a thin proxy; no state should be added here
+    # that is not retrievable from the "sync" engine / connection, e.g.
+    # current transaction, info, etc.   It should be possible to
+    # create a new AsyncConnection that matches this one given only the
+    # "sync" elements.
     __slots__ = (
         "sync_engine",
         "sync_connection",
     )
 
     def __init__(
-        self, sync_engine: Engine, sync_connection: Optional[Connection] = None
+        self,
+        async_engine: "AsyncEngine",
+        sync_connection: Optional[Connection] = None,
     ):
-        self.sync_engine = sync_engine
+        self.engine = async_engine
+        self.sync_engine = async_engine.sync_engine
         self.sync_connection = sync_connection
 
     async def start(self):
@@ -79,6 +103,34 @@ class AsyncConnection(StartableContext):
         self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
         return self
 
+    @property
+    def connection(self):
+        """Not implemented for async; call
+        :meth:`_asyncio.AsyncConnection.get_raw_connection`.
+
+        """
+        raise exc.InvalidRequestError(
+            "AsyncConnection.connection accessor is not implemented as the "
+            "attribute may need to reconnect on an invalidated connection.  "
+            "Use the get_raw_connection() method."
+        )
+
+    async def get_raw_connection(self):
+        """Return the pooled DBAPI-level connection in use by this
+        :class:`_asyncio.AsyncConnection`.
+
+        This is typically the SQLAlchemy connection-pool proxied connection
+        which then has an attribute .connection that refers to the actual
+        DBAPI-level connection.
+        """
+        conn = self._sync_connection()
+
+        return await greenlet_spawn(getattr, conn, "connection")
+
+    @property
+    def _proxied(self):
+        return self.sync_connection
+
     def _sync_connection(self):
         if not self.sync_connection:
             self._raise_for_not_started()
@@ -94,6 +146,43 @@ class AsyncConnection(StartableContext):
         self._sync_connection()
         return AsyncTransaction(self, nested=True)
 
+    async def invalidate(self, exception=None):
+        """Invalidate the underlying DBAPI connection associated with
+        this :class:`_engine.Connection`.
+
+        See the method :meth:`_engine.Connection.invalidate` for full
+        detail on this method.
+
+        """
+
+        conn = self._sync_connection()
+        return await greenlet_spawn(conn.invalidate, exception=exception)
+
+    async def get_isolation_level(self):
+        conn = self._sync_connection()
+        return await greenlet_spawn(conn.get_isolation_level)
+
+    async def set_isolation_level(self):
+        conn = self._sync_connection()
+        return await greenlet_spawn(conn.get_isolation_level)
+
+    async def execution_options(self, **opt):
+        r"""Set non-SQL options for the connection which take effect
+        during execution.
+
+        This returns this :class:`_asyncio.AsyncConnection` object with
+        the new options added.
+
+        See :meth:`_future.Connection.execution_options` for full details
+        on this method.
+
+        """
+
+        conn = self._sync_connection()
+        c2 = await greenlet_spawn(conn.execution_options, **opt)
+        assert c2 is conn
+        return self
+
     async def commit(self):
         """Commit the transaction that is currently in progress.
 
@@ -287,7 +376,19 @@ class AsyncConnection(StartableContext):
         await self.close()
 
 
-class AsyncEngine:
+@util.create_proxy_methods(
+    Engine,
+    ":class:`_future.Engine`",
+    ":class:`_asyncio.AsyncEngine`",
+    classmethods=[],
+    methods=[
+        "clear_compiled_cache",
+        "update_execution_options",
+        "get_execution_options",
+    ],
+    attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
+)
+class AsyncEngine(AsyncConnectable):
     """An asyncio proxy for a :class:`_engine.Engine`.
 
     :class:`_asyncio.AsyncEngine` is acquired using the
@@ -301,7 +402,12 @@ class AsyncEngine:
 
     """  # noqa
 
-    __slots__ = ("sync_engine",)
+    # AsyncEngine is a thin proxy; no state should be added here
+    # that is not retrievable from the "sync" engine / connection, e.g.
+    # current transaction, info, etc.   It should be possible to
+    # create a new AsyncEngine that matches this one given only the
+    # "sync" elements.
+    __slots__ = ("sync_engine", "_proxied")
 
     _connection_cls = AsyncConnection
 
@@ -327,7 +433,7 @@ class AsyncEngine:
             await self.conn.close()
 
     def __init__(self, sync_engine: Engine):
-        self.sync_engine = sync_engine
+        self.sync_engine = self._proxied = sync_engine
 
     def begin(self):
         """Return a context manager which when entered will deliver an
@@ -363,7 +469,7 @@ class AsyncEngine:
 
         """
 
-        return self._connection_cls(self.sync_engine)
+        return self._connection_cls(self)
 
     async def raw_connection(self) -> Any:
         """Return a "raw" DBAPI connection from the connection pool.
@@ -375,12 +481,33 @@ class AsyncEngine:
         """
         return await greenlet_spawn(self.sync_engine.raw_connection)
 
+    def execution_options(self, **opt):
+        """Return a new :class:`_asyncio.AsyncEngine` that will provide
+        :class:`_asyncio.AsyncConnection` objects with the given execution
+        options.
+
+        Proxied from :meth:`_future.Engine.execution_options`.  See that
+        method for details.
+
+        """
+
+        return AsyncEngine(self.sync_engine.execution_options(**opt))
 
-class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
-    pass
+    async def dispose(self):
+        """Dispose of the connection pool used by this
+        :class:`_asyncio.AsyncEngine`.
 
+        This will close all connection pool connections that are
+        **currently checked in**.  See the documentation for the underlying
+        :meth:`_future.Engine.dispose` method for further notes.
+
+        .. seealso::
+
+            :meth:`_future.Engine.dispose`
+
+        """
 
-AsyncEngine._option_cls = AsyncOptionEngine
+        return await greenlet_spawn(self.sync_engine.dispose)
 
 
 class AsyncTransaction(StartableContext):
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
new file mode 100644 (file)
index 0000000..a8daefc
--- /dev/null
@@ -0,0 +1,29 @@
+from .engine import AsyncConnectable
+from .session import AsyncSession
+from ...engine import events as engine_event
+from ...orm import events as orm_event
+
+
+class AsyncConnectionEvents(engine_event.ConnectionEvents):
+    _target_class_doc = "SomeEngine"
+    _dispatch_target = AsyncConnectable
+
+    @classmethod
+    def _listen(cls, event_key, retval=False):
+        raise NotImplementedError(
+            "asynchronous events are not implemented at this time.  Apply "
+            "synchronous listeners to the AsyncEngine.sync_engine or "
+            "AsyncConnection.sync_connection attributes."
+        )
+
+
+class AsyncSessionEvents(orm_event.SessionEvents):
+    _target_class_doc = "SomeSession"
+    _dispatch_target = AsyncSession
+
+    @classmethod
+    def _listen(cls, event_key, retval=False):
+        raise NotImplementedError(
+            "asynchronous events are not implemented at this time.  Apply "
+            "synchronous listeners to the AsyncSession.sync_session."
+        )
index cb06aa26d5f886c17987f0fc105e180b8409f68a..4ae1fb385ec96bcc781af853e1cb61994df0ef4c 100644 (file)
@@ -1,6 +1,5 @@
 from typing import Any
 from typing import Callable
-from typing import List
 from typing import Mapping
 from typing import Optional
 
@@ -15,6 +14,35 @@ from ...sql import Executable
 from ...util.concurrency import greenlet_spawn
 
 
+@util.create_proxy_methods(
+    Session,
+    ":class:`_orm.Session`",
+    ":class:`_asyncio.AsyncSession`",
+    classmethods=["object_session", "identity_key"],
+    methods=[
+        "__contains__",
+        "__iter__",
+        "add",
+        "add_all",
+        "delete",
+        "expire",
+        "expire_all",
+        "expunge",
+        "expunge_all",
+        "get_bind",
+        "is_modified",
+    ],
+    attributes=[
+        "dirty",
+        "deleted",
+        "new",
+        "identity_map",
+        "is_active",
+        "autoflush",
+        "no_autoflush",
+        "info",
+    ],
+)
 class AsyncSession:
     """Asyncio version of :class:`_orm.Session`.
 
@@ -23,6 +51,16 @@ class AsyncSession:
 
     """
 
+    __slots__ = (
+        "binds",
+        "bind",
+        "sync_session",
+        "_proxied",
+        "_slots_dispatch",
+    )
+
+    dispatch = None
+
     def __init__(
         self,
         bind: AsyncEngine = None,
@@ -31,46 +69,18 @@ class AsyncSession:
     ):
         kw["future"] = True
         if bind:
+            self.bind = engine
             bind = engine._get_sync_engine(bind)
 
         if binds:
+            self.binds = binds
             binds = {
                 key: engine._get_sync_engine(b) for key, b in binds.items()
             }
 
-        self.sync_session = Session(bind=bind, binds=binds, **kw)
-
-    def add(self, instance: object) -> None:
-        """Place an object in this :class:`_asyncio.AsyncSession`.
-
-        .. seealso::
-
-            :meth:`_orm.Session.add`
-
-        """
-        self.sync_session.add(instance)
-
-    def add_all(self, instances: List[object]) -> None:
-        """Add the given collection of instances to this
-        :class:`_asyncio.AsyncSession`."""
-
-        self.sync_session.add_all(instances)
-
-    def expire_all(self):
-        """Expires all persistent instances within this Session.
-
-        See :meth:`_orm.Session.expire_all` for usage details.
-
-        """
-        self.sync_session.expire_all()
-
-    def expire(self, instance, attribute_names=None):
-        """Expire the attributes on an instance.
-
-        See :meth:`._orm.Session.expire` for usage details.
-
-        """
-        self.sync_session.expire()
+        self.sync_session = self._proxied = Session(
+            bind=bind, binds=binds, **kw
+        )
 
     async def refresh(
         self, instance, attribute_names=None, with_for_update=None
@@ -178,8 +188,17 @@ class AsyncSession:
         :class:`.Session` object's transactional state.
 
         """
+
+        # POSSIBLY TODO: here, we see that the sync engine / connection
+        # that are generated from AsyncEngine / AsyncConnection don't
+        # provide any backlink from those sync objects back out to the
+        # async ones.   it's not *too* big a deal since AsyncEngine/Connection
+        # are just proxies and all the state is actually in the sync
+        # version of things.  However!  it has to stay that way :)
         sync_connection = await greenlet_spawn(self.sync_session.connection)
-        return engine.AsyncConnection(sync_connection.engine, sync_connection)
+        return engine.AsyncConnection(
+            engine.AsyncEngine(sync_connection.engine), sync_connection
+        )
 
     def begin(self, **kw):
         """Return an :class:`_asyncio.AsyncSessionTransaction` object.
@@ -218,14 +237,22 @@ class AsyncSession:
         return AsyncSessionTransaction(self, nested=True)
 
     async def rollback(self):
+        """Rollback the current transaction in progress."""
         return await greenlet_spawn(self.sync_session.rollback)
 
     async def commit(self):
+        """Commit the current transaction in progress."""
         return await greenlet_spawn(self.sync_session.commit)
 
     async def close(self):
+        """Close this :class:`_asyncio.AsyncSession`."""
         return await greenlet_spawn(self.sync_session.close)
 
+    @classmethod
+    async def close_all(self):
+        """Close all :class:`_asyncio.AsyncSession` sessions."""
+        return await greenlet_spawn(self.sync_session.close_all)
+
     async def __aenter__(self):
         return self
 
index 29a509cb9a7ed3bcfa501778ad7671e2f55a7f1d..4e11ebb8c6438672601d2af3a1e2f0895bc8e85b 100644 (file)
@@ -1371,7 +1371,8 @@ class SessionEvents(event.Events):
         elif isinstance(target, Session):
             return target
         else:
-            return None
+            # allows alternate SessionEvents-like-classes to be consulted
+            return event.Events._accept_with(target)
 
     @classmethod
     def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
index 1090501ca1ca9f0802d223e93962696a02b29ce1..29d845c0a7cbb7ea8be124a52887ae37c54733ff 100644 (file)
@@ -9,14 +9,60 @@ from . import class_mapper
 from . import exc as orm_exc
 from .session import Session
 from .. import exc as sa_exc
+from ..util import create_proxy_methods
 from ..util import ScopedRegistry
 from ..util import ThreadLocalRegistry
 from ..util import warn
 
-
 __all__ = ["scoped_session"]
 
 
+@create_proxy_methods(
+    Session,
+    ":class:`_orm.Session`",
+    ":class:`_orm.scoping.scoped_session`",
+    classmethods=["close_all", "object_session", "identity_key"],
+    methods=[
+        "__contains__",
+        "__iter__",
+        "add",
+        "add_all",
+        "begin",
+        "begin_nested",
+        "close",
+        "commit",
+        "connection",
+        "delete",
+        "execute",
+        "expire",
+        "expire_all",
+        "expunge",
+        "expunge_all",
+        "flush",
+        "get_bind",
+        "is_modified",
+        "bulk_save_objects",
+        "bulk_insert_mappings",
+        "bulk_update_mappings",
+        "merge",
+        "query",
+        "refresh",
+        "rollback",
+        "scalar",
+    ],
+    attributes=[
+        "bind",
+        "dirty",
+        "deleted",
+        "new",
+        "identity_map",
+        "is_active",
+        "autoflush",
+        "no_autoflush",
+        "info",
+        "autocommit",
+    ],
+)
 class scoped_session(object):
     """Provides scoped management of :class:`.Session` objects.
 
@@ -53,6 +99,10 @@ class scoped_session(object):
         else:
             self.registry = ThreadLocalRegistry(session_factory)
 
+    @property
+    def _proxied(self):
+        return self.registry()
+
     def __call__(self, **kw):
         r"""Return the current :class:`.Session`, creating it
         using the :attr:`.scoped_session.session_factory` if not present.
@@ -156,50 +206,3 @@ class scoped_session(object):
 
 ScopedSession = scoped_session
 """Old name for backwards compatibility."""
-
-
-def instrument(name):
-    def do(self, *args, **kwargs):
-        return getattr(self.registry(), name)(*args, **kwargs)
-
-    return do
-
-
-for meth in Session.public_methods:
-    setattr(scoped_session, meth, instrument(meth))
-
-
-def makeprop(name):
-    def set_(self, attr):
-        setattr(self.registry(), name, attr)
-
-    def get(self):
-        return getattr(self.registry(), name)
-
-    return property(get, set_)
-
-
-for prop in (
-    "bind",
-    "dirty",
-    "deleted",
-    "new",
-    "identity_map",
-    "is_active",
-    "autoflush",
-    "no_autoflush",
-    "info",
-    "autocommit",
-):
-    setattr(scoped_session, prop, makeprop(prop))
-
-
-def clslevel(name):
-    def do(cls, *args, **kwargs):
-        return getattr(Session, name)(*args, **kwargs)
-
-    return classmethod(do)
-
-
-for prop in ("close_all", "object_session", "identity_key"):
-    setattr(scoped_session, prop, clslevel(prop))
index e32e055103e989a7cba50a6c5b16889ccef2a6a9..af0ac63e0d26510eb55646309f61a2dfcfec7ce0 100644 (file)
@@ -835,35 +835,6 @@ class Session(_SessionClassMethods):
 
     """
 
-    public_methods = (
-        "__contains__",
-        "__iter__",
-        "add",
-        "add_all",
-        "begin",
-        "begin_nested",
-        "close",
-        "commit",
-        "connection",
-        "delete",
-        "execute",
-        "expire",
-        "expire_all",
-        "expunge",
-        "expunge_all",
-        "flush",
-        "get_bind",
-        "is_modified",
-        "bulk_save_objects",
-        "bulk_insert_mappings",
-        "bulk_update_mappings",
-        "merge",
-        "query",
-        "refresh",
-        "rollback",
-        "scalar",
-    )
-
     @util.deprecated_params(
         autocommit=(
             "2.0",
@@ -3028,7 +2999,14 @@ class Session(_SessionClassMethods):
             will unexpire attributes on access.
 
         """
-        state = attributes.instance_state(obj)
+        try:
+            state = attributes.instance_state(obj)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(obj),
+                replace_context=err,
+            )
+
         to_attach = self._before_attach(state, obj)
         state._load_pending = True
         if to_attach:
index 87383fef717a6a1d76b44e4d625d6eb405a9f431..68fa5fe85a778a210c3ef5f96b27266202aa12f4 100644 (file)
@@ -509,6 +509,7 @@ class _ConnectionRecord(object):
                 "Soft " if soft else "",
                 self.connection,
             )
+
         if soft:
             self._soft_invalidate_time = time.time()
         else:
index dfefd3b95f9d580fe6cb3f9720c41b9c8ab0059b..644ea6dc20d0a3c4fecf33e9a6773914213b254d 100644 (file)
@@ -372,7 +372,7 @@ def _pytest_fn_decorator(target):
         if add_positional_parameters:
             spec.args.extend(add_positional_parameters)
 
-        metadata = dict(target="target", fn="fn", name=fn.__name__)
+        metadata = dict(target="target", fn="__fn", name=fn.__name__)
         metadata.update(format_argspec_plus(spec, grouped=False))
         code = (
             """\
@@ -382,7 +382,7 @@ def %(name)s(%(args)s):
             % metadata
         )
         decorated = _exec_code_in_env(
-            code, {"target": target, "fn": fn}, fn.__name__
+            code, {"target": target, "__fn": fn}, fn.__name__
         )
         if not add_positional_parameters:
             decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
index 8ef2f010321879f0cce341cc04c63a6f2171f096..885f62f97b329ea6857d2d7d6c7157faf380bafd 100644 (file)
@@ -123,6 +123,7 @@ from .langhelpers import coerce_kw_type  # noqa
 from .langhelpers import constructor_copy  # noqa
 from .langhelpers import constructor_key  # noqa
 from .langhelpers import counter  # noqa
+from .langhelpers import create_proxy_methods  # noqa
 from .langhelpers import decode_slice  # noqa
 from .langhelpers import decorator  # noqa
 from .langhelpers import dictlike_iteritems  # noqa
index e546f196d511ef2ff10333e705d2931b33e693e1..4289db81295eeed22ac8510957955e683dc46e3b 100644 (file)
@@ -9,6 +9,7 @@
 modules, classes, hierarchies, attributes, functions, and methods.
 
 """
+
 from functools import update_wrapper
 import hashlib
 import inspect
@@ -462,6 +463,8 @@ def format_argspec_plus(fn, grouped=True):
       passed positionally.
     apply_kw
       Like apply_pos, except keyword-ish args are passed as keywords.
+    apply_pos_proxied
+      Like apply_pos but omits the self/cls argument
 
     Example::
 
@@ -478,16 +481,27 @@ def format_argspec_plus(fn, grouped=True):
         spec = fn
 
     args = compat.inspect_formatargspec(*spec)
+
+    apply_pos = compat.inspect_formatargspec(
+        spec[0], spec[1], spec[2], None, spec[4]
+    )
+
     if spec[0]:
         self_arg = spec[0][0]
+
+        apply_pos_proxied = compat.inspect_formatargspec(
+            spec[0][1:], spec[1], spec[2], None, spec[4]
+        )
+
     elif spec[1]:
+        # im not sure what this is
         self_arg = "%s[0]" % spec[1]
+
+        apply_pos_proxied = apply_pos
     else:
         self_arg = None
+        apply_pos_proxied = apply_pos
 
-    apply_pos = compat.inspect_formatargspec(
-        spec[0], spec[1], spec[2], None, spec[4]
-    )
     num_defaults = 0
     if spec[3]:
         num_defaults += len(spec[3])
@@ -513,6 +527,7 @@ def format_argspec_plus(fn, grouped=True):
             self_arg=self_arg,
             apply_pos=apply_pos,
             apply_kw=apply_kw,
+            apply_pos_proxied=apply_pos_proxied,
         )
     else:
         return dict(
@@ -520,6 +535,7 @@ def format_argspec_plus(fn, grouped=True):
             self_arg=self_arg,
             apply_pos=apply_pos[1:-1],
             apply_kw=apply_kw[1:-1],
+            apply_pos_proxied=apply_pos_proxied[1:-1],
         )
 
 
@@ -534,17 +550,140 @@ def format_argspec_init(method, grouped=True):
 
     """
     if method is object.__init__:
-        args = grouped and "(self)" or "self"
+        args = "(self)" if grouped else "self"
+        proxied = "()" if grouped else ""
     else:
         try:
             return format_argspec_plus(method, grouped=grouped)
         except TypeError:
             args = (
-                grouped
-                and "(self, *args, **kwargs)"
-                or "self, *args, **kwargs"
+                "(self, *args, **kwargs)"
+                if grouped
+                else "self, *args, **kwargs"
             )
-    return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
+            proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
+    return dict(
+        self_arg="self",
+        args=args,
+        apply_pos=args,
+        apply_kw=args,
+        apply_pos_proxied=proxied,
+    )
+
+
+def create_proxy_methods(
+    target_cls,
+    target_cls_sphinx_name,
+    proxy_cls_sphinx_name,
+    classmethods=(),
+    methods=(),
+    attributes=(),
+):
+    """A class decorator that will copy attributes to a proxy class.
+
+    The class to be instrumented must define a single accessor "_proxied".
+
+    """
+
+    def decorate(cls):
+        def instrument(name, clslevel=False):
+            fn = getattr(target_cls, name)
+            spec = compat.inspect_getfullargspec(fn)
+            env = {}
+
+            spec = _update_argspec_defaults_into_env(spec, env)
+            caller_argspec = format_argspec_plus(spec, grouped=False)
+
+            metadata = {
+                "name": fn.__name__,
+                "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
+                "args": caller_argspec["args"],
+                "self_arg": caller_argspec["self_arg"],
+            }
+
+            if clslevel:
+                code = (
+                    "def %(name)s(%(args)s):\n"
+                    "    return target_cls.%(name)s(%(apply_pos_proxied)s)"
+                    % metadata
+                )
+                env["target_cls"] = target_cls
+            else:
+                code = (
+                    "def %(name)s(%(args)s):\n"
+                    "    return %(self_arg)s._proxied.%(name)s(%(apply_pos_proxied)s)"  # noqa E501
+                    % metadata
+                )
+
+            proxy_fn = _exec_code_in_env(code, env, fn.__name__)
+            proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+            proxy_fn.__doc__ = inject_docstring_text(
+                fn.__doc__,
+                ".. container:: class_bases\n\n    "
+                "Proxied for the %s class on behalf of the %s class."
+                % (target_cls_sphinx_name, proxy_cls_sphinx_name),
+                1,
+            )
+
+            if clslevel:
+                proxy_fn = classmethod(proxy_fn)
+
+            return proxy_fn
+
+        def makeprop(name):
+            attr = target_cls.__dict__.get(name, None)
+
+            if attr is not None:
+                doc = inject_docstring_text(
+                    attr.__doc__,
+                    ".. container:: class_bases\n\n    "
+                    "Proxied for the %s class on behalf of the %s class."
+                    % (
+                        target_cls_sphinx_name,
+                        proxy_cls_sphinx_name,
+                    ),
+                    1,
+                )
+            else:
+                doc = None
+
+            code = (
+                "def set_(self, attr):\n"
+                "    self._proxied.%(name)s = attr\n"
+                "def get(self):\n"
+                "    return self._proxied.%(name)s\n"
+                "get.__doc__ = doc\n"
+                "getset = property(get, set_)"
+            ) % {"name": name}
+
+            getset = _exec_code_in_env(code, {"doc": doc}, "getset")
+
+            return getset
+
+        for meth in methods:
+            if hasattr(cls, meth):
+                raise TypeError(
+                    "class %s already has a method %s" % (cls, meth)
+                )
+            setattr(cls, meth, instrument(meth))
+
+        for prop in attributes:
+            if hasattr(cls, prop):
+                raise TypeError(
+                    "class %s already has a method %s" % (cls, prop)
+                )
+            setattr(cls, prop, makeprop(prop))
+
+        for prop in classmethods:
+            if hasattr(cls, prop):
+                raise TypeError(
+                    "class %s already has a method %s" % (cls, prop)
+                )
+            setattr(cls, prop, instrument(prop, clslevel=True))
+
+        return cls
+
+    return decorate
 
 
 def getargspec_init(method):
index a4ed1000b8b8861261630d0042fae25efa26de34..19f68e9a3509ba9a6dc0354ed752b8bfa5a514b7 100644 (file)
@@ -15,7 +15,19 @@ from sqlalchemy.testing.mock import Mock
 from sqlalchemy.testing.util import gc_collect
 
 
-class EventsTest(fixtures.TestBase):
+class TearDownLocalEventsFixture(object):
+    def tearDown(self):
+        classes = set()
+        for entry in event.base._registrars.values():
+            for evt_cls in entry:
+                if evt_cls.__module__ == __name__:
+                    classes.add(evt_cls)
+
+        for evt_cls in classes:
+            event.base._remove_dispatcher(evt_cls)
+
+
+class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test class- and instance-level event registration."""
 
     def setUp(self):
@@ -34,9 +46,6 @@ class EventsTest(fixtures.TestBase):
 
         self.Target = Target
 
-    def tearDown(self):
-        event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
-
     def test_register_class(self):
         def listen(x, y):
             pass
@@ -258,7 +267,60 @@ class EventsTest(fixtures.TestBase):
             )
 
 
-class NamedCallTest(fixtures.TestBase):
+class SlotsEventsTest(fixtures.TestBase):
+    @testing.requires.python3
+    def test_no_slots_dispatch(self):
+        class Target(object):
+            __slots__ = ()
+
+        class TargetEvents(event.Events):
+            _dispatch_target = Target
+
+            def event_one(self, x, y):
+                pass
+
+            def event_two(self, x):
+                pass
+
+            def event_three(self, x):
+                pass
+
+        t1 = Target()
+
+        with testing.expect_raises_message(
+            TypeError,
+            r"target .*Target.* doesn't have __dict__, should it "
+            "be defining _slots_dispatch",
+        ):
+            event.listen(t1, "event_one", Mock())
+
+    def test_slots_dispatch(self):
+        class Target(object):
+            __slots__ = ("_slots_dispatch",)
+
+        class TargetEvents(event.Events):
+            _dispatch_target = Target
+
+            def event_one(self, x, y):
+                pass
+
+            def event_two(self, x):
+                pass
+
+            def event_three(self, x):
+                pass
+
+        t1 = Target()
+
+        m1 = Mock()
+        event.listen(t1, "event_one", m1)
+
+        t1.dispatch.event_one(2, 4)
+
+        eq_(m1.mock_calls, [call(2, 4)])
+
+
+class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
     def _fixture(self):
         class TargetEventsOne(event.Events):
             def event_one(self, x, y):
@@ -373,7 +435,7 @@ class NamedCallTest(fixtures.TestBase):
         eq_(canary.mock_calls, [call({"x": 4, "y": 5, "z": 8, "q": 5})])
 
 
-class LegacySignatureTest(fixtures.TestBase):
+class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """test adaption of legacy args"""
 
     def setUp(self):
@@ -397,11 +459,6 @@ class LegacySignatureTest(fixtures.TestBase):
 
         self.TargetOne = TargetOne
 
-    def tearDown(self):
-        event.base._remove_dispatcher(
-            self.TargetOne.__dict__["dispatch"].events
-        )
-
     def test_legacy_accept(self):
         canary = Mock()
 
@@ -550,12 +607,7 @@ class LegacySignatureTest(fixtures.TestBase):
         )
 
 
-class ClsLevelListenTest(fixtures.TestBase):
-    def tearDown(self):
-        event.base._remove_dispatcher(
-            self.TargetOne.__dict__["dispatch"].events
-        )
-
+class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
     def setUp(self):
         class TargetEventsOne(event.Events):
             def event_one(self, x, y):
@@ -622,7 +674,7 @@ class ClsLevelListenTest(fixtures.TestBase):
         assert handler2 not in s2.dispatch.event_one
 
 
-class AcceptTargetsTest(fixtures.TestBase):
+class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test default target acceptance."""
 
     def setUp(self):
@@ -643,14 +695,6 @@ class AcceptTargetsTest(fixtures.TestBase):
         self.TargetOne = TargetOne
         self.TargetTwo = TargetTwo
 
-    def tearDown(self):
-        event.base._remove_dispatcher(
-            self.TargetOne.__dict__["dispatch"].events
-        )
-        event.base._remove_dispatcher(
-            self.TargetTwo.__dict__["dispatch"].events
-        )
-
     def test_target_accept(self):
         """Test that events of the same name are routed to the correct
         collection based on the type of target given.
@@ -687,7 +731,7 @@ class AcceptTargetsTest(fixtures.TestBase):
         eq_(list(t2.dispatch.event_one), [listen_two, listen_four])
 
 
-class CustomTargetsTest(fixtures.TestBase):
+class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test custom target acceptance."""
 
     def setUp(self):
@@ -707,9 +751,6 @@ class CustomTargetsTest(fixtures.TestBase):
 
         self.Target = Target
 
-    def tearDown(self):
-        event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
-
     def test_indirect(self):
         def listen(x, y):
             pass
@@ -727,7 +768,7 @@ class CustomTargetsTest(fixtures.TestBase):
         )
 
 
-class SubclassGrowthTest(fixtures.TestBase):
+class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """test that ad-hoc subclasses are garbage collected."""
 
     def setUp(self):
@@ -752,7 +793,7 @@ class SubclassGrowthTest(fixtures.TestBase):
         eq_(self.Target.__subclasses__(), [])
 
 
-class ListenOverrideTest(fixtures.TestBase):
+class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test custom listen functions which change the listener function
     signature."""
 
@@ -778,9 +819,6 @@ class ListenOverrideTest(fixtures.TestBase):
 
         self.Target = Target
 
-    def tearDown(self):
-        event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
-
     def test_listen_override(self):
         listen_one = Mock()
         listen_two = Mock()
@@ -816,7 +854,7 @@ class ListenOverrideTest(fixtures.TestBase):
         eq_(listen_one.mock_calls, [call(12)])
 
 
-class PropagateTest(fixtures.TestBase):
+class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
     def setUp(self):
         class TargetEvents(event.Events):
             def event_one(self, arg):
@@ -850,7 +888,7 @@ class PropagateTest(fixtures.TestBase):
         eq_(listen_two.mock_calls, [])
 
 
-class JoinTest(fixtures.TestBase):
+class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
     def setUp(self):
         class TargetEvents(event.Events):
             def event_one(self, target, arg):
@@ -875,11 +913,6 @@ class JoinTest(fixtures.TestBase):
         self.TargetFactory = TargetFactory
         self.TargetElement = TargetElement
 
-    def tearDown(self):
-        for cls in (self.TargetElement, self.TargetFactory, self.BaseTarget):
-            if "dispatch" in cls.__dict__:
-                event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
-
     def test_neither(self):
         element = self.TargetFactory().create()
         element.run_event(1)
@@ -1075,7 +1108,7 @@ class JoinTest(fixtures.TestBase):
         )
 
 
-class DisableClsPropagateTest(fixtures.TestBase):
+class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
     def setUp(self):
         class TargetEvents(event.Events):
             def event_one(self, target, arg):
@@ -1093,11 +1126,6 @@ class DisableClsPropagateTest(fixtures.TestBase):
         self.BaseTarget = BaseTarget
         self.SubTarget = SubTarget
 
-    def tearDown(self):
-        for cls in (self.SubTarget, self.BaseTarget):
-            if "dispatch" in cls.__dict__:
-                event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
-
     def test_listen_invoke_clslevel(self):
         canary = Mock()
 
@@ -1132,7 +1160,7 @@ class DisableClsPropagateTest(fixtures.TestBase):
         eq_(canary.mock_calls, [])
 
 
-class RemovalTest(fixtures.TestBase):
+class RemovalTest(TearDownLocalEventsFixture, fixtures.TestBase):
     def _fixture(self):
         class TargetEvents(event.Events):
             def event_one(self, x, y):
index fa347243e7555127c47f417fa6158e32b1280a31..9220720b679ed4663b1cfb9f42260cf1aed23671 100644 (file)
@@ -2300,7 +2300,14 @@ class SymbolTest(fixtures.TestBase):
 
 
 class _Py3KFixtures(object):
-    pass
+    def _kw_only_fixture(self):
+        pass
+
+    def _kw_plus_posn_fixture(self):
+        pass
+
+    def _kw_opt_fixture(self):
+        pass
 
 
 if util.py3k:
@@ -2321,185 +2328,208 @@ def _kw_opt_fixture(self, a, *, b, c="c"):
     for k in _locals:
         setattr(_Py3KFixtures, k, _locals[k])
 
+py3k_fixtures = _Py3KFixtures()
 
-class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
-    def _test_format_argspec_plus(self, fn, wanted, grouped=None):
-
-        # test direct function
-        if grouped is None:
-            parsed = util.format_argspec_plus(fn)
-        else:
-            parsed = util.format_argspec_plus(fn, grouped=grouped)
-        eq_(parsed, wanted)
-
-        # test sending fullargspec
-        spec = compat.inspect_getfullargspec(fn)
-        if grouped is None:
-            parsed = util.format_argspec_plus(spec)
-        else:
-            parsed = util.format_argspec_plus(spec, grouped=grouped)
-        eq_(parsed, wanted)
 
-    def test_specs(self):
-        self._test_format_argspec_plus(
+class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
+    @testing.combinations(
+        (
             lambda: None,
             {
                 "args": "()",
                 "self_arg": None,
                 "apply_kw": "()",
                 "apply_pos": "()",
+                "apply_pos_proxied": "()",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda: None,
-            {"args": "", "self_arg": None, "apply_kw": "", "apply_pos": ""},
-            grouped=False,
-        )
-
-        self._test_format_argspec_plus(
+            {
+                "args": "",
+                "self_arg": None,
+                "apply_kw": "",
+                "apply_pos": "",
+                "apply_pos_proxied": "",
+            },
+            False,
+        ),
+        (
             lambda self: None,
             {
                 "args": "(self)",
                 "self_arg": "self",
                 "apply_kw": "(self)",
                 "apply_pos": "(self)",
+                "apply_pos_proxied": "()",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda self: None,
             {
                 "args": "self",
                 "self_arg": "self",
                 "apply_kw": "self",
                 "apply_pos": "self",
+                "apply_pos_proxied": "",
             },
-            grouped=False,
-        )
-
-        self._test_format_argspec_plus(
+            False,
+        ),
+        (
             lambda *a: None,
             {
                 "args": "(*a)",
                 "self_arg": "a[0]",
                 "apply_kw": "(*a)",
                 "apply_pos": "(*a)",
+                "apply_pos_proxied": "(*a)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda **kw: None,
             {
                 "args": "(**kw)",
                 "self_arg": None,
                 "apply_kw": "(**kw)",
                 "apply_pos": "(**kw)",
+                "apply_pos_proxied": "(**kw)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda *a, **kw: None,
             {
                 "args": "(*a, **kw)",
                 "self_arg": "a[0]",
                 "apply_kw": "(*a, **kw)",
                 "apply_pos": "(*a, **kw)",
+                "apply_pos_proxied": "(*a, **kw)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda a, *b: None,
             {
                 "args": "(a, *b)",
                 "self_arg": "a",
                 "apply_kw": "(a, *b)",
                 "apply_pos": "(a, *b)",
+                "apply_pos_proxied": "(*b)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda a, **b: None,
             {
                 "args": "(a, **b)",
                 "self_arg": "a",
                 "apply_kw": "(a, **b)",
                 "apply_pos": "(a, **b)",
+                "apply_pos_proxied": "(**b)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda a, *b, **c: None,
             {
                 "args": "(a, *b, **c)",
                 "self_arg": "a",
                 "apply_kw": "(a, *b, **c)",
                 "apply_pos": "(a, *b, **c)",
+                "apply_pos_proxied": "(*b, **c)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda a, b=1, **c: None,
             {
                 "args": "(a, b=1, **c)",
                 "self_arg": "a",
                 "apply_kw": "(a, b=b, **c)",
                 "apply_pos": "(a, b, **c)",
+                "apply_pos_proxied": "(b, **c)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda a=1, b=2: None,
             {
                 "args": "(a=1, b=2)",
                 "self_arg": "a",
                 "apply_kw": "(a=a, b=b)",
                 "apply_pos": "(a, b)",
+                "apply_pos_proxied": "(b)",
             },
-        )
-
-        self._test_format_argspec_plus(
+            True,
+        ),
+        (
             lambda a=1, b=2: None,
             {
                 "args": "a=1, b=2",
                 "self_arg": "a",
                 "apply_kw": "a=a, b=b",
                 "apply_pos": "a, b",
+                "apply_pos_proxied": "b",
             },
-            grouped=False,
-        )
-
-        if util.py3k:
-            self._test_format_argspec_plus(
-                self._kw_only_fixture,
-                {
-                    "args": "self, a, *, b, c",
-                    "self_arg": "self",
-                    "apply_pos": "self, a, *, b, c",
-                    "apply_kw": "self, a, b=b, c=c",
-                },
-                grouped=False,
-            )
-            self._test_format_argspec_plus(
-                self._kw_plus_posn_fixture,
-                {
-                    "args": "self, a, *args, b, c",
-                    "self_arg": "self",
-                    "apply_pos": "self, a, *args, b, c",
-                    "apply_kw": "self, a, b=b, c=c, *args",
-                },
-                grouped=False,
-            )
-            self._test_format_argspec_plus(
-                self._kw_opt_fixture,
-                {
-                    "args": "self, a, *, b, c='c'",
-                    "self_arg": "self",
-                    "apply_pos": "self, a, *, b, c",
-                    "apply_kw": "self, a, b=b, c=c",
-                },
-                grouped=False,
-            )
+            False,
+        ),
+        (
+            py3k_fixtures._kw_only_fixture,
+            {
+                "args": "self, a, *, b, c",
+                "self_arg": "self",
+                "apply_pos": "self, a, *, b, c",
+                "apply_kw": "self, a, b=b, c=c",
+                "apply_pos_proxied": "a, *, b, c",
+            },
+            False,
+            testing.requires.python3,
+        ),
+        (
+            py3k_fixtures._kw_plus_posn_fixture,
+            {
+                "args": "self, a, *args, b, c",
+                "self_arg": "self",
+                "apply_pos": "self, a, *args, b, c",
+                "apply_kw": "self, a, b=b, c=c, *args",
+                "apply_pos_proxied": "a, *args, b, c",
+            },
+            False,
+            testing.requires.python3,
+        ),
+        (
+            py3k_fixtures._kw_opt_fixture,
+            {
+                "args": "self, a, *, b, c='c'",
+                "self_arg": "self",
+                "apply_pos": "self, a, *, b, c",
+                "apply_kw": "self, a, b=b, c=c",
+                "apply_pos_proxied": "a, *, b, c",
+            },
+            False,
+            testing.requires.python3,
+        ),
+        argnames="fn,wanted,grouped",
+    )
+    def test_specs(self, fn, wanted, grouped):
+
+        # test direct function
+        if grouped is None:
+            parsed = util.format_argspec_plus(fn)
+        else:
+            parsed = util.format_argspec_plus(fn, grouped=grouped)
+        eq_(parsed, wanted)
+
+        # test sending fullargspec
+        spec = compat.inspect_getfullargspec(fn)
+        if grouped is None:
+            parsed = util.format_argspec_plus(spec)
+        else:
+            parsed = util.format_argspec_plus(spec, grouped=grouped)
+        eq_(parsed, wanted)
 
     @testing.requires.cpython
     def test_init_grouped(self):
@@ -2508,17 +2538,20 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
             "self_arg": "self",
             "apply_pos": "(self)",
             "apply_kw": "(self)",
+            "apply_pos_proxied": "()",
         }
         wrapper_spec = {
             "args": "(self, *args, **kwargs)",
             "self_arg": "self",
             "apply_pos": "(self, *args, **kwargs)",
             "apply_kw": "(self, *args, **kwargs)",
+            "apply_pos_proxied": "(*args, **kwargs)",
         }
         custom_spec = {
             "args": "(slef, a=123)",
             "self_arg": "slef",  # yes, slef
             "apply_pos": "(slef, a)",
+            "apply_pos_proxied": "(a)",
             "apply_kw": "(slef, a=a)",
         }
 
@@ -2532,18 +2565,21 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
             "self_arg": "self",
             "apply_pos": "self",
             "apply_kw": "self",
+            "apply_pos_proxied": "",
         }
         wrapper_spec = {
             "args": "self, *args, **kwargs",
             "self_arg": "self",
             "apply_pos": "self, *args, **kwargs",
             "apply_kw": "self, *args, **kwargs",
+            "apply_pos_proxied": "*args, **kwargs",
         }
         custom_spec = {
             "args": "slef, a=123",
             "self_arg": "slef",  # yes, slef
             "apply_pos": "slef, a",
             "apply_kw": "slef, a=a",
+            "apply_pos_proxied": "a",
         }
 
         self._test_init(False, object_spec, wrapper_spec, custom_spec)
index 7c7d90e2175a97f22f842e873a1dff810b94f82b..83987b06f1173e269397d1083b54acd30b5fe29e 100644 (file)
@@ -2,6 +2,7 @@ import asyncio
 
 from sqlalchemy import Column
 from sqlalchemy import delete
+from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import Integer
@@ -9,13 +10,19 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
+from sqlalchemy import text
 from sqlalchemy import union_all
 from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import engine as _async_engine
 from sqlalchemy.ext.asyncio import exc as asyncio_exc
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not
+from sqlalchemy.testing import mock
 from sqlalchemy.testing.asyncio import assert_raises_message_async
+from sqlalchemy.util.concurrency import greenlet_spawn
 
 
 class EngineFixture(fixtures.TablesTest):
@@ -50,6 +57,117 @@ class EngineFixture(fixtures.TablesTest):
 class AsyncEngineTest(EngineFixture):
     __backend__ = True
 
+    def test_proxied_attrs_engine(self, async_engine):
+        sync_engine = async_engine.sync_engine
+
+        is_(async_engine.url, sync_engine.url)
+        is_(async_engine.pool, sync_engine.pool)
+        is_(async_engine.dialect, sync_engine.dialect)
+        eq_(async_engine.name, sync_engine.name)
+        eq_(async_engine.driver, sync_engine.driver)
+        eq_(async_engine.echo, sync_engine.echo)
+
+    def test_clear_compiled_cache(self, async_engine):
+        async_engine.sync_engine._compiled_cache["foo"] = "bar"
+        eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
+        async_engine.clear_compiled_cache()
+        assert "foo" not in async_engine.sync_engine._compiled_cache
+
+    def test_execution_options(self, async_engine):
+        a2 = async_engine.execution_options(foo="bar")
+        assert isinstance(a2, _async_engine.AsyncEngine)
+        eq_(a2.sync_engine._execution_options, {"foo": "bar"})
+        eq_(async_engine.sync_engine._execution_options, {})
+
+        """
+
+            attr uri, pool, dialect, engine, name, driver, echo
+            methods clear_compiled_cache, update_execution_options,
+            execution_options, get_execution_options, dispose
+
+        """
+
+    @async_test
+    async def test_proxied_attrs_connection(self, async_engine):
+        conn = await async_engine.connect()
+
+        sync_conn = conn.sync_connection
+
+        is_(conn.engine, async_engine)
+        is_(conn.closed, sync_conn.closed)
+        is_(conn.dialect, async_engine.sync_engine.dialect)
+        eq_(conn.default_isolation_level, sync_conn.default_isolation_level)
+
+    @async_test
+    async def test_invalidate(self, async_engine):
+        conn = await async_engine.connect()
+
+        is_(conn.invalidated, False)
+
+        connection_fairy = await conn.get_raw_connection()
+        is_(connection_fairy.is_valid, True)
+        dbapi_connection = connection_fairy.connection
+
+        await conn.invalidate()
+        assert dbapi_connection._connection.is_closed()
+
+        new_fairy = await conn.get_raw_connection()
+        is_not(new_fairy.connection, dbapi_connection)
+        is_not(new_fairy, connection_fairy)
+        is_(new_fairy.is_valid, True)
+        is_(connection_fairy.is_valid, False)
+
+    @async_test
+    async def test_get_dbapi_connection_raise(self, async_engine):
+
+        conn = await async_engine.connect()
+
+        with testing.expect_raises_message(
+            exc.InvalidRequestError,
+            "AsyncConnection.connection accessor is not "
+            "implemented as the attribute",
+        ):
+            conn.connection
+
+    @async_test
+    async def test_get_raw_connection(self, async_engine):
+
+        conn = await async_engine.connect()
+
+        pooled = await conn.get_raw_connection()
+        is_(pooled, conn.sync_connection.connection)
+
+    @async_test
+    async def test_isolation_level(self, async_engine):
+        conn = await async_engine.connect()
+
+        sync_isolation_level = await greenlet_spawn(
+            conn.sync_connection.get_isolation_level
+        )
+        isolation_level = await conn.get_isolation_level()
+
+        eq_(isolation_level, sync_isolation_level)
+
+        await conn.execution_options(isolation_level="SERIALIZABLE")
+        isolation_level = await conn.get_isolation_level()
+
+        eq_(isolation_level, "SERIALIZABLE")
+
+    @async_test
+    async def test_dispose(self, async_engine):
+        c1 = await async_engine.connect()
+        c2 = await async_engine.connect()
+
+        await c1.close()
+        await c2.close()
+
+        p1 = async_engine.pool
+        eq_(async_engine.pool.checkedin(), 2)
+
+        await async_engine.dispose()
+        eq_(async_engine.pool.checkedin(), 0)
+        is_not(p1, async_engine.pool)
+
     @async_test
     async def test_init_once_concurrency(self, async_engine):
         c1 = async_engine.connect()
@@ -169,6 +287,70 @@ class AsyncEngineTest(EngineFixture):
         )
 
 
+class AsyncEventTest(EngineFixture):
+    """The engine events all run in their normal synchronous context.
+
+    we do not provide an asyncio event interface at this time.
+
+    """
+
+    __backend__ = True
+
+    @async_test
+    async def test_no_async_listeners(self, async_engine):
+        with testing.expect_raises_message(
+            NotImplementedError,
+            "asynchronous events are not implemented "
+            "at this time.  Apply synchronous listeners to the "
+            "AsyncEngine.sync_engine or "
+            "AsyncConnection.sync_connection attributes.",
+        ):
+            event.listen(async_engine, "before_cursor_execute", mock.Mock())
+
+        conn = await async_engine.connect()
+
+        with testing.expect_raises_message(
+            NotImplementedError,
+            "asynchronous events are not implemented "
+            "at this time.  Apply synchronous listeners to the "
+            "AsyncEngine.sync_engine or "
+            "AsyncConnection.sync_connection attributes.",
+        ):
+            event.listen(conn, "before_cursor_execute", mock.Mock())
+
+    @async_test
+    async def test_sync_before_cursor_execute_engine(self, async_engine):
+        canary = mock.Mock()
+
+        event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
+
+        async with async_engine.connect() as conn:
+            sync_conn = conn.sync_connection
+            await conn.execute(text("select 1"))
+
+        eq_(
+            canary.mock_calls,
+            [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+        )
+
+    @async_test
+    async def test_sync_before_cursor_execute_connection(self, async_engine):
+        canary = mock.Mock()
+
+        async with async_engine.connect() as conn:
+            sync_conn = conn.sync_connection
+
+            event.listen(
+                async_engine.sync_engine, "before_cursor_execute", canary
+            )
+            await conn.execute(text("select 1"))
+
+        eq_(
+            canary.mock_calls,
+            [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+        )
+
+
 class AsyncResultTest(EngineFixture):
     @testing.combinations(
         (None,), ("scalars",), ("mappings",), argnames="filter_"
index e8caaca3e45ee4579fa9ed2e48c36986171c78e0..a3b8add6774e26a8fb264bb56ffd570fb6867dc9 100644 (file)
@@ -1,3 +1,4 @@
+from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import select
@@ -9,6 +10,7 @@ from sqlalchemy.orm import selectinload
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import mock
 from ...orm import _fixtures
 
 
@@ -139,6 +141,27 @@ class AsyncSessionTransactionTest(AsyncFixture):
 
             eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
 
+    @async_test
+    async def test_delete(self, async_session):
+        User = self.classes.User
+
+        async with async_session.begin():
+            u1 = User(name="u1")
+
+            async_session.add(u1)
+
+            await async_session.flush()
+
+            conn = await async_session.connection()
+
+            eq_(await conn.scalar(select(func.count(User.id))), 1)
+
+            async_session.delete(u1)
+
+            await async_session.flush()
+
+            eq_(await conn.scalar(select(func.count(User.id))), 0)
+
     @async_test
     async def test_flush(self, async_session):
         User = self.classes.User
@@ -198,3 +221,38 @@ class AsyncSessionTransactionTest(AsyncFixture):
 
             is_(new_u_merged, u1)
             eq_(u1.name, "new u1")
+
+
+class AsyncEventTest(AsyncFixture):
+    """The engine events all run in their normal synchronous context.
+
+    we do not provide an asyncio event interface at this time.
+
+    """
+
+    __backend__ = True
+
+    @async_test
+    async def test_no_async_listeners(self, async_session):
+        with testing.expect_raises(
+            NotImplementedError,
+            "NotImplementedError: asynchronous events are not implemented "
+            "at this time.  Apply synchronous listeners to the "
+            "AsyncEngine.sync_engine or "
+            "AsyncConnection.sync_connection attributes.",
+        ):
+            event.listen(async_session, "before_flush", mock.Mock())
+
+    @async_test
+    async def test_sync_before_commit(self, async_session):
+        canary = mock.Mock()
+
+        event.listen(async_session.sync_session, "before_commit", canary)
+
+        async with async_session.begin():
+            pass
+
+        eq_(
+            canary.mock_calls,
+            [mock.call(async_session.sync_session)],
+        )
index 6b7feaea7f200c224eb4e67f45ff89f97002c20b..d1ed9acc16dd47517c39fbcbcf59368acaec96d6 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy.orm import scoped_session
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import mock
 from sqlalchemy.testing.mock import Mock
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -127,3 +128,26 @@ class ScopedSessionTest(fixtures.MappedTest):
         mock_scope_func.return_value = 1
         s2 = Session(autocommit=True)
         assert s2.autocommit == True
+
+    def test_methods_etc(self):
+        mock_session = Mock()
+        mock_session.bind = "the bind"
+
+        sess = scoped_session(lambda: mock_session)
+
+        sess.add("add")
+        sess.delete("delete")
+
+        eq_(sess.bind, "the bind")
+
+        eq_(
+            mock_session.mock_calls,
+            [mock.call.add("add", True), mock.call.delete("delete")],
+        )
+
+        with mock.patch(
+            "sqlalchemy.orm.session.object_session"
+        ) as mock_object_session:
+            sess.object_session("foo")
+
+        eq_(mock_object_session.mock_calls, [mock.call("foo")])
index 9bc6c6f7ce45a6930fe141295fdd1dd6378327f3..4562df44a60289265e81f3e94f69bb4f02943c14 100644 (file)
@@ -1,3 +1,5 @@
+import inspect as _py_inspect
+
 import sqlalchemy as sa
 from sqlalchemy import event
 from sqlalchemy import ForeignKey
@@ -1820,22 +1822,28 @@ class DisposedStates(fixtures.MappedTest):
 class SessionInterface(fixtures.TestBase):
     """Bogus args to Session methods produce actionable exceptions."""
 
-    # TODO: expand with message body assertions.
-
     _class_methods = set(("connection", "execute", "get_bind", "scalar"))
 
     def _public_session_methods(self):
         Session = sa.orm.session.Session
 
-        blacklist = set(("begin", "query"))
-
+        blacklist = {"begin", "query", "bind_mapper", "get", "bind_table"}
+        specials = {"__iter__", "__contains__"}
         ok = set()
-        for meth in Session.public_methods:
-            if meth in blacklist:
-                continue
-            spec = inspect_getfullargspec(getattr(Session, meth))
-            if len(spec[0]) > 1 or spec[1]:
-                ok.add(meth)
+        for name in dir(Session):
+            if (
+                name in Session.__dict__
+                and (not name.startswith("_") or name in specials)
+                and (
+                    _py_inspect.ismethod(getattr(Session, name))
+                    or _py_inspect.isfunction(getattr(Session, name))
+                )
+            ):
+                if name in blacklist:
+                    continue
+                spec = inspect_getfullargspec(getattr(Session, name))
+                if len(spec[0]) > 1 or spec[1]:
+                    ok.add(name)
         return ok
 
     def _map_it(self, cls):
@@ -1866,18 +1874,21 @@ class SessionInterface(fixtures.TestBase):
         def raises_(method, *args, **kw):
             x_raises_(create_session(), method, *args, **kw)
 
-        raises_("__contains__", user_arg)
-
-        raises_("add", user_arg)
+        for name in [
+            "__contains__",
+            "is_modified",
+            "merge",
+            "refresh",
+            "add",
+            "delete",
+            "expire",
+            "expunge",
+            "enable_relationship_loading",
+        ]:
+            raises_(name, user_arg)
 
         raises_("add_all", (user_arg,))
 
-        raises_("delete", user_arg)
-
-        raises_("expire", user_arg)
-
-        raises_("expunge", user_arg)
-
         # flush will no-op without something in the unit of work
         def _():
             class OK(object):
@@ -1891,12 +1902,6 @@ class SessionInterface(fixtures.TestBase):
 
         _()
 
-        raises_("is_modified", user_arg)
-
-        raises_("merge", user_arg)
-
-        raises_("refresh", user_arg)
-
         instance_methods = (
             self._public_session_methods()
             - self._class_methods