From 2665a0c4cb3e94e6545d0b9bbcbcc39ccffebaba Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 8 Oct 2020 15:20:48 -0400 Subject: [PATCH] generalize scoped_session proxying and apply to asyncio elements Reworked the proxy creation used by scoped_session() to be based on fully copied code with augmented docstrings and moved it into langhelpers. asyncio session, engine, connection can now take advantage of it so that all non-async methods are availble. Overall implementation of most important accessors / methods on AsyncConnection, etc. , including awaitable versions of invalidate, execution_options, etc. In order to support an event dispatcher on the async classes while still allowing them to hold __slots__, make some adjustments to the event system to allow that to be present, at least rudimentally. Fixes: #5628 Change-Id: I5eb6929fc1e4fdac99e4b767dcfd49672d56e2b2 --- lib/sqlalchemy/engine/base.py | 30 ++- lib/sqlalchemy/event/base.py | 35 ++- lib/sqlalchemy/ext/asyncio/__init__.py | 2 + lib/sqlalchemy/ext/asyncio/engine.py | 153 +++++++++++- lib/sqlalchemy/ext/asyncio/events.py | 29 +++ lib/sqlalchemy/ext/asyncio/session.py | 97 +++++--- lib/sqlalchemy/orm/events.py | 3 +- lib/sqlalchemy/orm/scoping.py | 99 ++++---- lib/sqlalchemy/orm/session.py | 38 +-- lib/sqlalchemy/pool/base.py | 1 + lib/sqlalchemy/testing/plugin/pytestplugin.py | 4 +- lib/sqlalchemy/util/__init__.py | 1 + lib/sqlalchemy/util/langhelpers.py | 155 +++++++++++- test/base/test_events.py | 126 ++++++---- test/base/test_utils.py | 222 ++++++++++-------- test/ext/asyncio/test_engine_py3k.py | 182 ++++++++++++++ test/ext/asyncio/test_session_py3k.py | 58 +++++ test/orm/test_scoping.py | 24 ++ test/orm/test_session.py | 55 +++-- 19 files changed, 997 insertions(+), 317 deletions(-) create mode 100644 lib/sqlalchemy/ext/asyncio/events.py diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 9a6bdd7f36..4fbdec1452 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -184,14 +184,17 @@ class Connection(Connectable): r""" Set non-SQL options for the connection which take effect during execution. - The method returns a copy of this :class:`_engine.Connection` - which references - the same underlying DBAPI connection, but also defines the given - execution options which will take effect for a call to - :meth:`execute`. As the new :class:`_engine.Connection` - references the same - underlying resource, it's usually a good idea to ensure that the copies - will be discarded immediately, which is implicit if used as in:: + For a "future" style connection, this method returns this same + :class:`_future.Connection` object with the new options added. + + For a legacy connection, this method returns a copy of this + :class:`_engine.Connection` which references the same underlying DBAPI + connection, but also defines the given execution options which will + take effect for a call to + :meth:`execute`. As the new :class:`_engine.Connection` references the + same underlying resource, it's usually a good idea to ensure that + the copies will be discarded immediately, which is implicit if used + as in:: result = connection.execution_options(stream_results=True).\ execute(stmt) @@ -549,9 +552,10 @@ class Connection(Connectable): """Invalidate the underlying DBAPI connection associated with this :class:`_engine.Connection`. - The underlying DBAPI connection is literally closed (if - possible), and is discarded. Its source connection pool will - typically lazily create a new connection to replace it. + An attempt will be made to close the underlying DBAPI connection + immediately; however if this operation fails, the error is logged + but not raised. The connection is then discarded whether or not + close() succeeded. Upon the next use (where "use" typically means using the :meth:`_engine.Connection.execute` method or similar), @@ -580,6 +584,10 @@ class Connection(Connectable): will at the connection pool level invoke the :meth:`_events.PoolEvents.invalidate` event. + :param exception: an optional ``Exception`` instance that's the + reason for the invalidation. is passed along to event handlers + and logging functions. + .. seealso:: :ref:`pool_connection_invalidation` diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index daa6f9aeab..1ba88f3d21 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -195,7 +195,14 @@ def _create_dispatcher_class(cls, classname, bases, dict_): dispatch_cls._event_names.append(ls.name) if getattr(cls, "_dispatch_target", None): - cls._dispatch_target.dispatch = dispatcher(cls) + the_cls = cls._dispatch_target + if ( + hasattr(the_cls, "__slots__") + and "_slots_dispatch" in the_cls.__slots__ + ): + cls._dispatch_target.dispatch = slots_dispatcher(cls) + else: + cls._dispatch_target.dispatch = dispatcher(cls) def _remove_dispatcher(cls): @@ -304,5 +311,29 @@ class dispatcher(object): def __get__(self, obj, cls): if obj is None: return self.dispatch - obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj) + + disp = self.dispatch._for_instance(obj) + try: + obj.__dict__["dispatch"] = disp + except AttributeError as ae: + util.raise_( + TypeError( + "target %r doesn't have __dict__, should it be " + "defining _slots_dispatch?" % (obj,) + ), + replace_context=ae, + ) + return disp + + +class slots_dispatcher(dispatcher): + def __get__(self, obj, cls): + if obj is None: + return self.dispatch + + if hasattr(obj, "_slots_dispatch"): + return obj._slots_dispatch + + disp = self.dispatch._for_instance(obj) + obj._slots_dispatch = disp return disp diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index fbbc958d42..9c7d6443cb 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -2,6 +2,8 @@ from .engine import AsyncConnection # noqa from .engine import AsyncEngine # noqa from .engine import AsyncTransaction # noqa from .engine import create_async_engine # noqa +from .events import AsyncConnectionEvents # noqa +from .events import AsyncSessionEvents # noqa from .result import AsyncMappingResult # noqa from .result import AsyncResult # noqa from .result import AsyncScalarResult # noqa diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 4a92fb1f2c..9e4851dfcd 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -8,12 +8,11 @@ from .base import StartableContext from .result import AsyncResult from ... import exc from ... import util -from ...engine import Connection from ...engine import create_engine as _create_engine -from ...engine import Engine from ...engine import Result from ...engine import Transaction -from ...engine.base import OptionEngineMixin +from ...future import Connection +from ...future import Engine from ...sql import Executable from ...util.concurrency import greenlet_spawn @@ -41,7 +40,24 @@ def create_async_engine(*arg, **kw): return AsyncEngine(sync_engine) -class AsyncConnection(StartableContext): +class AsyncConnectable: + __slots__ = "_slots_dispatch" + + +@util.create_proxy_methods( + Connection, + ":class:`_future.Connection`", + ":class:`_asyncio.AsyncConnection`", + classmethods=[], + methods=[], + attributes=[ + "closed", + "invalidated", + "dialect", + "default_isolation_level", + ], +) +class AsyncConnection(StartableContext, AsyncConnectable): """An asyncio proxy for a :class:`_engine.Connection`. :class:`_asyncio.AsyncConnection` is acquired using the @@ -58,15 +74,23 @@ class AsyncConnection(StartableContext): """ # noqa + # AsyncConnection is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncConnection that matches this one given only the + # "sync" elements. __slots__ = ( "sync_engine", "sync_connection", ) def __init__( - self, sync_engine: Engine, sync_connection: Optional[Connection] = None + self, + async_engine: "AsyncEngine", + sync_connection: Optional[Connection] = None, ): - self.sync_engine = sync_engine + self.engine = async_engine + self.sync_engine = async_engine.sync_engine self.sync_connection = sync_connection async def start(self): @@ -79,6 +103,34 @@ class AsyncConnection(StartableContext): self.sync_connection = await (greenlet_spawn(self.sync_engine.connect)) return self + @property + def connection(self): + """Not implemented for async; call + :meth:`_asyncio.AsyncConnection.get_raw_connection`. + + """ + raise exc.InvalidRequestError( + "AsyncConnection.connection accessor is not implemented as the " + "attribute may need to reconnect on an invalidated connection. " + "Use the get_raw_connection() method." + ) + + async def get_raw_connection(self): + """Return the pooled DBAPI-level connection in use by this + :class:`_asyncio.AsyncConnection`. + + This is typically the SQLAlchemy connection-pool proxied connection + which then has an attribute .connection that refers to the actual + DBAPI-level connection. + """ + conn = self._sync_connection() + + return await greenlet_spawn(getattr, conn, "connection") + + @property + def _proxied(self): + return self.sync_connection + def _sync_connection(self): if not self.sync_connection: self._raise_for_not_started() @@ -94,6 +146,43 @@ class AsyncConnection(StartableContext): self._sync_connection() return AsyncTransaction(self, nested=True) + async def invalidate(self, exception=None): + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + See the method :meth:`_engine.Connection.invalidate` for full + detail on this method. + + """ + + conn = self._sync_connection() + return await greenlet_spawn(conn.invalidate, exception=exception) + + async def get_isolation_level(self): + conn = self._sync_connection() + return await greenlet_spawn(conn.get_isolation_level) + + async def set_isolation_level(self): + conn = self._sync_connection() + return await greenlet_spawn(conn.get_isolation_level) + + async def execution_options(self, **opt): + r"""Set non-SQL options for the connection which take effect + during execution. + + This returns this :class:`_asyncio.AsyncConnection` object with + the new options added. + + See :meth:`_future.Connection.execution_options` for full details + on this method. + + """ + + conn = self._sync_connection() + c2 = await greenlet_spawn(conn.execution_options, **opt) + assert c2 is conn + return self + async def commit(self): """Commit the transaction that is currently in progress. @@ -287,7 +376,19 @@ class AsyncConnection(StartableContext): await self.close() -class AsyncEngine: +@util.create_proxy_methods( + Engine, + ":class:`_future.Engine`", + ":class:`_asyncio.AsyncEngine`", + classmethods=[], + methods=[ + "clear_compiled_cache", + "update_execution_options", + "get_execution_options", + ], + attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], +) +class AsyncEngine(AsyncConnectable): """An asyncio proxy for a :class:`_engine.Engine`. :class:`_asyncio.AsyncEngine` is acquired using the @@ -301,7 +402,12 @@ class AsyncEngine: """ # noqa - __slots__ = ("sync_engine",) + # AsyncEngine is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncEngine that matches this one given only the + # "sync" elements. + __slots__ = ("sync_engine", "_proxied") _connection_cls = AsyncConnection @@ -327,7 +433,7 @@ class AsyncEngine: await self.conn.close() def __init__(self, sync_engine: Engine): - self.sync_engine = sync_engine + self.sync_engine = self._proxied = sync_engine def begin(self): """Return a context manager which when entered will deliver an @@ -363,7 +469,7 @@ class AsyncEngine: """ - return self._connection_cls(self.sync_engine) + return self._connection_cls(self) async def raw_connection(self) -> Any: """Return a "raw" DBAPI connection from the connection pool. @@ -375,12 +481,33 @@ class AsyncEngine: """ return await greenlet_spawn(self.sync_engine.raw_connection) + def execution_options(self, **opt): + """Return a new :class:`_asyncio.AsyncEngine` that will provide + :class:`_asyncio.AsyncConnection` objects with the given execution + options. + + Proxied from :meth:`_future.Engine.execution_options`. See that + method for details. + + """ + + return AsyncEngine(self.sync_engine.execution_options(**opt)) -class AsyncOptionEngine(OptionEngineMixin, AsyncEngine): - pass + async def dispose(self): + """Dispose of the connection pool used by this + :class:`_asyncio.AsyncEngine`. + This will close all connection pool connections that are + **currently checked in**. See the documentation for the underlying + :meth:`_future.Engine.dispose` method for further notes. + + .. seealso:: + + :meth:`_future.Engine.dispose` + + """ -AsyncEngine._option_cls = AsyncOptionEngine + return await greenlet_spawn(self.sync_engine.dispose) class AsyncTransaction(StartableContext): diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py new file mode 100644 index 0000000000..a8daefc4b7 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/events.py @@ -0,0 +1,29 @@ +from .engine import AsyncConnectable +from .session import AsyncSession +from ...engine import events as engine_event +from ...orm import events as orm_event + + +class AsyncConnectionEvents(engine_event.ConnectionEvents): + _target_class_doc = "SomeEngine" + _dispatch_target = AsyncConnectable + + @classmethod + def _listen(cls, event_key, retval=False): + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + + +class AsyncSessionEvents(orm_event.SessionEvents): + _target_class_doc = "SomeSession" + _dispatch_target = AsyncSession + + @classmethod + def _listen(cls, event_key, retval=False): + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index cb06aa26d5..4ae1fb385e 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -1,6 +1,5 @@ from typing import Any from typing import Callable -from typing import List from typing import Mapping from typing import Optional @@ -15,6 +14,35 @@ from ...sql import Executable from ...util.concurrency import greenlet_spawn +@util.create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_asyncio.AsyncSession`", + classmethods=["object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "delete", + "expire", + "expire_all", + "expunge", + "expunge_all", + "get_bind", + "is_modified", + ], + attributes=[ + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) class AsyncSession: """Asyncio version of :class:`_orm.Session`. @@ -23,6 +51,16 @@ class AsyncSession: """ + __slots__ = ( + "binds", + "bind", + "sync_session", + "_proxied", + "_slots_dispatch", + ) + + dispatch = None + def __init__( self, bind: AsyncEngine = None, @@ -31,46 +69,18 @@ class AsyncSession: ): kw["future"] = True if bind: + self.bind = engine bind = engine._get_sync_engine(bind) if binds: + self.binds = binds binds = { key: engine._get_sync_engine(b) for key, b in binds.items() } - self.sync_session = Session(bind=bind, binds=binds, **kw) - - def add(self, instance: object) -> None: - """Place an object in this :class:`_asyncio.AsyncSession`. - - .. seealso:: - - :meth:`_orm.Session.add` - - """ - self.sync_session.add(instance) - - def add_all(self, instances: List[object]) -> None: - """Add the given collection of instances to this - :class:`_asyncio.AsyncSession`.""" - - self.sync_session.add_all(instances) - - def expire_all(self): - """Expires all persistent instances within this Session. - - See :meth:`_orm.Session.expire_all` for usage details. - - """ - self.sync_session.expire_all() - - def expire(self, instance, attribute_names=None): - """Expire the attributes on an instance. - - See :meth:`._orm.Session.expire` for usage details. - - """ - self.sync_session.expire() + self.sync_session = self._proxied = Session( + bind=bind, binds=binds, **kw + ) async def refresh( self, instance, attribute_names=None, with_for_update=None @@ -178,8 +188,17 @@ class AsyncSession: :class:`.Session` object's transactional state. """ + + # POSSIBLY TODO: here, we see that the sync engine / connection + # that are generated from AsyncEngine / AsyncConnection don't + # provide any backlink from those sync objects back out to the + # async ones. it's not *too* big a deal since AsyncEngine/Connection + # are just proxies and all the state is actually in the sync + # version of things. However! it has to stay that way :) sync_connection = await greenlet_spawn(self.sync_session.connection) - return engine.AsyncConnection(sync_connection.engine, sync_connection) + return engine.AsyncConnection( + engine.AsyncEngine(sync_connection.engine), sync_connection + ) def begin(self, **kw): """Return an :class:`_asyncio.AsyncSessionTransaction` object. @@ -218,14 +237,22 @@ class AsyncSession: return AsyncSessionTransaction(self, nested=True) async def rollback(self): + """Rollback the current transaction in progress.""" return await greenlet_spawn(self.sync_session.rollback) async def commit(self): + """Commit the current transaction in progress.""" return await greenlet_spawn(self.sync_session.commit) async def close(self): + """Close this :class:`_asyncio.AsyncSession`.""" return await greenlet_spawn(self.sync_session.close) + @classmethod + async def close_all(self): + """Close all :class:`_asyncio.AsyncSession` sessions.""" + return await greenlet_spawn(self.sync_session.close_all) + async def __aenter__(self): return self diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 29a509cb9a..4e11ebb8c6 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1371,7 +1371,8 @@ class SessionEvents(event.Events): elif isinstance(target, Session): return target else: - return None + # allows alternate SessionEvents-like-classes to be consulted + return event.Events._accept_with(target) @classmethod def _listen(cls, event_key, raw=False, restore_load_context=False, **kw): diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 1090501ca1..29d845c0a7 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -9,14 +9,60 @@ from . import class_mapper from . import exc as orm_exc from .session import Session from .. import exc as sa_exc +from ..util import create_proxy_methods from ..util import ScopedRegistry from ..util import ThreadLocalRegistry from ..util import warn - __all__ = ["scoped_session"] +@create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_orm.scoping.scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get_bind", + "is_modified", + "bulk_save_objects", + "bulk_insert_mappings", + "bulk_update_mappings", + "merge", + "query", + "refresh", + "rollback", + "scalar", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + "autocommit", + ], +) class scoped_session(object): """Provides scoped management of :class:`.Session` objects. @@ -53,6 +99,10 @@ class scoped_session(object): else: self.registry = ThreadLocalRegistry(session_factory) + @property + def _proxied(self): + return self.registry() + def __call__(self, **kw): r"""Return the current :class:`.Session`, creating it using the :attr:`.scoped_session.session_factory` if not present. @@ -156,50 +206,3 @@ class scoped_session(object): ScopedSession = scoped_session """Old name for backwards compatibility.""" - - -def instrument(name): - def do(self, *args, **kwargs): - return getattr(self.registry(), name)(*args, **kwargs) - - return do - - -for meth in Session.public_methods: - setattr(scoped_session, meth, instrument(meth)) - - -def makeprop(name): - def set_(self, attr): - setattr(self.registry(), name, attr) - - def get(self): - return getattr(self.registry(), name) - - return property(get, set_) - - -for prop in ( - "bind", - "dirty", - "deleted", - "new", - "identity_map", - "is_active", - "autoflush", - "no_autoflush", - "info", - "autocommit", -): - setattr(scoped_session, prop, makeprop(prop)) - - -def clslevel(name): - def do(cls, *args, **kwargs): - return getattr(Session, name)(*args, **kwargs) - - return classmethod(do) - - -for prop in ("close_all", "object_session", "identity_key"): - setattr(scoped_session, prop, clslevel(prop)) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e32e055103..af0ac63e0d 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -835,35 +835,6 @@ class Session(_SessionClassMethods): """ - public_methods = ( - "__contains__", - "__iter__", - "add", - "add_all", - "begin", - "begin_nested", - "close", - "commit", - "connection", - "delete", - "execute", - "expire", - "expire_all", - "expunge", - "expunge_all", - "flush", - "get_bind", - "is_modified", - "bulk_save_objects", - "bulk_insert_mappings", - "bulk_update_mappings", - "merge", - "query", - "refresh", - "rollback", - "scalar", - ) - @util.deprecated_params( autocommit=( "2.0", @@ -3028,7 +2999,14 @@ class Session(_SessionClassMethods): will unexpire attributes on access. """ - state = attributes.instance_state(obj) + try: + state = attributes.instance_state(obj) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(obj), + replace_context=err, + ) + to_attach = self._before_attach(state, obj) state._load_pending = True if to_attach: diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 87383fef71..68fa5fe85a 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -509,6 +509,7 @@ class _ConnectionRecord(object): "Soft " if soft else "", self.connection, ) + if soft: self._soft_invalidate_time = time.time() else: diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index dfefd3b95f..644ea6dc20 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -372,7 +372,7 @@ def _pytest_fn_decorator(target): if add_positional_parameters: spec.args.extend(add_positional_parameters) - metadata = dict(target="target", fn="fn", name=fn.__name__) + metadata = dict(target="target", fn="__fn", name=fn.__name__) metadata.update(format_argspec_plus(spec, grouped=False)) code = ( """\ @@ -382,7 +382,7 @@ def %(name)s(%(args)s): % metadata ) decorated = _exec_code_in_env( - code, {"target": target, "fn": fn}, fn.__name__ + code, {"target": target, "__fn": fn}, fn.__name__ ) if not add_positional_parameters: decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 8ef2f01032..885f62f97b 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -123,6 +123,7 @@ from .langhelpers import coerce_kw_type # noqa from .langhelpers import constructor_copy # noqa from .langhelpers import constructor_key # noqa from .langhelpers import counter # noqa +from .langhelpers import create_proxy_methods # noqa from .langhelpers import decode_slice # noqa from .langhelpers import decorator # noqa from .langhelpers import dictlike_iteritems # noqa diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index e546f196d5..4289db8129 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -9,6 +9,7 @@ modules, classes, hierarchies, attributes, functions, and methods. """ + from functools import update_wrapper import hashlib import inspect @@ -462,6 +463,8 @@ def format_argspec_plus(fn, grouped=True): passed positionally. apply_kw Like apply_pos, except keyword-ish args are passed as keywords. + apply_pos_proxied + Like apply_pos but omits the self/cls argument Example:: @@ -478,16 +481,27 @@ def format_argspec_plus(fn, grouped=True): spec = fn args = compat.inspect_formatargspec(*spec) + + apply_pos = compat.inspect_formatargspec( + spec[0], spec[1], spec[2], None, spec[4] + ) + if spec[0]: self_arg = spec[0][0] + + apply_pos_proxied = compat.inspect_formatargspec( + spec[0][1:], spec[1], spec[2], None, spec[4] + ) + elif spec[1]: + # im not sure what this is self_arg = "%s[0]" % spec[1] + + apply_pos_proxied = apply_pos else: self_arg = None + apply_pos_proxied = apply_pos - apply_pos = compat.inspect_formatargspec( - spec[0], spec[1], spec[2], None, spec[4] - ) num_defaults = 0 if spec[3]: num_defaults += len(spec[3]) @@ -513,6 +527,7 @@ def format_argspec_plus(fn, grouped=True): self_arg=self_arg, apply_pos=apply_pos, apply_kw=apply_kw, + apply_pos_proxied=apply_pos_proxied, ) else: return dict( @@ -520,6 +535,7 @@ def format_argspec_plus(fn, grouped=True): self_arg=self_arg, apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1], + apply_pos_proxied=apply_pos_proxied[1:-1], ) @@ -534,17 +550,140 @@ def format_argspec_init(method, grouped=True): """ if method is object.__init__: - args = grouped and "(self)" or "self" + args = "(self)" if grouped else "self" + proxied = "()" if grouped else "" else: try: return format_argspec_plus(method, grouped=grouped) except TypeError: args = ( - grouped - and "(self, *args, **kwargs)" - or "self, *args, **kwargs" + "(self, *args, **kwargs)" + if grouped + else "self, *args, **kwargs" ) - return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args) + proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs" + return dict( + self_arg="self", + args=args, + apply_pos=args, + apply_kw=args, + apply_pos_proxied=proxied, + ) + + +def create_proxy_methods( + target_cls, + target_cls_sphinx_name, + proxy_cls_sphinx_name, + classmethods=(), + methods=(), + attributes=(), +): + """A class decorator that will copy attributes to a proxy class. + + The class to be instrumented must define a single accessor "_proxied". + + """ + + def decorate(cls): + def instrument(name, clslevel=False): + fn = getattr(target_cls, name) + spec = compat.inspect_getfullargspec(fn) + env = {} + + spec = _update_argspec_defaults_into_env(spec, env) + caller_argspec = format_argspec_plus(spec, grouped=False) + + metadata = { + "name": fn.__name__, + "apply_pos_proxied": caller_argspec["apply_pos_proxied"], + "args": caller_argspec["args"], + "self_arg": caller_argspec["self_arg"], + } + + if clslevel: + code = ( + "def %(name)s(%(args)s):\n" + " return target_cls.%(name)s(%(apply_pos_proxied)s)" + % metadata + ) + env["target_cls"] = target_cls + else: + code = ( + "def %(name)s(%(args)s):\n" + " return %(self_arg)s._proxied.%(name)s(%(apply_pos_proxied)s)" # noqa E501 + % metadata + ) + + proxy_fn = _exec_code_in_env(code, env, fn.__name__) + proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + proxy_fn.__doc__ = inject_docstring_text( + fn.__doc__, + ".. container:: class_bases\n\n " + "Proxied for the %s class on behalf of the %s class." + % (target_cls_sphinx_name, proxy_cls_sphinx_name), + 1, + ) + + if clslevel: + proxy_fn = classmethod(proxy_fn) + + return proxy_fn + + def makeprop(name): + attr = target_cls.__dict__.get(name, None) + + if attr is not None: + doc = inject_docstring_text( + attr.__doc__, + ".. container:: class_bases\n\n " + "Proxied for the %s class on behalf of the %s class." + % ( + target_cls_sphinx_name, + proxy_cls_sphinx_name, + ), + 1, + ) + else: + doc = None + + code = ( + "def set_(self, attr):\n" + " self._proxied.%(name)s = attr\n" + "def get(self):\n" + " return self._proxied.%(name)s\n" + "get.__doc__ = doc\n" + "getset = property(get, set_)" + ) % {"name": name} + + getset = _exec_code_in_env(code, {"doc": doc}, "getset") + + return getset + + for meth in methods: + if hasattr(cls, meth): + raise TypeError( + "class %s already has a method %s" % (cls, meth) + ) + setattr(cls, meth, instrument(meth)) + + for prop in attributes: + if hasattr(cls, prop): + raise TypeError( + "class %s already has a method %s" % (cls, prop) + ) + setattr(cls, prop, makeprop(prop)) + + for prop in classmethods: + if hasattr(cls, prop): + raise TypeError( + "class %s already has a method %s" % (cls, prop) + ) + setattr(cls, prop, instrument(prop, clslevel=True)) + + return cls + + return decorate def getargspec_init(method): diff --git a/test/base/test_events.py b/test/base/test_events.py index a4ed1000b8..19f68e9a35 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -15,7 +15,19 @@ from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.util import gc_collect -class EventsTest(fixtures.TestBase): +class TearDownLocalEventsFixture(object): + def tearDown(self): + classes = set() + for entry in event.base._registrars.values(): + for evt_cls in entry: + if evt_cls.__module__ == __name__: + classes.add(evt_cls) + + for evt_cls in classes: + event.base._remove_dispatcher(evt_cls) + + +class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test class- and instance-level event registration.""" def setUp(self): @@ -34,9 +46,6 @@ class EventsTest(fixtures.TestBase): self.Target = Target - def tearDown(self): - event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events) - def test_register_class(self): def listen(x, y): pass @@ -258,7 +267,60 @@ class EventsTest(fixtures.TestBase): ) -class NamedCallTest(fixtures.TestBase): +class SlotsEventsTest(fixtures.TestBase): + @testing.requires.python3 + def test_no_slots_dispatch(self): + class Target(object): + __slots__ = () + + class TargetEvents(event.Events): + _dispatch_target = Target + + def event_one(self, x, y): + pass + + def event_two(self, x): + pass + + def event_three(self, x): + pass + + t1 = Target() + + with testing.expect_raises_message( + TypeError, + r"target .*Target.* doesn't have __dict__, should it " + "be defining _slots_dispatch", + ): + event.listen(t1, "event_one", Mock()) + + def test_slots_dispatch(self): + class Target(object): + __slots__ = ("_slots_dispatch",) + + class TargetEvents(event.Events): + _dispatch_target = Target + + def event_one(self, x, y): + pass + + def event_two(self, x): + pass + + def event_three(self, x): + pass + + t1 = Target() + + m1 = Mock() + event.listen(t1, "event_one", m1) + + t1.dispatch.event_one(2, 4) + + eq_(m1.mock_calls, [call(2, 4)]) + + +class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase): def _fixture(self): class TargetEventsOne(event.Events): def event_one(self, x, y): @@ -373,7 +435,7 @@ class NamedCallTest(fixtures.TestBase): eq_(canary.mock_calls, [call({"x": 4, "y": 5, "z": 8, "q": 5})]) -class LegacySignatureTest(fixtures.TestBase): +class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase): """test adaption of legacy args""" def setUp(self): @@ -397,11 +459,6 @@ class LegacySignatureTest(fixtures.TestBase): self.TargetOne = TargetOne - def tearDown(self): - event.base._remove_dispatcher( - self.TargetOne.__dict__["dispatch"].events - ) - def test_legacy_accept(self): canary = Mock() @@ -550,12 +607,7 @@ class LegacySignatureTest(fixtures.TestBase): ) -class ClsLevelListenTest(fixtures.TestBase): - def tearDown(self): - event.base._remove_dispatcher( - self.TargetOne.__dict__["dispatch"].events - ) - +class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase): def setUp(self): class TargetEventsOne(event.Events): def event_one(self, x, y): @@ -622,7 +674,7 @@ class ClsLevelListenTest(fixtures.TestBase): assert handler2 not in s2.dispatch.event_one -class AcceptTargetsTest(fixtures.TestBase): +class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test default target acceptance.""" def setUp(self): @@ -643,14 +695,6 @@ class AcceptTargetsTest(fixtures.TestBase): self.TargetOne = TargetOne self.TargetTwo = TargetTwo - def tearDown(self): - event.base._remove_dispatcher( - self.TargetOne.__dict__["dispatch"].events - ) - event.base._remove_dispatcher( - self.TargetTwo.__dict__["dispatch"].events - ) - def test_target_accept(self): """Test that events of the same name are routed to the correct collection based on the type of target given. @@ -687,7 +731,7 @@ class AcceptTargetsTest(fixtures.TestBase): eq_(list(t2.dispatch.event_one), [listen_two, listen_four]) -class CustomTargetsTest(fixtures.TestBase): +class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test custom target acceptance.""" def setUp(self): @@ -707,9 +751,6 @@ class CustomTargetsTest(fixtures.TestBase): self.Target = Target - def tearDown(self): - event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events) - def test_indirect(self): def listen(x, y): pass @@ -727,7 +768,7 @@ class CustomTargetsTest(fixtures.TestBase): ) -class SubclassGrowthTest(fixtures.TestBase): +class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase): """test that ad-hoc subclasses are garbage collected.""" def setUp(self): @@ -752,7 +793,7 @@ class SubclassGrowthTest(fixtures.TestBase): eq_(self.Target.__subclasses__(), []) -class ListenOverrideTest(fixtures.TestBase): +class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test custom listen functions which change the listener function signature.""" @@ -778,9 +819,6 @@ class ListenOverrideTest(fixtures.TestBase): self.Target = Target - def tearDown(self): - event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events) - def test_listen_override(self): listen_one = Mock() listen_two = Mock() @@ -816,7 +854,7 @@ class ListenOverrideTest(fixtures.TestBase): eq_(listen_one.mock_calls, [call(12)]) -class PropagateTest(fixtures.TestBase): +class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): def setUp(self): class TargetEvents(event.Events): def event_one(self, arg): @@ -850,7 +888,7 @@ class PropagateTest(fixtures.TestBase): eq_(listen_two.mock_calls, []) -class JoinTest(fixtures.TestBase): +class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase): def setUp(self): class TargetEvents(event.Events): def event_one(self, target, arg): @@ -875,11 +913,6 @@ class JoinTest(fixtures.TestBase): self.TargetFactory = TargetFactory self.TargetElement = TargetElement - def tearDown(self): - for cls in (self.TargetElement, self.TargetFactory, self.BaseTarget): - if "dispatch" in cls.__dict__: - event.base._remove_dispatcher(cls.__dict__["dispatch"].events) - def test_neither(self): element = self.TargetFactory().create() element.run_event(1) @@ -1075,7 +1108,7 @@ class JoinTest(fixtures.TestBase): ) -class DisableClsPropagateTest(fixtures.TestBase): +class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): def setUp(self): class TargetEvents(event.Events): def event_one(self, target, arg): @@ -1093,11 +1126,6 @@ class DisableClsPropagateTest(fixtures.TestBase): self.BaseTarget = BaseTarget self.SubTarget = SubTarget - def tearDown(self): - for cls in (self.SubTarget, self.BaseTarget): - if "dispatch" in cls.__dict__: - event.base._remove_dispatcher(cls.__dict__["dispatch"].events) - def test_listen_invoke_clslevel(self): canary = Mock() @@ -1132,7 +1160,7 @@ class DisableClsPropagateTest(fixtures.TestBase): eq_(canary.mock_calls, []) -class RemovalTest(fixtures.TestBase): +class RemovalTest(TearDownLocalEventsFixture, fixtures.TestBase): def _fixture(self): class TargetEvents(event.Events): def event_one(self, x, y): diff --git a/test/base/test_utils.py b/test/base/test_utils.py index fa347243e7..9220720b67 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -2300,7 +2300,14 @@ class SymbolTest(fixtures.TestBase): class _Py3KFixtures(object): - pass + def _kw_only_fixture(self): + pass + + def _kw_plus_posn_fixture(self): + pass + + def _kw_opt_fixture(self): + pass if util.py3k: @@ -2321,185 +2328,208 @@ def _kw_opt_fixture(self, a, *, b, c="c"): for k in _locals: setattr(_Py3KFixtures, k, _locals[k]) +py3k_fixtures = _Py3KFixtures() -class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase): - def _test_format_argspec_plus(self, fn, wanted, grouped=None): - - # test direct function - if grouped is None: - parsed = util.format_argspec_plus(fn) - else: - parsed = util.format_argspec_plus(fn, grouped=grouped) - eq_(parsed, wanted) - - # test sending fullargspec - spec = compat.inspect_getfullargspec(fn) - if grouped is None: - parsed = util.format_argspec_plus(spec) - else: - parsed = util.format_argspec_plus(spec, grouped=grouped) - eq_(parsed, wanted) - def test_specs(self): - self._test_format_argspec_plus( +class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase): + @testing.combinations( + ( lambda: None, { "args": "()", "self_arg": None, "apply_kw": "()", "apply_pos": "()", + "apply_pos_proxied": "()", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda: None, - {"args": "", "self_arg": None, "apply_kw": "", "apply_pos": ""}, - grouped=False, - ) - - self._test_format_argspec_plus( + { + "args": "", + "self_arg": None, + "apply_kw": "", + "apply_pos": "", + "apply_pos_proxied": "", + }, + False, + ), + ( lambda self: None, { "args": "(self)", "self_arg": "self", "apply_kw": "(self)", "apply_pos": "(self)", + "apply_pos_proxied": "()", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda self: None, { "args": "self", "self_arg": "self", "apply_kw": "self", "apply_pos": "self", + "apply_pos_proxied": "", }, - grouped=False, - ) - - self._test_format_argspec_plus( + False, + ), + ( lambda *a: None, { "args": "(*a)", "self_arg": "a[0]", "apply_kw": "(*a)", "apply_pos": "(*a)", + "apply_pos_proxied": "(*a)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda **kw: None, { "args": "(**kw)", "self_arg": None, "apply_kw": "(**kw)", "apply_pos": "(**kw)", + "apply_pos_proxied": "(**kw)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda *a, **kw: None, { "args": "(*a, **kw)", "self_arg": "a[0]", "apply_kw": "(*a, **kw)", "apply_pos": "(*a, **kw)", + "apply_pos_proxied": "(*a, **kw)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda a, *b: None, { "args": "(a, *b)", "self_arg": "a", "apply_kw": "(a, *b)", "apply_pos": "(a, *b)", + "apply_pos_proxied": "(*b)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda a, **b: None, { "args": "(a, **b)", "self_arg": "a", "apply_kw": "(a, **b)", "apply_pos": "(a, **b)", + "apply_pos_proxied": "(**b)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda a, *b, **c: None, { "args": "(a, *b, **c)", "self_arg": "a", "apply_kw": "(a, *b, **c)", "apply_pos": "(a, *b, **c)", + "apply_pos_proxied": "(*b, **c)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda a, b=1, **c: None, { "args": "(a, b=1, **c)", "self_arg": "a", "apply_kw": "(a, b=b, **c)", "apply_pos": "(a, b, **c)", + "apply_pos_proxied": "(b, **c)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda a=1, b=2: None, { "args": "(a=1, b=2)", "self_arg": "a", "apply_kw": "(a=a, b=b)", "apply_pos": "(a, b)", + "apply_pos_proxied": "(b)", }, - ) - - self._test_format_argspec_plus( + True, + ), + ( lambda a=1, b=2: None, { "args": "a=1, b=2", "self_arg": "a", "apply_kw": "a=a, b=b", "apply_pos": "a, b", + "apply_pos_proxied": "b", }, - grouped=False, - ) - - if util.py3k: - self._test_format_argspec_plus( - self._kw_only_fixture, - { - "args": "self, a, *, b, c", - "self_arg": "self", - "apply_pos": "self, a, *, b, c", - "apply_kw": "self, a, b=b, c=c", - }, - grouped=False, - ) - self._test_format_argspec_plus( - self._kw_plus_posn_fixture, - { - "args": "self, a, *args, b, c", - "self_arg": "self", - "apply_pos": "self, a, *args, b, c", - "apply_kw": "self, a, b=b, c=c, *args", - }, - grouped=False, - ) - self._test_format_argspec_plus( - self._kw_opt_fixture, - { - "args": "self, a, *, b, c='c'", - "self_arg": "self", - "apply_pos": "self, a, *, b, c", - "apply_kw": "self, a, b=b, c=c", - }, - grouped=False, - ) + False, + ), + ( + py3k_fixtures._kw_only_fixture, + { + "args": "self, a, *, b, c", + "self_arg": "self", + "apply_pos": "self, a, *, b, c", + "apply_kw": "self, a, b=b, c=c", + "apply_pos_proxied": "a, *, b, c", + }, + False, + testing.requires.python3, + ), + ( + py3k_fixtures._kw_plus_posn_fixture, + { + "args": "self, a, *args, b, c", + "self_arg": "self", + "apply_pos": "self, a, *args, b, c", + "apply_kw": "self, a, b=b, c=c, *args", + "apply_pos_proxied": "a, *args, b, c", + }, + False, + testing.requires.python3, + ), + ( + py3k_fixtures._kw_opt_fixture, + { + "args": "self, a, *, b, c='c'", + "self_arg": "self", + "apply_pos": "self, a, *, b, c", + "apply_kw": "self, a, b=b, c=c", + "apply_pos_proxied": "a, *, b, c", + }, + False, + testing.requires.python3, + ), + argnames="fn,wanted,grouped", + ) + def test_specs(self, fn, wanted, grouped): + + # test direct function + if grouped is None: + parsed = util.format_argspec_plus(fn) + else: + parsed = util.format_argspec_plus(fn, grouped=grouped) + eq_(parsed, wanted) + + # test sending fullargspec + spec = compat.inspect_getfullargspec(fn) + if grouped is None: + parsed = util.format_argspec_plus(spec) + else: + parsed = util.format_argspec_plus(spec, grouped=grouped) + eq_(parsed, wanted) @testing.requires.cpython def test_init_grouped(self): @@ -2508,17 +2538,20 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase): "self_arg": "self", "apply_pos": "(self)", "apply_kw": "(self)", + "apply_pos_proxied": "()", } wrapper_spec = { "args": "(self, *args, **kwargs)", "self_arg": "self", "apply_pos": "(self, *args, **kwargs)", "apply_kw": "(self, *args, **kwargs)", + "apply_pos_proxied": "(*args, **kwargs)", } custom_spec = { "args": "(slef, a=123)", "self_arg": "slef", # yes, slef "apply_pos": "(slef, a)", + "apply_pos_proxied": "(a)", "apply_kw": "(slef, a=a)", } @@ -2532,18 +2565,21 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase): "self_arg": "self", "apply_pos": "self", "apply_kw": "self", + "apply_pos_proxied": "", } wrapper_spec = { "args": "self, *args, **kwargs", "self_arg": "self", "apply_pos": "self, *args, **kwargs", "apply_kw": "self, *args, **kwargs", + "apply_pos_proxied": "*args, **kwargs", } custom_spec = { "args": "slef, a=123", "self_arg": "slef", # yes, slef "apply_pos": "slef, a", "apply_kw": "slef, a=a", + "apply_pos_proxied": "a", } self._test_init(False, object_spec, wrapper_spec, custom_spec) diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 7c7d90e217..83987b06f1 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -2,6 +2,7 @@ import asyncio from sqlalchemy import Column from sqlalchemy import delete +from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import Integer @@ -9,13 +10,19 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing +from sqlalchemy import text from sqlalchemy import union_all from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import engine as _async_engine from sqlalchemy.ext.asyncio import exc as asyncio_exc from sqlalchemy.testing import async_test from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_not +from sqlalchemy.testing import mock from sqlalchemy.testing.asyncio import assert_raises_message_async +from sqlalchemy.util.concurrency import greenlet_spawn class EngineFixture(fixtures.TablesTest): @@ -50,6 +57,117 @@ class EngineFixture(fixtures.TablesTest): class AsyncEngineTest(EngineFixture): __backend__ = True + def test_proxied_attrs_engine(self, async_engine): + sync_engine = async_engine.sync_engine + + is_(async_engine.url, sync_engine.url) + is_(async_engine.pool, sync_engine.pool) + is_(async_engine.dialect, sync_engine.dialect) + eq_(async_engine.name, sync_engine.name) + eq_(async_engine.driver, sync_engine.driver) + eq_(async_engine.echo, sync_engine.echo) + + def test_clear_compiled_cache(self, async_engine): + async_engine.sync_engine._compiled_cache["foo"] = "bar" + eq_(async_engine.sync_engine._compiled_cache["foo"], "bar") + async_engine.clear_compiled_cache() + assert "foo" not in async_engine.sync_engine._compiled_cache + + def test_execution_options(self, async_engine): + a2 = async_engine.execution_options(foo="bar") + assert isinstance(a2, _async_engine.AsyncEngine) + eq_(a2.sync_engine._execution_options, {"foo": "bar"}) + eq_(async_engine.sync_engine._execution_options, {}) + + """ + + attr uri, pool, dialect, engine, name, driver, echo + methods clear_compiled_cache, update_execution_options, + execution_options, get_execution_options, dispose + + """ + + @async_test + async def test_proxied_attrs_connection(self, async_engine): + conn = await async_engine.connect() + + sync_conn = conn.sync_connection + + is_(conn.engine, async_engine) + is_(conn.closed, sync_conn.closed) + is_(conn.dialect, async_engine.sync_engine.dialect) + eq_(conn.default_isolation_level, sync_conn.default_isolation_level) + + @async_test + async def test_invalidate(self, async_engine): + conn = await async_engine.connect() + + is_(conn.invalidated, False) + + connection_fairy = await conn.get_raw_connection() + is_(connection_fairy.is_valid, True) + dbapi_connection = connection_fairy.connection + + await conn.invalidate() + assert dbapi_connection._connection.is_closed() + + new_fairy = await conn.get_raw_connection() + is_not(new_fairy.connection, dbapi_connection) + is_not(new_fairy, connection_fairy) + is_(new_fairy.is_valid, True) + is_(connection_fairy.is_valid, False) + + @async_test + async def test_get_dbapi_connection_raise(self, async_engine): + + conn = await async_engine.connect() + + with testing.expect_raises_message( + exc.InvalidRequestError, + "AsyncConnection.connection accessor is not " + "implemented as the attribute", + ): + conn.connection + + @async_test + async def test_get_raw_connection(self, async_engine): + + conn = await async_engine.connect() + + pooled = await conn.get_raw_connection() + is_(pooled, conn.sync_connection.connection) + + @async_test + async def test_isolation_level(self, async_engine): + conn = await async_engine.connect() + + sync_isolation_level = await greenlet_spawn( + conn.sync_connection.get_isolation_level + ) + isolation_level = await conn.get_isolation_level() + + eq_(isolation_level, sync_isolation_level) + + await conn.execution_options(isolation_level="SERIALIZABLE") + isolation_level = await conn.get_isolation_level() + + eq_(isolation_level, "SERIALIZABLE") + + @async_test + async def test_dispose(self, async_engine): + c1 = await async_engine.connect() + c2 = await async_engine.connect() + + await c1.close() + await c2.close() + + p1 = async_engine.pool + eq_(async_engine.pool.checkedin(), 2) + + await async_engine.dispose() + eq_(async_engine.pool.checkedin(), 0) + is_not(p1, async_engine.pool) + @async_test async def test_init_once_concurrency(self, async_engine): c1 = async_engine.connect() @@ -169,6 +287,70 @@ class AsyncEngineTest(EngineFixture): ) +class AsyncEventTest(EngineFixture): + """The engine events all run in their normal synchronous context. + + we do not provide an asyncio event interface at this time. + + """ + + __backend__ = True + + @async_test + async def test_no_async_listeners(self, async_engine): + with testing.expect_raises_message( + NotImplementedError, + "asynchronous events are not implemented " + "at this time. Apply synchronous listeners to the " + "AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes.", + ): + event.listen(async_engine, "before_cursor_execute", mock.Mock()) + + conn = await async_engine.connect() + + with testing.expect_raises_message( + NotImplementedError, + "asynchronous events are not implemented " + "at this time. Apply synchronous listeners to the " + "AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes.", + ): + event.listen(conn, "before_cursor_execute", mock.Mock()) + + @async_test + async def test_sync_before_cursor_execute_engine(self, async_engine): + canary = mock.Mock() + + event.listen(async_engine.sync_engine, "before_cursor_execute", canary) + + async with async_engine.connect() as conn: + sync_conn = conn.sync_connection + await conn.execute(text("select 1")) + + eq_( + canary.mock_calls, + [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)], + ) + + @async_test + async def test_sync_before_cursor_execute_connection(self, async_engine): + canary = mock.Mock() + + async with async_engine.connect() as conn: + sync_conn = conn.sync_connection + + event.listen( + async_engine.sync_engine, "before_cursor_execute", canary + ) + await conn.execute(text("select 1")) + + eq_( + canary.mock_calls, + [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)], + ) + + class AsyncResultTest(EngineFixture): @testing.combinations( (None,), ("scalars",), ("mappings",), argnames="filter_" diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index e8caaca3e4..a3b8add677 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -1,3 +1,4 @@ +from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import select @@ -9,6 +10,7 @@ from sqlalchemy.orm import selectinload from sqlalchemy.testing import async_test from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import mock from ...orm import _fixtures @@ -139,6 +141,27 @@ class AsyncSessionTransactionTest(AsyncFixture): eq_(await outer_conn.scalar(select(func.count(User.id))), 1) + @async_test + async def test_delete(self, async_session): + User = self.classes.User + + async with async_session.begin(): + u1 = User(name="u1") + + async_session.add(u1) + + await async_session.flush() + + conn = await async_session.connection() + + eq_(await conn.scalar(select(func.count(User.id))), 1) + + async_session.delete(u1) + + await async_session.flush() + + eq_(await conn.scalar(select(func.count(User.id))), 0) + @async_test async def test_flush(self, async_session): User = self.classes.User @@ -198,3 +221,38 @@ class AsyncSessionTransactionTest(AsyncFixture): is_(new_u_merged, u1) eq_(u1.name, "new u1") + + +class AsyncEventTest(AsyncFixture): + """The engine events all run in their normal synchronous context. + + we do not provide an asyncio event interface at this time. + + """ + + __backend__ = True + + @async_test + async def test_no_async_listeners(self, async_session): + with testing.expect_raises( + NotImplementedError, + "NotImplementedError: asynchronous events are not implemented " + "at this time. Apply synchronous listeners to the " + "AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes.", + ): + event.listen(async_session, "before_flush", mock.Mock()) + + @async_test + async def test_sync_before_commit(self, async_session): + canary = mock.Mock() + + event.listen(async_session.sync_session, "before_commit", canary) + + async with async_session.begin(): + pass + + eq_( + canary.mock_calls, + [mock.call(async_session.sync_session)], + ) diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index 6b7feaea7f..d1ed9acc16 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import scoped_session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import mock from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -127,3 +128,26 @@ class ScopedSessionTest(fixtures.MappedTest): mock_scope_func.return_value = 1 s2 = Session(autocommit=True) assert s2.autocommit == True + + def test_methods_etc(self): + mock_session = Mock() + mock_session.bind = "the bind" + + sess = scoped_session(lambda: mock_session) + + sess.add("add") + sess.delete("delete") + + eq_(sess.bind, "the bind") + + eq_( + mock_session.mock_calls, + [mock.call.add("add", True), mock.call.delete("delete")], + ) + + with mock.patch( + "sqlalchemy.orm.session.object_session" + ) as mock_object_session: + sess.object_session("foo") + + eq_(mock_object_session.mock_calls, [mock.call("foo")]) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 9bc6c6f7ce..4562df44a6 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1,3 +1,5 @@ +import inspect as _py_inspect + import sqlalchemy as sa from sqlalchemy import event from sqlalchemy import ForeignKey @@ -1820,22 +1822,28 @@ class DisposedStates(fixtures.MappedTest): class SessionInterface(fixtures.TestBase): """Bogus args to Session methods produce actionable exceptions.""" - # TODO: expand with message body assertions. - _class_methods = set(("connection", "execute", "get_bind", "scalar")) def _public_session_methods(self): Session = sa.orm.session.Session - blacklist = set(("begin", "query")) - + blacklist = {"begin", "query", "bind_mapper", "get", "bind_table"} + specials = {"__iter__", "__contains__"} ok = set() - for meth in Session.public_methods: - if meth in blacklist: - continue - spec = inspect_getfullargspec(getattr(Session, meth)) - if len(spec[0]) > 1 or spec[1]: - ok.add(meth) + for name in dir(Session): + if ( + name in Session.__dict__ + and (not name.startswith("_") or name in specials) + and ( + _py_inspect.ismethod(getattr(Session, name)) + or _py_inspect.isfunction(getattr(Session, name)) + ) + ): + if name in blacklist: + continue + spec = inspect_getfullargspec(getattr(Session, name)) + if len(spec[0]) > 1 or spec[1]: + ok.add(name) return ok def _map_it(self, cls): @@ -1866,18 +1874,21 @@ class SessionInterface(fixtures.TestBase): def raises_(method, *args, **kw): x_raises_(create_session(), method, *args, **kw) - raises_("__contains__", user_arg) - - raises_("add", user_arg) + for name in [ + "__contains__", + "is_modified", + "merge", + "refresh", + "add", + "delete", + "expire", + "expunge", + "enable_relationship_loading", + ]: + raises_(name, user_arg) raises_("add_all", (user_arg,)) - raises_("delete", user_arg) - - raises_("expire", user_arg) - - raises_("expunge", user_arg) - # flush will no-op without something in the unit of work def _(): class OK(object): @@ -1891,12 +1902,6 @@ class SessionInterface(fixtures.TestBase): _() - raises_("is_modified", user_arg) - - raises_("merge", user_arg) - - raises_("refresh", user_arg) - instance_methods = ( self._public_session_methods() - self._class_methods -- 2.39.5